123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353 |
- """test building messages with Session"""
- # Copyright (c) Jupyter Development Team.
- # Distributed under the terms of the Modified BSD License.
- import hmac
- import os
- import sys
- import uuid
- from datetime import datetime
- try:
- from unittest import mock
- except ImportError:
- import mock
- import pytest
- import zmq
- from zmq.tests import BaseZMQTestCase
- from zmq.eventloop.zmqstream import ZMQStream
- from jupyter_client import session as ss
- from jupyter_client import jsonutil
- from ipython_genutils.py3compat import string_types
- def _bad_packer(obj):
- raise TypeError("I don't work")
- def _bad_unpacker(bytes):
- raise TypeError("I don't work either")
class SessionTestCase(BaseZMQTestCase):
    """Base test case that provides a fresh ``ss.Session`` as ``self.session``."""

    def setUp(self):
        # Run the pyzmq base-class setup first.
        super(SessionTestCase, self).setUp()
        # Build a new default Session for every test so per-session state
        # (e.g. digest_history) starts clean.
        self.session = ss.Session()
@pytest.fixture
def no_copy_threshold():
    """Disable zero-copy optimizations in pyzmq >= 17

    Sets ``zmq.COPY_THRESHOLD`` to 1 for the duration of the test, creating
    the attribute if this pyzmq version does not define it, and restores it
    afterwards.
    NOTE(review): the intent is that pyzmq's size-based copy shortcut does
    not affect the message-tracking tests; confirm against the pyzmq
    version in use.
    """
    patcher = mock.patch.object(zmq, 'COPY_THRESHOLD', 1, create=True)
    patcher.start()
    try:
        yield
    finally:
        # Undo the patch even if the test body raised.
        patcher.stop()
