123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307 |
- # -*- coding: utf-8 -*-
- """
- hyper/ssl_compat
- ~~~~~~~~~~~~~~~~
- Shoves pyOpenSSL into an API that looks like the standard Python 3.x ssl
- module.
- Currently exposes exactly those attributes, classes, and methods that we
- actually use in hyper (all method signatures are complete, however). May be
- expanded to something more general-purpose in the future.
- """
- try:
- import StringIO as BytesIO
- except ImportError:
- from io import BytesIO
- import errno
- import socket
- import time
- from OpenSSL import SSL as ossl
- from service_identity.pyopenssl import verify_hostname as _verify
# Verification modes, named the way the stdlib ``ssl`` module names them.
CERT_NONE = ossl.VERIFY_NONE
CERT_REQUIRED = ossl.VERIFY_PEER | ossl.VERIFY_FAIL_IF_NO_PEER_CERT

# stdlib ``ssl`` name -> pyOpenSSL ``SSL`` attribute name. Only the
# attributes that the installed pyOpenSSL actually provides get exported.
_OPENSSL_ATTRS = dict(
    OP_NO_COMPRESSION='OP_NO_COMPRESSION',
    PROTOCOL_TLSv1_2='TLSv1_2_METHOD',
    PROTOCOL_SSLv23='SSLv23_METHOD',
)


def _export_openssl_attrs(namespace):
    """
    Copy every attribute listed in _OPENSSL_ATTRS that pyOpenSSL defines into
    *namespace*, under its stdlib name.
    """
    for external, internal in _OPENSSL_ATTRS.items():
        value = getattr(ossl, internal, None)
        # Test against None rather than truthiness: a constant whose value
        # is 0 is still a real constant.
        if value is not None:
            namespace[external] = value


# Write through globals() explicitly: the dict returned by locals() is
# documented as not-to-be-modified. Doing the work in a function also keeps
# the loop variables out of the module namespace.
_export_openssl_attrs(globals())

# TODO: Find out the names of these other flags.
# Bits 0-9 plus bit 31. The bits are distinct, so summing them equals
# OR-ing them together.
OP_ALL = sum(1 << bit for bit in [31] + list(range(10)))

HAS_NPN = True
- def _proxy(method):
- def inner(self, *args, **kwargs):
- return getattr(self._conn, method)(*args, **kwargs)
- return inner
# hyper/http20/connection.py refers to these two names. Their values match
# the same-named constants in the stdlib ``ssl`` package; they are defined
# here so that hyper also runs on Python versions older than 2.7.9.
SSL_ERROR_WANT_READ, SSL_ERROR_WANT_WRITE = 2, 3
# TODO: still missing some attributes of the stdlib ssl.SSLError.
class SSLError(OSError):
    """Base SSL error of this compatibility layer (stand-in for ssl.SSLError)."""


class CertificateError(SSLError):
    """Certificate-related SSL error (stand-in for ssl.CertificateError)."""
def verify_hostname(ssl_sock, server_hostname):
    """
    Check the peer certificate of *ssl_sock* against *server_hostname*.

    Behaves nearly like the stdlib's ``match_hostname``. *server_hostname*
    may be text or ASCII bytes; bytes are decoded before the check is handed
    to ``service_identity``'s pyOpenSSL verifier, using the wrapped
    connection ``ssl_sock._conn``.
    """
    hostname = (
        server_hostname.decode('ascii')
        if isinstance(server_hostname, bytes)
        else server_hostname
    )
    return _verify(ssl_sock._conn, hostname)
