| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192 |
- # Copyright (c) PyZMQ Developers.
- # Distributed under the terms of the Modified BSD License.
- import sys
- import time
- from threading import Thread
- from unittest import TestCase
- try:
- from unittest import SkipTest
- except ImportError:
- from unittest2 import SkipTest
- from pytest import mark
- import zmq
- from zmq.utils import jsonapi
- try:
- import gevent
- from zmq import green as gzmq
- have_gevent = True
- except ImportError:
- have_gevent = False
# True when the interpreter's version string mentions PyPy.
PYPY = 'PyPy' in sys.version

#-----------------------------------------------------------------------------
# skip markers (pytest ``mark.skipif`` conditions, evaluated at import time)
#-----------------------------------------------------------------------------


def _id(x):
    """Return *x* unchanged (identity function)."""
    return x


# Skips the marked test when running under PyPy.
skip_pypy = mark.skipif(PYPY, reason="Doesn't work on PyPy")

# Skips the marked test when zmq.zmq_version_info() reports a version below 4.
require_zmq_4 = mark.skipif(
    zmq.zmq_version_info() < (4,), reason="requires zmq >= 4"
)
- #-----------------------------------------------------------------------------
- # Base test class
- #-----------------------------------------------------------------------------
class BaseZMQTestCase(TestCase):
    """Base class for tests that need a zmq context and tracked sockets.

    ``setUp`` obtains a context via ``Context.instance()`` and starts an empty
    ``self.sockets`` list.  ``tearDown`` closes every tracked socket and then
    terminates every context those sockets belong to.
    """

    # When true, the ``Context`` property returns the zmq.green Context class.
    green = False
    # Seconds to wait for each context to terminate in tearDown.
    teardown_timeout = 10

    @property
    def Context(self):
        """Context class to use: ``gzmq.Context`` if ``green``, else ``zmq.Context``."""
        return gzmq.Context if self.green else zmq.Context

    def socket(self, socket_type):
        """Create a socket on ``self.context`` and track it for tearDown."""
        s = self.context.socket(socket_type)
        self.sockets.append(s)
        return s

    def setUp(self):
        """Prepare ``self.context`` and an empty ``self.sockets`` list.

        Raises SkipTest when the test is green but gevent could not be
        imported.
        """
        super(BaseZMQTestCase, self).setUp()
        if self.green and not have_gevent:
            raise SkipTest("requires gevent")
        self.context = self.Context.instance()
        self.sockets = []

    def tearDown(self):
        """Close all tracked sockets and terminate their contexts.

        Raises RuntimeError if a context has not terminated after
        ``teardown_timeout`` seconds.
        """
        contexts = set([self.context])
        while self.sockets:
            sock = self.sockets.pop()
            contexts.add(sock.context)  # in case additional contexts are created
            sock.close(0)
        for ctx in contexts:
            # run term() in a daemon thread so the wait can be bounded by
            # teardown_timeout instead of hanging the test run
            t = Thread(target=ctx.term)
            t.daemon = True
            t.start()
            t.join(timeout=self.teardown_timeout)
            if t.is_alive():
                # reset Context.instance, so the failure to term doesn't
                # corrupt subsequent tests
                zmq.sugar.context.Context._instance = None
                raise RuntimeError(
                    "context could not terminate, open sockets likely remain in test"
                )
        super(BaseZMQTestCase, self).tearDown()

    def create_bound_pair(self, type1=zmq.PAIR, type2=zmq.PAIR, interface='tcp://127.0.0.1'):
        """Create a bound socket pair using a random port.

        The first socket (``type1``) is bound to a random port on
        ``interface``; the second (``type2``) connects to that port.  Both
        have LINGER set to 0 and are tracked in ``self.sockets``.

        Returns the tuple ``(s1, s2)``.
        """
        s1 = self.context.socket(type1)
        # Track each socket as soon as it exists, so tearDown still closes it
        # if a later setsockopt/bind/connect call raises.
        self.sockets.append(s1)
        s1.setsockopt(zmq.LINGER, 0)
        port = s1.bind_to_random_port(interface)
        s2 = self.context.socket(type2)
        self.sockets.append(s2)
        s2.setsockopt(zmq.LINGER, 0)
        s2.connect('%s:%s' % (interface, port))
        return s1, s2

    def ping_pong(self, s1, s2, msg):
        """Send ``msg`` from s1 to s2, echo it back, and return what s1 receives."""
        s1.send(msg)
        msg2 = s2.recv()
        s2.send(msg2)
        msg3 = s1.recv()
        return msg3

    def ping_pong_json(self, s1, s2, o):
        """Like ``ping_pong`` but using send_json/recv_json.

        Raises SkipTest when ``jsonapi.jsonmod`` is None.
        """
        if jsonapi.jsonmod is None:
            raise SkipTest("No json library")
        s1.send_json(o)
        o2 = s2.recv_json()
        s2.send_json(o2)
        o3 = s1.recv_json()
        return o3

    def ping_pong_pyobj(self, s1, s2, o):
        """Like ``ping_pong`` but using send_pyobj/recv_pyobj."""
        s1.send_pyobj(o)
        o2 = s2.recv_pyobj()
        s2.send_pyobj(o2)
        o3 = s1.recv_pyobj()
        return o3

    def assertRaisesErrno(self, errno, func, *args, **kwargs):
        """Assert that ``func(*args, **kwargs)`` raises ZMQError with ``errno``.

        Fails the test if a ZMQError with a different errno is raised, or if
        no error is raised at all.
        """
        try:
            func(*args, **kwargs)
        except zmq.ZMQError as e:
            self.assertEqual(
                e.errno,
                errno,
                "wrong error raised, expected '%s' got '%s'"
                % (zmq.ZMQError(errno), zmq.ZMQError(e.errno)),
            )
        else:
            self.fail("Function did not raise any error")

    def _select_recv(self, multipart, socket, **kwargs):
        """call recv[_multipart] in a way that raises if there is nothing to receive

        A ``timeout`` keyword (default 5) is removed from ``kwargs`` and passed
        to ``zmq.select``; the remaining keywords go to the recv call, with
        ``zmq.DONTWAIT`` OR-ed into ``flags``.
        """
        if zmq.zmq_version_info() >= (3, 1, 0):
            # zmq 3.1 has a bug, where poll can return false positives,
            # so we wait a little bit just in case
            # See LIBZMQ-280 on JIRA
            time.sleep(0.1)

        timeout = kwargs.pop('timeout', 5)
        readable, _, _ = zmq.select([socket], [], [], timeout=timeout)
        assert len(readable) > 0, "Should have received a message"
        kwargs['flags'] = zmq.DONTWAIT | kwargs.get('flags', 0)

        recv = socket.recv_multipart if multipart else socket.recv
        return recv(**kwargs)

    def recv(self, socket, **kwargs):
        """call recv in a way that raises if there is nothing to receive"""
        return self._select_recv(False, socket, **kwargs)

    def recv_multipart(self, socket, **kwargs):
        """call recv_multipart in a way that raises if there is nothing to receive"""
        return self._select_recv(True, socket, **kwargs)
