zmqhandlers.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304
  1. # coding: utf-8
  2. """Tornado handlers for WebSocket <-> ZMQ sockets."""
  3. # Copyright (c) Jupyter Development Team.
  4. # Distributed under the terms of the Modified BSD License.
  5. import os
  6. import json
  7. import struct
  8. import warnings
  9. import sys
  10. try:
  11. from urllib.parse import urlparse # Py 3
  12. except ImportError:
  13. from urlparse import urlparse # Py 2
  14. import tornado
  15. from tornado import gen, ioloop, web
  16. from tornado.iostream import StreamClosedError
  17. from tornado.websocket import WebSocketHandler, WebSocketClosedError
  18. from jupyter_client.session import Session
  19. from jupyter_client.jsonutil import date_default, extract_dates
  20. from ipython_genutils.py3compat import cast_unicode
  21. from .handlers import IPythonHandler
  def serialize_binary_message(msg):
      """Serialize a message as a binary blob.

      Layout (all integers are big-endian unsigned 32-bit):

      - 4 bytes: number of message parts (``nbufs``); the JSON-encoded message
        is the first part and each entry of ``msg['buffers']`` is one more.
      - 4 * nbufs bytes: the offset at which each part starts. Offsets are
        measured from the start of the blob, header included.
      - the parts themselves, concatenated in order.

      Parameters
      ----------
      msg : dict
          Message dict with a ``'buffers'`` entry holding bytes-like objects.
          Neither ``msg`` nor its buffer list is modified.

      Returns
      -------
      The message serialized to bytes.
      """
      # don't modify msg or buffer list in-place
      msg = msg.copy()
      buffers = list(msg.pop('buffers'))
      if sys.version_info < (3, 4):
          # presumably b''.join() cannot take memoryviews on these versions;
          # memoryview() also accepts plain bytes, so any buffer type works
          buffers = [memoryview(x).tobytes() for x in buffers]
      bmsg = json.dumps(msg, default=date_default).encode('utf8')
      buffers.insert(0, bmsg)
      nbufs = len(buffers)
      offsets = [4 * (nbufs + 1)]
      for buf in buffers[:-1]:
          # count bytes, not items: len() of a memoryview with itemsize > 1
          # is the number of items, while b''.join() copies every byte
          size = len(buf) if isinstance(buf, bytes) else memoryview(buf).nbytes
          offsets.append(offsets[-1] + size)
      offsets_buf = struct.pack('!' + 'I' * (nbufs + 1), nbufs, *offsets)
      buffers.insert(0, offsets_buf)
      return b''.join(buffers)
  def deserialize_binary_message(bmsg):
      """Deserialize a message from a binary blob.

      This reverses ``serialize_binary_message``. Layout (all integers are
      big-endian unsigned 32-bit):

      - 4 bytes: number of message parts (``nbufs``)
      - 4 * nbufs bytes: offset of each part from the start of the blob,
        header included
      - the parts; the first is the UTF-8 JSON message, the rest are buffers

      The header is validated before it is used to size anything, so a
      corrupt or hostile blob cannot trigger a huge allocation.

      Parameters
      ----------
      bmsg : bytes
          The serialized message.

      Returns
      -------
      message dictionary, with ``header`` and ``parent_header`` passed through
      ``extract_dates`` and the remaining parts (slices of ``bmsg``) stored in
      ``msg['buffers']``.

      Raises
      ------
      ValueError
          If the blob is shorter than its own header or declares no parts.
      """
      if len(bmsg) < 4:
          raise ValueError("binary message too short: %d bytes" % len(bmsg))
      # unsigned, matching the '!I' used by serialize_binary_message
      nbufs = struct.unpack_from('!I', bmsg)[0]
      header_len = 4 * (nbufs + 1)
      if nbufs < 1 or len(bmsg) < header_len:
          raise ValueError(
              "invalid binary message header: %d parts in %d bytes"
              % (nbufs, len(bmsg))
          )
      offsets = list(struct.unpack_from('!%dI' % nbufs, bmsg, 4))
      # sentinel so the last part runs to the end of the blob
      offsets.append(None)
      bufs = [bmsg[start:stop] for start, stop in zip(offsets[:-1], offsets[1:])]
      msg = json.loads(bufs[0].decode('utf8'))
      msg['header'] = extract_dates(msg['header'])
      msg['parent_header'] = extract_dates(msg['parent_header'])
      msg['buffers'] = bufs[1:]
      return msg
  # Default interval between websocket keep-alive pings, in milliseconds
  # (30 seconds).
  WS_PING_INTERVAL = 30 * 1000
  class WebSocketMixin(object):
      """Mixin for common websocket options.

      Provides origin checking and a keep-alive ping loop for a websocket
      handler. It relies on attributes it does not define itself
      (``settings``, ``request``, ``log``, ``allow_origin``,
      ``allow_origin_pat``, ``get_origin``, ``ws_connection``, ``ping``,
      ``close``); in this module it is combined with tornado's
      ``WebSocketHandler``.
      """

      # PeriodicCallback driving send_ping; None until open() starts pinging
      ping_callback = None
      # IOLoop times (seconds) of the last ping sent and the last pong received
      last_ping = 0
      last_pong = 0
      # NOTE(review): read by the tornado < 4.1 send_error backport in
      # ZMQStreamHandler; presumably replaced by tornado once connected — confirm
      stream = None

      @property
      def ping_interval(self):
          """The interval for websocket keep-alive pings, in milliseconds.

          Set ws_ping_interval = 0 to disable pings.
          """
          return self.settings.get('ws_ping_interval', WS_PING_INTERVAL)

      @property
      def ping_timeout(self):
          """If no pong is received in this many milliseconds,
          close the websocket connection (VPNs, etc. can fail to cleanly close ws connections).

          Default is max of 3 pings or 30 seconds.
          """
          return self.settings.get('ws_ping_timeout',
              max(3 * self.ping_interval, WS_PING_INTERVAL)
          )

      def check_origin(self, origin=None):
          """Check Origin == Host or Access-Control-Allow-Origin.

          Tornado >= 4 calls this method automatically, raising 403 if it returns False.
          Host names are compared case-insensitively.
          """
          if self.allow_origin == '*' or (
              hasattr(self, 'skip_check_origin') and self.skip_check_origin()):
              return True

          host = self.request.headers.get("Host")
          if origin is None:
              origin = self.get_origin()

          # If no origin or host header is provided, assume from script
          if origin is None or host is None:
              return True

          origin = origin.lower()
          origin_host = urlparse(origin).netloc

          # OK if origin matches host; the Host header is normalized too,
          # otherwise e.g. "Localhost:8888" would never match a lowercased origin
          if origin_host == host.lower():
              return True

          # Check CORS headers
          if self.allow_origin:
              allow = self.allow_origin == origin
          elif self.allow_origin_pat:
              allow = bool(self.allow_origin_pat.match(origin))
          else:
              # No CORS headers deny the request
              allow = False
          if not allow:
              self.log.warning("Blocking Cross Origin WebSocket Attempt. Origin: %s, Host: %s",
                  origin, host,
              )
          return allow

      def clear_cookie(self, *args, **kwargs):
          """meaningless for websockets"""
          pass

      def open(self, *args, **kwargs):
          """Start the keep-alive ping loop (when enabled), then defer to the next ``open``."""
          self.log.debug("Opening websocket %s", self.request.path)

          # start the pinging
          if self.ping_interval > 0:
              loop = ioloop.IOLoop.current()
              self.last_ping = loop.time()  # Remember time of last ping
              self.last_pong = self.last_ping
              self.ping_callback = ioloop.PeriodicCallback(
                  self.send_ping, self.ping_interval,
              )
              self.ping_callback.start()
          return super(WebSocketMixin, self).open(*args, **kwargs)

      def _stop_pinging(self):
          """Stop the keep-alive callback, if one was started."""
          if self.ping_callback is not None:
              self.ping_callback.stop()

      def send_ping(self):
          """Send a ping to keep the websocket alive.

          Stops pinging once the connection is gone, and closes the connection
          when no pong has arrived within ``ping_timeout`` milliseconds.
          """
          if self.ws_connection is None:
              self._stop_pinging()
              return

          # check for timeout on pong. Make sure that we really have sent a recent ping in
          # case the machine with both server and client has been suspended since the last ping.
          now = ioloop.IOLoop.current().time()
          since_last_pong = 1e3 * (now - self.last_pong)
          since_last_ping = 1e3 * (now - self.last_ping)
          if since_last_ping < 2 * self.ping_interval and since_last_pong > self.ping_timeout:
              self.log.warning("WebSocket ping timeout after %i ms.", since_last_pong)
              self.close()
              # nothing left to keep alive
              self._stop_pinging()
              return
          try:
              self.ping(b'')
          except (StreamClosedError, WebSocketClosedError):
              # websocket has been closed, stop pinging
              self._stop_pinging()
              return

          self.last_ping = now

      def on_pong(self, data):
          """Record the arrival time of a pong (used for the timeout check)."""
          self.last_pong = ioloop.IOLoop.current().time()
  class ZMQStreamHandler(WebSocketMixin, WebSocketHandler):
      """Websocket handler that forwards messages from ZMQ streams to the browser."""

      if tornado.version_info < (4, 1):
          # Backport send_error from tornado 4.1 to 4.0
          def send_error(self, *args, **kwargs):
              if self.stream is None:
                  super(WebSocketHandler, self).send_error(*args, **kwargs)
              else:
                  # If we get an uncaught exception during the handshake,
                  # we have no choice but to abruptly close the connection.
                  # TODO: for uncaught exceptions after the handshake,
                  # we can close the connection more gracefully.
                  self.stream.close()

      def _reserialize_reply(self, msg_or_list, channel=None):
          """Reserialize a reply message using JSON.

          msg_or_list can be an already-deserialized msg dict or the zmq buffer list.
          If it is the zmq list, it will be deserialized with self.session.

          This takes the msg list from the ZMQ socket and serializes the result for the websocket.
          This method should be used by self._on_zmq_reply to build messages that can
          be sent back to the browser.

          Returns bytes (binary layout) when the message carries buffers,
          otherwise a unicode JSON string. A dict passed in is updated in place
          with ``channel`` when one is given.
          """
          if isinstance(msg_or_list, dict):
              # already unpacked
              msg = msg_or_list
          else:
              _idents, msg_list = self.session.feed_identities(msg_or_list)
              msg = self.session.deserialize(msg_list)
          if channel:
              msg['channel'] = channel
          # .get(): a caller-supplied dict may have no 'buffers' entry at all
          if msg.get('buffers'):
              # binary buffers cannot travel inside JSON
              return serialize_binary_message(msg)
          smsg = json.dumps(msg, default=date_default)
          return cast_unicode(smsg)

      def _on_zmq_reply(self, stream, msg_list):
          """Forward a message received on a ZMQ ``stream`` to the websocket."""
          # Sometimes this gets triggered when the on_close method is scheduled in the
          # eventloop but hasn't been called.
          if self.ws_connection is None or stream.closed():
              self.log.warning("zmq message arrived on closed channel")
              self.close()
              return
          channel = getattr(stream, 'channel', None)
          try:
              msg = self._reserialize_reply(msg_list, channel=channel)
          except Exception:
              # lazy %-args: the (possibly large) message is only formatted if emitted
              self.log.critical("Malformed message: %r", msg_list, exc_info=True)
              return

          try:
              self.write_message(msg, binary=isinstance(msg, bytes))
          except (StreamClosedError, WebSocketClosedError):
              self.log.warning("zmq message arrived on closed channel")
              self.close()
              return
  class AuthenticatedZMQStreamHandler(ZMQStreamHandler, IPythonHandler):
      """ZMQ websocket handler that requires an authenticated user.

      The GET that opens the websocket first runs ``pre_get``, which rejects
      unauthenticated requests with HTTP 403.
      """

      def set_default_headers(self):
          """Undo the set_default_headers in IPythonHandler

          which doesn't make sense for websockets
          """
          pass

      def pre_get(self):
          """Run before finishing the GET request.

          Extend this method to add logic that should fire before
          the websocket finishes completing.

          Raises ``web.HTTPError(403)`` when there is no current user. When a
          ``session_id`` request argument is present it becomes the session id
          of ``self.session``; otherwise a warning is logged.
          """
          # authenticate the request before opening the websocket
          if self.get_current_user() is None:
              self.log.warning("Couldn't authenticate WebSocket connection")
              raise web.HTTPError(403)

          # look the argument up once and reuse it
          session_id = self.get_argument('session_id', False)
          if session_id:
              self.session.session = cast_unicode(session_id)
          else:
              self.log.warning("No session ID specified")

      @gen.coroutine
      def get(self, *args, **kwargs):
          """Run ``pre_get``, then the inherited websocket ``get``."""
          # pre_get can be a coroutine in subclasses
          # assign and yield in two step to avoid tornado 3 issues
          res = self.pre_get()
          yield gen.maybe_future(res)
          res = super(AuthenticatedZMQStreamHandler, self).get(*args, **kwargs)
          yield gen.maybe_future(res)

      def initialize(self):
          """Create a fresh Session for this connection from the app config."""
          self.log.debug("Initializing websocket connection %s", self.request.path)
          self.session = Session(config=self.config)

      def get_compression_options(self):
          """Return the ``websocket_compression_options`` setting (None if unset)."""
          return self.settings.get('websocket_compression_options', None)
  243. return self.settings.get('websocket_compression_options', None)