test_session.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353
  1. """test building messages with Session"""
  2. # Copyright (c) Jupyter Development Team.
  3. # Distributed under the terms of the Modified BSD License.
  4. import hmac
  5. import os
  6. import sys
  7. import uuid
  8. from datetime import datetime
  9. try:
  10. from unittest import mock
  11. except ImportError:
  12. import mock
  13. import pytest
  14. import zmq
  15. from zmq.tests import BaseZMQTestCase
  16. from zmq.eventloop.zmqstream import ZMQStream
  17. from jupyter_client import session as ss
  18. from jupyter_client import jsonutil
  19. from ipython_genutils.py3compat import string_types
  20. def _bad_packer(obj):
  21. raise TypeError("I don't work")
  22. def _bad_unpacker(bytes):
  23. raise TypeError("I don't work either")
  24. class SessionTestCase(BaseZMQTestCase):
  25. def setUp(self):
  26. BaseZMQTestCase.setUp(self)
  27. self.session = ss.Session()
  28. @pytest.fixture
  29. def no_copy_threshold():
  30. """Disable zero-copy optimizations in pyzmq >= 17"""
  31. with mock.patch.object(zmq, 'COPY_THRESHOLD', 1, create=True):
  32. yield
  33. @pytest.mark.usefixtures('no_copy_threshold')
  34. class TestSession(SessionTestCase):
  35. def test_msg(self):
  36. """message format"""
  37. msg = self.session.msg('execute')
  38. thekeys = set('header parent_header metadata content msg_type msg_id'.split())
  39. s = set(msg.keys())
  40. self.assertEqual(s, thekeys)
  41. self.assertTrue(isinstance(msg['content'],dict))
  42. self.assertTrue(isinstance(msg['metadata'],dict))
  43. self.assertTrue(isinstance(msg['header'],dict))
  44. self.assertTrue(isinstance(msg['parent_header'],dict))
  45. self.assertTrue(isinstance(msg['msg_id'], string_types))
  46. self.assertTrue(isinstance(msg['msg_type'], string_types))
  47. self.assertEqual(msg['header']['msg_type'], 'execute')
  48. self.assertEqual(msg['msg_type'], 'execute')
  49. def test_serialize(self):
  50. msg = self.session.msg('execute', content=dict(a=10, b=1.1))
  51. msg_list = self.session.serialize(msg, ident=b'foo')
  52. ident, msg_list = self.session.feed_identities(msg_list)
  53. new_msg = self.session.deserialize(msg_list)
  54. self.assertEqual(ident[0], b'foo')
  55. self.assertEqual(new_msg['msg_id'],msg['msg_id'])
  56. self.assertEqual(new_msg['msg_type'],msg['msg_type'])
  57. self.assertEqual(new_msg['header'],msg['header'])
  58. self.assertEqual(new_msg['content'],msg['content'])
  59. self.assertEqual(new_msg['parent_header'],msg['parent_header'])
  60. self.assertEqual(new_msg['metadata'],msg['metadata'])
  61. # ensure floats don't come out as Decimal:
  62. self.assertEqual(type(new_msg['content']['b']),type(new_msg['content']['b']))
  63. def test_default_secure(self):
  64. self.assertIsInstance(self.session.key, bytes)
  65. self.assertIsInstance(self.session.auth, hmac.HMAC)
  66. def test_send(self):
  67. ctx = zmq.Context()
  68. A = ctx.socket(zmq.PAIR)
  69. B = ctx.socket(zmq.PAIR)
  70. A.bind("inproc://test")
  71. B.connect("inproc://test")
  72. msg = self.session.msg('execute', content=dict(a=10))
  73. self.session.send(A, msg, ident=b'foo', buffers=[b'bar'])
  74. ident, msg_list = self.session.feed_identities(B.recv_multipart())
  75. new_msg = self.session.deserialize(msg_list)
  76. self.assertEqual(ident[0], b'foo')
  77. self.assertEqual(new_msg['msg_id'],msg['msg_id'])
  78. self.assertEqual(new_msg['msg_type'],msg['msg_type'])
  79. self.assertEqual(new_msg['header'],msg['header'])
  80. self.assertEqual(new_msg['content'],msg['content'])
  81. self.assertEqual(new_msg['parent_header'],msg['parent_header'])
  82. self.assertEqual(new_msg['metadata'],msg['metadata'])
  83. self.assertEqual(new_msg['buffers'],[b'bar'])
  84. content = msg['content']
  85. header = msg['header']
  86. header['msg_id'] = self.session.msg_id
  87. parent = msg['parent_header']
  88. metadata = msg['metadata']
  89. msg_type = header['msg_type']
  90. self.session.send(A, None, content=content, parent=parent,
  91. header=header, metadata=metadata, ident=b'foo', buffers=[b'bar'])
  92. ident, msg_list = self.session.feed_identities(B.recv_multipart())
  93. new_msg = self.session.deserialize(msg_list)
  94. self.assertEqual(ident[0], b'foo')
  95. self.assertEqual(new_msg['msg_id'],header['msg_id'])
  96. self.assertEqual(new_msg['msg_type'],msg['msg_type'])
  97. self.assertEqual(new_msg['header'],msg['header'])
  98. self.assertEqual(new_msg['content'],msg['content'])
  99. self.assertEqual(new_msg['metadata'],msg['metadata'])
  100. self.assertEqual(new_msg['parent_header'],msg['parent_header'])
  101. self.assertEqual(new_msg['buffers'],[b'bar'])
  102. header['msg_id'] = self.session.msg_id
  103. self.session.send(A, msg, ident=b'foo', buffers=[b'bar'])
  104. ident, new_msg = self.session.recv(B)
  105. self.assertEqual(ident[0], b'foo')
  106. self.assertEqual(new_msg['msg_id'],header['msg_id'])
  107. self.assertEqual(new_msg['msg_type'],msg['msg_type'])
  108. self.assertEqual(new_msg['header'],msg['header'])
  109. self.assertEqual(new_msg['content'],msg['content'])
  110. self.assertEqual(new_msg['metadata'],msg['metadata'])
  111. self.assertEqual(new_msg['parent_header'],msg['parent_header'])
  112. self.assertEqual(new_msg['buffers'],[b'bar'])
  113. # buffers must support the buffer protocol
  114. with self.assertRaises(TypeError):
  115. self.session.send(A, msg, ident=b'foo', buffers=[1])
  116. # buffers must be contiguous
  117. buf = memoryview(os.urandom(16))
  118. if sys.version_info >= (3,3):
  119. with self.assertRaises(ValueError):
  120. self.session.send(A, msg, ident=b'foo', buffers=[buf[::2]])
  121. A.close()
  122. B.close()
  123. ctx.term()
  124. def test_args(self):
  125. """initialization arguments for Session"""
  126. s = self.session
  127. self.assertTrue(s.pack is ss.default_packer)
  128. self.assertTrue(s.unpack is ss.default_unpacker)
  129. self.assertEqual(s.username, os.environ.get('USER', u'username'))
  130. s = ss.Session()
  131. self.assertEqual(s.username, os.environ.get('USER', u'username'))
  132. self.assertRaises(TypeError, ss.Session, pack='hi')
  133. self.assertRaises(TypeError, ss.Session, unpack='hi')
  134. u = str(uuid.uuid4())
  135. s = ss.Session(username=u'carrot', session=u)
  136. self.assertEqual(s.session, u)
  137. self.assertEqual(s.username, u'carrot')
  138. def test_tracking(self):
  139. """test tracking messages"""
  140. a,b = self.create_bound_pair(zmq.PAIR, zmq.PAIR)
  141. s = self.session
  142. s.copy_threshold = 1
  143. stream = ZMQStream(a)
  144. msg = s.send(a, 'hello', track=False)
  145. self.assertTrue(msg['tracker'] is ss.DONE)
  146. msg = s.send(a, 'hello', track=True)
  147. self.assertTrue(isinstance(msg['tracker'], zmq.MessageTracker))
  148. M = zmq.Message(b'hi there', track=True)
  149. msg = s.send(a, 'hello', buffers=[M], track=True)
  150. t = msg['tracker']
  151. self.assertTrue(isinstance(t, zmq.MessageTracker))
  152. self.assertRaises(zmq.NotDone, t.wait, .1)
  153. del M
  154. t.wait(1) # this will raise
  155. def test_unique_msg_ids(self):
  156. """test that messages receive unique ids"""
  157. ids = set()
  158. for i in range(2**12):
  159. h = self.session.msg_header('test')
  160. msg_id = h['msg_id']
  161. self.assertTrue(msg_id not in ids)
  162. ids.add(msg_id)
  163. def test_feed_identities(self):
  164. """scrub the front for zmq IDENTITIES"""
  165. theids = "engine client other".split()
  166. content = dict(code='whoda',stuff=object())
  167. themsg = self.session.msg('execute',content=content)
  168. pmsg = theids
  169. def test_session_id(self):
  170. session = ss.Session()
  171. # get bs before us
  172. bs = session.bsession
  173. us = session.session
  174. self.assertEqual(us.encode('ascii'), bs)
  175. session = ss.Session()
  176. # get us before bs
  177. us = session.session
  178. bs = session.bsession
  179. self.assertEqual(us.encode('ascii'), bs)
  180. # change propagates:
  181. session.session = 'something else'
  182. bs = session.bsession
  183. us = session.session
  184. self.assertEqual(us.encode('ascii'), bs)
  185. session = ss.Session(session='stuff')
  186. # get us before bs
  187. self.assertEqual(session.bsession, session.session.encode('ascii'))
  188. self.assertEqual(b'stuff', session.bsession)
  189. def test_zero_digest_history(self):
  190. session = ss.Session(digest_history_size=0)
  191. for i in range(11):
  192. session._add_digest(uuid.uuid4().bytes)
  193. self.assertEqual(len(session.digest_history), 0)
  194. def test_cull_digest_history(self):
  195. session = ss.Session(digest_history_size=100)
  196. for i in range(100):
  197. session._add_digest(uuid.uuid4().bytes)
  198. self.assertTrue(len(session.digest_history) == 100)
  199. session._add_digest(uuid.uuid4().bytes)
  200. self.assertTrue(len(session.digest_history) == 91)
  201. for i in range(9):
  202. session._add_digest(uuid.uuid4().bytes)
  203. self.assertTrue(len(session.digest_history) == 100)
  204. session._add_digest(uuid.uuid4().bytes)
  205. self.assertTrue(len(session.digest_history) == 91)
  206. def test_bad_pack(self):
  207. try:
  208. session = ss.Session(pack=_bad_packer)
  209. except ValueError as e:
  210. self.assertIn("could not serialize", str(e))
  211. self.assertIn("don't work", str(e))
  212. else:
  213. self.fail("Should have raised ValueError")
  214. def test_bad_unpack(self):
  215. try:
  216. session = ss.Session(unpack=_bad_unpacker)
  217. except ValueError as e:
  218. self.assertIn("could not handle output", str(e))
  219. self.assertIn("don't work either", str(e))
  220. else:
  221. self.fail("Should have raised ValueError")
  222. def test_bad_packer(self):
  223. try:
  224. session = ss.Session(packer=__name__ + '._bad_packer')
  225. except ValueError as e:
  226. self.assertIn("could not serialize", str(e))
  227. self.assertIn("don't work", str(e))
  228. else:
  229. self.fail("Should have raised ValueError")
  230. def test_bad_unpacker(self):
  231. try:
  232. session = ss.Session(unpacker=__name__ + '._bad_unpacker')
  233. except ValueError as e:
  234. self.assertIn("could not handle output", str(e))
  235. self.assertIn("don't work either", str(e))
  236. else:
  237. self.fail("Should have raised ValueError")
  238. def test_bad_roundtrip(self):
  239. with self.assertRaises(ValueError):
  240. session = ss.Session(unpack=lambda b: 5)
  241. def _datetime_test(self, session):
  242. content = dict(t=ss.utcnow())
  243. metadata = dict(t=ss.utcnow())
  244. p = session.msg('msg')
  245. msg = session.msg('msg', content=content, metadata=metadata, parent=p['header'])
  246. smsg = session.serialize(msg)
  247. msg2 = session.deserialize(session.feed_identities(smsg)[1])
  248. assert isinstance(msg2['header']['date'], datetime)
  249. self.assertEqual(msg['header'], msg2['header'])
  250. self.assertEqual(msg['parent_header'], msg2['parent_header'])
  251. self.assertEqual(msg['parent_header'], msg2['parent_header'])
  252. assert isinstance(msg['content']['t'], datetime)
  253. assert isinstance(msg['metadata']['t'], datetime)
  254. assert isinstance(msg2['content']['t'], string_types)
  255. assert isinstance(msg2['metadata']['t'], string_types)
  256. self.assertEqual(msg['content'], jsonutil.extract_dates(msg2['content']))
  257. self.assertEqual(msg['content'], jsonutil.extract_dates(msg2['content']))
  258. def test_datetimes(self):
  259. self._datetime_test(self.session)
  260. def test_datetimes_pickle(self):
  261. session = ss.Session(packer='pickle')
  262. self._datetime_test(session)
  263. def test_datetimes_msgpack(self):
  264. msgpack = pytest.importorskip('msgpack')
  265. session = ss.Session(
  266. pack=msgpack.packb,
  267. unpack=lambda buf: msgpack.unpackb(buf, encoding='utf8'),
  268. )
  269. self._datetime_test(session)
  270. def test_send_raw(self):
  271. ctx = zmq.Context()
  272. A = ctx.socket(zmq.PAIR)
  273. B = ctx.socket(zmq.PAIR)
  274. A.bind("inproc://test")
  275. B.connect("inproc://test")
  276. msg = self.session.msg('execute', content=dict(a=10))
  277. msg_list = [self.session.pack(msg[part]) for part in
  278. ['header', 'parent_header', 'metadata', 'content']]
  279. self.session.send_raw(A, msg_list, ident=b'foo')
  280. ident, new_msg_list = self.session.feed_identities(B.recv_multipart())
  281. new_msg = self.session.deserialize(new_msg_list)
  282. self.assertEqual(ident[0], b'foo')
  283. self.assertEqual(new_msg['msg_type'],msg['msg_type'])
  284. self.assertEqual(new_msg['header'],msg['header'])
  285. self.assertEqual(new_msg['parent_header'],msg['parent_header'])
  286. self.assertEqual(new_msg['content'],msg['content'])
  287. self.assertEqual(new_msg['metadata'],msg['metadata'])
  288. A.close()
  289. B.close()
  290. ctx.term()
  291. def test_clone(self):
  292. s = self.session
  293. s._add_digest('initial')
  294. s2 = s.clone()
  295. assert s2.session == s.session
  296. assert s2.digest_history == s.digest_history
  297. assert s2.digest_history is not s.digest_history
  298. digest = 'abcdef'
  299. s._add_digest(digest)
  300. assert digest in s.digest_history
  301. assert digest not in s2.digest_history