123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379 |
- """Base implementation of 0MQ authentication."""
- # Copyright (C) PyZMQ Developers
- # Distributed under the terms of the Modified BSD License.
- import logging
- import zmq
- from zmq.utils import z85
- from zmq.utils.strtypes import bytes, unicode, b, u
- from zmq.error import _check_version
- from .certs import load_certificates
# Sentinel "location" for Authenticator.configure_curve: accept every CURVE client key.
CURVE_ALLOW_ANY = "*"
# ZAP protocol version; requests carrying any other version are rejected,
# and every reply is stamped with this value.
VERSION = b"1.0"
class Authenticator(object):
    """Implementation of ZAP authentication for zmq connections.

    Note:

    - libzmq provides four security mechanisms: NULL, PLAIN, CURVE, and GSSAPI.
      Default NULL connections are not seen by the Authenticator; authenticated
      NULL, PLAIN, CURVE, and GSSAPI connections are.
    - Until you add policies, all incoming NULL connections are allowed
      (classic ZeroMQ behavior), and all PLAIN and CURVE connections are denied.
    - GSSAPI requires no configuration.
    """

    def __init__(self, context=None, encoding='utf-8', log=None):
        _check_version((4, 0), "security")
        self.context = context or zmq.Context.instance()
        # Used to decode domain/address/PLAIN credential frames and to encode
        # text user ids in replies.
        self.encoding = encoding
        # When True every CURVE client key is accepted. Note: this flag is a
        # single attribute shared by all domains, not stored per domain.
        self.allow_any = False
        # dict keyed by domain; values are objects exposing
        # ``callback(domain, z85_key) -> bool`` (see configure_curve_callback).
        self.credentials_providers = {}
        self.zap_socket = None
        self.whitelist = set()
        self.blacklist = set()
        # passwords is a dict keyed by domain and contains values
        # of dicts with username:password pairs.
        self.passwords = {}
        # certs is dict keyed by domain and contains values
        # of dicts keyed by the public keys from the specified location.
        self.certs = {}
        self.log = log or logging.getLogger('zmq.auth')

    def start(self):
        """Create and bind the ZAP socket."""
        self.zap_socket = self.context.socket(zmq.REP)
        self.zap_socket.linger = 1
        self.zap_socket.bind("inproc://zeromq.zap.01")
        self.log.debug("Starting")

    def stop(self):
        """Close the ZAP socket (no-op if it is not open)."""
        if self.zap_socket:
            self.zap_socket.close()
        self.zap_socket = None

    def allow(self, *addresses):
        """Allow (whitelist) IP address(es).

        Connections from addresses not in the whitelist will be rejected.

        - For NULL, all clients from this address will be accepted.
        - For real auth setups, they will be allowed to continue with authentication.

        whitelist is mutually exclusive with blacklist.

        Raises
        ------
        ValueError
            If a blacklist has already been configured.
        """
        if self.blacklist:
            raise ValueError("Only use a whitelist or a blacklist, not both")
        self.log.debug("Allowing %s", ','.join(addresses))
        self.whitelist.update(addresses)

    def deny(self, *addresses):
        """Deny (blacklist) IP address(es).

        Addresses not in the blacklist will be allowed to continue with authentication.

        Blacklist is mutually exclusive with whitelist.

        Raises
        ------
        ValueError
            If a whitelist has already been configured.
        """
        if self.whitelist:
            raise ValueError("Only use a whitelist or a blacklist, not both")
        self.log.debug("Denying %s", ','.join(addresses))
        self.blacklist.update(addresses)

    def configure_plain(self, domain='*', passwords=None):
        """Configure PLAIN authentication for a given domain.

        PLAIN authentication checks clients against an in-memory mapping of
        ``{username: password}``. To cover all domains, use "*".

        Calling this again for the same domain replaces that domain's mapping.
        If ``passwords`` is empty or None, the existing configuration for the
        domain is left unchanged.
        """
        if passwords:
            self.passwords[domain] = passwords
        self.log.debug("Configure plain: %s", domain)

    def configure_curve(self, domain='*', location=None):
        """Configure CURVE authentication for a given domain.

        CURVE authentication uses a directory that holds all public client certificates,
        i.e. their public keys.

        To cover all domains, use "*".

        You can add and remove certificates in that directory at any time. configure_curve must be called
        every time certificates are added or removed, in order to update the Authenticator's state

        To allow all client keys without checking, specify CURVE_ALLOW_ANY for the location.
        Note that the allow-any setting is global: it applies to every domain until
        configure_curve or configure_curve_callback is called again.

        Failures to load certificates are logged, not raised.
        """
        # If location is CURVE_ALLOW_ANY then allow all clients. Otherwise
        # treat location as a directory that holds the certificates.
        self.log.debug("Configure curve: %s[%s]", domain, location)
        if location == CURVE_ALLOW_ANY:
            self.allow_any = True
        else:
            self.allow_any = False
            try:
                self.certs[domain] = load_certificates(location)
            except Exception as e:
                # Best-effort: a bad location leaves the previous certs in place.
                self.log.error("Failed to load CURVE certs from %s: %s", location, e)

    def configure_curve_callback(self, domain='*', credentials_provider=None):
        """Configure CURVE authentication for a given domain.

        CURVE authentication using a callback function validating
        the client public key according to a custom mechanism, e.g. checking the
        key against records in a db. credentials_provider is an object of a class which
        implements a callback method accepting two parameters (domain and key), e.g.::

            class CredentialsProvider(object):

                def __init__(self):
                    ...e.g. db connection

                def callback(self, domain, key):
                    valid = ...lookup key and/or domain in db
                    if valid:
                        logging.info('Authorizing: {0}, {1}'.format(domain, key))
                        return True
                    else:
                        logging.warning('NOT Authorizing: {0}, {1}'.format(domain, key))
                        return False

        The key passed to the callback is the z85-encoded client public key.

        To cover all domains, use "*".

        Calling this disables any previous CURVE_ALLOW_ANY setting; use
        configure_curve with CURVE_ALLOW_ANY to accept all client keys instead.
        """
        self.allow_any = False

        if credentials_provider is not None:
            self.credentials_providers[domain] = credentials_provider
        else:
            self.log.error("None credentials_provider provided for domain:%s", domain)

    def curve_user_id(self, client_public_key):
        """Return the User-Id corresponding to a CURVE client's public key

        Default implementation uses the z85-encoding of the public key.

        Override to define a custom mapping of public key : user-id

        This is only called on successful authentication.

        Parameters
        ----------
        client_public_key: bytes
            The client public key used for the given message

        Returns
        -------
        user_id: unicode
            The user ID as text
        """
        return z85.encode(client_public_key).decode('ascii')

    def configure_gssapi(self, domain='*', location=None):
        """Configure GSSAPI authentication

        Currently this is a no-op because there is nothing to configure with GSSAPI.
        """
        pass

    def handle_zap_message(self, msg):
        """Perform ZAP authentication and send the reply.

        Parameters
        ----------
        msg : sequence of bytes frames
            The ZAP request: version, request_id, domain, address, identity,
            mechanism, followed by mechanism-specific credential frames.

        A reply is sent on ``self.zap_socket`` whenever a request id is
        available. Requests with fewer than two frames are only logged.
        """
        if len(msg) < 6:
            self.log.error("Invalid ZAP message, not enough frames: %r", msg)
            if len(msg) < 2:
                self.log.error("Not enough information to reply")
            else:
                self._send_zap_reply(msg[1], b"400", b"Not enough frames")
            return

        version, request_id, domain, address, identity, mechanism = msg[:6]
        credentials = msg[6:]

        domain = u(domain, self.encoding, 'replace')
        address = u(address, self.encoding, 'replace')

        if (version != VERSION):
            self.log.error("Invalid ZAP version: %r", msg)
            self._send_zap_reply(request_id, b"400", b"Invalid version")
            return

        self.log.debug("version: %r, request_id: %r, domain: %r,"
                       " address: %r, identity: %r, mechanism: %r",
                       version, request_id, domain,
                       address, identity, mechanism,
                       )

        # Is address is explicitly whitelisted or blacklisted?
        allowed = False
        denied = False
        reason = b"NO ACCESS"

        if self.whitelist:
            if address in self.whitelist:
                allowed = True
                self.log.debug("PASSED (whitelist) address=%s", address)
            else:
                denied = True
                reason = b"Address not in whitelist"
                self.log.debug("DENIED (not in whitelist) address=%s", address)

        elif self.blacklist:
            if address in self.blacklist:
                denied = True
                reason = b"Address is blacklisted"
                self.log.debug("DENIED (blacklist) address=%s", address)
            else:
                allowed = True
                self.log.debug("PASSED (not in blacklist) address=%s", address)

        # Perform authentication mechanism-specific checks if necessary
        username = u("anonymous")
        if not denied:

            if mechanism == b'NULL' and not allowed:
                # For NULL, we allow if the address wasn't blacklisted
                self.log.debug("ALLOWED (NULL)")
                allowed = True

            elif mechanism == b'PLAIN':
                # For PLAIN, even a whitelisted address must authenticate
                if len(credentials) != 2:
                    # Log only the frame count: the frames may hold a password.
                    self.log.error("Invalid PLAIN credentials: expected 2 frames, got %d",
                                   len(credentials))
                    self._send_zap_reply(request_id, b"400", b"Invalid credentials")
                    return
                username, password = [u(c, self.encoding, 'replace') for c in credentials]
                allowed, reason = self._authenticate_plain(domain, username, password)

            elif mechanism == b'CURVE':
                # For CURVE, even a whitelisted address must authenticate
                if len(credentials) != 1:
                    self.log.error("Invalid CURVE credentials: %r", credentials)
                    self._send_zap_reply(request_id, b"400", b"Invalid credentials")
                    return
                key = credentials[0]
                allowed, reason = self._authenticate_curve(domain, key)
                if allowed:
                    username = self.curve_user_id(key)

            elif mechanism == b'GSSAPI':
                if len(credentials) != 1:
                    self.log.error("Invalid GSSAPI credentials: %r", credentials)
                    self._send_zap_reply(request_id, b"400", b"Invalid credentials")
                    return
                # use principal as user-id for now
                principal = username = credentials[0]
                allowed, reason = self._authenticate_gssapi(domain, principal)

            # Any other mechanism keeps the verdict from the address policy above.

        if allowed:
            self._send_zap_reply(request_id, b"200", b"OK", username)
        else:
            self._send_zap_reply(request_id, b"400", reason)

    def _authenticate_plain(self, domain, username, password):
        """PLAIN ZAP authentication.

        Returns
        -------
        (allowed, reason) : (bool, bytes)
            ``reason`` is empty on success and describes the failure otherwise.
        """
        allowed = False
        reason = b""
        if self.passwords:
            # If no domain is not specified then use the default domain
            if not domain:
                domain = '*'

            if domain in self.passwords:
                if username in self.passwords[domain]:
                    if password == self.passwords[domain][username]:
                        allowed = True
                    else:
                        reason = b"Invalid password"
                else:
                    reason = b"Invalid username"
            else:
                reason = b"Invalid domain"

            if allowed:
                # Never write the plaintext password to the log.
                self.log.debug("ALLOWED (PLAIN) domain=%s username=%s",
                               domain, username,
                               )
            else:
                self.log.debug("DENIED (PLAIN) %s", reason)

        else:
            reason = b"No passwords defined"
            self.log.debug("DENIED (PLAIN) %s", reason)

        return allowed, reason

    def _authenticate_curve(self, domain, client_key):
        """CURVE ZAP authentication.

        Checked in order: the global allow-any flag, then credential-provider
        callbacks (if any are configured, certificate directories are not
        consulted), then loaded certificates.

        Returns
        -------
        (allowed, reason) : (bool, bytes)
        """
        allowed = False
        reason = b""
        if self.allow_any:
            allowed = True
            reason = b"OK"
            self.log.debug("ALLOWED (CURVE allow any client)")
        elif self.credentials_providers != {}:
            # If no explicit domain is specified then use the default domain
            if not domain:
                domain = '*'

            if domain in self.credentials_providers:
                z85_client_key = z85.encode(client_key)
                # Callback to check if key is Allowed
                if (self.credentials_providers[domain].callback(domain, z85_client_key)):
                    allowed = True
                    reason = b"OK"
                else:
                    reason = b"Unknown key"

                status = "ALLOWED" if allowed else "DENIED"
                self.log.debug("%s (CURVE auth_callback) domain=%s client_key=%s",
                               status, domain, z85_client_key,
                               )
            else:
                reason = b"Unknown domain"
        else:
            # If no explicit domain is specified then use the default domain
            if not domain:
                domain = '*'

            if domain in self.certs:
                # The certs dict stores keys in z85 format, convert binary key to z85 bytes
                z85_client_key = z85.encode(client_key)
                if self.certs[domain].get(z85_client_key):
                    allowed = True
                    reason = b"OK"
                else:
                    reason = b"Unknown key"

                status = "ALLOWED" if allowed else "DENIED"
                self.log.debug("%s (CURVE) domain=%s client_key=%s",
                               status, domain, z85_client_key,
                               )
            else:
                reason = b"Unknown domain"

        return allowed, reason

    def _authenticate_gssapi(self, domain, principal):
        """Nothing to do for GSSAPI, which has already been handled by an external service."""
        self.log.debug("ALLOWED (GSSAPI) domain=%s principal=%s", domain, principal)
        return True, b'OK'

    def _send_zap_reply(self, request_id, status_code, status_text, user_id='anonymous'):
        """Send a ZAP reply to finish the authentication.

        ``user_id`` is only sent with a b'200' status; text user ids are
        encoded with ``self.encoding``.
        """
        user_id = user_id if status_code == b'200' else b''
        if isinstance(user_id, unicode):
            user_id = user_id.encode(self.encoding, 'replace')
        metadata = b''  # not currently used
        self.log.debug("ZAP reply code=%s text=%s", status_code, status_text)
        reply = [VERSION, request_id, status_code, status_text, user_id, metadata]
        self.zap_socket.send_multipart(reply)
# Names exported by ``from <module> import *``.
__all__ = [
    "Authenticator",
    "CURVE_ALLOW_ANY",
]
|