test_zmqstream.py 2.4 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879
  1. # -*- coding: utf8 -*-
  2. # Copyright (C) PyZMQ Developers
  3. # Distributed under the terms of the Modified BSD License.
  4. from __future__ import absolute_import
  5. try:
  6. import asyncio
  7. except ImportError:
  8. asyncio = None
  9. from unittest import TestCase
  10. import pytest
  11. import zmq
  12. try:
  13. import tornado
  14. from tornado import gen
  15. from zmq.eventloop import ioloop, zmqstream
  16. except ImportError:
  17. tornado = None
  class TestZMQStream(TestCase):
      """Tests for ZMQStream send/recv callbacks driven by an IOLoop.

      Each test gets its own zmq Context, IOLoop and a connected PUSH/PULL
      pair of ZMQStreams over TCP loopback.
      """

      def setUp(self):
          if tornado is None:
              pytest.skip("requires tornado")
          # Install a fresh asyncio loop for this test, and keep a reference
          # so tearDown can close it rather than leaking one loop (and its
          # file descriptors) per test.
          self._asyncio_loop = None
          if asyncio is not None:
              self._asyncio_loop = asyncio.new_event_loop()
              asyncio.set_event_loop(self._asyncio_loop)
          self.context = zmq.Context()
          self.loop = ioloop.IOLoop()
          # Streams created below pick up the current loop by default.
          self.loop.make_current()
          self.push = zmqstream.ZMQStream(self.context.socket(zmq.PUSH))
          self.pull = zmqstream.ZMQStream(self.context.socket(zmq.PULL))
          port = self.push.bind_to_random_port('tcp://127.0.0.1')
          self.pull.connect('tcp://127.0.0.1:%i' % port)
          self.stream = self.push

      def tearDown(self):
          # Close both streams with linger=0 before terminating the context:
          # otherwise context.term() can block on messages that were queued
          # but never delivered (e.g. after a failed or timed-out test).
          self.push.close(linger=0)
          self.pull.close(linger=0)
          self.loop.close(all_fds=True)
          self.context.term()
          ioloop.IOLoop.clear_current()
          if self._asyncio_loop is not None:
              self._asyncio_loop.close()

      def run_until_timeout(self, timeout=10):
          """Run the IOLoop until a callback stops it.

          Fails the test if the loop is still running after ``timeout``
          seconds.
          """
          # A list is used as a mutable cell so the nested function can
          # flag the timeout without ``nonlocal`` (Python 2 compatible).
          timed_out = []

          def on_timeout():
              timed_out.append(True)
              self.loop.stop()

          handle = self.loop.call_later(timeout, on_timeout)
          try:
              self.loop.start()
          finally:
              # Cancel the timer so no stale callback stays scheduled on the
              # loop after a test stopped it early.
              self.loop.remove_timeout(handle)
          if timed_out:
              self.fail('IOLoop did not stop within %s seconds' % timeout)

      def test_callable_check(self):
          """Ensure callable check works (py3k)."""
          self.stream.on_send(lambda *args: None)
          self.stream.on_recv(lambda *args: None)
          self.assertRaises(AssertionError, self.stream.on_recv, 1)
          self.assertRaises(AssertionError, self.stream.on_send, 1)
          self.assertRaises(AssertionError, self.stream.on_recv, zmq)

      def test_on_recv_basic(self):
          """A message sent once the loop runs reaches the on_recv callback."""
          sent = [b'basic']
          received = []

          def callback(msg):
              # Record instead of asserting here: exceptions raised inside
              # loop callbacks are logged by the loop, not propagated to the
              # test, so a mismatch would only show up as a timeout.
              received.append(msg)
              self.loop.stop()

          self.loop.add_callback(lambda: self.push.send_multipart(sent))
          self.pull.on_recv(callback)
          self.run_until_timeout()
          self.assertEqual(received, [sent])

      def test_on_recv_wake(self):
          """A message sent after a delay wakes the idle loop's on_recv."""
          sent = [b'wake']
          received = []

          def callback(msg):
              # Record instead of asserting (see test_on_recv_basic's reason:
              # callback exceptions are swallowed by the event loop).
              received.append(msg)
              self.loop.stop()

          self.pull.on_recv(callback)
          self.loop.call_later(1, lambda: self.push.send_multipart(sent))
          self.run_until_timeout()
          self.assertEqual(received, [sent])