thread.py 7.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226
  1. """ZAP Authenticator in a Python Thread.
  2. .. versionadded:: 14.1
  3. """
  4. # Copyright (C) PyZMQ Developers
  5. # Distributed under the terms of the Modified BSD License.
  6. import time
  7. import logging
  8. from threading import Thread, Event
  9. import zmq
  10. from zmq.utils import jsonapi
  11. from zmq.utils.strtypes import bytes, unicode, b, u
  12. import sys
  13. from .base import Authenticator
  14. class AuthenticationThread(Thread):
  15. """A Thread for running a zmq Authenticator
  16. This is run in the background by ThreadedAuthenticator
  17. """
  18. def __init__(self, context, endpoint, encoding='utf-8', log=None, authenticator=None):
  19. super(AuthenticationThread, self).__init__()
  20. self.context = context or zmq.Context.instance()
  21. self.encoding = encoding
  22. self.log = log = log or logging.getLogger('zmq.auth')
  23. self.started = Event()
  24. self.authenticator = authenticator or Authenticator(context, encoding=encoding, log=log)
  25. # create a socket to communicate back to main thread.
  26. self.pipe = context.socket(zmq.PAIR)
  27. self.pipe.linger = 1
  28. self.pipe.connect(endpoint)
  29. def run(self):
  30. """Start the Authentication Agent thread task"""
  31. self.authenticator.start()
  32. self.started.set()
  33. zap = self.authenticator.zap_socket
  34. poller = zmq.Poller()
  35. poller.register(self.pipe, zmq.POLLIN)
  36. poller.register(zap, zmq.POLLIN)
  37. while True:
  38. try:
  39. socks = dict(poller.poll())
  40. except zmq.ZMQError:
  41. break # interrupted
  42. if self.pipe in socks and socks[self.pipe] == zmq.POLLIN:
  43. # Make sure all API requests are processed before
  44. # looking at the ZAP socket.
  45. while True:
  46. try:
  47. msg = self.pipe.recv_multipart(flags=zmq.NOBLOCK)
  48. except zmq.Again:
  49. break
  50. terminate = self._handle_pipe(msg)
  51. if terminate:
  52. break
  53. if terminate:
  54. break
  55. if zap in socks and socks[zap] == zmq.POLLIN:
  56. self._handle_zap()
  57. self.pipe.close()
  58. self.authenticator.stop()
  59. def _handle_zap(self):
  60. """
  61. Handle a message from the ZAP socket.
  62. """
  63. msg = self.authenticator.zap_socket.recv_multipart()
  64. if not msg: return
  65. self.authenticator.handle_zap_message(msg)
  66. def _handle_pipe(self, msg):
  67. """
  68. Handle a message from front-end API.
  69. """
  70. terminate = False
  71. if msg is None:
  72. terminate = True
  73. return terminate
  74. command = msg[0]
  75. self.log.debug("auth received API command %r", command)
  76. if command == b'ALLOW':
  77. addresses = [u(m, self.encoding) for m in msg[1:]]
  78. try:
  79. self.authenticator.allow(*addresses)
  80. except Exception as e:
  81. self.log.exception("Failed to allow %s", addresses)
  82. elif command == b'DENY':
  83. addresses = [u(m, self.encoding) for m in msg[1:]]
  84. try:
  85. self.authenticator.deny(*addresses)
  86. except Exception as e:
  87. self.log.exception("Failed to deny %s", addresses)
  88. elif command == b'PLAIN':
  89. domain = u(msg[1], self.encoding)
  90. json_passwords = msg[2]
  91. self.authenticator.configure_plain(domain, jsonapi.loads(json_passwords))
  92. elif command == b'CURVE':
  93. # For now we don't do anything with domains
  94. domain = u(msg[1], self.encoding)
  95. # If location is CURVE_ALLOW_ANY, allow all clients. Otherwise
  96. # treat location as a directory that holds the certificates.
  97. location = u(msg[2], self.encoding)
  98. self.authenticator.configure_curve(domain, location)
  99. elif command == b'TERMINATE':
  100. terminate = True
  101. else:
  102. self.log.error("Invalid auth command from API: %r", command)
  103. return terminate
  104. def _inherit_docstrings(cls):
  105. """inherit docstrings from Authenticator, so we don't duplicate them"""
  106. for name, method in cls.__dict__.items():
  107. if name.startswith('_') or not callable(method):
  108. continue
  109. upstream_method = getattr(Authenticator, name, None)
  110. if not method.__doc__:
  111. method.__doc__ = upstream_method.__doc__
  112. return cls
  113. @_inherit_docstrings
  114. class ThreadAuthenticator(object):
  115. """Run ZAP authentication in a background thread"""
  116. context = None
  117. log = None
  118. encoding = None
  119. pipe = None
  120. pipe_endpoint = ''
  121. thread = None
  122. auth = None
  123. def __init__(self, context=None, encoding='utf-8', log=None):
  124. self.context = context or zmq.Context.instance()
  125. self.log = log
  126. self.encoding = encoding
  127. self.pipe = None
  128. self.pipe_endpoint = "inproc://{0}.inproc".format(id(self))
  129. self.thread = None
  130. # proxy base Authenticator attributes
  131. def __setattr__(self, key, value):
  132. for obj in [self] + self.__class__.mro():
  133. if key in obj.__dict__:
  134. object.__setattr__(self, key, value)
  135. return
  136. setattr(self.thread.authenticator, key, value)
  137. def __getattr__(self, key):
  138. try:
  139. object.__getattr__(self, key)
  140. except AttributeError:
  141. return getattr(self.thread.authenticator, key)
  142. def allow(self, *addresses):
  143. self.pipe.send_multipart([b'ALLOW'] + [b(a, self.encoding) for a in addresses])
  144. def deny(self, *addresses):
  145. self.pipe.send_multipart([b'DENY'] + [b(a, self.encoding) for a in addresses])
  146. def configure_plain(self, domain='*', passwords=None):
  147. self.pipe.send_multipart([b'PLAIN', b(domain, self.encoding), jsonapi.dumps(passwords or {})])
  148. def configure_curve(self, domain='*', location=''):
  149. domain = b(domain, self.encoding)
  150. location = b(location, self.encoding)
  151. self.pipe.send_multipart([b'CURVE', domain, location])
  152. def configure_curve_callback(self, domain='*', credentials_provider=None):
  153. self.thread.authenticator.configure_curve_callback(domain, credentials_provider=credentials_provider)
  154. def start(self):
  155. """Start the authentication thread"""
  156. # create a socket to communicate with auth thread.
  157. self.pipe = self.context.socket(zmq.PAIR)
  158. self.pipe.linger = 1
  159. self.pipe.bind(self.pipe_endpoint)
  160. self.thread = AuthenticationThread(self.context, self.pipe_endpoint, encoding=self.encoding, log=self.log)
  161. self.thread.start()
  162. # Event.wait:Changed in version 2.7: Previously, the method always returned None.
  163. if sys.version_info < (2,7):
  164. self.thread.started.wait(timeout=10)
  165. else:
  166. if not self.thread.started.wait(timeout=10):
  167. raise RuntimeError("Authenticator thread failed to start")
  168. def stop(self):
  169. """Stop the authentication thread"""
  170. if self.pipe:
  171. self.pipe.send(b'TERMINATE')
  172. if self.is_alive():
  173. self.thread.join()
  174. self.thread = None
  175. self.pipe.close()
  176. self.pipe = None
  177. def is_alive(self):
  178. """Is the ZAP thread currently running?"""
  179. if self.thread and self.thread.is_alive():
  180. return True
  181. return False
  182. def __del__(self):
  183. self.stop()
  184. __all__ = ['ThreadAuthenticator']