| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392 |
- # Copyright (C) PyZMQ Developers
- # Distributed under the terms of the Modified BSD License.
- import copy
- import gc
- import os
- import sys
- import time
- from threading import Thread, Event
- try:
- from queue import Queue
- except ImportError:
- from Queue import Queue
- try:
- from unittest import mock
- except ImportError:
- mock = None
- from pytest import mark
- import zmq
- from zmq.tests import (
- BaseZMQTestCase, have_gevent, GreenTest, skip_green, PYPY, SkipTest,
- )
class KwargTestSocket(zmq.Socket):
    """zmq.Socket subclass that records an extra ``test_kwarg`` keyword.

    The keyword is removed from ``kwargs`` before the remaining arguments are
    forwarded to ``zmq.Socket.__init__``; its value (``None`` when it was not
    supplied) is stored on ``test_kwarg_value``.
    """

    # Class-level default for the recorded keyword value.
    # NOTE(review): presumably declared at class level so that zmq.Socket's
    # attribute handling treats it as a plain attribute rather than a socket
    # option -- confirm before removing.
    test_kwarg_value = None

    def __init__(self, *args, **kwargs):
        captured = kwargs.pop('test_kwarg', None)
        self.test_kwarg_value = captured
        super(KwargTestSocket, self).__init__(*args, **kwargs)
class KwargTestContext(zmq.Context):
    """zmq.Context whose sockets are ``KwargTestSocket`` instances.

    Used by ``TestContext.test_socket_passes_kwargs`` to check that extra
    keyword arguments given to ``socket()`` reach the socket constructor.
    """

    # Socket class this context instantiates (presumably consulted by
    # Context.socket(); see test_socket_passes_kwargs).
    _socket_class = KwargTestSocket