class SSLSocket(object):
    """
    Socket-like wrapper around a pyOpenSSL ``Connection``.

    Exposes the subset of the stdlib ``ssl.SSLSocket`` API that hyper uses.
    Where pyOpenSSL's exceptions or return values differ from the stdlib's,
    they are translated.
    """

    # Seconds _safe_ssl_call keeps retrying a call that needs more I/O.
    SSL_TIMEOUT = 3
    # Seconds to sleep between those retries.
    SSL_RETRY = .01

    def __init__(self, conn, server_side, do_handshake_on_connect,
                 suppress_ragged_eofs, server_hostname, check_hostname):
        """
        :param conn: The pyOpenSSL ``Connection`` to wrap.
        :param server_side: Put the connection in accept (server) state
            rather than connect (client) state.
        :param do_handshake_on_connect: Handshake immediately if already
            connected, and again from connect().
        :param suppress_ragged_eofs: Make recv() return ``b''`` on an
            unexpected EOF instead of raising.
        :param server_hostname: Hostname sent via SNI and checked when
            *check_hostname* is true. Text or bytes.
        :param check_hostname: Verify the peer certificate against
            *server_hostname* after the handshake.
        :raises ValueError: if *check_hostname* is set without a
            *server_hostname* (same rule as the stdlib).
        """
        if check_hostname and not server_hostname:
            raise ValueError("check_hostname requires server_hostname")

        self._conn = conn
        self._do_handshake_on_connect = do_handshake_on_connect
        self._suppress_ragged_eofs = suppress_ragged_eofs
        self._check_hostname = check_hostname
        # Always set, so do_handshake() can read it however we were built.
        self._server_hostname = server_hostname

        if server_side:
            self._conn.set_accept_state()
        else:
            if server_hostname:
                sni = server_hostname
                if not isinstance(sni, bytes):
                    sni = sni.encode('utf-8')
                self._conn.set_tlsext_host_name(sni)
            # FIXME does this override do_handshake_on_connect=False?
            self._conn.set_connect_state()

        if self.connected and self._do_handshake_on_connect:
            self.do_handshake()

    @property
    def connected(self):
        """True if the underlying socket currently has a peer."""
        try:
            self._conn.getpeername()
        except socket.error as e:
            if e.errno != errno.ENOTCONN:
                # Anything other than "not connected" is a genuine failure.
                raise
            return False
        return True

    # Lovingly stolen from CherryPy
    # (http://svn.cherrypy.org/tags/cherrypy-3.2.1/cherrypy/wsgiserver/ssl_pyopenssl.py).
    def _safe_ssl_call(self, suppress_ragged_eofs, call, *args, **kwargs):
        """
        Run ``call(*args, **kwargs)`` with SSL error trapping.

        - ``WantReadError`` / ``WantWriteError``: the call is retried every
          SSL_RETRY seconds. Once SSL_TIMEOUT seconds have passed,
          ``socket.timeout`` is raised.
        - An "Unexpected EOF" error returns ``b''`` when
          *suppress_ragged_eofs* is true.
        - Any other pyOpenSSL error is re-raised as ``socket.error``.
        """
        start = time.time()
        while True:
            try:
                return call(*args, **kwargs)
            except (ossl.WantReadError, ossl.WantWriteError):
                # Sleep and try again. This is dangerous, because it means
                # the rest of the stack has no way of differentiating
                # between a "new handshake" error and "client dropped".
                # Note this isn't an endless loop: there's a timeout below.
                time.sleep(self.SSL_RETRY)
            except ossl.Error as e:
                if suppress_ragged_eofs and e.args == (-1, 'Unexpected EOF'):
                    return b''
                # Some pyOpenSSL errors (e.g. ZeroReturnError) carry no
                # args; indexing e.args[0] would mask them with IndexError.
                detail = e.args[0] if e.args else type(e).__name__
                raise socket.error(detail)
            if time.time() - start > self.SSL_TIMEOUT:
                raise socket.timeout('timed out')

    def connect(self, address):
        """Connect to *address*, then handshake if configured to."""
        self._conn.connect(address)
        if self._do_handshake_on_connect:
            self.do_handshake()

    def do_handshake(self):
        """Perform the TLS handshake, then verify the hostname if enabled."""
        self._safe_ssl_call(False, self._conn.do_handshake)
        if self._check_hostname:
            verify_hostname(self, self._server_hostname)

    def recv(self, bufsize, flags=None):
        """Read up to *bufsize* bytes of decrypted data."""
        return self._safe_ssl_call(
            self._suppress_ragged_eofs,
            self._conn.recv,
            bufsize,
            flags
        )

    def recv_into(self, buffer, bufsize=None, flags=None):
        """
        Read into *buffer* and return the number of bytes written.

        As with the stdlib, a *bufsize* of None or 0 means "the whole buffer".
        A negative *bufsize*, or one larger than the buffer, raises
        ValueError.
        """
        # A temporary recv_into implementation. Should be replaced when
        # PyOpenSSL has merged pyca/pyopenssl#121.
        buffer_len = len(buffer)
        if bufsize is not None and bufsize < 0:
            raise ValueError("negative buffersize in recv_into")
        if not bufsize:
            bufsize = buffer_len
        elif bufsize > buffer_len:
            raise ValueError("nbytes is greater than the length of the buffer")
        data = self.recv(bufsize, flags)
        data_len = len(data)
        buffer[0:data_len] = data
        return data_len

    def send(self, data, flags=None):
        """Send *data*; return what pyOpenSSL's ``send`` returns."""
        return self._safe_ssl_call(False, self._conn.send, data, flags)

    def sendall(self, data, flags=None):
        """Send all of *data* via pyOpenSSL's ``sendall``."""
        return self._safe_ssl_call(False, self._conn.sendall, data, flags)

    @staticmethod
    def _decode_protocol(proto):
        """Normalize a negotiated protocol: bytes -> text, empty -> None."""
        if isinstance(proto, bytes):
            proto = proto.decode('ascii')
        return proto if proto else None

    def selected_npn_protocol(self):
        """Protocol chosen via NPN, or None."""
        return self._decode_protocol(self._conn.get_next_proto_negotiated())

    def selected_alpn_protocol(self):
        """Protocol chosen via ALPN, or None."""
        return self._decode_protocol(self._conn.get_alpn_proto_negotiated())

    def getpeercert(self, binary_form=False):
        """
        Return the peer's certificate, or None if the peer sent none.

        If *binary_form* is true, the DER-encoded certificate is returned as
        bytes. Otherwise the result is a dict approximating the stdlib's
        getpeercert() layout, with keys issuer, subject, version,
        serialNumber, notBefore and notAfter. serialNumber, notBefore and
        notAfter are passed through exactly as pyOpenSSL returns them, so
        their format differs from the stdlib's strings.
        """
        cert = self._conn.get_peer_certificate()
        if cert is None:
            return None

        if binary_form:
            from OpenSSL import crypto
            return crypto.dump_certificate(crypto.FILETYPE_ASN1, cert)

        # Short X.509 attribute names -> the long names the stdlib reports.
        aliases = dict(
            C='countryName',
            ST='stateOrProvinceName',
            L='localityName',
            O='organizationName',
            OU='organizationalUnitName',
            CN='commonName',
        )

        def to_components(name):
            # TODO Verify that these are actually *supposed* to all be
            # single-element tuples, and that's not just a quirk of the
            # examples I've seen.
            components = []
            for key, value in name.get_components():
                key = key.decode('utf-8')
                components.append(
                    ((aliases.get(key, key), value.decode('utf-8')),)
                )
            return tuple(components)

        # The standard getpeercert() takes the nice X509 object tree returned
        # by OpenSSL and turns it into a dict according to some format it seems
        # to have made up on the spot. Here, we do our best to emulate that.
        result = dict(
            issuer=to_components(cert.get_issuer()),
            subject=to_components(cert.get_subject()),
            # OpenSSL versions are 0-based; the stdlib reports them + 1.
            version=cert.get_version() + 1,
            serialNumber=cert.get_serial_number(),
            notBefore=cert.get_notBefore(),
            notAfter=cert.get_notAfter(),
        )
        # TODO extensions, including subjectAltName
        # (see _decode_certificate in _ssl.c)
        return result

    # Socket methods that need no translation are forwarded to the
    # connection as-is.
    accept = _proxy('accept')
    bind = _proxy('bind')
    close = _proxy('close')
    getsockname = _proxy('getsockname')
    listen = _proxy('listen')
    fileno = _proxy('fileno')
