threaded.py 7.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257
  1. """ Defines a KernelClient that provides thread-safe sockets with async callbacks on message replies.
  2. """
  3. from __future__ import absolute_import
  4. import atexit
  5. import errno
  6. import sys
  7. from threading import Thread, Event
  8. import time
  9. # import ZMQError in top-level namespace, to avoid ugly attribute-error messages
  10. # during garbage collection of threads at exit:
  11. from zmq import ZMQError
  12. from zmq.eventloop import ioloop, zmqstream
  13. # Local imports
  14. from traitlets import Type, Instance
  15. from jupyter_client.channels import HBChannel
  16. from jupyter_client import KernelClient
  17. class ThreadedZMQSocketChannel(object):
  18. """A ZMQ socket invoking a callback in the ioloop"""
  19. session = None
  20. socket = None
  21. ioloop = None
  22. stream = None
  23. _inspect = None
  24. def __init__(self, socket, session, loop):
  25. """Create a channel.
  26. Parameters
  27. ----------
  28. socket : :class:`zmq.Socket`
  29. The ZMQ socket to use.
  30. session : :class:`session.Session`
  31. The session to use.
  32. loop
  33. A pyzmq ioloop to connect the socket to using a ZMQStream
  34. """
  35. super(ThreadedZMQSocketChannel, self).__init__()
  36. self.socket = socket
  37. self.session = session
  38. self.ioloop = loop
  39. evt = Event()
  40. def setup_stream():
  41. self.stream = zmqstream.ZMQStream(self.socket, self.ioloop)
  42. self.stream.on_recv(self._handle_recv)
  43. evt.set()
  44. self.ioloop.add_callback(setup_stream)
  45. evt.wait()
  46. _is_alive = False
  47. def is_alive(self):
  48. return self._is_alive
  49. def start(self):
  50. self._is_alive = True
  51. def stop(self):
  52. self._is_alive = False
  53. def close(self):
  54. if self.socket is not None:
  55. try:
  56. self.socket.close(linger=0)
  57. except Exception:
  58. pass
  59. self.socket = None
  60. def send(self, msg):
  61. """Queue a message to be sent from the IOLoop's thread.
  62. Parameters
  63. ----------
  64. msg : message to send
  65. This is threadsafe, as it uses IOLoop.add_callback to give the loop's
  66. thread control of the action.
  67. """
  68. def thread_send():
  69. self.session.send(self.stream, msg)
  70. self.ioloop.add_callback(thread_send)
  71. def _handle_recv(self, msg):
  72. """Callback for stream.on_recv.
  73. Unpacks message, and calls handlers with it.
  74. """
  75. ident,smsg = self.session.feed_identities(msg)
  76. msg = self.session.deserialize(smsg)
  77. # let client inspect messages
  78. if self._inspect:
  79. self._inspect(msg)
  80. self.call_handlers(msg)
  81. def call_handlers(self, msg):
  82. """This method is called in the ioloop thread when a message arrives.
  83. Subclasses should override this method to handle incoming messages.
  84. It is important to remember that this method is called in the thread
  85. so that some logic must be done to ensure that the application level
  86. handlers are called in the application thread.
  87. """
  88. pass
  89. def process_events(self):
  90. """Subclasses should override this with a method
  91. processing any pending GUI events.
  92. """
  93. pass
  94. def flush(self, timeout=1.0):
  95. """Immediately processes all pending messages on this channel.
  96. This is only used for the IOPub channel.
  97. Callers should use this method to ensure that :meth:`call_handlers`
  98. has been called for all messages that have been received on the
  99. 0MQ SUB socket of this channel.
  100. This method is thread safe.
  101. Parameters
  102. ----------
  103. timeout : float, optional
  104. The maximum amount of time to spend flushing, in seconds. The
  105. default is one second.
  106. """
  107. # We do the IOLoop callback process twice to ensure that the IOLoop
  108. # gets to perform at least one full poll.
  109. stop_time = time.time() + timeout
  110. for i in range(2):
  111. self._flushed = False
  112. self.ioloop.add_callback(self._flush)
  113. while not self._flushed and time.time() < stop_time:
  114. time.sleep(0.01)
  115. def _flush(self):
  116. """Callback for :method:`self.flush`."""
  117. self.stream.flush()
  118. self._flushed = True
  119. class IOLoopThread(Thread):
  120. """Run a pyzmq ioloop in a thread to send and receive messages
  121. """
  122. _exiting = False
  123. ioloop = None
  124. def __init__(self):
  125. super(IOLoopThread, self).__init__()
  126. self.daemon = True
  127. @staticmethod
  128. @atexit.register
  129. def _notice_exit():
  130. # Class definitions can be torn down during interpreter shutdown.
  131. # We only need to set _exiting flag if this hasn't happened.
  132. if IOLoopThread is not None:
  133. IOLoopThread._exiting = True
  134. def start(self):
  135. """Start the IOLoop thread
  136. Don't return until self.ioloop is defined,
  137. which is created in the thread
  138. """
  139. self._start_event = Event()
  140. Thread.start(self)
  141. self._start_event.wait()
  142. def run(self):
  143. """Run my loop, ignoring EINTR events in the poller"""
  144. if 'asyncio' in sys.modules:
  145. # tornado may be using asyncio,
  146. # ensure an eventloop exists for this thread
  147. import asyncio
  148. asyncio.set_event_loop(asyncio.new_event_loop())
  149. self.ioloop = ioloop.IOLoop()
  150. # signal that self.ioloop is defined
  151. self._start_event.set()
  152. while True:
  153. try:
  154. self.ioloop.start()
  155. except ZMQError as e:
  156. if e.errno == errno.EINTR:
  157. continue
  158. else:
  159. raise
  160. except Exception:
  161. if self._exiting:
  162. break
  163. else:
  164. raise
  165. else:
  166. break
  167. def stop(self):
  168. """Stop the channel's event loop and join its thread.
  169. This calls :meth:`~threading.Thread.join` and returns when the thread
  170. terminates. :class:`RuntimeError` will be raised if
  171. :meth:`~threading.Thread.start` is called again.
  172. """
  173. if self.ioloop is not None:
  174. self.ioloop.add_callback(self.ioloop.stop)
  175. self.join()
  176. self.close()
  177. self.ioloop = None
  178. def close(self):
  179. if self.ioloop is not None:
  180. try:
  181. self.ioloop.close(all_fds=True)
  182. except Exception:
  183. pass
  184. class ThreadedKernelClient(KernelClient):
  185. """ A KernelClient that provides thread-safe sockets with async callbacks on message replies.
  186. """
  187. @property
  188. def ioloop(self):
  189. return self.ioloop_thread.ioloop
  190. ioloop_thread = Instance(IOLoopThread, allow_none=True)
  191. def start_channels(self, shell=True, iopub=True, stdin=True, hb=True):
  192. self.ioloop_thread = IOLoopThread()
  193. self.ioloop_thread.start()
  194. if shell:
  195. self.shell_channel._inspect = self._check_kernel_info_reply
  196. super(ThreadedKernelClient, self).start_channels(shell, iopub, stdin, hb)
  197. def _check_kernel_info_reply(self, msg):
  198. """This is run in the ioloop thread when the kernel info reply is received
  199. """
  200. if msg['msg_type'] == 'kernel_info_reply':
  201. self._handle_kernel_info_reply(msg)
  202. self.shell_channel._inspect = None
  203. def stop_channels(self):
  204. super(ThreadedKernelClient, self).stop_channels()
  205. if self.ioloop_thread.is_alive():
  206. self.ioloop_thread.stop()
  207. iopub_channel_class = Type(ThreadedZMQSocketChannel)
  208. shell_channel_class = Type(ThreadedZMQSocketChannel)
  209. stdin_channel_class = Type(ThreadedZMQSocketChannel)
  210. hb_channel_class = Type(HBChannel)