123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615 |
- # -*- coding: utf8 -*-
- # Copyright (C) PyZMQ Developers
- # Distributed under the terms of the Modified BSD License.
- import copy
- import errno
- import json
- import os
- import platform
- import time
- import warnings
- import socket
- import sys
- try:
- from unittest import mock
- except ImportError:
- mock = None
- import pytest
- from pytest import mark
- import zmq
- from zmq.tests import (
- BaseZMQTestCase, SkipTest, have_gevent, GreenTest, skip_pypy
- )
- from zmq.utils.strtypes import unicode
- pypy = platform.python_implementation().lower() == 'pypy'
- windows = platform.platform().lower().startswith('windows')
- on_travis = bool(os.environ.get('TRAVIS_PYTHON_VERSION'))
- # polling on windows is slow
- POLL_TIMEOUT = 1000 if windows else 100
- class TestSocket(BaseZMQTestCase):
- def test_create(self):
- ctx = self.Context()
- s = ctx.socket(zmq.PUB)
- # Superluminal protocol not yet implemented
- self.assertRaisesErrno(zmq.EPROTONOSUPPORT, s.bind, 'ftl://a')
- self.assertRaisesErrno(zmq.EPROTONOSUPPORT, s.connect, 'ftl://a')
- self.assertRaisesErrno(zmq.EINVAL, s.bind, 'tcp://')
- s.close()
- del ctx
-
- def test_context_manager(self):
- url = 'inproc://a'
- with self.Context() as ctx:
- with ctx.socket(zmq.PUSH) as a:
- a.bind(url)
- with ctx.socket(zmq.PULL) as b:
- b.connect(url)
- msg = b'hi'
- a.send(msg)
- rcvd = self.recv(b)
- self.assertEqual(rcvd, msg)
- self.assertEqual(b.closed, True)
- self.assertEqual(a.closed, True)
- self.assertEqual(ctx.closed, True)
- def test_dir(self):
- ctx = self.Context()
- s = ctx.socket(zmq.PUB)
- self.assertTrue('send' in dir(s))
- self.assertTrue('IDENTITY' in dir(s))
- self.assertTrue('AFFINITY' in dir(s))
- self.assertTrue('FD' in dir(s))
- s.close()
- ctx.term()
- @mark.skipif(mock is None, reason="requires unittest.mock")
- def test_mockable(self):
- s = self.socket(zmq.SUB)
- m = mock.Mock(spec=s)
- s.close()
- def test_bind_unicode(self):
- s = self.socket(zmq.PUB)
- p = s.bind_to_random_port(unicode("tcp://*"))
- def test_connect_unicode(self):
- s = self.socket(zmq.PUB)
- s.connect(unicode("tcp://127.0.0.1:5555"))
- def test_bind_to_random_port(self):
- # Check that bind_to_random_port do not hide useful exception
- ctx = self.Context()
- c = ctx.socket(zmq.PUB)
- # Invalid format
- try:
- c.bind_to_random_port('tcp:*')
- except zmq.ZMQError as e:
- self.assertEqual(e.errno, zmq.EINVAL)
- # Invalid protocol
- try:
- c.bind_to_random_port('rand://*')
- except zmq.ZMQError as e:
- self.assertEqual(e.errno, zmq.EPROTONOSUPPORT)
- def test_identity(self):
- s = self.context.socket(zmq.PULL)
- self.sockets.append(s)
- ident = b'identity\0\0'
- s.identity = ident
- self.assertEqual(s.get(zmq.IDENTITY), ident)
- def test_unicode_sockopts(self):
- """test setting/getting sockopts with unicode strings"""
- topic = "tést"
- if str is not unicode:
- topic = topic.decode('utf8')
- p,s = self.create_bound_pair(zmq.PUB, zmq.SUB)
- self.assertEqual(s.send_unicode, s.send_unicode)
- self.assertEqual(p.recv_unicode, p.recv_unicode)
- self.assertRaises(TypeError, s.setsockopt, zmq.SUBSCRIBE, topic)
- self.assertRaises(TypeError, s.setsockopt, zmq.IDENTITY, topic)
- s.setsockopt_unicode(zmq.IDENTITY, topic, 'utf16')
- self.assertRaises(TypeError, s.setsockopt, zmq.AFFINITY, topic)
- s.setsockopt_unicode(zmq.SUBSCRIBE, topic)
- self.assertRaises(TypeError, s.getsockopt_unicode, zmq.AFFINITY)
- self.assertRaisesErrno(zmq.EINVAL, s.getsockopt_unicode, zmq.SUBSCRIBE)
-
- identb = s.getsockopt(zmq.IDENTITY)
- identu = identb.decode('utf16')
- identu2 = s.getsockopt_unicode(zmq.IDENTITY, 'utf16')
- self.assertEqual(identu, identu2)
- time.sleep(0.1) # wait for connection/subscription
- p.send_unicode(topic,zmq.SNDMORE)
- p.send_unicode(topic*2, encoding='latin-1')
- self.assertEqual(topic, s.recv_unicode())
- self.assertEqual(topic*2, s.recv_unicode(encoding='latin-1'))
-
- def test_int_sockopts(self):
- "test integer sockopts"
- v = zmq.zmq_version_info()
- if v < (3,0):
- default_hwm = 0
- else:
- default_hwm = 1000
- p,s = self.create_bound_pair(zmq.PUB, zmq.SUB)
- p.setsockopt(zmq.LINGER, 0)
- self.assertEqual(p.getsockopt(zmq.LINGER), 0)
- p.setsockopt(zmq.LINGER, -1)
- self.assertEqual(p.getsockopt(zmq.LINGER), -1)
- self.assertEqual(p.hwm, default_hwm)
- p.hwm = 11
- self.assertEqual(p.hwm, 11)
- # p.setsockopt(zmq.EVENTS, zmq.POLLIN)
- self.assertEqual(p.getsockopt(zmq.EVENTS), zmq.POLLOUT)
- self.assertRaisesErrno(zmq.EINVAL, p.setsockopt,zmq.EVENTS, 2**7-1)
- self.assertEqual(p.getsockopt(zmq.TYPE), p.socket_type)
- self.assertEqual(p.getsockopt(zmq.TYPE), zmq.PUB)
- self.assertEqual(s.getsockopt(zmq.TYPE), s.socket_type)
- self.assertEqual(s.getsockopt(zmq.TYPE), zmq.SUB)
-
- # check for overflow / wrong type:
- errors = []
- backref = {}
- constants = zmq.constants
- for name in constants.__all__:
- value = getattr(constants, name)
- if isinstance(value, int):
- backref[value] = name
- for opt in zmq.constants.int_sockopts.union(zmq.constants.int64_sockopts):
- sopt = backref[opt]
- if sopt.startswith((
- 'ROUTER', 'XPUB', 'TCP', 'FAIL',
- 'REQ_', 'CURVE_', 'PROBE_ROUTER',
- 'IPC_FILTER', 'GSSAPI', 'STREAM_',
- 'VMCI_BUFFER_SIZE', 'VMCI_BUFFER_MIN_SIZE',
- 'VMCI_BUFFER_MAX_SIZE', 'VMCI_CONNECT_TIMEOUT',
- )):
- # some sockopts are write-only
- continue
- try:
- n = p.getsockopt(opt)
- except zmq.ZMQError as e:
- errors.append("getsockopt(zmq.%s) raised '%s'."%(sopt, e))
- else:
- if n > 2**31:
- errors.append("getsockopt(zmq.%s) returned a ridiculous value."
- " It is probably the wrong type."%sopt)
- if errors:
- self.fail('\n'.join([''] + errors))
-
- def test_bad_sockopts(self):
- """Test that appropriate errors are raised on bad socket options"""
- s = self.context.socket(zmq.PUB)
- self.sockets.append(s)
- s.setsockopt(zmq.LINGER, 0)
- # unrecognized int sockopts pass through to libzmq, and should raise EINVAL
- self.assertRaisesErrno(zmq.EINVAL, s.setsockopt, 9999, 5)
- self.assertRaisesErrno(zmq.EINVAL, s.getsockopt, 9999)
- # but only int sockopts are allowed through this way, otherwise raise a TypeError
- self.assertRaises(TypeError, s.setsockopt, 9999, b"5")
- # some sockopts are valid in general, but not on every socket:
- self.assertRaisesErrno(zmq.EINVAL, s.setsockopt, zmq.SUBSCRIBE, b'hi')
-
- def test_sockopt_roundtrip(self):
- "test set/getsockopt roundtrip."
- p = self.context.socket(zmq.PUB)
- self.sockets.append(p)
- p.setsockopt(zmq.LINGER, 11)
- self.assertEqual(p.getsockopt(zmq.LINGER), 11)
-
- def test_send_unicode(self):
- "test sending unicode objects"
- a,b = self.create_bound_pair(zmq.PAIR, zmq.PAIR)
- self.sockets.extend([a,b])
- u = "çπ§"
- if str is not unicode:
- u = u.decode('utf8')
- self.assertRaises(TypeError, a.send, u,copy=False)
- self.assertRaises(TypeError, a.send, u,copy=True)
- a.send_unicode(u)
- s = b.recv()
- self.assertEqual(s,u.encode('utf8'))
- self.assertEqual(s.decode('utf8'),u)
- a.send_unicode(u,encoding='utf16')
- s = b.recv_unicode(encoding='utf16')
- self.assertEqual(s,u)
-
- def test_send_multipart_check_type(self):
- "check type on all frames in send_multipart"
- a,b = self.create_bound_pair(zmq.PAIR, zmq.PAIR)
- self.sockets.extend([a,b])
- self.assertRaises(TypeError, a.send_multipart, [b'a', 5])
- a.send_multipart([b'b'])
- rcvd = self.recv_multipart(b)
- self.assertEqual(rcvd, [b'b'])
-
- @skip_pypy
- def test_tracker(self):
- "test the MessageTracker object for tracking when zmq is done with a buffer"
- addr = 'tcp://127.0.0.1'
- # get a port:
- sock = socket.socket()
- sock.bind(('127.0.0.1', 0))
- port = sock.getsockname()[1]
- iface = "%s:%i" % (addr, port)
- sock.close()
- time.sleep(0.1)
- a = self.context.socket(zmq.PUSH)
- b = self.context.socket(zmq.PULL)
- self.sockets.extend([a,b])
- a.connect(iface)
- time.sleep(0.1)
- p1 = a.send(b'something', copy=False, track=True)
- assert isinstance(p1, zmq.MessageTracker)
- assert p1 is zmq._FINISHED_TRACKER
- # small message, should start done
- assert p1.done
- # disable zero-copy threshold
- a.copy_threshold = 0
- p2 = a.send_multipart([b'something', b'else'], copy=False, track=True)
- assert isinstance(p2, zmq.MessageTracker)
- assert not p2.done
- b.bind(iface)
- msg = self.recv_multipart(b)
- for i in range(10):
- if p1.done:
- break
- time.sleep(0.1)
- self.assertEqual(p1.done, True)
- self.assertEqual(msg, [b'something'])
- msg = self.recv_multipart(b)
- for i in range(10):
- if p2.done:
- break
- time.sleep(0.1)
- self.assertEqual(p2.done, True)
- self.assertEqual(msg, [b'something', b'else'])
- m = zmq.Frame(b"again", copy=False, track=True)
- self.assertEqual(m.tracker.done, False)
- p1 = a.send(m, copy=False)
- p2 = a.send(m, copy=False)
- self.assertEqual(m.tracker.done, False)
- self.assertEqual(p1.done, False)
- self.assertEqual(p2.done, False)
- msg = self.recv_multipart(b)
- self.assertEqual(m.tracker.done, False)
- self.assertEqual(msg, [b'again'])
- msg = self.recv_multipart(b)
- self.assertEqual(m.tracker.done, False)
- self.assertEqual(msg, [b'again'])
- self.assertEqual(p1.done, False)
- self.assertEqual(p2.done, False)
- pm = m.tracker
- del m
- for i in range(10):
- if p1.done:
- break
- time.sleep(0.1)
- self.assertEqual(p1.done, True)
- self.assertEqual(p2.done, True)
- m = zmq.Frame(b'something', track=False)
- self.assertRaises(ValueError, a.send, m, copy=False, track=True)
- def test_close(self):
- ctx = self.Context()
- s = ctx.socket(zmq.PUB)
- s.close()
- self.assertRaisesErrno(zmq.ENOTSOCK, s.bind, b'')
- self.assertRaisesErrno(zmq.ENOTSOCK, s.connect, b'')
- self.assertRaisesErrno(zmq.ENOTSOCK, s.setsockopt, zmq.SUBSCRIBE, b'')
- self.assertRaisesErrno(zmq.ENOTSOCK, s.send, b'asdf')
- self.assertRaisesErrno(zmq.ENOTSOCK, s.recv)
- del ctx
-
- def test_attr(self):
- """set setting/getting sockopts as attributes"""
- s = self.context.socket(zmq.DEALER)
- self.sockets.append(s)
- linger = 10
- s.linger = linger
- self.assertEqual(linger, s.linger)
- self.assertEqual(linger, s.getsockopt(zmq.LINGER))
- self.assertEqual(s.fd, s.getsockopt(zmq.FD))
-
- def test_bad_attr(self):
- s = self.context.socket(zmq.DEALER)
- self.sockets.append(s)
- try:
- s.apple='foo'
- except AttributeError:
- pass
- else:
- self.fail("bad setattr should have raised AttributeError")
- try:
- s.apple
- except AttributeError:
- pass
- else:
- self.fail("bad getattr should have raised AttributeError")
- def test_subclass(self):
- """subclasses can assign attributes"""
- class S(zmq.Socket):
- a = None
- def __init__(self, *a, **kw):
- self.a=-1
- super(S, self).__init__(*a, **kw)
-
- s = S(self.context, zmq.REP)
- self.sockets.append(s)
- self.assertEqual(s.a, -1)
- s.a=1
- self.assertEqual(s.a, 1)
- a=s.a
- self.assertEqual(a, 1)
-
- def test_recv_multipart(self):
- a,b = self.create_bound_pair()
- msg = b'hi'
- for i in range(3):
- a.send(msg)
- time.sleep(0.1)
- for i in range(3):
- self.assertEqual(self.recv_multipart(b), [msg])
-
- def test_close_after_destroy(self):
- """s.close() after ctx.destroy() should be fine"""
- ctx = self.Context()
- s = ctx.socket(zmq.REP)
- ctx.destroy()
- # reaper is not instantaneous
- time.sleep(1e-2)
- s.close()
- self.assertTrue(s.closed)
-
- def test_poll(self):
- a,b = self.create_bound_pair()
- tic = time.time()
- evt = a.poll(POLL_TIMEOUT)
- self.assertEqual(evt, 0)
- evt = a.poll(POLL_TIMEOUT, zmq.POLLOUT)
- self.assertEqual(evt, zmq.POLLOUT)
- msg = b'hi'
- a.send(msg)
- evt = b.poll(POLL_TIMEOUT)
- self.assertEqual(evt, zmq.POLLIN)
- msg2 = self.recv(b)
- evt = b.poll(POLL_TIMEOUT)
- self.assertEqual(evt, 0)
- self.assertEqual(msg2, msg)
-
- def test_ipc_path_max_length(self):
- """IPC_PATH_MAX_LEN is a sensible value"""
- if zmq.IPC_PATH_MAX_LEN == 0:
- raise SkipTest("IPC_PATH_MAX_LEN undefined")
-
- msg = "Surprising value for IPC_PATH_MAX_LEN: %s" % zmq.IPC_PATH_MAX_LEN
- self.assertTrue(zmq.IPC_PATH_MAX_LEN > 30, msg)
- self.assertTrue(zmq.IPC_PATH_MAX_LEN < 1025, msg)
- def test_ipc_path_max_length_msg(self):
- if zmq.IPC_PATH_MAX_LEN == 0:
- raise SkipTest("IPC_PATH_MAX_LEN undefined")
-
- s = self.context.socket(zmq.PUB)
- self.sockets.append(s)
- try:
- s.bind('ipc://{0}'.format('a' * (zmq.IPC_PATH_MAX_LEN + 1)))
- except zmq.ZMQError as e:
- self.assertTrue(str(zmq.IPC_PATH_MAX_LEN) in e.strerror)
- @mark.skipif(windows, reason="ipc not supported on Windows.")
- def test_ipc_path_no_such_file_or_directory_message(self):
- """Display the ipc path in case of an ENOENT exception"""
- s = self.context.socket(zmq.PUB)
- self.sockets.append(s)
- invalid_path = '/foo/bar'
- with pytest.raises(zmq.ZMQError) as error:
- s.bind('ipc://{0}'.format(invalid_path))
- assert error.value.errno == errno.ENOENT
- error_message = str(error.value)
- assert invalid_path in error_message
- assert "no such file or directory" in error_message.lower()
- def test_hwm(self):
- zmq3 = zmq.zmq_version_info()[0] >= 3
- for stype in (zmq.PUB, zmq.ROUTER, zmq.SUB, zmq.REQ, zmq.DEALER):
- s = self.context.socket(stype)
- s.hwm = 100
- self.assertEqual(s.hwm, 100)
- if zmq3:
- try:
- self.assertEqual(s.sndhwm, 100)
- except AttributeError:
- pass
- try:
- self.assertEqual(s.rcvhwm, 100)
- except AttributeError:
- pass
- s.close()
- def test_copy(self):
- s = self.socket(zmq.PUB)
- scopy = copy.copy(s)
- sdcopy = copy.deepcopy(s)
- self.assert_(scopy._shadow)
- self.assert_(sdcopy._shadow)
- self.assertEqual(s.underlying, scopy.underlying)
- self.assertEqual(s.underlying, sdcopy.underlying)
- s.close()
- def test_send_buffer(self):
- a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
- for buffer_type in (memoryview, bytearray):
- rawbytes = str(buffer_type).encode('ascii')
- msg = buffer_type(rawbytes)
- a.send(msg)
- recvd = b.recv()
- assert recvd == rawbytes
- def test_shadow(self):
- p = self.socket(zmq.PUSH)
- p.bind("tcp://127.0.0.1:5555")
- p2 = zmq.Socket.shadow(p.underlying)
- self.assertEqual(p.underlying, p2.underlying)
- s = self.socket(zmq.PULL)
- s2 = zmq.Socket.shadow(s.underlying)
- self.assertNotEqual(s.underlying, p.underlying)
- self.assertEqual(s.underlying, s2.underlying)
- s2.connect("tcp://127.0.0.1:5555")
- sent = b'hi'
- p2.send(sent)
- rcvd = self.recv(s2)
- self.assertEqual(rcvd, sent)
- def test_shadow_pyczmq(self):
- try:
- from pyczmq import zctx, zsocket
- except Exception:
- raise SkipTest("Requires pyczmq")
- ctx = zctx.new()
- ca = zsocket.new(ctx, zmq.PUSH)
- cb = zsocket.new(ctx, zmq.PULL)
- a = zmq.Socket.shadow(ca)
- b = zmq.Socket.shadow(cb)
- a.bind("inproc://a")
- b.connect("inproc://a")
- a.send(b'hi')
- rcvd = self.recv(b)
- self.assertEqual(rcvd, b'hi')
- def test_subscribe_method(self):
- pub, sub = self.create_bound_pair(zmq.PUB, zmq.SUB)
- sub.subscribe('prefix')
- sub.subscribe = 'c'
- p = zmq.Poller()
- p.register(sub, zmq.POLLIN)
- # wait for subscription handshake
- for i in range(100):
- pub.send(b'canary')
- events = p.poll(250)
- if events:
- break
- self.recv(sub)
- pub.send(b'prefixmessage')
- msg = self.recv(sub)
- self.assertEqual(msg, b'prefixmessage')
- sub.unsubscribe('prefix')
- pub.send(b'prefixmessage')
- events = p.poll(1000)
- self.assertEqual(events, [])
- # Travis can't handle how much memory PyPy uses on this test
- @mark.skipif(
- (
- pypy and on_travis
- ) or (
- sys.maxsize < 2**32
- ) or (
- windows
- ),
- reason="only run on 64b and not on Travis."
- )
- @mark.large
- def test_large_send(self):
- c = os.urandom(1)
- N = 2**31 + 1
- try:
- buf = c * N
- except MemoryError as e:
- raise SkipTest("Not enough memory: %s" % e)
- a, b = self.create_bound_pair()
- try:
- a.send(buf, copy=False)
- rcvd = b.recv(copy=False)
- except MemoryError as e:
- raise SkipTest("Not enough memory: %s" % e)
- # sample the front and back of the received message
- # without checking the whole content
- # Python 2: items in memoryview are bytes
- # Python 3: items im memoryview are int
- byte = c if sys.version_info < (3,) else ord(c)
- view = memoryview(rcvd)
- assert len(view) == N
- assert view[0] == byte
- assert view[-1] == byte
- def test_custom_serialize(self):
- a, b = self.create_bound_pair(zmq.DEALER, zmq.ROUTER)
- def serialize(msg):
- frames = []
- frames.extend(msg.get('identities', []))
- content = json.dumps(msg['content']).encode('utf8')
- frames.append(content)
- return frames
- def deserialize(frames):
- identities = frames[:-1]
- content = json.loads(frames[-1].decode('utf8'))
- return {
- 'identities': identities,
- 'content': content,
- }
-
- msg = {
- 'content': {
- 'a': 5,
- 'b': 'bee',
- }
- }
- a.send_serialized(msg, serialize)
- recvd = b.recv_serialized(deserialize)
- assert recvd['content'] == msg['content']
- assert recvd['identities']
- # bounce back, tests identities
- b.send_serialized(recvd, serialize)
- r2 = a.recv_serialized(deserialize)
- assert r2['content'] == msg['content']
- assert not r2['identities']
- if have_gevent and not windows:
- import gevent
-
- class TestSocketGreen(GreenTest, TestSocket):
- test_bad_attr = GreenTest.skip_green
- test_close_after_destroy = GreenTest.skip_green
-
- def test_timeout(self):
- a,b = self.create_bound_pair()
- g = gevent.spawn_later(0.5, lambda: a.send(b'hi'))
- timeout = gevent.Timeout(0.1)
- timeout.start()
- self.assertRaises(gevent.Timeout, b.recv)
- g.kill()
-
- @mark.skipif(not hasattr(zmq, 'RCVTIMEO'), reason="requires RCVTIMEO")
- def test_warn_set_timeo(self):
- s = self.context.socket(zmq.REQ)
- with warnings.catch_warnings(record=True) as w:
- s.rcvtimeo = 5
- s.close()
- self.assertEqual(len(w), 1)
- self.assertEqual(w[0].category, UserWarning)
-
- @mark.skipif(not hasattr(zmq, 'SNDTIMEO'), reason="requires SNDTIMEO")
- def test_warn_get_timeo(self):
- s = self.context.socket(zmq.REQ)
- with warnings.catch_warnings(record=True) as w:
- s.sndtimeo
- s.close()
- self.assertEqual(len(w), 1)
- self.assertEqual(w[0].category, UserWarning)
|