__init__.py 6.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192
  1. # Copyright (c) PyZMQ Developers.
  2. # Distributed under the terms of the Modified BSD License.
  3. import sys
  4. import time
  5. from threading import Thread
  6. from unittest import TestCase
  7. try:
  8. from unittest import SkipTest
  9. except ImportError:
  10. from unittest2 import SkipTest
  11. from pytest import mark
  12. import zmq
  13. from zmq.utils import jsonapi
  14. try:
  15. import gevent
  16. from zmq import green as gzmq
  17. have_gevent = True
  18. except ImportError:
  19. have_gevent = False
  20. PYPY = 'PyPy' in sys.version
  21. #-----------------------------------------------------------------------------
  22. # skip decorators (directly from unittest)
  23. #-----------------------------------------------------------------------------
  24. _id = lambda x: x
  25. skip_pypy = mark.skipif(PYPY, reason="Doesn't work on PyPy")
  26. require_zmq_4 = mark.skipif(zmq.zmq_version_info() < (4,), reason="requires zmq >= 4")
  27. #-----------------------------------------------------------------------------
  28. # Base test class
  29. #-----------------------------------------------------------------------------
  30. class BaseZMQTestCase(TestCase):
  31. green = False
  32. teardown_timeout = 10
  33. @property
  34. def Context(self):
  35. if self.green:
  36. return gzmq.Context
  37. else:
  38. return zmq.Context
  39. def socket(self, socket_type):
  40. s = self.context.socket(socket_type)
  41. self.sockets.append(s)
  42. return s
  43. def setUp(self):
  44. super(BaseZMQTestCase, self).setUp()
  45. if self.green and not have_gevent:
  46. raise SkipTest("requires gevent")
  47. self.context = self.Context.instance()
  48. self.sockets = []
  49. def tearDown(self):
  50. contexts = set([self.context])
  51. while self.sockets:
  52. sock = self.sockets.pop()
  53. contexts.add(sock.context) # in case additional contexts are created
  54. sock.close(0)
  55. for ctx in contexts:
  56. t = Thread(target=ctx.term)
  57. t.daemon = True
  58. t.start()
  59. t.join(timeout=self.teardown_timeout)
  60. if t.is_alive():
  61. # reset Context.instance, so the failure to term doesn't corrupt subsequent tests
  62. zmq.sugar.context.Context._instance = None
  63. raise RuntimeError("context could not terminate, open sockets likely remain in test")
  64. super(BaseZMQTestCase, self).tearDown()
  65. def create_bound_pair(self, type1=zmq.PAIR, type2=zmq.PAIR, interface='tcp://127.0.0.1'):
  66. """Create a bound socket pair using a random port."""
  67. s1 = self.context.socket(type1)
  68. s1.setsockopt(zmq.LINGER, 0)
  69. port = s1.bind_to_random_port(interface)
  70. s2 = self.context.socket(type2)
  71. s2.setsockopt(zmq.LINGER, 0)
  72. s2.connect('%s:%s' % (interface, port))
  73. self.sockets.extend([s1,s2])
  74. return s1, s2
  75. def ping_pong(self, s1, s2, msg):
  76. s1.send(msg)
  77. msg2 = s2.recv()
  78. s2.send(msg2)
  79. msg3 = s1.recv()
  80. return msg3
  81. def ping_pong_json(self, s1, s2, o):
  82. if jsonapi.jsonmod is None:
  83. raise SkipTest("No json library")
  84. s1.send_json(o)
  85. o2 = s2.recv_json()
  86. s2.send_json(o2)
  87. o3 = s1.recv_json()
  88. return o3
  89. def ping_pong_pyobj(self, s1, s2, o):
  90. s1.send_pyobj(o)
  91. o2 = s2.recv_pyobj()
  92. s2.send_pyobj(o2)
  93. o3 = s1.recv_pyobj()
  94. return o3
  95. def assertRaisesErrno(self, errno, func, *args, **kwargs):
  96. try:
  97. func(*args, **kwargs)
  98. except zmq.ZMQError as e:
  99. self.assertEqual(e.errno, errno, "wrong error raised, expected '%s' \
  100. got '%s'" % (zmq.ZMQError(errno), zmq.ZMQError(e.errno)))
  101. else:
  102. self.fail("Function did not raise any error")
  103. def _select_recv(self, multipart, socket, **kwargs):
  104. """call recv[_multipart] in a way that raises if there is nothing to receive"""
  105. if zmq.zmq_version_info() >= (3,1,0):
  106. # zmq 3.1 has a bug, where poll can return false positives,
  107. # so we wait a little bit just in case
  108. # See LIBZMQ-280 on JIRA
  109. time.sleep(0.1)
  110. r,w,x = zmq.select([socket], [], [], timeout=kwargs.pop('timeout', 5))
  111. assert len(r) > 0, "Should have received a message"
  112. kwargs['flags'] = zmq.DONTWAIT | kwargs.get('flags', 0)
  113. recv = socket.recv_multipart if multipart else socket.recv
  114. return recv(**kwargs)
  115. def recv(self, socket, **kwargs):
  116. """call recv in a way that raises if there is nothing to receive"""
  117. return self._select_recv(False, socket, **kwargs)
  118. def recv_multipart(self, socket, **kwargs):
  119. """call recv_multipart in a way that raises if there is nothing to receive"""
  120. return self._select_recv(True, socket, **kwargs)
  121. class PollZMQTestCase(BaseZMQTestCase):
  122. pass
  123. class GreenTest:
  124. """Mixin for making green versions of test classes"""
  125. green = True
  126. teardown_timeout = 10
  127. def assertRaisesErrno(self, errno, func, *args, **kwargs):
  128. if errno == zmq.EAGAIN:
  129. raise SkipTest("Skipping because we're green.")
  130. try:
  131. func(*args, **kwargs)
  132. except zmq.ZMQError:
  133. e = sys.exc_info()[1]
  134. self.assertEqual(e.errno, errno, "wrong error raised, expected '%s' \
  135. got '%s'" % (zmq.ZMQError(errno), zmq.ZMQError(e.errno)))
  136. else:
  137. self.fail("Function did not raise any error")
  138. def tearDown(self):
  139. contexts = set([self.context])
  140. while self.sockets:
  141. sock = self.sockets.pop()
  142. contexts.add(sock.context) # in case additional contexts are created
  143. sock.close()
  144. try:
  145. gevent.joinall(
  146. [gevent.spawn(ctx.term) for ctx in contexts],
  147. timeout=self.teardown_timeout,
  148. raise_error=True,
  149. )
  150. except gevent.Timeout:
  151. raise RuntimeError("context could not terminate, open sockets likely remain in test")
  152. def skip_green(self):
  153. raise SkipTest("Skipping because we are green")
  154. def skip_green(f):
  155. def skipping_test(self, *args, **kwargs):
  156. if self.green:
  157. raise SkipTest("Skipping because we are green")
  158. else:
  159. return f(self, *args, **kwargs)
  160. return skipping_test