@pytest.mark.usefixtures('no_copy_threshold')
class TestSession(SessionTestCase):
    """Tests for building, serializing, sending and receiving Session messages.

    Every test runs with the ``no_copy_threshold`` fixture active and gets a
    fresh ``self.session`` from ``SessionTestCase.setUp``.
    """

    def _assert_same_msg(self, new_msg, msg, msg_id=None):
        """Assert that ``new_msg`` matches ``msg`` field by field.

        ``msg_id`` overrides the expected id.  Pass it when the header's
        ``msg_id`` was regenerated after ``msg`` was built, because the
        top-level ``msg['msg_id']`` is not updated in that case.
        """
        expected_id = msg['msg_id'] if msg_id is None else msg_id
        self.assertEqual(new_msg['msg_id'], expected_id)
        for key in ('msg_type', 'header', 'content', 'parent_header', 'metadata'):
            self.assertEqual(new_msg[key], msg[key], key)

    def test_msg(self):
        """message format"""
        msg = self.session.msg('execute')
        thekeys = set('header parent_header metadata content msg_type msg_id'.split())
        self.assertEqual(set(msg.keys()), thekeys)
        for key in ('content', 'metadata', 'header', 'parent_header'):
            self.assertIsInstance(msg[key], dict)
        self.assertIsInstance(msg['msg_id'], string_types)
        self.assertIsInstance(msg['msg_type'], string_types)
        self.assertEqual(msg['header']['msg_type'], 'execute')
        self.assertEqual(msg['msg_type'], 'execute')

    def test_serialize(self):
        """serialize -> feed_identities -> deserialize roundtrip"""
        msg = self.session.msg('execute', content=dict(a=10, b=1.1))
        msg_list = self.session.serialize(msg, ident=b'foo')
        ident, msg_list = self.session.feed_identities(msg_list)
        new_msg = self.session.deserialize(msg_list)
        self.assertEqual(ident[0], b'foo')
        self._assert_same_msg(new_msg, msg)
        # ensure floats don't come out as Decimal: compare against the
        # type that was sent, not against the received value itself
        self.assertIs(type(new_msg['content']['b']), type(msg['content']['b']))

    def test_default_secure(self):
        self.assertIsInstance(self.session.key, bytes)
        self.assertIsInstance(self.session.auth, hmac.HMAC)

    def test_send(self):
        """send/recv over an inproc PAIR socket pair, including buffers"""
        ctx = zmq.Context()
        A = ctx.socket(zmq.PAIR)
        B = ctx.socket(zmq.PAIR)
        try:
            A.bind("inproc://test")
            B.connect("inproc://test")

            # 1. send a prebuilt message dict
            msg = self.session.msg('execute', content=dict(a=10))
            self.session.send(A, msg, ident=b'foo', buffers=[b'bar'])
            ident, msg_list = self.session.feed_identities(B.recv_multipart())
            new_msg = self.session.deserialize(msg_list)
            self.assertEqual(ident[0], b'foo')
            self._assert_same_msg(new_msg, msg)
            self.assertEqual(new_msg['buffers'], [b'bar'])

            # 2. send the parts individually under a fresh msg_id.
            # `header` is the same dict as msg['header'], so msg sees the
            # new id too.
            content = msg['content']
            header = msg['header']
            header['msg_id'] = self.session.msg_id
            parent = msg['parent_header']
            metadata = msg['metadata']
            self.session.send(A, None, content=content, parent=parent,
                              header=header, metadata=metadata,
                              ident=b'foo', buffers=[b'bar'])
            ident, msg_list = self.session.feed_identities(B.recv_multipart())
            new_msg = self.session.deserialize(msg_list)
            self.assertEqual(ident[0], b'foo')
            self._assert_same_msg(new_msg, msg, msg_id=header['msg_id'])
            self.assertEqual(new_msg['buffers'], [b'bar'])

            # 3. resend the whole dict under another fresh msg_id and
            # receive it through Session.recv
            header['msg_id'] = self.session.msg_id
            self.session.send(A, msg, ident=b'foo', buffers=[b'bar'])
            ident, new_msg = self.session.recv(B)
            self.assertEqual(ident[0], b'foo')
            self._assert_same_msg(new_msg, msg, msg_id=header['msg_id'])
            self.assertEqual(new_msg['buffers'], [b'bar'])

            # buffers must support the buffer protocol
            with self.assertRaises(TypeError):
                self.session.send(A, msg, ident=b'foo', buffers=[1])

            # buffers must be contiguous
            buf = memoryview(os.urandom(16))
            if sys.version_info >= (3, 3):
                with self.assertRaises(ValueError):
                    self.session.send(A, msg, ident=b'foo', buffers=[buf[::2]])
        finally:
            # Always release the sockets and context, even when an assertion
            # fails. linger=0 keeps ctx.term() from waiting on unsent frames.
            A.close(linger=0)
            B.close(linger=0)
            ctx.term()

    def test_args(self):
        """initialization arguments for Session"""
        s = self.session
        self.assertIs(s.pack, ss.default_packer)
        self.assertIs(s.unpack, ss.default_unpacker)
        self.assertEqual(s.username, os.environ.get('USER', u'username'))

        s = ss.Session()
        self.assertEqual(s.username, os.environ.get('USER', u'username'))

        self.assertRaises(TypeError, ss.Session, pack='hi')
        self.assertRaises(TypeError, ss.Session, unpack='hi')

        u = str(uuid.uuid4())
        s = ss.Session(username=u'carrot', session=u)
        self.assertEqual(s.session, u)
        self.assertEqual(s.username, u'carrot')

    def test_tracking(self):
        """test tracking messages"""
        a, b = self.create_bound_pair(zmq.PAIR, zmq.PAIR)
        s = self.session
        s.copy_threshold = 1
        # NOTE(review): `stream` is not used below; it is kept as-is.
        # Confirm whether wrapping `a` in a ZMQStream is still needed here.
        stream = ZMQStream(a)
        msg = s.send(a, 'hello', track=False)
        self.assertIs(msg['tracker'], ss.DONE)
        msg = s.send(a, 'hello', track=True)
        self.assertIsInstance(msg['tracker'], zmq.MessageTracker)
        M = zmq.Message(b'hi there', track=True)
        msg = s.send(a, 'hello', buffers=[M], track=True)
        t = msg['tracker']
        self.assertIsInstance(t, zmq.MessageTracker)
        # not done yet: this test still holds a reference to M
        self.assertRaises(zmq.NotDone, t.wait, .1)
        del M
        # After releasing M the tracker should complete.
        # t.wait(1) raises zmq.NotDone (failing the test) if it has not.
        t.wait(1)

    def test_unique_msg_ids(self):
        """test that messages receive unique ids"""
        ids = set()
        for _ in range(2**12):
            msg_id = self.session.msg_header('test')['msg_id']
            self.assertNotIn(msg_id, ids)
            ids.add(msg_id)

    def test_feed_identities(self):
        """scrub the front for zmq IDENTITIES"""
        theids = [b'engine', b'client', b'other']
        # content must be serializable by the session's packer, so it holds
        # only plain values
        content = dict(code='whoda')
        themsg = self.session.msg('execute', content=content)
        msg_list = self.session.serialize(themsg, ident=theids)
        idents, msg_list = self.session.feed_identities(msg_list)
        self.assertEqual(idents, theids)
        # what remains after the identities must still deserialize
        new_msg = self.session.deserialize(msg_list)
        self._assert_same_msg(new_msg, themsg)

    def test_session_id(self):
        # bs = session.bsession (bytes), us = session.session (text)
        session = ss.Session()
        # get bs before us
        bs = session.bsession
        us = session.session
        self.assertEqual(us.encode('ascii'), bs)

        session = ss.Session()
        # get us before bs
        us = session.session
        bs = session.bsession
        self.assertEqual(us.encode('ascii'), bs)

        # change propagates:
        session.session = 'something else'
        bs = session.bsession
        us = session.session
        self.assertEqual(us.encode('ascii'), bs)

        # explicit session id passed to the constructor; bsession is read first
        session = ss.Session(session='stuff')
        self.assertEqual(session.bsession, session.session.encode('ascii'))
        self.assertEqual(b'stuff', session.bsession)

    def test_zero_digest_history(self):
        session = ss.Session(digest_history_size=0)
        for _ in range(11):
            session._add_digest(uuid.uuid4().bytes)
        self.assertEqual(len(session.digest_history), 0)

    def test_cull_digest_history(self):
        session = ss.Session(digest_history_size=100)
        for _ in range(100):
            session._add_digest(uuid.uuid4().bytes)
        self.assertEqual(len(session.digest_history), 100)
        # one past the limit triggers a cull down to 91 entries
        session._add_digest(uuid.uuid4().bytes)
        self.assertEqual(len(session.digest_history), 91)
        for _ in range(9):
            session._add_digest(uuid.uuid4().bytes)
        self.assertEqual(len(session.digest_history), 100)
        session._add_digest(uuid.uuid4().bytes)
        self.assertEqual(len(session.digest_history), 91)

    def test_bad_pack(self):
        with self.assertRaises(ValueError) as cm:
            ss.Session(pack=_bad_packer)
        self.assertIn("could not serialize", str(cm.exception))
        self.assertIn("don't work", str(cm.exception))

    def test_bad_unpack(self):
        with self.assertRaises(ValueError) as cm:
            ss.Session(unpack=_bad_unpacker)
        self.assertIn("could not handle output", str(cm.exception))
        self.assertIn("don't work either", str(cm.exception))

    def test_bad_packer(self):
        with self.assertRaises(ValueError) as cm:
            ss.Session(packer=__name__ + '._bad_packer')
        self.assertIn("could not serialize", str(cm.exception))
        self.assertIn("don't work", str(cm.exception))

    def test_bad_unpacker(self):
        with self.assertRaises(ValueError) as cm:
            ss.Session(unpacker=__name__ + '._bad_unpacker')
        self.assertIn("could not handle output", str(cm.exception))
        self.assertIn("don't work either", str(cm.exception))

    def test_bad_roundtrip(self):
        with self.assertRaises(ValueError):
            ss.Session(unpack=lambda b: 5)

    def _datetime_test(self, session):
        """Round-trip datetimes in header, parent_header, content and metadata."""
        content = dict(t=ss.utcnow())
        metadata = dict(t=ss.utcnow())
        p = session.msg('msg')
        msg = session.msg('msg', content=content, metadata=metadata, parent=p['header'])
        smsg = session.serialize(msg)
        msg2 = session.deserialize(session.feed_identities(smsg)[1])
        # header dates come back as datetime objects
        assert isinstance(msg2['header']['date'], datetime)
        self.assertEqual(msg['header'], msg2['header'])
        self.assertEqual(msg['parent_header'], msg2['parent_header'])
        # content/metadata dates go in as datetimes and come back as strings...
        assert isinstance(msg['content']['t'], datetime)
        assert isinstance(msg['metadata']['t'], datetime)
        assert isinstance(msg2['content']['t'], string_types)
        assert isinstance(msg2['metadata']['t'], string_types)
        # ...which extract_dates turns back into the original values
        self.assertEqual(msg['content'], jsonutil.extract_dates(msg2['content']))
        self.assertEqual(msg['metadata'], jsonutil.extract_dates(msg2['metadata']))

    def test_datetimes(self):
        self._datetime_test(self.session)

    def test_datetimes_pickle(self):
        session = ss.Session(packer='pickle')
        self._datetime_test(session)

    def test_datetimes_msgpack(self):
        msgpack = pytest.importorskip('msgpack')
        # raw=False decodes msgpack strings to str. The older
        # `encoding='utf8'` keyword was removed in msgpack 1.0.
        session = ss.Session(
            pack=msgpack.packb,
            unpack=lambda buf: msgpack.unpackb(buf, raw=False),
        )
        self._datetime_test(session)

    def test_send_raw(self):
        """send_raw of pre-packed parts deserializes to the original message"""
        ctx = zmq.Context()
        A = ctx.socket(zmq.PAIR)
        B = ctx.socket(zmq.PAIR)
        try:
            A.bind("inproc://test")
            B.connect("inproc://test")
            msg = self.session.msg('execute', content=dict(a=10))
            msg_list = [self.session.pack(msg[part]) for part in
                        ['header', 'parent_header', 'metadata', 'content']]
            self.session.send_raw(A, msg_list, ident=b'foo')

            ident, new_msg_list = self.session.feed_identities(B.recv_multipart())
            new_msg = self.session.deserialize(new_msg_list)
            self.assertEqual(ident[0], b'foo')
            self._assert_same_msg(new_msg, msg)
        finally:
            # Always release the sockets and context, even on failure.
            A.close(linger=0)
            B.close(linger=0)
            ctx.term()

    def test_clone(self):
        """clone() copies session id and digest history, but not by reference"""
        s = self.session
        s._add_digest('initial')
        s2 = s.clone()
        assert s2.session == s.session
        assert s2.digest_history == s.digest_history
        assert s2.digest_history is not s.digest_history
        digest = 'abcdef'
        s._add_digest(digest)
        assert digest in s.digest_history
        assert digest not in s2.digest_history
|