class TestContext(BaseZMQTestCase):
    """Tests for zmq.Context: creation, termination, singletons, options,
    copying/shadowing and interaction with threads, GC and fork.

    ``self.Context`` and ``self.context`` are not defined here; presumably
    they are provided (and ``self.context`` cleaned up) by BaseZMQTestCase --
    TODO confirm against zmq.tests.
    """

    def test_init(self):
        # Several contexts can be created one after another; each one is
        # dropped right away so it can be released.
        for _ in range(3):
            ctx = self.Context()
            self.assertIsInstance(ctx, self.Context)
            del ctx

    def test_dir(self):
        ctx = self.Context()
        self.assertTrue('socket' in dir(ctx))
        if zmq.zmq_version_info() > (3,):
            # context options are exposed as attributes with libzmq >= 3
            self.assertTrue('IO_THREADS' in dir(ctx))
        ctx.term()

    @mark.skipif(mock is None, reason="requires unittest.mock")
    def test_mockable(self):
        # Building a spec'd Mock from a live context must not raise, and the
        # mock should pass isinstance checks against the context's class.
        m = mock.Mock(spec=self.context)
        self.assertIsInstance(m, type(self.context))

    def test_term(self):
        c = self.Context()
        c.term()
        self.assertTrue(c.closed)

    def test_context_manager(self):
        with self.Context() as c:
            pass
        # leaving the with-block must close the context
        self.assertTrue(c.closed)

    def test_fail_init(self):
        # a negative io_threads count is rejected with EINVAL
        self.assertRaisesErrno(zmq.EINVAL, self.Context, -1)

    def test_term_hang(self):
        # A message still queued on a LINGER=0 socket must not make term()
        # block once the sockets are closed.
        rep, req = self.create_bound_pair(zmq.ROUTER, zmq.DEALER)
        req.setsockopt(zmq.LINGER, 0)
        req.send(b'hello', copy=False)
        req.close()
        rep.close()
        self.context.term()

    def test_instance(self):
        ctx = self.Context.instance()
        c2 = self.Context.instance(io_threads=2)
        self.assertIs(c2, ctx)
        c2.term()
        # after the singleton is terminated, instance() hands out a new one
        c3 = self.Context.instance()
        c4 = self.Context.instance()
        self.assertFalse(c3 is c2)
        self.assertFalse(c3.closed)
        self.assertIs(c3, c4)

    def test_instance_subclass_first(self):
        self.context.term()

        class SubContext(zmq.Context):
            pass

        # creating the subclass singleton first must not affect the base one
        sctx = SubContext.instance()
        ctx = zmq.Context.instance()
        ctx.term()
        sctx.term()
        assert type(ctx) is zmq.Context
        assert type(sctx) is SubContext

    def test_instance_subclass_second(self):
        self.context.term()

        class SubContextInherit(zmq.Context):
            pass

        class SubContextNoInherit(zmq.Context):
            # a fresh _instance slot gives this subclass its own singleton
            _instance = None

        ctx = zmq.Context.instance()
        sctx = SubContextInherit.instance()
        sctx2 = SubContextNoInherit.instance()
        ctx.term()
        sctx.term()
        sctx2.term()
        assert type(ctx) is zmq.Context
        # without its own _instance, the subclass shares the base singleton
        assert type(sctx) is zmq.Context
        assert type(sctx2) is SubContextNoInherit

    def test_instance_threadsafe(self):
        self.context.term()  # clear default context

        q = Queue()

        # slow context initialization,
        # to ensure that we are both trying to create one at the same time
        class SlowContext(self.Context):
            def __init__(self, *a, **kw):
                time.sleep(1)
                super(SlowContext, self).__init__(*a, **kw)

        def f():
            q.put(SlowContext.instance())

        # call ctx.instance() in several threads at once
        N = 16
        threads = [Thread(target=f) for _ in range(N)]
        for t in threads:
            t.start()
        # also call it in the main thread (not first)
        ctx = SlowContext.instance()
        assert isinstance(ctx, SlowContext)
        # check that all the threads got the same context
        for _ in range(N):
            thread_ctx = q.get(timeout=5)
            assert thread_ctx is ctx
        # cleanup
        ctx.term()
        for t in threads:
            t.join(timeout=5)

    def test_socket_passes_kwargs(self):
        test_kwarg_value = 'testing one two three'
        with KwargTestContext() as ctx:
            with ctx.socket(zmq.DEALER, test_kwarg=test_kwarg_value) as socket:
                self.assertIs(socket.test_kwarg_value, test_kwarg_value)

    def test_many_sockets(self):
        """opening and closing many sockets shouldn't cause problems"""
        ctx = self.Context()
        for _round in range(16):
            sockets = [ctx.socket(zmq.REP) for _ in range(65)]
            for s in sockets:
                s.close()
            # give the reaper a chance
            time.sleep(1e-2)
        ctx.term()

    def test_sockopts(self):
        """setting socket options with ctx attributes"""
        ctx = self.Context()
        ctx.linger = 5
        self.assertEqual(ctx.linger, 5)
        # sockets created afterwards inherit the context-level option
        s = ctx.socket(zmq.REQ)
        self.assertEqual(s.linger, 5)
        self.assertEqual(s.getsockopt(zmq.LINGER), 5)
        s.close()
        # check that subscribe doesn't get set on sockets that don't subscribe:
        ctx.subscribe = b''
        s = ctx.socket(zmq.REQ)
        s.close()

        ctx.term()

    @mark.skipif(
        sys.platform.startswith('win'),
        reason='Segfaults on Windows')
    def test_destroy(self):
        """Context.destroy should close sockets"""
        ctx = self.Context()
        sockets = [ctx.socket(zmq.REP) for _ in range(65)]

        # close half of the sockets
        for s in sockets[::2]:
            s.close()

        ctx.destroy()
        # reaper is not instantaneous
        time.sleep(1e-2)
        for s in sockets:
            self.assertTrue(s.closed)

    def test_destroy_linger(self):
        """Context.destroy should set linger on closing sockets"""
        req, rep = self.create_bound_pair(zmq.REQ, zmq.REP)
        req.send(b'hi')
        time.sleep(1e-2)
        self.context.destroy(linger=0)
        # reaper is not instantaneous
        time.sleep(1e-2)
        for s in (req, rep):
            self.assertTrue(s.closed)

    def test_term_noclose(self):
        """Context.term won't close sockets"""
        ctx = self.Context()
        s = ctx.socket(zmq.REQ)
        self.assertFalse(s.closed)
        # term() runs in a thread because it blocks while s is still open
        t = Thread(target=ctx.term)
        t.start()
        t.join(timeout=0.1)
        self.assertTrue(t.is_alive(), "Context should be waiting")
        s.close()
        t.join(timeout=0.1)
        self.assertFalse(t.is_alive(), "Context should have closed")

    def test_gc(self):
        """test close&term by garbage collection alone"""
        if PYPY:
            raise SkipTest("GC doesn't work ")

        # test credit @dln (GH #137):
        def gcf():
            def inner():
                # context and socket become unreachable when inner() returns
                ctx = self.Context()
                s = ctx.socket(zmq.PUSH)
            inner()
            gc.collect()
        t = Thread(target=gcf)
        t.start()
        t.join(timeout=1)
        self.assertFalse(t.is_alive(), "Garbage collection should have cleaned up context")

    def test_cyclic_destroy(self):
        """ctx.destroy should succeed when cyclic ref prevents gc"""
        # test credit @dln (GH #137):
        class CyclicReference(object):
            def __init__(self, parent=None):
                self.parent = parent

            def crash(self, sock):
                # parent <-> child cycle that also holds the socket
                self.sock = sock
                self.child = CyclicReference(self)

        def crash_zmq():
            ctx = self.Context()
            sock = ctx.socket(zmq.PULL)
            c = CyclicReference()
            c.crash(sock)
            ctx.destroy()

        crash_zmq()

    def test_term_thread(self):
        """ctx.term should not crash active threads (#139)"""
        ctx = self.Context()
        evt = Event()
        evt.clear()
        # Errnos seen by the worker thread. Exceptions raised inside a Thread
        # target (including failed assertions) are only printed and never
        # fail the test, so the outcome is checked on the main thread below.
        errnos = []

        def block():
            s = ctx.socket(zmq.REP)
            s.bind_to_random_port('tcp://127.0.0.1')
            evt.set()
            try:
                s.recv()
            except zmq.ZMQError as e:
                errnos.append(e.errno)
            finally:
                s.close()

        t = Thread(target=block)
        t.start()

        evt.wait(1)
        self.assertTrue(evt.is_set(), "sync event never fired")
        time.sleep(0.01)
        ctx.term()
        t.join(timeout=1)
        self.assertFalse(t.is_alive(), "term should have interrupted s.recv()")
        self.assertEqual(errnos, [zmq.ETERM], "recv should have been interrupted with ETERM")

    def test_destroy_no_sockets(self):
        ctx = self.Context()
        s = ctx.socket(zmq.PUB)
        s.bind_to_random_port('tcp://127.0.0.1')
        s.close()
        # destroy must still succeed when every socket is already closed
        ctx.destroy()
        assert s.closed
        assert ctx.closed

    def test_ctx_opts(self):
        if zmq.zmq_version_info() < (3,):
            raise SkipTest("context options require libzmq 3")
        ctx = self.Context()
        # explicit get/set and attribute access address the same option
        ctx.set(zmq.MAX_SOCKETS, 2)
        self.assertEqual(ctx.get(zmq.MAX_SOCKETS), 2)
        ctx.max_sockets = 100
        self.assertEqual(ctx.max_sockets, 100)
        self.assertEqual(ctx.get(zmq.MAX_SOCKETS), 100)
        ctx.term()

    def test_copy(self):
        c1 = self.Context()
        c2 = copy.copy(c1)
        c2b = copy.deepcopy(c1)
        c3 = copy.deepcopy(c2)
        # copies are shadows that share the original's underlying context
        self.assertTrue(c2._shadow)
        self.assertTrue(c3._shadow)
        self.assertEqual(c1.underlying, c2.underlying)
        self.assertEqual(c1.underlying, c3.underlying)
        self.assertEqual(c1.underlying, c2b.underlying)
        s = c3.socket(zmq.PUB)
        s.close()
        c1.term()

    def test_shadow(self):
        ctx = self.Context()
        ctx2 = self.Context.shadow(ctx.underlying)
        self.assertEqual(ctx.underlying, ctx2.underlying)
        s = ctx.socket(zmq.PUB)
        s.close()
        # dropping a shadow must not close the context it shadows
        del ctx2
        self.assertFalse(ctx.closed)
        s = ctx.socket(zmq.PUB)
        ctx2 = self.Context.shadow(ctx.underlying)
        s2 = ctx2.socket(zmq.PUB)
        s.close()
        s2.close()
        ctx.term()
        # the shadow cannot create sockets once the real context is gone
        self.assertRaisesErrno(zmq.EFAULT, ctx2.socket, zmq.PUB)
        del ctx2

    def test_shadow_pyczmq(self):
        try:
            from pyczmq import zctx, zsocket, zstr
        except Exception:
            raise SkipTest("Requires pyczmq")

        ctx = zctx.new()
        a = zsocket.new(ctx, zmq.PUSH)
        zsocket.bind(a, "inproc://a")
        ctx2 = self.Context.shadow_pyczmq(ctx)
        b = ctx2.socket(zmq.PULL)
        b.connect("inproc://a")
        zstr.send(a, b'hi')
        rcvd = self.recv(b)
        self.assertEqual(rcvd, b'hi')
        b.close()

    @mark.skipif(
        sys.platform.startswith('win'),
        reason='No fork on Windows')
    def test_fork_instance(self):
        ctx = self.Context.instance()
        parent_ctx_id = id(ctx)
        r_fd, w_fd = os.pipe()
        reader = os.fdopen(r_fd, 'r')
        child_pid = os.fork()
        if child_pid == 0:
            # Child: report the id of its own Context.instance() over the
            # pipe. Always leave through os._exit so the child can never fall
            # back into the test runner; exit non-zero if anything failed.
            status = 1
            try:
                ctx = self.Context.instance()
                writer = os.fdopen(w_fd, 'w')
                child_ctx_id = id(ctx)
                ctx.term()
                writer.write(str(child_ctx_id) + "\n")
                writer.flush()
                writer.close()
                status = 0
            finally:
                os._exit(status)
        else:
            os.close(w_fd)

        child_id_s = reader.readline()
        reader.close()
        # reap the child (avoids a zombie) and make sure it succeeded
        _, exit_status = os.waitpid(child_pid, 0)
        self.assertEqual(exit_status, 0)
        assert child_id_s
        # the forked child must get a fresh singleton, not the parent's
        assert int(child_id_s) != parent_ctx_id
        ctx.term()
# The gevent ("green") variant of the context tests is deliberately switched
# off: the condition is constant False, so the class below is never defined.
if False:
    class TestContextGreen(GreenTest, TestContext):
        """Re-run the TestContext suite on top of gevent (currently disabled)."""

        # These tests use real OS threads, so the green variant skips them.
        test_gc = test_term_thread = test_destroy_linger = GreenTest.skip_green
|