test_auth.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557
  1. # -*- coding: utf8 -*-
  2. # Copyright (C) PyZMQ Developers
  3. # Distributed under the terms of the Modified BSD License.
  4. import logging
  5. import os
  6. import shutil
  7. import sys
  8. import tempfile
  9. import pytest
  10. import zmq.auth
  11. from zmq.auth.thread import ThreadAuthenticator
  12. from zmq.utils.strtypes import u
  13. from zmq.tests import BaseZMQTestCase, SkipTest, skip_pypy
  14. class BaseAuthTestCase(BaseZMQTestCase):
  15. def setUp(self):
  16. if zmq.zmq_version_info() < (4,0):
  17. raise SkipTest("security is new in libzmq 4.0")
  18. try:
  19. zmq.curve_keypair()
  20. except zmq.ZMQError:
  21. raise SkipTest("security requires libzmq to have curve support")
  22. super(BaseAuthTestCase, self).setUp()
  23. # enable debug logging while we run tests
  24. logging.getLogger('zmq.auth').setLevel(logging.DEBUG)
  25. self.auth = self.make_auth()
  26. self.auth.start()
  27. self.base_dir, self.public_keys_dir, self.secret_keys_dir = self.create_certs()
  28. def make_auth(self):
  29. raise NotImplementedError()
  30. def tearDown(self):
  31. if self.auth:
  32. self.auth.stop()
  33. self.auth = None
  34. self.remove_certs(self.base_dir)
  35. super(BaseAuthTestCase, self).tearDown()
  36. def create_certs(self):
  37. """Create CURVE certificates for a test"""
  38. # Create temporary CURVE keypairs for this test run. We create all keys in a
  39. # temp directory and then move them into the appropriate private or public
  40. # directory.
  41. base_dir = tempfile.mkdtemp()
  42. keys_dir = os.path.join(base_dir, 'certificates')
  43. public_keys_dir = os.path.join(base_dir, 'public_keys')
  44. secret_keys_dir = os.path.join(base_dir, 'private_keys')
  45. os.mkdir(keys_dir)
  46. os.mkdir(public_keys_dir)
  47. os.mkdir(secret_keys_dir)
  48. server_public_file, server_secret_file = zmq.auth.create_certificates(keys_dir, "server")
  49. client_public_file, client_secret_file = zmq.auth.create_certificates(keys_dir, "client")
  50. for key_file in os.listdir(keys_dir):
  51. if key_file.endswith(".key"):
  52. shutil.move(os.path.join(keys_dir, key_file),
  53. os.path.join(public_keys_dir, '.'))
  54. for key_file in os.listdir(keys_dir):
  55. if key_file.endswith(".key_secret"):
  56. shutil.move(os.path.join(keys_dir, key_file),
  57. os.path.join(secret_keys_dir, '.'))
  58. return (base_dir, public_keys_dir, secret_keys_dir)
  59. def remove_certs(self, base_dir):
  60. """Remove certificates for a test"""
  61. shutil.rmtree(base_dir)
  62. def load_certs(self, secret_keys_dir):
  63. """Return server and client certificate keys"""
  64. server_secret_file = os.path.join(secret_keys_dir, "server.key_secret")
  65. client_secret_file = os.path.join(secret_keys_dir, "client.key_secret")
  66. server_public, server_secret = zmq.auth.load_certificate(server_secret_file)
  67. client_public, client_secret = zmq.auth.load_certificate(client_secret_file)
  68. return server_public, server_secret, client_public, client_secret
  69. class TestThreadAuthentication(BaseAuthTestCase):
  70. """Test authentication running in a thread"""
  71. def make_auth(self):
  72. return ThreadAuthenticator(self.context)
  73. def can_connect(self, server, client):
  74. """Check if client can connect to server using tcp transport"""
  75. result = False
  76. iface = 'tcp://127.0.0.1'
  77. port = server.bind_to_random_port(iface)
  78. client.connect("%s:%i" % (iface, port))
  79. msg = [b"Hello World"]
  80. if server.poll(1000, zmq.POLLOUT):
  81. server.send_multipart(msg)
  82. if client.poll(1000):
  83. rcvd_msg = client.recv_multipart()
  84. self.assertEqual(rcvd_msg, msg)
  85. result = True
  86. return result
  87. def test_null(self):
  88. """threaded auth - NULL"""
  89. # A default NULL connection should always succeed, and not
  90. # go through our authentication infrastructure at all.
  91. self.auth.stop()
  92. self.auth = None
  93. # use a new context, so ZAP isn't inherited
  94. self.context = self.Context()
  95. server = self.socket(zmq.PUSH)
  96. client = self.socket(zmq.PULL)
  97. self.assertTrue(self.can_connect(server, client))
  98. # By setting a domain we switch on authentication for NULL sockets,
  99. # though no policies are configured yet. The client connection
  100. # should still be allowed.
  101. server = self.socket(zmq.PUSH)
  102. server.zap_domain = b'global'
  103. client = self.socket(zmq.PULL)
  104. self.assertTrue(self.can_connect(server, client))
  105. def test_blacklist(self):
  106. """threaded auth - Blacklist"""
  107. # Blacklist 127.0.0.1, connection should fail
  108. self.auth.deny('127.0.0.1')
  109. server = self.socket(zmq.PUSH)
  110. # By setting a domain we switch on authentication for NULL sockets,
  111. # though no policies are configured yet.
  112. server.zap_domain = b'global'
  113. client = self.socket(zmq.PULL)
  114. self.assertFalse(self.can_connect(server, client))
  115. def test_whitelist(self):
  116. """threaded auth - Whitelist"""
  117. # Whitelist 127.0.0.1, connection should pass"
  118. self.auth.allow('127.0.0.1')
  119. server = self.socket(zmq.PUSH)
  120. # By setting a domain we switch on authentication for NULL sockets,
  121. # though no policies are configured yet.
  122. server.zap_domain = b'global'
  123. client = self.socket(zmq.PULL)
  124. self.assertTrue(self.can_connect(server, client))
  125. def test_plain(self):
  126. """threaded auth - PLAIN"""
  127. # Try PLAIN authentication - without configuring server, connection should fail
  128. server = self.socket(zmq.PUSH)
  129. server.plain_server = True
  130. client = self.socket(zmq.PULL)
  131. client.plain_username = b'admin'
  132. client.plain_password = b'Password'
  133. self.assertFalse(self.can_connect(server, client))
  134. # Try PLAIN authentication - with server configured, connection should pass
  135. server = self.socket(zmq.PUSH)
  136. server.plain_server = True
  137. client = self.socket(zmq.PULL)
  138. client.plain_username = b'admin'
  139. client.plain_password = b'Password'
  140. self.auth.configure_plain(domain='*', passwords={'admin': 'Password'})
  141. self.assertTrue(self.can_connect(server, client))
  142. # Try PLAIN authentication - with bogus credentials, connection should fail
  143. server = self.socket(zmq.PUSH)
  144. server.plain_server = True
  145. client = self.socket(zmq.PULL)
  146. client.plain_username = b'admin'
  147. client.plain_password = b'Bogus'
  148. self.assertFalse(self.can_connect(server, client))
  149. # Remove authenticator and check that a normal connection works
  150. self.auth.stop()
  151. self.auth = None
  152. server = self.socket(zmq.PUSH)
  153. client = self.socket(zmq.PULL)
  154. self.assertTrue(self.can_connect(server, client))
  155. client.close()
  156. server.close()
  157. def test_curve(self):
  158. """threaded auth - CURVE"""
  159. self.auth.allow('127.0.0.1')
  160. certs = self.load_certs(self.secret_keys_dir)
  161. server_public, server_secret, client_public, client_secret = certs
  162. #Try CURVE authentication - without configuring server, connection should fail
  163. server = self.socket(zmq.PUSH)
  164. server.curve_publickey = server_public
  165. server.curve_secretkey = server_secret
  166. server.curve_server = True
  167. client = self.socket(zmq.PULL)
  168. client.curve_publickey = client_public
  169. client.curve_secretkey = client_secret
  170. client.curve_serverkey = server_public
  171. self.assertFalse(self.can_connect(server, client))
  172. #Try CURVE authentication - with server configured to CURVE_ALLOW_ANY, connection should pass
  173. self.auth.configure_curve(domain='*', location=zmq.auth.CURVE_ALLOW_ANY)
  174. server = self.socket(zmq.PUSH)
  175. server.curve_publickey = server_public
  176. server.curve_secretkey = server_secret
  177. server.curve_server = True
  178. client = self.socket(zmq.PULL)
  179. client.curve_publickey = client_public
  180. client.curve_secretkey = client_secret
  181. client.curve_serverkey = server_public
  182. self.assertTrue(self.can_connect(server, client))
  183. # Try CURVE authentication - with server configured, connection should pass
  184. self.auth.configure_curve(domain='*', location=self.public_keys_dir)
  185. server = self.socket(zmq.PULL)
  186. server.curve_publickey = server_public
  187. server.curve_secretkey = server_secret
  188. server.curve_server = True
  189. client = self.socket(zmq.PUSH)
  190. client.curve_publickey = client_public
  191. client.curve_secretkey = client_secret
  192. client.curve_serverkey = server_public
  193. assert self.can_connect(client, server)
  194. # Remove authenticator and check that a normal connection works
  195. self.auth.stop()
  196. self.auth = None
  197. # Try connecting using NULL and no authentication enabled, connection should pass
  198. server = self.socket(zmq.PUSH)
  199. client = self.socket(zmq.PULL)
  200. self.assertTrue(self.can_connect(server, client))
  201. def test_curve_callback(self):
  202. """threaded auth - CURVE with callback authentication"""
  203. self.auth.allow('127.0.0.1')
  204. certs = self.load_certs(self.secret_keys_dir)
  205. server_public, server_secret, client_public, client_secret = certs
  206. #Try CURVE authentication - without configuring server, connection should fail
  207. server = self.socket(zmq.PUSH)
  208. server.curve_publickey = server_public
  209. server.curve_secretkey = server_secret
  210. server.curve_server = True
  211. client = self.socket(zmq.PULL)
  212. client.curve_publickey = client_public
  213. client.curve_secretkey = client_secret
  214. client.curve_serverkey = server_public
  215. self.assertFalse(self.can_connect(server, client))
  216. #Try CURVE authentication - with callback authentication configured, connection should pass
  217. class CredentialsProvider(object):
  218. def __init__(self):
  219. self.client = client_public
  220. def callback(self, domain, key):
  221. if (key == self.client):
  222. return True
  223. else:
  224. return False
  225. provider = CredentialsProvider()
  226. self.auth.configure_curve_callback(credentials_provider=provider)
  227. server = self.socket(zmq.PUSH)
  228. server.curve_publickey = server_public
  229. server.curve_secretkey = server_secret
  230. server.curve_server = True
  231. client = self.socket(zmq.PULL)
  232. client.curve_publickey = client_public
  233. client.curve_secretkey = client_secret
  234. client.curve_serverkey = server_public
  235. self.assertTrue(self.can_connect(server, client))
  236. #Try CURVE authentication - with callback authentication configured with wrong key, connection should not pass
  237. class WrongCredentialsProvider(object):
  238. def __init__(self):
  239. self.client = "WrongCredentials"
  240. def callback(self, domain, key):
  241. if (key == self.client):
  242. return True
  243. else:
  244. return False
  245. provider = WrongCredentialsProvider()
  246. self.auth.configure_curve_callback(credentials_provider=provider)
  247. server = self.socket(zmq.PUSH)
  248. server.curve_publickey = server_public
  249. server.curve_secretkey = server_secret
  250. server.curve_server = True
  251. client = self.socket(zmq.PULL)
  252. client.curve_publickey = client_public
  253. client.curve_secretkey = client_secret
  254. client.curve_serverkey = server_public
  255. self.assertFalse(self.can_connect(server, client))
  256. @skip_pypy
  257. def test_curve_user_id(self):
  258. """threaded auth - CURVE"""
  259. self.auth.allow('127.0.0.1')
  260. certs = self.load_certs(self.secret_keys_dir)
  261. server_public, server_secret, client_public, client_secret = certs
  262. self.auth.configure_curve(domain='*', location=self.public_keys_dir)
  263. server = self.socket(zmq.PULL)
  264. server.curve_publickey = server_public
  265. server.curve_secretkey = server_secret
  266. server.curve_server = True
  267. client = self.socket(zmq.PUSH)
  268. client.curve_publickey = client_public
  269. client.curve_secretkey = client_secret
  270. client.curve_serverkey = server_public
  271. assert self.can_connect(client, server)
  272. # test default user-id map
  273. client.send(b'test')
  274. msg = self.recv(server, copy=False)
  275. assert msg.bytes == b'test'
  276. try:
  277. user_id = msg.get('User-Id')
  278. except zmq.ZMQVersionError:
  279. pass
  280. else:
  281. assert user_id == u(client_public)
  282. # test custom user-id map
  283. self.auth.curve_user_id = lambda client_key: u'custom'
  284. client2 = self.socket(zmq.PUSH)
  285. client2.curve_publickey = client_public
  286. client2.curve_secretkey = client_secret
  287. client2.curve_serverkey = server_public
  288. assert self.can_connect(client2, server)
  289. client2.send(b'test2')
  290. msg = self.recv(server, copy=False)
  291. assert msg.bytes == b'test2'
  292. try:
  293. user_id = msg.get('User-Id')
  294. except zmq.ZMQVersionError:
  295. pass
  296. else:
  297. assert user_id == u'custom'
  298. def with_ioloop(method, expect_success=True):
  299. """decorator for running tests with an IOLoop"""
  300. def test_method(self):
  301. r = method(self)
  302. loop = self.io_loop
  303. if expect_success:
  304. self.pullstream.on_recv(self.on_message_succeed)
  305. else:
  306. self.pullstream.on_recv(self.on_message_fail)
  307. loop.call_later(1, self.attempt_connection)
  308. loop.call_later(1.2, self.send_msg)
  309. if expect_success:
  310. loop.call_later(2, self.on_test_timeout_fail)
  311. else:
  312. loop.call_later(2, self.on_test_timeout_succeed)
  313. loop.start()
  314. if self.fail_msg:
  315. self.fail(self.fail_msg)
  316. return r
  317. return test_method
  318. def should_auth(method):
  319. return with_ioloop(method, True)
  320. def should_not_auth(method):
  321. return with_ioloop(method, False)
  322. class TestIOLoopAuthentication(BaseAuthTestCase):
  323. """Test authentication running in ioloop"""
  324. def setUp(self):
  325. try:
  326. from tornado import ioloop
  327. except ImportError:
  328. pytest.skip("Requires tornado")
  329. from zmq.eventloop import zmqstream
  330. self.fail_msg = None
  331. self.io_loop = ioloop.IOLoop()
  332. super(TestIOLoopAuthentication, self).setUp()
  333. self.server = self.socket(zmq.PUSH)
  334. self.client = self.socket(zmq.PULL)
  335. self.pushstream = zmqstream.ZMQStream(self.server, self.io_loop)
  336. self.pullstream = zmqstream.ZMQStream(self.client, self.io_loop)
  337. def make_auth(self):
  338. from zmq.auth.ioloop import IOLoopAuthenticator
  339. return IOLoopAuthenticator(self.context, io_loop=self.io_loop)
  340. def tearDown(self):
  341. if self.auth:
  342. self.auth.stop()
  343. self.auth = None
  344. self.io_loop.close(all_fds=True)
  345. super(TestIOLoopAuthentication, self).tearDown()
  346. def attempt_connection(self):
  347. """Check if client can connect to server using tcp transport"""
  348. iface = 'tcp://127.0.0.1'
  349. port = self.server.bind_to_random_port(iface)
  350. self.client.connect("%s:%i" % (iface, port))
  351. def send_msg(self):
  352. """Send a message from server to a client"""
  353. msg = [b"Hello World"]
  354. self.pushstream.send_multipart(msg)
  355. def on_message_succeed(self, frames):
  356. """A message was received, as expected."""
  357. if frames != [b"Hello World"]:
  358. self.fail_msg = "Unexpected message received"
  359. self.io_loop.stop()
  360. def on_message_fail(self, frames):
  361. """A message was received, unexpectedly."""
  362. self.fail_msg = 'Received messaged unexpectedly, security failed'
  363. self.io_loop.stop()
  364. def on_test_timeout_succeed(self):
  365. """Test timer expired, indicates test success"""
  366. self.io_loop.stop()
  367. def on_test_timeout_fail(self):
  368. """Test timer expired, indicates test failure"""
  369. self.fail_msg = 'Test timed out'
  370. self.io_loop.stop()
  371. @should_auth
  372. def test_none(self):
  373. """ioloop auth - NONE"""
  374. # A default NULL connection should always succeed, and not
  375. # go through our authentication infrastructure at all.
  376. # no auth should be running
  377. self.auth.stop()
  378. self.auth = None
  379. @should_auth
  380. def test_null(self):
  381. """ioloop auth - NULL"""
  382. # By setting a domain we switch on authentication for NULL sockets,
  383. # though no policies are configured yet. The client connection
  384. # should still be allowed.
  385. self.server.zap_domain = b'global'
  386. @should_not_auth
  387. def test_blacklist(self):
  388. """ioloop auth - Blacklist"""
  389. # Blacklist 127.0.0.1, connection should fail
  390. self.auth.deny('127.0.0.1')
  391. self.server.zap_domain = b'global'
  392. @should_auth
  393. def test_whitelist(self):
  394. """ioloop auth - Whitelist"""
  395. # Whitelist 127.0.0.1, which overrides the blacklist, connection should pass"
  396. self.auth.allow('127.0.0.1')
  397. self.server.setsockopt(zmq.ZAP_DOMAIN, b'global')
  398. @should_not_auth
  399. def test_plain_unconfigured_server(self):
  400. """ioloop auth - PLAIN, unconfigured server"""
  401. self.client.plain_username = b'admin'
  402. self.client.plain_password = b'Password'
  403. # Try PLAIN authentication - without configuring server, connection should fail
  404. self.server.plain_server = True
  405. @should_auth
  406. def test_plain_configured_server(self):
  407. """ioloop auth - PLAIN, configured server"""
  408. self.client.plain_username = b'admin'
  409. self.client.plain_password = b'Password'
  410. # Try PLAIN authentication - with server configured, connection should pass
  411. self.server.plain_server = True
  412. self.auth.configure_plain(domain='*', passwords={'admin': 'Password'})
  413. @should_not_auth
  414. def test_plain_bogus_credentials(self):
  415. """ioloop auth - PLAIN, bogus credentials"""
  416. self.client.plain_username = b'admin'
  417. self.client.plain_password = b'Bogus'
  418. self.server.plain_server = True
  419. self.auth.configure_plain(domain='*', passwords={'admin': 'Password'})
  420. @should_not_auth
  421. def test_curve_unconfigured_server(self):
  422. """ioloop auth - CURVE, unconfigured server"""
  423. certs = self.load_certs(self.secret_keys_dir)
  424. server_public, server_secret, client_public, client_secret = certs
  425. self.auth.allow('127.0.0.1')
  426. self.server.curve_publickey = server_public
  427. self.server.curve_secretkey = server_secret
  428. self.server.curve_server = True
  429. self.client.curve_publickey = client_public
  430. self.client.curve_secretkey = client_secret
  431. self.client.curve_serverkey = server_public
  432. @should_auth
  433. def test_curve_allow_any(self):
  434. """ioloop auth - CURVE, CURVE_ALLOW_ANY"""
  435. certs = self.load_certs(self.secret_keys_dir)
  436. server_public, server_secret, client_public, client_secret = certs
  437. self.auth.allow('127.0.0.1')
  438. self.auth.configure_curve(domain='*', location=zmq.auth.CURVE_ALLOW_ANY)
  439. self.server.curve_publickey = server_public
  440. self.server.curve_secretkey = server_secret
  441. self.server.curve_server = True
  442. self.client.curve_publickey = client_public
  443. self.client.curve_secretkey = client_secret
  444. self.client.curve_serverkey = server_public
  445. @should_auth
  446. def test_curve_configured_server(self):
  447. """ioloop auth - CURVE, configured server"""
  448. self.auth.allow('127.0.0.1')
  449. certs = self.load_certs(self.secret_keys_dir)
  450. server_public, server_secret, client_public, client_secret = certs
  451. self.auth.configure_curve(domain='*', location=self.public_keys_dir)
  452. self.server.curve_publickey = server_public
  453. self.server.curve_secretkey = server_secret
  454. self.server.curve_server = True
  455. self.client.curve_publickey = client_public
  456. self.client.curve_secretkey = client_secret
  457. self.client.curve_serverkey = server_public