123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221 |
- # Copyright (C) PyZMQ Developers
- # Distributed under the terms of the Modified BSD License.
- import time
- from unittest import TestCase
- import zmq
- from zmq import devices
- from zmq.tests import BaseZMQTestCase, SkipTest, PYPY
- from zmq.utils.strtypes import unicode
# Devices normally share a Context, but cleanup of a shared Context doesn't
# work on PyPy, and libzmq 4.1 also appears to have a cleanup bug
# (zeromq/libzmq#1052).  In those environments, use ``zmq.Context`` itself as
# the device context factory so devices get their own Context.
if PYPY or zmq.zmq_version_info() >= (4, 1):
    devices.Device.context_factory = zmq.Context
class TestMonitoredQueue(BaseZMQTestCase):
    """Tests for ``devices.ThreadMonitoredQueue`` and the monitored-queue call."""

    # Sockets created by build_device / individual tests; closed by
    # teardown_device.  NOTE(review): this is a class attribute, so unless the
    # base class's setUp rebinds ``self.sockets`` per test, every test instance
    # shares this one list -- teardown_device therefore empties it in place.
    sockets = []

    def build_device(self, mon_sub=b"", in_prefix=b'in', out_prefix=b'out'):
        """Start a PAIR/PAIR/PUB ThreadMonitoredQueue wired to three local sockets.

        :param mon_sub: subscription prefix set on the monitor SUB socket.
        :param in_prefix: frame the device prepends when mirroring messages
            that arrive on its "in" side (sent by ``alice``).
        :param out_prefix: frame the device prepends when mirroring messages
            that arrive on its "out" side (sent by ``bob``).
        :returns: ``(alice, bob, mon)`` -- the PAIR peer of the device's "in"
            socket, the PAIR peer of its "out" socket, and the SUB monitor.
        """
        self.device = devices.ThreadMonitoredQueue(zmq.PAIR, zmq.PAIR, zmq.PUB,
                                                   in_prefix, out_prefix)
        alice = self.context.socket(zmq.PAIR)
        bob = self.context.socket(zmq.PAIR)
        mon = self.context.socket(zmq.SUB)

        aport = alice.bind_to_random_port('tcp://127.0.0.1')
        bport = bob.bind_to_random_port('tcp://127.0.0.1')
        mport = mon.bind_to_random_port('tcp://127.0.0.1')
        mon.setsockopt(zmq.SUBSCRIBE, mon_sub)

        self.device.connect_in("tcp://127.0.0.1:%i" % aport)
        self.device.connect_out("tcp://127.0.0.1:%i" % bport)
        self.device.connect_mon("tcp://127.0.0.1:%i" % mport)
        self.device.start()
        time.sleep(.2)
        try:
            # this is currently necessary to ensure no dropped monitor messages
            # see LIBZMQ-248 for more info
            mon.recv_multipart(zmq.NOBLOCK)
        except zmq.ZMQError:
            pass
        self.sockets.extend([alice, bob, mon])
        return alice, bob, mon

    def teardown_device(self):
        """Close all tracked sockets, forget them, and drop the device reference."""
        for sock in self.sockets:
            sock.close()
        # Empty the list in place: it may be the shared class-level list, and
        # otherwise it would keep growing (and re-closing sockets) across tests.
        del self.sockets[:]
        del self.device

    def _check_monitored_exchange(self, alice, bob, mon, in_prefix, out_prefix):
        """Send three messages alice -> bob and one bob -> alice.

        Each message must reach its peer unchanged and be mirrored on ``mon``
        with ``in_prefix`` (alice -> bob) or ``out_prefix`` (bob -> alice)
        prepended as the first frame.
        """
        alices = b"hello bob".split()
        alice.send_multipart(alices)
        alices2 = b"hello again".split()
        alice.send_multipart(alices2)
        alices3 = b"hello again and again".split()
        alice.send_multipart(alices3)
        bobs = self.recv_multipart(bob)
        self.assertEqual(alices, bobs)
        mons = self.recv_multipart(mon)
        self.assertEqual([in_prefix] + bobs, mons)
        bobs = self.recv_multipart(bob)
        self.assertEqual(alices2, bobs)
        bobs = self.recv_multipart(bob)
        self.assertEqual(alices3, bobs)
        mons = self.recv_multipart(mon)
        self.assertEqual([in_prefix] + alices2, mons)
        bobs = b"hello alice".split()
        bob.send_multipart(bobs)
        alices = self.recv_multipart(alice)
        self.assertEqual(alices, bobs)
        mons = self.recv_multipart(mon)
        self.assertEqual([in_prefix] + alices3, mons)
        mons = self.recv_multipart(mon)
        self.assertEqual([out_prefix] + bobs, mons)

    def test_reply(self):
        """A message relays alice -> bob and a reply relays bob -> alice."""
        alice, bob, mon = self.build_device()
        alices = b"hello bob".split()
        alice.send_multipart(alices)
        bobs = self.recv_multipart(bob)
        self.assertEqual(alices, bobs)
        bobs = b"hello alice".split()
        bob.send_multipart(bobs)
        alices = self.recv_multipart(alice)
        self.assertEqual(alices, bobs)
        self.teardown_device()

    def test_queue(self):
        """Several queued messages are delivered in the order they were sent."""
        alice, bob, mon = self.build_device()
        alices = b"hello bob".split()
        alice.send_multipart(alices)
        alices2 = b"hello again".split()
        alice.send_multipart(alices2)
        alices3 = b"hello again and again".split()
        alice.send_multipart(alices3)
        bobs = self.recv_multipart(bob)
        self.assertEqual(alices, bobs)
        bobs = self.recv_multipart(bob)
        self.assertEqual(alices2, bobs)
        bobs = self.recv_multipart(bob)
        self.assertEqual(alices3, bobs)
        bobs = b"hello alice".split()
        bob.send_multipart(bobs)
        alices = self.recv_multipart(alice)
        self.assertEqual(alices, bobs)
        self.teardown_device()

    def test_monitor(self):
        """Traffic is mirrored on the monitor with the default b'in'/b'out' prefixes."""
        alice, bob, mon = self.build_device()
        self._check_monitored_exchange(alice, bob, mon, b'in', b'out')
        self.teardown_device()

    def test_prefix(self):
        """Custom in/out prefixes are used for the monitor frames."""
        alice, bob, mon = self.build_device(b"", b'foo', b'bar')
        self._check_monitored_exchange(alice, bob, mon, b'foo', b'bar')
        self.teardown_device()

    def test_monitor_subscribe(self):
        """A b"out" subscription only sees the bob -> alice traffic."""
        alice, bob, mon = self.build_device(b"out")
        alices = b"hello bob".split()
        alice.send_multipart(alices)
        alices2 = b"hello again".split()
        alice.send_multipart(alices2)
        alices3 = b"hello again and again".split()
        alice.send_multipart(alices3)
        bobs = self.recv_multipart(bob)
        self.assertEqual(alices, bobs)
        bobs = self.recv_multipart(bob)
        self.assertEqual(alices2, bobs)
        bobs = self.recv_multipart(bob)
        self.assertEqual(alices3, bobs)
        bobs = b"hello alice".split()
        bob.send_multipart(bobs)
        alices = self.recv_multipart(alice)
        self.assertEqual(alices, bobs)
        mons = self.recv_multipart(mon)
        self.assertEqual([b'out'] + bobs, mons)
        self.teardown_device()

    def test_router_router(self):
        """test router-router MQ devices"""
        dev = devices.ThreadMonitoredQueue(zmq.ROUTER, zmq.ROUTER, zmq.PUB, b'in', b'out')
        self.device = dev
        dev.setsockopt_in(zmq.LINGER, 0)
        dev.setsockopt_out(zmq.LINGER, 0)
        dev.setsockopt_mon(zmq.LINGER, 0)

        porta = dev.bind_in_to_random_port('tcp://127.0.0.1')
        portb = dev.bind_out_to_random_port('tcp://127.0.0.1')
        a = self.context.socket(zmq.DEALER)
        a.identity = b'a'
        b = self.context.socket(zmq.DEALER)
        b.identity = b'b'
        self.sockets.extend([a, b])

        a.connect('tcp://127.0.0.1:%i' % porta)
        b.connect('tcp://127.0.0.1:%i' % portb)
        dev.start()
        time.sleep(1)
        if zmq.zmq_version_info() >= (3, 1, 0):
            # flush erroneous poll state, due to LIBZMQ-280
            ping_msg = [b'ping', b'pong']
            for s in (a, b):
                s.send_multipart(ping_msg)
                try:
                    s.recv(zmq.NOBLOCK)
                except zmq.ZMQError:
                    # best-effort flush: nothing to read is fine
                    pass
        msg = [b'hello', b'there']
        # The first frame addresses the peer by identity; the receiver sees
        # the sender's identity in its place.
        a.send_multipart([b'b'] + msg)
        bmsg = self.recv_multipart(b)
        self.assertEqual(bmsg, [b'a'] + msg)
        b.send_multipart(bmsg)
        amsg = self.recv_multipart(a)
        self.assertEqual(amsg, [b'b'] + msg)
        self.teardown_device()

    def test_default_mq_args(self):
        """Constructing and starting with only socket types must not raise."""
        self.device = dev = devices.ThreadMonitoredQueue(zmq.ROUTER, zmq.DEALER, zmq.PUB)
        dev.setsockopt_in(zmq.LINGER, 0)
        dev.setsockopt_out(zmq.LINGER, 0)
        dev.setsockopt_mon(zmq.LINGER, 0)
        # this will raise if default args are wrong
        dev.start()
        self.teardown_device()

    def test_mq_check_prefix(self):
        """Intended to check that non-bytes (unicode) prefixes raise TypeError."""
        ins = self.context.socket(zmq.ROUTER)
        outs = self.context.socket(zmq.DEALER)
        mons = self.context.socket(zmq.PUB)
        self.sockets.extend([ins, outs, mons])

        # NOTE(review): ``ins``/``outs`` are rebound here from the sockets to
        # the unicode prefix strings, so the sockets are never passed and the
        # call below receives strings where sockets are expected.  The
        # TypeError may therefore come from the socket arguments rather than
        # from prefix validation.  Also confirm ``devices.monitoredqueue`` is
        # the intended callable (the function may be exported as
        # ``monitored_queue``).  TODO: pass the sockets plus separately named
        # unicode prefixes once it is confirmed the call cannot block forever
        # on any backend that skips prefix type checks.
        ins = unicode('in')
        outs = unicode('out')
        self.assertRaises(TypeError, devices.monitoredqueue, ins, outs, mons)
|