-
class PollZMQTestCase(BaseZMQTestCase):
    """BaseZMQTestCase subclass that adds no behavior of its own."""
class GreenTest:
    """Mixin for making green versions of test classes"""

    # Marks the test class as green; read by skip_green-style checks.
    green = True
    # Timeout passed to gevent.joinall while terminating contexts in tearDown.
    teardown_timeout = 10

    def assertRaisesErrno(self, errno, func, *args, **kwargs):
        """Assert that ``func(*args, **kwargs)`` raises ZMQError with ``errno``.

        Checks for ``zmq.EAGAIN`` are not attempted: they raise SkipTest
        before ``func`` is called.
        """
        if errno == zmq.EAGAIN:
            raise SkipTest("Skipping because we're green.")
        try:
            func(*args, **kwargs)
        except zmq.ZMQError as e:
            self.assertEqual(
                e.errno,
                errno,
                "wrong error raised, expected '%s' got '%s'"
                % (zmq.ZMQError(errno), zmq.ZMQError(e.errno)),
            )
        else:
            self.fail("Function did not raise any error")

    def tearDown(self):
        """Close all tracked sockets and terminate their contexts in greenlets."""
        contexts = set([self.context])
        while self.sockets:
            sock = self.sockets.pop()
            contexts.add(sock.context)  # in case additional contexts are created
            sock.close()
        try:
            gevent.joinall(
                [gevent.spawn(ctx.term) for ctx in contexts],
                timeout=self.teardown_timeout,
                raise_error=True,
            )
        except gevent.Timeout:
            # NOTE(review): gevent.joinall appears to return (rather than
            # raise) when its timeout expires -- confirm this branch can
            # actually fire for a context that fails to terminate.
            raise RuntimeError("context could not terminate, open sockets likely remain in test")

    def skip_green(self):
        """Unconditionally skip the current test by raising SkipTest."""
        raise SkipTest("Skipping because we are green")
- def skip_green(f):
- def skipping_test(self, *args, **kwargs):
- if self.green:
- raise SkipTest("Skipping because we are green")
- else:
- return f(self, *args, **kwargs)
- return skipping_test
|