class SSLContext(object):
    """
    pyOpenSSL-backed stand-in for the stdlib ``ssl.SSLContext``.

    Wraps an ``OpenSSL.SSL.Context`` and produces SSLSocket objects from
    wrap_socket().
    """

    def __init__(self, protocol):
        self.protocol = protocol
        self._ctx = ossl.Context(protocol)
        self.options = OP_ALL
        self.check_hostname = False
        self.npn_protos = []

    @staticmethod
    def _to_bytes(value):
        """Return *value* as bytes (text is UTF-8 encoded); None stays None."""
        if value is None or isinstance(value, bytes):
            return value
        if isinstance(value, bytearray):
            return bytes(value)
        return value.encode('utf-8')

    @staticmethod
    def _encode_protocols(protocols):
        """Return *protocols* as a list of bytes (text is ASCII encoded)."""
        return [
            p if isinstance(p, bytes) else p.encode('ascii')
            for p in protocols
        ]

    @property
    def options(self):
        return self._options

    @options.setter
    def options(self, value):
        # NOTE: pyOpenSSL's set_options adds bits and never clears ones set
        # earlier, so assigning a smaller value does not remove options.
        self._options = value
        self._ctx.set_options(value)

    @property
    def verify_mode(self):
        return self._ctx.get_verify_mode()

    @verify_mode.setter
    def verify_mode(self, value):
        # The callback returns OpenSSL's own verdict (``ok``) unchanged.
        # TODO verify exception is raised on failure
        self._ctx.set_verify(
            value, lambda conn, cert, errnum, errdepth, ok: ok
        )

    def set_default_verify_paths(self):
        self._ctx.set_default_verify_paths()

    def load_verify_locations(self, cafile=None, capath=None, cadata=None):
        """
        Add trusted CA certificates from a file, a directory, and/or memory.

        :param cafile: Path to a PEM file of CA certificates (text or bytes).
        :param capath: Path to a directory of CA certificates (text or bytes).
        :param cadata: In-memory certificates: PEM text/bytes holding one or
            more certificates, or the bytes of a single DER certificate.
        :raises TypeError: if all three arguments are None (as in stdlib).
        """
        if cafile is None and capath is None and cadata is None:
            raise TypeError("cafile, capath and cadata cannot be all omitted")
        # Only call into OpenSSL when there is a path to load: it rejects
        # the (None, None) combination.
        if cafile is not None or capath is not None:
            self._ctx.load_verify_locations(
                self._to_bytes(cafile), self._to_bytes(capath)
            )
        if cadata is not None:
            self._load_cadata(cadata)

    def _load_cadata(self, cadata):
        """Add the certificate(s) in *cadata* to the context's trust store."""
        # pyOpenSSL's Context.load_verify_locations only accepts paths, so
        # in-memory data has to be parsed and added to the store directly.
        from OpenSSL import crypto

        if isinstance(cadata, bytearray):
            cadata = bytes(cadata)
        elif not isinstance(cadata, bytes):
            # Text can only be PEM, which is pure ASCII.
            cadata = cadata.encode('ascii')

        store = self._ctx.get_cert_store()
        begin_marker = b'-----BEGIN CERTIFICATE-----'
        end_marker = b'-----END CERTIFICATE-----'
        if begin_marker in cadata:
            # load_certificate reads only the first PEM certificate, so feed
            # each one in separately.
            for chunk in cadata.split(end_marker):
                if begin_marker in chunk:
                    store.add_cert(crypto.load_certificate(
                        crypto.FILETYPE_PEM, chunk + end_marker
                    ))
        else:
            # NOTE(review): only one DER certificate is supported here;
            # concatenated DER blobs would need ASN.1 length parsing.
            store.add_cert(
                crypto.load_certificate(crypto.FILETYPE_ASN1, cadata)
            )

    def load_cert_chain(self, certfile, keyfile=None, password=None):
        """
        Load our certificate chain and private key.

        :param certfile: PEM file holding the certificate, optionally
            followed by intermediates (and the key, if *keyfile* is None).
        :param keyfile: PEM file holding the private key. Defaults to
            *certfile*.
        :param password: Key passphrase as text, bytes, or a callable that
            returns one of those. Text is UTF-8 encoded, as in the stdlib.
        """
        # The chain variant also loads any intermediate certificates that
        # follow the leaf, which is what the stdlib's load_cert_chain does.
        self._ctx.use_certificate_chain_file(certfile)
        if password is not None:
            to_bytes = self._to_bytes

            def passwd_cb(max_length, prompt_twice, userdata):
                secret = password() if callable(password) else password
                return to_bytes(secret)

            self._ctx.set_passwd_cb(passwd_cb)
        self._ctx.use_privatekey_file(keyfile or certfile)

    def set_npn_protocols(self, protocols):
        """
        Advertise *protocols* (most preferred first) for NPN selection.

        Each entry may be text or bytes.
        """
        protos = self._encode_protocols(protocols)
        self.npn_protos = protos
        # Also kept under ``protocols``, the attribute earlier versions of
        # this method set.
        self.protocols = protos

        def cb(conn, offered):
            # Select our most-preferred protocol that the server also
            # offered; b'' signals that there is no overlap.
            offered = set(offered)
            for proto in protos:
                if proto in offered:
                    return proto
            return b''

        self._ctx.set_npn_select_callback(cb)

    def set_alpn_protocols(self, protocols):
        """Offer *protocols* (text or bytes, most preferred first) via ALPN."""
        self._ctx.set_alpn_protos(self._encode_protocols(protocols))

    def wrap_socket(self,
                    sock,
                    server_side=False,
                    do_handshake_on_connect=True,
                    suppress_ragged_eofs=True,
                    server_hostname=None):
        """Wrap *sock* in a pyOpenSSL connection and return an SSLSocket."""
        conn = ossl.Connection(self._ctx, sock)
        return SSLSocket(conn, server_side, do_handshake_on_connect,
                         suppress_ragged_eofs, server_hostname,
                         # TODO what if this is changed after the fact?
                         self.check_hostname)
|