test_message.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348
  1. # -*- coding: utf8 -*-
  2. # Copyright (C) PyZMQ Developers
  3. # Distributed under the terms of the Modified BSD License.
  4. import copy
  5. import sys
  6. try:
  7. from sys import getrefcount as grc
  8. except ImportError:
  9. grc = None
  10. import time
  11. from pprint import pprint
  12. from unittest import TestCase
  13. import zmq
  14. from zmq.tests import BaseZMQTestCase, SkipTest, skip_pypy, PYPY
  15. from zmq.utils.strtypes import unicode, bytes, b, u
  16. # some useful constants:
  17. x = b'x'
  18. if grc:
  19. rc0 = grc(x)
  20. v = memoryview(x)
  21. view_rc = grc(x) - rc0
  22. def await_gc(obj, rc):
  23. """wait for refcount on an object to drop to an expected value
  24. Necessary because of the zero-copy gc thread,
  25. which can take some time to receive its DECREF message.
  26. """
  27. for i in range(50):
  28. # rc + 2 because of the refs in this function
  29. if grc(obj) <= rc + 2:
  30. return
  31. time.sleep(0.05)
  32. class TestFrame(BaseZMQTestCase):
  33. @skip_pypy
  34. def test_above_30(self):
  35. """Message above 30 bytes are never copied by 0MQ."""
  36. for i in range(5, 16): # 32, 64,..., 65536
  37. s = (2**i)*x
  38. self.assertEqual(grc(s), 2)
  39. m = zmq.Frame(s, copy=False)
  40. self.assertEqual(grc(s), 4)
  41. del m
  42. await_gc(s, 2)
  43. self.assertEqual(grc(s), 2)
  44. del s
  45. def test_str(self):
  46. """Test the str representations of the Frames."""
  47. for i in range(16):
  48. s = (2**i)*x
  49. m = zmq.Frame(s)
  50. m_str = str(m)
  51. m_str_b = b(m_str) # py3compat
  52. self.assertEqual(s, m_str_b)
  53. def test_bytes(self):
  54. """Test the Frame.bytes property."""
  55. for i in range(1,16):
  56. s = (2**i)*x
  57. m = zmq.Frame(s)
  58. b = m.bytes
  59. self.assertEqual(s, m.bytes)
  60. if not PYPY:
  61. # check that it copies
  62. self.assert_(b is not s)
  63. # check that it copies only once
  64. self.assert_(b is m.bytes)
  65. def test_unicode(self):
  66. """Test the unicode representations of the Frames."""
  67. s = u('asdf')
  68. self.assertRaises(TypeError, zmq.Frame, s)
  69. for i in range(16):
  70. s = (2**i)*u('§')
  71. m = zmq.Frame(s.encode('utf8'))
  72. self.assertEqual(s, unicode(m.bytes,'utf8'))
  73. def test_len(self):
  74. """Test the len of the Frames."""
  75. for i in range(16):
  76. s = (2**i)*x
  77. m = zmq.Frame(s)
  78. self.assertEqual(len(s), len(m))
  79. @skip_pypy
  80. def test_lifecycle1(self):
  81. """Run through a ref counting cycle with a copy."""
  82. for i in range(5, 16): # 32, 64,..., 65536
  83. s = (2**i)*x
  84. rc = 2
  85. self.assertEqual(grc(s), rc)
  86. m = zmq.Frame(s, copy=False)
  87. rc += 2
  88. self.assertEqual(grc(s), rc)
  89. m2 = copy.copy(m)
  90. rc += 1
  91. self.assertEqual(grc(s), rc)
  92. buf = m2.buffer
  93. rc += view_rc
  94. self.assertEqual(grc(s), rc)
  95. self.assertEqual(s, b(str(m)))
  96. self.assertEqual(s, bytes(m2))
  97. self.assertEqual(s, m.bytes)
  98. # self.assert_(s is str(m))
  99. # self.assert_(s is str(m2))
  100. del m2
  101. rc -= 1
  102. self.assertEqual(grc(s), rc)
  103. rc -= view_rc
  104. del buf
  105. self.assertEqual(grc(s), rc)
  106. del m
  107. rc -= 2
  108. await_gc(s, rc)
  109. self.assertEqual(grc(s), rc)
  110. self.assertEqual(rc, 2)
  111. del s
  112. @skip_pypy
  113. def test_lifecycle2(self):
  114. """Run through a different ref counting cycle with a copy."""
  115. for i in range(5, 16): # 32, 64,..., 65536
  116. s = (2**i)*x
  117. rc = 2
  118. self.assertEqual(grc(s), rc)
  119. m = zmq.Frame(s, copy=False)
  120. rc += 2
  121. self.assertEqual(grc(s), rc)
  122. m2 = copy.copy(m)
  123. rc += 1
  124. self.assertEqual(grc(s), rc)
  125. buf = m.buffer
  126. rc += view_rc
  127. self.assertEqual(grc(s), rc)
  128. self.assertEqual(s, b(str(m)))
  129. self.assertEqual(s, bytes(m2))
  130. self.assertEqual(s, m2.bytes)
  131. self.assertEqual(s, m.bytes)
  132. # self.assert_(s is str(m))
  133. # self.assert_(s is str(m2))
  134. del buf
  135. self.assertEqual(grc(s), rc)
  136. del m
  137. # m.buffer is kept until m is del'd
  138. rc -= view_rc
  139. rc -= 1
  140. self.assertEqual(grc(s), rc)
  141. del m2
  142. rc -= 2
  143. await_gc(s, rc)
  144. self.assertEqual(grc(s), rc)
  145. self.assertEqual(rc, 2)
  146. del s
  147. @skip_pypy
  148. def test_tracker(self):
  149. m = zmq.Frame(b'asdf', copy=False, track=True)
  150. self.assertFalse(m.tracker.done)
  151. pm = zmq.MessageTracker(m)
  152. self.assertFalse(pm.done)
  153. del m
  154. for i in range(10):
  155. if pm.done:
  156. break
  157. time.sleep(0.1)
  158. self.assertTrue(pm.done)
  159. def test_no_tracker(self):
  160. m = zmq.Frame(b'asdf', track=False)
  161. self.assertEqual(m.tracker, None)
  162. m2 = copy.copy(m)
  163. self.assertEqual(m2.tracker, None)
  164. self.assertRaises(ValueError, zmq.MessageTracker, m)
  165. @skip_pypy
  166. def test_multi_tracker(self):
  167. m = zmq.Frame(b'asdf', copy=False, track=True)
  168. m2 = zmq.Frame(b'whoda', copy=False, track=True)
  169. mt = zmq.MessageTracker(m,m2)
  170. self.assertFalse(m.tracker.done)
  171. self.assertFalse(mt.done)
  172. self.assertRaises(zmq.NotDone, mt.wait, 0.1)
  173. del m
  174. time.sleep(0.1)
  175. self.assertRaises(zmq.NotDone, mt.wait, 0.1)
  176. self.assertFalse(mt.done)
  177. del m2
  178. self.assertTrue(mt.wait() is None)
  179. self.assertTrue(mt.done)
  180. def test_buffer_in(self):
  181. """test using a buffer as input"""
  182. ins = b("§§¶•ªº˜µ¬˚…∆˙åß∂©œ∑´†≈ç√")
  183. m = zmq.Frame(memoryview(ins))
  184. def test_bad_buffer_in(self):
  185. """test using a bad object"""
  186. self.assertRaises(TypeError, zmq.Frame, 5)
  187. self.assertRaises(TypeError, zmq.Frame, object())
  188. def test_buffer_out(self):
  189. """receiving buffered output"""
  190. ins = b("§§¶•ªº˜µ¬˚…∆˙åß∂©œ∑´†≈ç√")
  191. m = zmq.Frame(ins)
  192. outb = m.buffer
  193. self.assertTrue(isinstance(outb, memoryview))
  194. self.assert_(outb is m.buffer)
  195. self.assert_(m.buffer is m.buffer)
  196. @skip_pypy
  197. def test_memoryview_shape(self):
  198. """memoryview shape info"""
  199. if sys.version_info < (3,):
  200. raise SkipTest("only test memoryviews on Python 3")
  201. data = b("§§¶•ªº˜µ¬˚…∆˙åß∂©œ∑´†≈ç√")
  202. n = len(data)
  203. f = zmq.Frame(data)
  204. view1 = f.buffer
  205. self.assertEqual(view1.ndim, 1)
  206. self.assertEqual(view1.shape, (n,))
  207. self.assertEqual(view1.tobytes(), data)
  208. view2 = memoryview(f)
  209. self.assertEqual(view2.ndim, 1)
  210. self.assertEqual(view2.shape, (n,))
  211. self.assertEqual(view2.tobytes(), data)
  212. def test_multisend(self):
  213. """ensure that a message remains intact after multiple sends"""
  214. a,b = self.create_bound_pair(zmq.PAIR, zmq.PAIR)
  215. s = b"message"
  216. m = zmq.Frame(s)
  217. self.assertEqual(s, m.bytes)
  218. a.send(m, copy=False)
  219. time.sleep(0.1)
  220. self.assertEqual(s, m.bytes)
  221. a.send(m, copy=False)
  222. time.sleep(0.1)
  223. self.assertEqual(s, m.bytes)
  224. a.send(m, copy=True)
  225. time.sleep(0.1)
  226. self.assertEqual(s, m.bytes)
  227. a.send(m, copy=True)
  228. time.sleep(0.1)
  229. self.assertEqual(s, m.bytes)
  230. for i in range(4):
  231. r = b.recv()
  232. self.assertEqual(s,r)
  233. self.assertEqual(s, m.bytes)
  234. def test_memoryview(self):
  235. """test messages from memoryview"""
  236. major,minor = sys.version_info[:2]
  237. if not (major >= 3 or (major == 2 and minor >= 7)):
  238. raise SkipTest("memoryviews only in python >= 2.7")
  239. s = b'carrotjuice'
  240. v = memoryview(s)
  241. m = zmq.Frame(s)
  242. buf = m.buffer
  243. s2 = buf.tobytes()
  244. self.assertEqual(s2,s)
  245. self.assertEqual(m.bytes,s)
  246. def test_noncopying_recv(self):
  247. """check for clobbering message buffers"""
  248. null = b'\0'*64
  249. sa,sb = self.create_bound_pair(zmq.PAIR, zmq.PAIR)
  250. for i in range(32):
  251. # try a few times
  252. sb.send(null, copy=False)
  253. m = sa.recv(copy=False)
  254. mb = m.bytes
  255. # buf = memoryview(m)
  256. buf = m.buffer
  257. del m
  258. for i in range(5):
  259. ff=b'\xff'*(40 + i*10)
  260. sb.send(ff, copy=False)
  261. m2 = sa.recv(copy=False)
  262. b = buf.tobytes()
  263. self.assertEqual(b, null)
  264. self.assertEqual(mb, null)
  265. self.assertEqual(m2.bytes, ff)
  266. @skip_pypy
  267. def test_buffer_numpy(self):
  268. """test non-copying numpy array messages"""
  269. try:
  270. import numpy
  271. from numpy.testing import assert_array_equal
  272. except ImportError:
  273. raise SkipTest("requires numpy")
  274. if sys.version_info < (2,7):
  275. raise SkipTest("requires new-style buffer interface (py >= 2.7)")
  276. rand = numpy.random.randint
  277. shapes = [ rand(2,5) for i in range(5) ]
  278. a,b = self.create_bound_pair(zmq.PAIR, zmq.PAIR)
  279. dtypes = [int, float, '>i4', 'B']
  280. for i in range(1,len(shapes)+1):
  281. shape = shapes[:i]
  282. for dt in dtypes:
  283. A = numpy.empty(shape, dtype=dt)
  284. a.send(A, copy=False)
  285. msg = b.recv(copy=False)
  286. B = numpy.frombuffer(msg, A.dtype).reshape(A.shape)
  287. assert_array_equal(A, B)
  288. A = numpy.empty(shape, dtype=[('a', int), ('b', float), ('c', 'a32')])
  289. A['a'] = 1024
  290. A['b'] = 1e9
  291. A['c'] = 'hello there'
  292. a.send(A, copy=False)
  293. msg = b.recv(copy=False)
  294. B = numpy.frombuffer(msg, A.dtype).reshape(A.shape)
  295. assert_array_equal(A, B)
  296. def test_frame_more(self):
  297. """test Frame.more attribute"""
  298. frame = zmq.Frame(b"hello")
  299. self.assertFalse(frame.more)
  300. sa,sb = self.create_bound_pair(zmq.PAIR, zmq.PAIR)
  301. sa.send_multipart([b'hi', b'there'])
  302. frame = self.recv(sb, copy=False)
  303. self.assertTrue(frame.more)
  304. if zmq.zmq_version_info()[0] >= 3 and not PYPY:
  305. self.assertTrue(frame.get(zmq.MORE))
  306. frame = self.recv(sb, copy=False)
  307. self.assertFalse(frame.more)
  308. if zmq.zmq_version_info()[0] >= 3 and not PYPY:
  309. self.assertFalse(frame.get(zmq.MORE))