test_security.py 8.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236
  1. """Test libzmq security (libzmq >= 3.3.0)"""
  2. # -*- coding: utf8 -*-
  3. # Copyright (C) PyZMQ Developers
  4. # Distributed under the terms of the Modified BSD License.
  5. import os
  6. import contextlib
  7. import time
  8. from threading import Thread
  9. import zmq
  10. from zmq.tests import (
  11. BaseZMQTestCase, SkipTest, PYPY
  12. )
  13. from zmq.utils import z85
  # Only PLAIN credentials that zap_handler accepts (anything else gets a 400).
  USER, PASS = b"admin", b"password"
  class TestSecurity(BaseZMQTestCase):
      """Tests for libzmq's NULL, PLAIN and CURVE security mechanisms.

      Authenticated tests run a one-shot ZAP handler on a background thread
      (see ``zap``), then round-trip a message between a DEALER server and a
      DEALER client with ``bounce``.
      """

      def setUp(self):
          """Skip unless libzmq >= 4.0 built with CURVE support is available."""
          if zmq.zmq_version_info() < (4, 0):
              raise SkipTest("security is new in libzmq 4.0")
          try:
              zmq.curve_keypair()
          except zmq.ZMQError:
              raise SkipTest("security requires libzmq to be built with CURVE support")
          super(TestSecurity, self).setUp()

      def zap_handler(self):
          """Answer exactly one ZAP request, then close the handler socket.

          NULL and CURVE requests are accepted unconditionally. PLAIN requests
          are accepted only for the ``USER``/``PASS`` credentials. Any other
          request gets a 400 reply. Successful replies carry user id
          ``anonymous`` and a ``Hello: World`` metadata property, which
          ``bounce`` checks for.
          """
          zap_socket = self.context.socket(zmq.REP)
          try:
              # bind inside the try so the socket is closed even if bind fails
              zap_socket.bind("inproc://zeromq.zap.01")
              # Tell zap() the endpoint exists, so no security handshake can
              # start before a handler is listening.
              ready = getattr(self, '_zap_ready', None)
              if ready is not None:
                  ready.set()
              # NOTE: assertion failures below are raised in this handler
              # thread, so they show up as a thread traceback rather than as a
              # failure of the running test.
              msg = self.recv_multipart(zap_socket)
              version, sequence, domain, address, identity, mechanism = msg[:6]
              if mechanism == b'PLAIN':
                  # PLAIN requests carry exactly two credential frames.
                  username, password = msg[6:]
                  accepted = username == USER and password == PASS
              else:
                  # NULL and CURVE need no credential check here.
                  accepted = mechanism in (b'CURVE', b'NULL')
              self.assertEqual(version, b"1.0")
              self.assertEqual(identity, b"IDENT")
              reply = [version, sequence]
              if accepted:
                  reply.extend([
                      b"200",
                      b"OK",
                      b"anonymous",
                      # metadata: 1-byte name length + "Hello",
                      # 4-byte value length + "World"
                      b"\5Hello\0\0\0\5World",
                  ])
              else:
                  reply.extend([
                      b"400",
                      b"Invalid username or password",
                      b"",
                      b"",
                  ])
              zap_socket.send_multipart(reply)
          finally:
              zap_socket.close()

      @contextlib.contextmanager
      def zap(self):
          """Run a ZAP handler thread for the duration of the ``with`` body."""
          self.start_zap()
          # Wait until the handler has bound its endpoint, instead of sleeping
          # a fixed interval. The timeout keeps a handler that failed to bind
          # from blocking the test here forever.
          self._zap_ready.wait(timeout=5)
          try:
              yield
          finally:
              self.stop_zap()

      def start_zap(self):
          """Start ``zap_handler`` on a new thread with a fresh readiness flag."""
          from threading import Event
          self._zap_ready = Event()
          self.zap_thread = Thread(target=self.zap_handler)
          self.zap_thread.start()

      def stop_zap(self):
          """Block until the ZAP handler thread has finished."""
          self.zap_thread.join()

      def bounce(self, server, client, test_metadata=True):
          """Round-trip a random two-frame message client -> server -> client.

          When *test_metadata* is true (and not running on PyPy), also check
          that every frame the server receives carries the ZAP-supplied user id
          and metadata, plus the peer's socket type.
          """
          msg = [os.urandom(64), os.urandom(64)]
          client.send_multipart(msg)
          frames = self.recv_multipart(server, copy=False)
          recvd = [frame.bytes for frame in frames]
          try:
              if test_metadata and not PYPY:
                  for frame in frames:
                      self.assertEqual(frame.get('User-Id'), 'anonymous')
                      self.assertEqual(frame.get('Hello'), 'World')
                      self.assertEqual(frame['Socket-Type'], 'DEALER')
          except zmq.ZMQVersionError:
              # presumably this libzmq is too old for frame metadata;
              # skip only those checks
              pass
          self.assertEqual(recvd, msg)
          server.send_multipart(recvd)
          msg2 = self.recv_multipart(client)
          self.assertEqual(msg2, msg)

      def test_null(self):
          """test NULL (default) security"""
          server = self.socket(zmq.DEALER)
          client = self.socket(zmq.DEALER)
          self.assertEqual(client.MECHANISM, zmq.NULL)
          self.assertEqual(server.mechanism, zmq.NULL)
          self.assertEqual(client.plain_server, 0)
          self.assertEqual(server.plain_server, 0)
          iface = 'tcp://127.0.0.1'
          port = server.bind_to_random_port(iface)
          client.connect("%s:%i" % (iface, port))
          # no ZAP handler runs, so there is no metadata to check
          self.bounce(server, client, False)

      def test_plain(self):
          """test PLAIN authentication"""
          server = self.socket(zmq.DEALER)
          # zap_handler asserts this identity on every request
          server.identity = b'IDENT'
          client = self.socket(zmq.DEALER)
          self.assertEqual(client.plain_username, b'')
          self.assertEqual(client.plain_password, b'')
          client.plain_username = USER
          client.plain_password = PASS
          self.assertEqual(client.getsockopt(zmq.PLAIN_USERNAME), USER)
          self.assertEqual(client.getsockopt(zmq.PLAIN_PASSWORD), PASS)
          self.assertEqual(client.plain_server, 0)
          self.assertEqual(server.plain_server, 0)
          server.plain_server = True
          self.assertEqual(server.mechanism, zmq.PLAIN)
          self.assertEqual(client.mechanism, zmq.PLAIN)
          assert not client.plain_server
          assert server.plain_server
          with self.zap():
              iface = 'tcp://127.0.0.1'
              port = server.bind_to_random_port(iface)
              client.connect("%s:%i" % (iface, port))
              self.bounce(server, client)

      def skip_plain_inauth(self):
          """test PLAIN failed authentication

          Disabled: the name lacks the ``test`` prefix, so it is not collected.
          """
          server = self.socket(zmq.DEALER)
          server.identity = b'IDENT'
          client = self.socket(zmq.DEALER)
          self.sockets.extend([server, client])
          client.plain_username = USER
          client.plain_password = b'incorrect'
          server.plain_server = True
          self.assertEqual(server.mechanism, zmq.PLAIN)
          self.assertEqual(client.mechanism, zmq.PLAIN)
          with self.zap():
              iface = 'tcp://127.0.0.1'
              port = server.bind_to_random_port(iface)
              client.connect("%s:%i" % (iface, port))
              client.send(b'ping')
              # a rejected peer's message must never reach the server
              server.rcvtimeo = 250
              self.assertRaisesErrno(zmq.EAGAIN, server.recv)

      def test_keypair(self):
          """test curve_keypair"""
          try:
              public, secret = zmq.curve_keypair()
          except zmq.ZMQError:
              raise SkipTest("CURVE unsupported")
          self.assertEqual(type(secret), bytes)
          self.assertEqual(type(public), bytes)
          self.assertEqual(len(secret), 40)
          self.assertEqual(len(public), 40)
          # verify that it is indeed Z85 (40 Z85 chars decode to 32 bytes)
          bpublic, bsecret = [z85.decode(key) for key in (public, secret)]
          self.assertEqual(type(bsecret), bytes)
          self.assertEqual(type(bpublic), bytes)
          self.assertEqual(len(bsecret), 32)
          self.assertEqual(len(bpublic), 32)

      def test_curve_public(self):
          """test curve_public"""
          try:
              public, secret = zmq.curve_keypair()
          except zmq.ZMQError:
              raise SkipTest("CURVE unsupported")
          if zmq.zmq_version_info() < (4, 2):
              raise SkipTest("curve_public is new in libzmq 4.2")
          derived_public = zmq.curve_public(secret)
          self.assertEqual(type(derived_public), bytes)
          self.assertEqual(len(derived_public), 40)
          # verify that it is indeed Z85
          bpublic = z85.decode(derived_public)
          self.assertEqual(type(bpublic), bytes)
          self.assertEqual(len(bpublic), 32)
          # verify that it is equal to the known public key
          self.assertEqual(derived_public, public)

      def test_curve(self):
          """test CURVE encryption"""
          server = self.socket(zmq.DEALER)
          server.identity = b'IDENT'
          client = self.socket(zmq.DEALER)
          self.sockets.extend([server, client])
          try:
              server.curve_server = True
          except zmq.ZMQError as e:
              # will raise EINVAL if no CURVE support
              if e.errno == zmq.EINVAL:
                  raise SkipTest("CURVE unsupported")
              # any other failure is a real error; don't carry on silently
              raise
          server_public, server_secret = zmq.curve_keypair()
          client_public, client_secret = zmq.curve_keypair()
          server.curve_secretkey = server_secret
          server.curve_publickey = server_public
          # the client must know the server's public key in advance
          client.curve_serverkey = server_public
          client.curve_publickey = client_public
          client.curve_secretkey = client_secret
          self.assertEqual(server.mechanism, zmq.CURVE)
          self.assertEqual(client.mechanism, zmq.CURVE)
          self.assertEqual(server.get(zmq.CURVE_SERVER), True)
          self.assertEqual(client.get(zmq.CURVE_SERVER), False)
          with self.zap():
              iface = 'tcp://127.0.0.1'
              port = server.bind_to_random_port(iface)
              client.connect("%s:%i" % (iface, port))
              self.bounce(server, client)