# Copyright (C) PyZMQ Developers
# Distributed under the terms of the Modified BSD License.

import copy
import gc
import os
import sys
import time
from threading import Thread, Event

try:
    from queue import Queue
except ImportError:
    from Queue import Queue

try:
    from unittest import mock
except ImportError:
    mock = None

from pytest import mark

import zmq
from zmq.tests import (
    BaseZMQTestCase,
    have_gevent,
    GreenTest,
    skip_green,
    PYPY,
    SkipTest,
)


class KwargTestSocket(zmq.Socket):
    """Socket subclass that records a custom ``test_kwarg`` constructor argument."""

    test_kwarg_value = None

    def __init__(self, *args, **kwargs):
        # Pop our private kwarg before delegating, so the base class
        # never sees an argument it does not know about.
        self.test_kwarg_value = kwargs.pop('test_kwarg', None)
        super(KwargTestSocket, self).__init__(*args, **kwargs)


class KwargTestContext(zmq.Context):
    """Context whose sockets are created as :class:`KwargTestSocket`."""

    _socket_class = KwargTestSocket


class TestContext(BaseZMQTestCase):
    """Tests for Context creation, termination, singletons, options and shadowing."""

    def test_init(self):
        """Contexts can be created one after another once earlier ones are deleted."""
        c1 = self.Context()
        self.assertIsInstance(c1, self.Context)
        del c1
        c2 = self.Context()
        self.assertIsInstance(c2, self.Context)
        del c2
        c3 = self.Context()
        self.assertIsInstance(c3, self.Context)
        del c3

    def test_dir(self):
        """dir(ctx) lists methods and, on libzmq > 3, context option names."""
        ctx = self.Context()
        self.assertIn('socket', dir(ctx))
        if zmq.zmq_version_info() > (3,):
            self.assertIn('IO_THREADS', dir(ctx))
        ctx.term()

    @mark.skipif(mock is None, reason="requires unittest.mock")
    def test_mockable(self):
        """A Mock can be built with a Context instance as its spec."""
        mock.Mock(spec=self.context)

    def test_term(self):
        """term() marks the context as closed."""
        c = self.Context()
        c.term()
        self.assertTrue(c.closed)

    def test_context_manager(self):
        """Leaving a ``with`` block closes the context."""
        with self.Context() as c:
            pass
        self.assertTrue(c.closed)

    def test_fail_init(self):
        """Constructing a Context with -1 raises EINVAL."""
        self.assertRaisesErrno(zmq.EINVAL, self.Context, -1)

    def test_term_hang(self):
        """term() returns after sockets with a pending send are closed with LINGER=0."""
        rep, req = self.create_bound_pair(zmq.ROUTER, zmq.DEALER)
        req.setsockopt(zmq.LINGER, 0)
        req.send(b'hello', copy=False)
        req.close()
        rep.close()
        self.context.term()

    def test_instance(self):
        """instance() returns one shared context and a new one after it is terminated."""
        ctx = self.Context.instance()
        c2 = self.Context.instance(io_threads=2)
        self.assertIs(c2, ctx)
        c2.term()
        c3 = self.Context.instance()
        c4 = self.Context.instance()
        self.assertIsNot(c3, c2)
        self.assertFalse(c3.closed)
        self.assertIs(c3, c4)

    def test_instance_subclass_first(self):
        """A subclass instance() called first does not become the base class instance."""
        self.context.term()

        class SubContext(zmq.Context):
            pass

        sctx = SubContext.instance()
        ctx = zmq.Context.instance()
        ctx.term()
        sctx.term()
        assert type(ctx) is zmq.Context
        assert type(sctx) is SubContext

    def test_instance_subclass_second(self):
        """After the base instance exists, only a subclass with its own ``_instance`` gets its own."""
        self.context.term()

        class SubContextInherit(zmq.Context):
            pass

        class SubContextNoInherit(zmq.Context):
            _instance = None
            pass

        ctx = zmq.Context.instance()
        sctx = SubContextInherit.instance()
        sctx2 = SubContextNoInherit.instance()
        ctx.term()
        sctx.term()
        sctx2.term()
        assert type(ctx) is zmq.Context
        assert type(sctx) is zmq.Context
        assert type(sctx2) is SubContextNoInherit

    def test_instance_threadsafe(self):
        """instance() called from many threads at once yields a single context."""
        self.context.term()  # clear default context

        q = Queue()

        # slow context initialization,
        # to ensure that we are both trying to create one at the same time
        class SlowContext(self.Context):
            def __init__(self, *a, **kw):
                time.sleep(1)
                super(SlowContext, self).__init__(*a, **kw)

        def f():
            q.put(SlowContext.instance())

        # call ctx.instance() in several threads at once
        N = 16
        threads = [Thread(target=f) for _ in range(N)]
        for t in threads:
            t.start()
        # also call it in the main thread (not first)
        ctx = SlowContext.instance()
        assert isinstance(ctx, SlowContext)
        # check that all the threads got the same context
        for _ in range(N):
            thread_ctx = q.get(timeout=5)
            assert thread_ctx is ctx
        # cleanup
        ctx.term()
        for t in threads:
            t.join(timeout=5)

    def test_socket_passes_kwargs(self):
        """Extra keyword arguments to ctx.socket() reach the socket class constructor."""
        test_kwarg_value = 'testing one two three'
        with KwargTestContext() as ctx:
            with ctx.socket(zmq.DEALER, test_kwarg=test_kwarg_value) as socket:
                self.assertTrue(socket.test_kwarg_value is test_kwarg_value)

    def test_many_sockets(self):
        """opening and closing many sockets shouldn't cause problems"""
        ctx = self.Context()
        for _ in range(16):
            sockets = [ctx.socket(zmq.REP) for _ in range(65)]
            for s in sockets:
                s.close()
            # give the reaper a chance
            time.sleep(1e-2)
        ctx.term()

    def test_sockopts(self):
        """setting socket options with ctx attributes"""
        ctx = self.Context()
        ctx.linger = 5
        self.assertEqual(ctx.linger, 5)
        s = ctx.socket(zmq.REQ)
        self.assertEqual(s.linger, 5)
        self.assertEqual(s.getsockopt(zmq.LINGER), 5)
        s.close()
        # check that subscribe doesn't get set on sockets that don't subscribe:
        ctx.subscribe = b''
        s = ctx.socket(zmq.REQ)
        s.close()

        ctx.term()

    @mark.skipif(
        sys.platform.startswith('win'),
        reason='Segfaults on Windows')
    def test_destroy(self):
        """Context.destroy should close sockets"""
        ctx = self.Context()
        sockets = [ctx.socket(zmq.REP) for _ in range(65)]

        # close half of the sockets
        for s in sockets[::2]:
            s.close()

        ctx.destroy()
        # reaper is not instantaneous
        time.sleep(1e-2)
        for s in sockets:
            self.assertTrue(s.closed)

    def test_destroy_linger(self):
        """Context.destroy should set linger on closing sockets"""
        req, rep = self.create_bound_pair(zmq.REQ, zmq.REP)
        req.send(b'hi')
        time.sleep(1e-2)
        self.context.destroy(linger=0)
        # reaper is not instantaneous
        time.sleep(1e-2)
        for s in (req, rep):
            self.assertTrue(s.closed)

    def test_term_noclose(self):
        """Context.term won't close sockets"""
        ctx = self.Context()
        s = ctx.socket(zmq.REQ)
        self.assertFalse(s.closed)
        t = Thread(target=ctx.term)
        t.start()
        t.join(timeout=0.1)
        self.assertTrue(t.is_alive(), "Context should be waiting")
        s.close()
        t.join(timeout=0.1)
        self.assertFalse(t.is_alive(), "Context should have closed")

    def test_gc(self):
        """test close&term by garbage collection alone"""
        if PYPY:
            raise SkipTest("GC doesn't work ")

        # test credit @dln (GH #137):
        def gcf():
            def inner():
                ctx = self.Context()
                # keep the socket bound to a local so it lives exactly as
                # long as the context does (until inner() returns)
                s = ctx.socket(zmq.PUSH)
            inner()
            gc.collect()

        t = Thread(target=gcf)
        t.start()
        t.join(timeout=1)
        self.assertFalse(t.is_alive(), "Garbage collection should have cleaned up context")

    def test_cyclic_destroy(self):
        """ctx.destroy should succeed when cyclic ref prevents gc"""
        # test credit @dln (GH #137):
        class CyclicReference(object):
            def __init__(self, parent=None):
                self.parent = parent

            def crash(self, sock):
                self.sock = sock
                self.child = CyclicReference(self)

        def crash_zmq():
            ctx = self.Context()
            sock = ctx.socket(zmq.PULL)
            c = CyclicReference()
            c.crash(sock)
            ctx.destroy()

        crash_zmq()

    def test_term_thread(self):
        """ctx.term should not crash active threads (#139)"""
        ctx = self.Context()
        evt = Event()
        evt.clear()

        def block():
            s = ctx.socket(zmq.REP)
            s.bind_to_random_port('tcp://127.0.0.1')
            evt.set()
            try:
                s.recv()
            except zmq.ZMQError as e:
                self.assertEqual(e.errno, zmq.ETERM)
                return
            finally:
                s.close()
            self.fail("recv should have been interrupted with ETERM")

        t = Thread(target=block)
        t.start()

        evt.wait(1)
        self.assertTrue(evt.is_set(), "sync event never fired")
        time.sleep(0.01)
        ctx.term()
        t.join(timeout=1)
        self.assertFalse(t.is_alive(), "term should have interrupted s.recv()")

    def test_destroy_no_sockets(self):
        """destroy() closes the context when its only socket was already closed."""
        ctx = self.Context()
        s = ctx.socket(zmq.PUB)
        s.bind_to_random_port('tcp://127.0.0.1')
        s.close()
        ctx.destroy()
        assert s.closed
        assert ctx.closed

    def test_ctx_opts(self):
        """Context options round-trip through set()/get() and attribute access."""
        if zmq.zmq_version_info() < (3,):
            raise SkipTest("context options require libzmq 3")
        ctx = self.Context()
        ctx.set(zmq.MAX_SOCKETS, 2)
        self.assertEqual(ctx.get(zmq.MAX_SOCKETS), 2)
        ctx.max_sockets = 100
        self.assertEqual(ctx.max_sockets, 100)
        self.assertEqual(ctx.get(zmq.MAX_SOCKETS), 100)

    def test_copy(self):
        """copy/deepcopy of a context yield shadows of the same underlying context."""
        c1 = self.Context()
        c2 = copy.copy(c1)
        c2b = copy.deepcopy(c1)
        c3 = copy.deepcopy(c2)
        self.assertTrue(c2._shadow)
        self.assertTrue(c3._shadow)
        self.assertEqual(c1.underlying, c2.underlying)
        self.assertEqual(c1.underlying, c3.underlying)
        self.assertEqual(c1.underlying, c2b.underlying)
        s = c3.socket(zmq.PUB)
        s.close()
        c1.term()

    def test_shadow(self):
        """Deleting a shadow leaves the original open; a shadow fails with EFAULT after term."""
        ctx = self.Context()
        ctx2 = self.Context.shadow(ctx.underlying)
        self.assertEqual(ctx.underlying, ctx2.underlying)
        s = ctx.socket(zmq.PUB)
        s.close()
        del ctx2
        self.assertFalse(ctx.closed)
        s = ctx.socket(zmq.PUB)
        ctx2 = self.Context.shadow(ctx.underlying)
        s2 = ctx2.socket(zmq.PUB)
        s.close()
        s2.close()
        ctx.term()
        self.assertRaisesErrno(zmq.EFAULT, ctx2.socket, zmq.PUB)
        del ctx2

    def test_shadow_pyczmq(self):
        """A context shadowing a pyczmq zctx can exchange messages with pyczmq sockets."""
        try:
            from pyczmq import zctx, zsocket, zstr
        except Exception:
            # any failure to load pyczmq means the test cannot run
            raise SkipTest("Requires pyczmq")

        ctx = zctx.new()
        a = zsocket.new(ctx, zmq.PUSH)
        zsocket.bind(a, "inproc://a")
        ctx2 = self.Context.shadow_pyczmq(ctx)
        b = ctx2.socket(zmq.PULL)
        b.connect("inproc://a")
        zstr.send(a, b'hi')
        rcvd = self.recv(b)
        self.assertEqual(rcvd, b'hi')
        b.close()

    @mark.skipif(
        sys.platform.startswith('win'),
        reason='No fork on Windows')
    def test_fork_instance(self):
        """Context.instance() in a forked child is a different object than the parent's.

        The child reports ``id()`` of its instance to the parent over a pipe.
        """
        ctx = self.Context.instance()
        parent_ctx_id = id(ctx)
        r_fd, w_fd = os.pipe()
        reader = os.fdopen(r_fd, 'r')
        child_pid = os.fork()
        if child_pid == 0:
            # Child: always leave through os._exit(), even on error, so it
            # never falls back into the test runner while still holding the
            # write end of the pipe (which would block the parent's readline).
            status = 1
            try:
                ctx = self.Context.instance()
                writer = os.fdopen(w_fd, 'w')
                child_ctx_id = id(ctx)
                ctx.term()
                writer.write(str(child_ctx_id) + "\n")
                writer.flush()
                writer.close()
                status = 0
            finally:
                os._exit(status)

        # Parent: close our copy of the write end so readline() sees EOF
        # if the child exits without writing anything.
        os.close(w_fd)
        child_id_s = reader.readline()
        reader.close()
        # reap the child so it does not linger as a zombie
        os.waitpid(child_pid, 0)
        assert child_id_s
        assert int(child_id_s) != parent_ctx_id
        ctx.term()


if False:  # disable green context tests
    class TestContextGreen(GreenTest, TestContext):
        """gevent subclass of context tests"""
        # skip tests that use real threads:
        test_gc = GreenTest.skip_green
        test_term_thread = GreenTest.skip_green
        test_destroy_linger = GreenTest.skip_green