base.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379
  1. """Base implementation of 0MQ authentication."""
  2. # Copyright (C) PyZMQ Developers
  3. # Distributed under the terms of the Modified BSD License.
  4. import logging
  5. import zmq
  6. from zmq.utils import z85
  7. from zmq.utils.strtypes import bytes, unicode, b, u
  8. from zmq.error import _check_version
  9. from .certs import load_certificates
  10. CURVE_ALLOW_ANY = '*'
  11. VERSION = b'1.0'
  12. class Authenticator(object):
  13. """Implementation of ZAP authentication for zmq connections.
  14. Note:
  15. - libzmq provides four levels of security: default NULL (which the Authenticator does
  16. not see), and authenticated NULL, PLAIN, CURVE, and GSSAPI, which the Authenticator can see.
  17. - until you add policies, all incoming NULL connections are allowed.
  18. (classic ZeroMQ behavior), and all PLAIN and CURVE connections are denied.
  19. - GSSAPI requires no configuration.
  20. """
  21. def __init__(self, context=None, encoding='utf-8', log=None):
  22. _check_version((4,0), "security")
  23. self.context = context or zmq.Context.instance()
  24. self.encoding = encoding
  25. self.allow_any = False
  26. self.credentials_providers = {}
  27. self.zap_socket = None
  28. self.whitelist = set()
  29. self.blacklist = set()
  30. # passwords is a dict keyed by domain and contains values
  31. # of dicts with username:password pairs.
  32. self.passwords = {}
  33. # certs is dict keyed by domain and contains values
  34. # of dicts keyed by the public keys from the specified location.
  35. self.certs = {}
  36. self.log = log or logging.getLogger('zmq.auth')
  37. def start(self):
  38. """Create and bind the ZAP socket"""
  39. self.zap_socket = self.context.socket(zmq.REP)
  40. self.zap_socket.linger = 1
  41. self.zap_socket.bind("inproc://zeromq.zap.01")
  42. self.log.debug("Starting")
  43. def stop(self):
  44. """Close the ZAP socket"""
  45. if self.zap_socket:
  46. self.zap_socket.close()
  47. self.zap_socket = None
  48. def allow(self, *addresses):
  49. """Allow (whitelist) IP address(es).
  50. Connections from addresses not in the whitelist will be rejected.
  51. - For NULL, all clients from this address will be accepted.
  52. - For real auth setups, they will be allowed to continue with authentication.
  53. whitelist is mutually exclusive with blacklist.
  54. """
  55. if self.blacklist:
  56. raise ValueError("Only use a whitelist or a blacklist, not both")
  57. self.log.debug("Allowing %s", ','.join(addresses))
  58. self.whitelist.update(addresses)
  59. def deny(self, *addresses):
  60. """Deny (blacklist) IP address(es).
  61. Addresses not in the blacklist will be allowed to continue with authentication.
  62. Blacklist is mutually exclusive with whitelist.
  63. """
  64. if self.whitelist:
  65. raise ValueError("Only use a whitelist or a blacklist, not both")
  66. self.log.debug("Denying %s", ','.join(addresses))
  67. self.blacklist.update(addresses)
  68. def configure_plain(self, domain='*', passwords=None):
  69. """Configure PLAIN authentication for a given domain.
  70. PLAIN authentication uses a plain-text password file.
  71. To cover all domains, use "*".
  72. You can modify the password file at any time; it is reloaded automatically.
  73. """
  74. if passwords:
  75. self.passwords[domain] = passwords
  76. self.log.debug("Configure plain: %s", domain)
  77. def configure_curve(self, domain='*', location=None):
  78. """Configure CURVE authentication for a given domain.
  79. CURVE authentication uses a directory that holds all public client certificates,
  80. i.e. their public keys.
  81. To cover all domains, use "*".
  82. You can add and remove certificates in that directory at any time. configure_curve must be called
  83. every time certificates are added or removed, in order to update the Authenticator's state
  84. To allow all client keys without checking, specify CURVE_ALLOW_ANY for the location.
  85. """
  86. # If location is CURVE_ALLOW_ANY then allow all clients. Otherwise
  87. # treat location as a directory that holds the certificates.
  88. self.log.debug("Configure curve: %s[%s]", domain, location)
  89. if location == CURVE_ALLOW_ANY:
  90. self.allow_any = True
  91. else:
  92. self.allow_any = False
  93. try:
  94. self.certs[domain] = load_certificates(location)
  95. except Exception as e:
  96. self.log.error("Failed to load CURVE certs from %s: %s", location, e)
  97. def configure_curve_callback(self, domain='*', credentials_provider=None):
  98. """Configure CURVE authentication for a given domain.
  99. CURVE authentication using a callback function validating
  100. the client public key according to a custom mechanism, e.g. checking the
  101. key against records in a db. credentials_provider is an object of a class which
  102. implements a callback method accepting two parameters (domain and key), e.g.::
  103. class CredentialsProvider(object):
  104. def __init__(self):
  105. ...e.g. db connection
  106. def callback(self, domain, key):
  107. valid = ...lookup key and/or domain in db
  108. if valid:
  109. logging.info('Authorizing: {0}, {1}'.format(domain, key))
  110. return True
  111. else:
  112. logging.warning('NOT Authorizing: {0}, {1}'.format(domain, key))
  113. return False
  114. To cover all domains, use "*".
  115. To allow all client keys without checking, specify CURVE_ALLOW_ANY for the location.
  116. """
  117. self.allow_any = False
  118. if credentials_provider is not None:
  119. self.credentials_providers[domain] = credentials_provider
  120. else:
  121. self.log.error("None credentials_provider provided for domain:%s",domain)
  122. def curve_user_id(self, client_public_key):
  123. """Return the User-Id corresponding to a CURVE client's public key
  124. Default implementation uses the z85-encoding of the public key.
  125. Override to define a custom mapping of public key : user-id
  126. This is only called on successful authentication.
  127. Parameters
  128. ----------
  129. client_public_key: bytes
  130. The client public key used for the given message
  131. Returns
  132. -------
  133. user_id: unicode
  134. The user ID as text
  135. """
  136. return z85.encode(client_public_key).decode('ascii')
  137. def configure_gssapi(self, domain='*', location=None):
  138. """Configure GSSAPI authentication
  139. Currently this is a no-op because there is nothing to configure with GSSAPI.
  140. """
  141. pass
  142. def handle_zap_message(self, msg):
  143. """Perform ZAP authentication"""
  144. if len(msg) < 6:
  145. self.log.error("Invalid ZAP message, not enough frames: %r", msg)
  146. if len(msg) < 2:
  147. self.log.error("Not enough information to reply")
  148. else:
  149. self._send_zap_reply(msg[1], b"400", b"Not enough frames")
  150. return
  151. version, request_id, domain, address, identity, mechanism = msg[:6]
  152. credentials = msg[6:]
  153. domain = u(domain, self.encoding, 'replace')
  154. address = u(address, self.encoding, 'replace')
  155. if (version != VERSION):
  156. self.log.error("Invalid ZAP version: %r", msg)
  157. self._send_zap_reply(request_id, b"400", b"Invalid version")
  158. return
  159. self.log.debug("version: %r, request_id: %r, domain: %r,"
  160. " address: %r, identity: %r, mechanism: %r",
  161. version, request_id, domain,
  162. address, identity, mechanism,
  163. )
  164. # Is address is explicitly whitelisted or blacklisted?
  165. allowed = False
  166. denied = False
  167. reason = b"NO ACCESS"
  168. if self.whitelist:
  169. if address in self.whitelist:
  170. allowed = True
  171. self.log.debug("PASSED (whitelist) address=%s", address)
  172. else:
  173. denied = True
  174. reason = b"Address not in whitelist"
  175. self.log.debug("DENIED (not in whitelist) address=%s", address)
  176. elif self.blacklist:
  177. if address in self.blacklist:
  178. denied = True
  179. reason = b"Address is blacklisted"
  180. self.log.debug("DENIED (blacklist) address=%s", address)
  181. else:
  182. allowed = True
  183. self.log.debug("PASSED (not in blacklist) address=%s", address)
  184. # Perform authentication mechanism-specific checks if necessary
  185. username = u("anonymous")
  186. if not denied:
  187. if mechanism == b'NULL' and not allowed:
  188. # For NULL, we allow if the address wasn't blacklisted
  189. self.log.debug("ALLOWED (NULL)")
  190. allowed = True
  191. elif mechanism == b'PLAIN':
  192. # For PLAIN, even a whitelisted address must authenticate
  193. if len(credentials) != 2:
  194. self.log.error("Invalid PLAIN credentials: %r", credentials)
  195. self._send_zap_reply(request_id, b"400", b"Invalid credentials")
  196. return
  197. username, password = [ u(c, self.encoding, 'replace') for c in credentials ]
  198. allowed, reason = self._authenticate_plain(domain, username, password)
  199. elif mechanism == b'CURVE':
  200. # For CURVE, even a whitelisted address must authenticate
  201. if len(credentials) != 1:
  202. self.log.error("Invalid CURVE credentials: %r", credentials)
  203. self._send_zap_reply(request_id, b"400", b"Invalid credentials")
  204. return
  205. key = credentials[0]
  206. allowed, reason = self._authenticate_curve(domain, key)
  207. if allowed:
  208. username = self.curve_user_id(key)
  209. elif mechanism == b'GSSAPI':
  210. if len(credentials) != 1:
  211. self.log.error("Invalid GSSAPI credentials: %r", credentials)
  212. self._send_zap_reply(request_id, b"400", b"Invalid credentials")
  213. return
  214. # use principal as user-id for now
  215. principal = username = credentials[0]
  216. allowed, reason = self._authenticate_gssapi(domain, principal)
  217. if allowed:
  218. self._send_zap_reply(request_id, b"200", b"OK", username)
  219. else:
  220. self._send_zap_reply(request_id, b"400", reason)
  221. def _authenticate_plain(self, domain, username, password):
  222. """PLAIN ZAP authentication"""
  223. allowed = False
  224. reason = b""
  225. if self.passwords:
  226. # If no domain is not specified then use the default domain
  227. if not domain:
  228. domain = '*'
  229. if domain in self.passwords:
  230. if username in self.passwords[domain]:
  231. if password == self.passwords[domain][username]:
  232. allowed = True
  233. else:
  234. reason = b"Invalid password"
  235. else:
  236. reason = b"Invalid username"
  237. else:
  238. reason = b"Invalid domain"
  239. if allowed:
  240. self.log.debug("ALLOWED (PLAIN) domain=%s username=%s password=%s",
  241. domain, username, password,
  242. )
  243. else:
  244. self.log.debug("DENIED %s", reason)
  245. else:
  246. reason = b"No passwords defined"
  247. self.log.debug("DENIED (PLAIN) %s", reason)
  248. return allowed, reason
  249. def _authenticate_curve(self, domain, client_key):
  250. """CURVE ZAP authentication"""
  251. allowed = False
  252. reason = b""
  253. if self.allow_any:
  254. allowed = True
  255. reason = b"OK"
  256. self.log.debug("ALLOWED (CURVE allow any client)")
  257. elif self.credentials_providers != {}:
  258. # If no explicit domain is specified then use the default domain
  259. if not domain:
  260. domain = '*'
  261. if domain in self.credentials_providers:
  262. z85_client_key = z85.encode(client_key)
  263. # Callback to check if key is Allowed
  264. if (self.credentials_providers[domain].callback(domain, z85_client_key)):
  265. allowed = True
  266. reason = b"OK"
  267. else:
  268. reason = b"Unknown key"
  269. status = "ALLOWED" if allowed else "DENIED"
  270. self.log.debug("%s (CURVE auth_callback) domain=%s client_key=%s",
  271. status, domain, z85_client_key,
  272. )
  273. else:
  274. reason = b"Unknown domain"
  275. else:
  276. # If no explicit domain is specified then use the default domain
  277. if not domain:
  278. domain = '*'
  279. if domain in self.certs:
  280. # The certs dict stores keys in z85 format, convert binary key to z85 bytes
  281. z85_client_key = z85.encode(client_key)
  282. if self.certs[domain].get(z85_client_key):
  283. allowed = True
  284. reason = b"OK"
  285. else:
  286. reason = b"Unknown key"
  287. status = "ALLOWED" if allowed else "DENIED"
  288. self.log.debug("%s (CURVE) domain=%s client_key=%s",
  289. status, domain, z85_client_key,
  290. )
  291. else:
  292. reason = b"Unknown domain"
  293. return allowed, reason
  294. def _authenticate_gssapi(self, domain, principal):
  295. """Nothing to do for GSSAPI, which has already been handled by an external service."""
  296. self.log.debug("ALLOWED (GSSAPI) domain=%s principal=%s", domain, principal)
  297. return True, b'OK'
  298. def _send_zap_reply(self, request_id, status_code, status_text, user_id='anonymous'):
  299. """Send a ZAP reply to finish the authentication."""
  300. user_id = user_id if status_code == b'200' else b''
  301. if isinstance(user_id, unicode):
  302. user_id = user_id.encode(self.encoding, 'replace')
  303. metadata = b'' # not currently used
  304. self.log.debug("ZAP reply code=%s text=%s", status_code, status_text)
  305. reply = [VERSION, request_id, status_code, status_text, user_id, metadata]
  306. self.zap_socket.send_multipart(reply)
  307. __all__ = ['Authenticator', 'CURVE_ALLOW_ANY']