test_zmq_shell.py 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207
  1. # -*- coding: utf-8 -*-
  2. """ Tests for zmq shell / display publisher. """
  3. # Copyright (c) IPython Development Team.
  4. # Distributed under the terms of the Modified BSD License.
  5. import os
  6. try:
  7. from queue import Queue
  8. except ImportError:
  9. # py2
  10. from Queue import Queue
  11. from threading import Thread
  12. import unittest
  13. from traitlets import Int
  14. import zmq
  15. from ipykernel.zmqshell import ZMQDisplayPublisher
  16. from jupyter_client.session import Session
  # Test double: counts its invocations and swallows the message.
class NoReturnDisplayHook(object):
    """
    Dummy display hook that records how many times it is invoked.

    Calling an instance returns ``None``. The tests treat a hook that
    returns no message as having consumed it, so nothing should be sent.
    """

    # Class-level starting value; the first increment on an instance
    # creates a per-instance attribute, so separate hooks never share
    # a count and the class attribute itself stays at 0.
    call_count = 0

    def __call__(self, obj):
        """Record one invocation; ``obj`` is ignored and ``None`` is returned."""
        self.call_count = self.call_count + 1
  # Test double: counts its invocations and passes the message through.
class ReturnDisplayHook(NoReturnDisplayHook):
    """
    Dummy display hook that counts calls exactly like its base class,
    but hands the received message back unchanged.

    Because a message is returned, the tests expect the publisher to go
    on and send it in the usual manner.
    """

    def __call__(self, obj):
        """Count this invocation via the base class, then return ``obj``."""
        parent = super(ReturnDisplayHook, self)
        parent.__call__(obj)
        return obj
  # Session subclass used to observe how often the publisher sends.
class CounterSession(Session):
    """
    This is a simple subclass to allow us to count
    the calls made to the session object by the display
    publisher.
    """

    # Number of times ``send`` has been called on this session.
    send_count = Int(0)

    def send(self, *args, **kwargs):
        """
        Increment ``send_count``, then delegate to ``Session.send``.

        The base implementation's return value is passed back unchanged,
        so apart from the counting this override is transparent to callers.
        (Previously the result was discarded and callers always got None.)
        """
        self.send_count += 1
        return super(CounterSession, self).send(*args, **kwargs)
  # Unit tests for ipykernel.zmqshell.ZMQDisplayPublisher.
class ZMQDisplayPublisherTests(unittest.TestCase):
    """
    Tests the ZMQDisplayPublisher in zmqshell.py
    """

    def setUp(self):
        """
        Build a display publisher wired to a counting session and a PUB socket.

        A cleanup is registered as soon as each zmq resource is created.
        The socket is then closed and the context terminated even if
        ``setUp`` fails part-way or ``tearDown`` raises. Cleanups run
        last-in first-out, so the socket is closed before the context is
        terminated.
        """
        self.context = zmq.Context()
        self.addCleanup(self.context.term)
        self.socket = self.context.socket(zmq.PUB)
        self.addCleanup(self.socket.close)
        self.session = CounterSession()
        self.disp_pub = ZMQDisplayPublisher(
            session=self.session,
            pub_socket=self.socket,
        )

    def tearDown(self):
        """
        Call ``clear_output`` on the publisher after each test.

        Closing the socket and terminating the context is handled by the
        cleanups registered in ``setUp``, which run after this method even
        if it raises.

        TODO - There is still an open file handler to '/dev/null',
        presumably created by zmq.
        """
        self.disp_pub.clear_output()

    def test_display_publisher_creation(self):
        """
        Since there's no explicit constructor, here we confirm
        that keyword args get assigned correctly, and override
        the defaults.
        """
        self.assertIs(self.disp_pub.session, self.session)
        self.assertIs(self.disp_pub.pub_socket, self.socket)

    def test_thread_local_hooks(self):
        """
        Confirms that the thread_local attribute is correctly
        initialised with an empty list for the display hooks, and
        that a hook registered in this thread is not visible from
        another thread.
        """
        self.assertEqual(self.disp_pub._hooks, [])

        def hook(msg):
            return msg

        self.disp_pub.register_hook(hook)
        self.assertEqual(self.disp_pub._hooks, [hook])

        q = Queue()

        def report_thread_hooks():
            # Hand back the hook list as seen from a different thread.
            q.put(self.disp_pub._hooks)

        t = Thread(target=report_thread_hooks)
        t.start()
        try:
            thread_hooks = q.get(timeout=10)
        finally:
            # Don't leave the helper thread running past the test.
            t.join(timeout=10)
        self.assertEqual(thread_hooks, [])

    def test_publish(self):
        """
        Publish should prepare the message and eventually call
        `send` by default.
        """
        data = dict(a=1)
        self.assertEqual(self.session.send_count, 0)
        self.disp_pub.publish(data)
        self.assertEqual(self.session.send_count, 1)

    def test_display_hook_halts_send(self):
        """
        If a hook is installed, and on calling the object
        it does *not* return a message, then we assume that
        the message has been consumed, and should not be
        processed (`sent`) in the normal manner.
        """
        data = dict(a=1)
        hook = NoReturnDisplayHook()
        self.disp_pub.register_hook(hook)
        self.assertEqual(hook.call_count, 0)
        self.assertEqual(self.session.send_count, 0)

        self.disp_pub.publish(data)

        self.assertEqual(hook.call_count, 1)
        self.assertEqual(self.session.send_count, 0)

    def test_display_hook_return_calls_send(self):
        """
        If a hook is installed and on calling the object
        it returns a new message, then we assume that this
        is just a message transformation, and the message
        should be sent in the usual manner.
        """
        data = dict(a=1)
        hook = ReturnDisplayHook()
        self.disp_pub.register_hook(hook)
        self.assertEqual(hook.call_count, 0)
        self.assertEqual(self.session.send_count, 0)

        self.disp_pub.publish(data)

        self.assertEqual(hook.call_count, 1)
        self.assertEqual(self.session.send_count, 1)

    def test_unregister_hook(self):
        """
        Once a hook is unregistered, it should not be called
        during `publish`.
        """
        data = dict(a=1)
        hook = NoReturnDisplayHook()
        self.disp_pub.register_hook(hook)
        self.assertEqual(hook.call_count, 0)
        self.assertEqual(self.session.send_count, 0)

        self.disp_pub.publish(data)

        self.assertEqual(hook.call_count, 1)
        self.assertEqual(self.session.send_count, 0)

        #
        # After unregistering the `NoReturn` hook, any calls
        # to publish should *not* go through the DisplayHook,
        # but should instead hit the usual `session.send` call
        # at the end.
        #
        # As a result, the hook call count should *not* increase,
        # but the session send count *should* increase.
        #
        first = self.disp_pub.unregister_hook(hook)
        self.disp_pub.publish(data)

        self.assertTrue(first)
        self.assertEqual(hook.call_count, 1)
        self.assertEqual(self.session.send_count, 1)

        #
        # If a hook is not installed, `unregister_hook`
        # should return false.
        #
        second = self.disp_pub.unregister_hook(hook)
        self.assertFalse(second)
  # Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main(module="__main__")