test_future.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353
  1. # coding: utf-8
  2. # Copyright (c) PyZMQ Developers
  3. # Distributed under the terms of the Modified BSD License.
  4. from datetime import timedelta
  5. import os
  6. import json
  7. import sys
  8. import pytest
  9. gen = pytest.importorskip('tornado.gen')
  10. import zmq
  11. from zmq.eventloop import future
  12. from tornado.ioloop import IOLoop
  13. from zmq.utils.strtypes import u
  14. from zmq.tests import BaseZMQTestCase
  15. class TestFutureSocket(BaseZMQTestCase):
  16. Context = future.Context
  17. def setUp(self):
  18. self.loop = IOLoop()
  19. self.loop.make_current()
  20. super(TestFutureSocket, self).setUp()
  21. def tearDown(self):
  22. super(TestFutureSocket, self).tearDown()
  23. if self.loop:
  24. self.loop.close(all_fds=True)
  25. IOLoop.clear_current()
  26. IOLoop.clear_instance()
  27. def test_socket_class(self):
  28. s = self.context.socket(zmq.PUSH)
  29. assert isinstance(s, future.Socket)
  30. s.close()
  31. def test_instance_subclass_first(self):
  32. actx = self.Context.instance()
  33. ctx = zmq.Context.instance()
  34. ctx.term()
  35. actx.term()
  36. assert type(ctx) is zmq.Context
  37. assert type(actx) is self.Context
  38. def test_instance_subclass_second(self):
  39. ctx = zmq.Context.instance()
  40. actx = self.Context.instance()
  41. ctx.term()
  42. actx.term()
  43. assert type(ctx) is zmq.Context
  44. assert type(actx) is self.Context
  45. def test_recv_multipart(self):
  46. @gen.coroutine
  47. def test():
  48. a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
  49. f = b.recv_multipart()
  50. assert not f.done()
  51. yield a.send(b'hi')
  52. recvd = yield f
  53. self.assertEqual(recvd, [b'hi'])
  54. self.loop.run_sync(test)
  55. def test_recv(self):
  56. @gen.coroutine
  57. def test():
  58. a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
  59. f1 = b.recv()
  60. f2 = b.recv()
  61. assert not f1.done()
  62. assert not f2.done()
  63. yield a.send_multipart([b'hi', b'there'])
  64. recvd = yield f2
  65. assert f1.done()
  66. self.assertEqual(f1.result(), b'hi')
  67. self.assertEqual(recvd, b'there')
  68. self.loop.run_sync(test)
  69. def test_recv_cancel(self):
  70. @gen.coroutine
  71. def test():
  72. a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
  73. f1 = b.recv()
  74. f2 = b.recv_multipart()
  75. assert f1.cancel()
  76. assert f1.done()
  77. assert not f2.done()
  78. yield a.send_multipart([b'hi', b'there'])
  79. recvd = yield f2
  80. assert f1.cancelled()
  81. assert f2.done()
  82. self.assertEqual(recvd, [b'hi', b'there'])
  83. self.loop.run_sync(test)
  84. @pytest.mark.skipif(not hasattr(zmq, 'RCVTIMEO'), reason="requires RCVTIMEO")
  85. def test_recv_timeout(self):
  86. @gen.coroutine
  87. def test():
  88. a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
  89. b.rcvtimeo = 100
  90. f1 = b.recv()
  91. b.rcvtimeo = 1000
  92. f2 = b.recv_multipart()
  93. with pytest.raises(zmq.Again):
  94. yield f1
  95. yield a.send_multipart([b'hi', b'there'])
  96. recvd = yield f2
  97. assert f2.done()
  98. self.assertEqual(recvd, [b'hi', b'there'])
  99. self.loop.run_sync(test)
  100. @pytest.mark.skipif(not hasattr(zmq, 'SNDTIMEO'), reason="requires SNDTIMEO")
  101. def test_send_timeout(self):
  102. @gen.coroutine
  103. def test():
  104. s = self.socket(zmq.PUSH)
  105. s.sndtimeo = 100
  106. with pytest.raises(zmq.Again):
  107. yield s.send(b'not going anywhere')
  108. self.loop.run_sync(test)
  109. @pytest.mark.now
  110. def test_send_noblock(self):
  111. @gen.coroutine
  112. def test():
  113. s = self.socket(zmq.PUSH)
  114. with pytest.raises(zmq.Again):
  115. yield s.send(b'not going anywhere', flags=zmq.NOBLOCK)
  116. self.loop.run_sync(test)
  117. @pytest.mark.now
  118. def test_send_multipart_noblock(self):
  119. @gen.coroutine
  120. def test():
  121. s = self.socket(zmq.PUSH)
  122. with pytest.raises(zmq.Again):
  123. yield s.send_multipart([b'not going anywhere'], flags=zmq.NOBLOCK)
  124. self.loop.run_sync(test)
  125. def test_recv_string(self):
  126. @gen.coroutine
  127. def test():
  128. a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
  129. f = b.recv_string()
  130. assert not f.done()
  131. msg = u('πøøπ')
  132. yield a.send_string(msg)
  133. recvd = yield f
  134. assert f.done()
  135. self.assertEqual(f.result(), msg)
  136. self.assertEqual(recvd, msg)
  137. self.loop.run_sync(test)
  138. def test_recv_json(self):
  139. @gen.coroutine
  140. def test():
  141. a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
  142. f = b.recv_json()
  143. assert not f.done()
  144. obj = dict(a=5)
  145. yield a.send_json(obj)
  146. recvd = yield f
  147. assert f.done()
  148. self.assertEqual(f.result(), obj)
  149. self.assertEqual(recvd, obj)
  150. self.loop.run_sync(test)
  151. def test_recv_json_cancelled(self):
  152. @gen.coroutine
  153. def test():
  154. a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
  155. f = b.recv_json()
  156. assert not f.done()
  157. f.cancel()
  158. # cycle eventloop to allow cancel events to fire
  159. yield gen.sleep(0)
  160. obj = dict(a=5)
  161. yield a.send_json(obj)
  162. with pytest.raises(future.CancelledError):
  163. recvd = yield f
  164. assert f.done()
  165. # give it a chance to incorrectly consume the event
  166. events = yield b.poll(timeout=5)
  167. assert events
  168. yield gen.sleep(0)
  169. # make sure cancelled recv didn't eat up event
  170. recvd = yield gen.with_timeout(timedelta(seconds=5), b.recv_json())
  171. assert recvd == obj
  172. self.loop.run_sync(test)
  173. def test_recv_pyobj(self):
  174. @gen.coroutine
  175. def test():
  176. a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
  177. f = b.recv_pyobj()
  178. assert not f.done()
  179. obj = dict(a=5)
  180. yield a.send_pyobj(obj)
  181. recvd = yield f
  182. assert f.done()
  183. self.assertEqual(f.result(), obj)
  184. self.assertEqual(recvd, obj)
  185. self.loop.run_sync(test)
  186. def test_custom_serialize(self):
  187. def serialize(msg):
  188. frames = []
  189. frames.extend(msg.get('identities', []))
  190. content = json.dumps(msg['content']).encode('utf8')
  191. frames.append(content)
  192. return frames
  193. def deserialize(frames):
  194. identities = frames[:-1]
  195. content = json.loads(frames[-1].decode('utf8'))
  196. return {
  197. 'identities': identities,
  198. 'content': content,
  199. }
  200. @gen.coroutine
  201. def test():
  202. a, b = self.create_bound_pair(zmq.DEALER, zmq.ROUTER)
  203. msg = {
  204. 'content': {
  205. 'a': 5,
  206. 'b': 'bee',
  207. }
  208. }
  209. yield a.send_serialized(msg, serialize)
  210. recvd = yield b.recv_serialized(deserialize)
  211. assert recvd['content'] == msg['content']
  212. assert recvd['identities']
  213. # bounce back, tests identities
  214. yield b.send_serialized(recvd, serialize)
  215. r2 = yield a.recv_serialized(deserialize)
  216. assert r2['content'] == msg['content']
  217. assert not r2['identities']
  218. self.loop.run_sync(test)
  219. def test_custom_serialize_error(self):
  220. @gen.coroutine
  221. def test():
  222. a, b = self.create_bound_pair(zmq.DEALER, zmq.ROUTER)
  223. msg = {
  224. 'content': {
  225. 'a': 5,
  226. 'b': 'bee',
  227. }
  228. }
  229. with pytest.raises(TypeError):
  230. yield a.send_serialized(json, json.dumps)
  231. yield a.send(b'not json')
  232. with pytest.raises(TypeError):
  233. recvd = yield b.recv_serialized(json.loads)
  234. self.loop.run_sync(test)
  235. def test_poll(self):
  236. @gen.coroutine
  237. def test():
  238. a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
  239. f = b.poll(timeout=0)
  240. assert f.done()
  241. self.assertEqual(f.result(), 0)
  242. f = b.poll(timeout=1)
  243. assert not f.done()
  244. evt = yield f
  245. self.assertEqual(evt, 0)
  246. f = b.poll(timeout=1000)
  247. assert not f.done()
  248. yield a.send_multipart([b'hi', b'there'])
  249. evt = yield f
  250. self.assertEqual(evt, zmq.POLLIN)
  251. recvd = yield b.recv_multipart()
  252. self.assertEqual(recvd, [b'hi', b'there'])
  253. self.loop.run_sync(test)
  254. @pytest.mark.skipif(
  255. sys.platform.startswith('win'),
  256. reason='Windows unsupported socket type')
  257. def test_poll_base_socket(self):
  258. @gen.coroutine
  259. def test():
  260. ctx = zmq.Context()
  261. url = 'inproc://test'
  262. a = ctx.socket(zmq.PUSH)
  263. b = ctx.socket(zmq.PULL)
  264. self.sockets.extend([a, b])
  265. a.bind(url)
  266. b.connect(url)
  267. poller = future.Poller()
  268. poller.register(b, zmq.POLLIN)
  269. f = poller.poll(timeout=1000)
  270. assert not f.done()
  271. a.send_multipart([b'hi', b'there'])
  272. evt = yield f
  273. self.assertEqual(evt, [(b, zmq.POLLIN)])
  274. recvd = b.recv_multipart()
  275. self.assertEqual(recvd, [b'hi', b'there'])
  276. a.close()
  277. b.close()
  278. ctx.term()
  279. self.loop.run_sync(test)
  280. def test_close_all_fds(self):
  281. s = self.socket(zmq.PUB)
  282. self.loop.close(all_fds=True)
  283. self.loop = None # avoid second close later
  284. assert s.closed
  285. @pytest.mark.skipif(
  286. sys.platform.startswith('win'),
  287. reason='Windows does not support polling on files')
  288. def test_poll_raw(self):
  289. @gen.coroutine
  290. def test():
  291. p = future.Poller()
  292. # make a pipe
  293. r, w = os.pipe()
  294. r = os.fdopen(r, 'rb')
  295. w = os.fdopen(w, 'wb')
  296. # POLLOUT
  297. p.register(r, zmq.POLLIN)
  298. p.register(w, zmq.POLLOUT)
  299. evts = yield p.poll(timeout=1)
  300. evts = dict(evts)
  301. assert r.fileno() not in evts
  302. assert w.fileno() in evts
  303. assert evts[w.fileno()] == zmq.POLLOUT
  304. # POLLIN
  305. p.unregister(w)
  306. w.write(b'x')
  307. w.flush()
  308. evts = yield p.poll(timeout=1000)
  309. evts = dict(evts)
  310. assert r.fileno() in evts
  311. assert evts[r.fileno()] == zmq.POLLIN
  312. assert r.read(1) == b'x'
  313. r.close()
  314. w.close()
  315. self.loop.run_sync(test)