ssl_compat.py 9.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307
  1. # -*- coding: utf-8 -*-
  2. """
  3. hyper/ssl_compat
  4. ~~~~~~~~~~~~~~~~
  5. Shoves pyOpenSSL into an API that looks like the standard Python 3.x ssl
  6. module.
  7. Currently exposes exactly those attributes, classes, and methods that we
  8. actually use in hyper (all method signatures are complete, however). May be
  9. expanded to something more general-purpose in the future.
  10. """
  11. try:
  12. import StringIO as BytesIO
  13. except ImportError:
  14. from io import BytesIO
  15. import errno
  16. import socket
  17. import time
  18. from OpenSSL import SSL as ossl
  19. from service_identity.pyopenssl import verify_hostname as _verify
  20. CERT_NONE = ossl.VERIFY_NONE
  21. CERT_REQUIRED = ossl.VERIFY_PEER | ossl.VERIFY_FAIL_IF_NO_PEER_CERT
  22. _OPENSSL_ATTRS = dict(
  23. OP_NO_COMPRESSION='OP_NO_COMPRESSION',
  24. PROTOCOL_TLSv1_2='TLSv1_2_METHOD',
  25. PROTOCOL_SSLv23='SSLv23_METHOD',
  26. )
  27. for external, internal in _OPENSSL_ATTRS.items():
  28. value = getattr(ossl, internal, None)
  29. if value:
  30. locals()[external] = value
  31. OP_ALL = 0
  32. # TODO: Find out the names of these other flags.
  33. for bit in [31] + list(range(10)):
  34. OP_ALL |= 1 << bit
  35. HAS_NPN = True
  36. def _proxy(method):
  37. def inner(self, *args, **kwargs):
  38. return getattr(self._conn, method)(*args, **kwargs)
  39. return inner
  40. # Referenced in hyper/http20/connection.py. These values come
  41. # from the python ssl package, and must be defined in this file
  42. # for hyper to work in python versions <2.7.9
  43. SSL_ERROR_WANT_READ = 2
  44. SSL_ERROR_WANT_WRITE = 3
  45. # TODO missing some attributes
  46. class SSLError(OSError):
  47. pass
  48. class CertificateError(SSLError):
  49. pass
  50. def verify_hostname(ssl_sock, server_hostname):
  51. """
  52. A method nearly compatible with the stdlib's match_hostname.
  53. """
  54. if isinstance(server_hostname, bytes):
  55. server_hostname = server_hostname.decode('ascii')
  56. return _verify(ssl_sock._conn, server_hostname)
  57. class SSLSocket(object):
  58. SSL_TIMEOUT = 3
  59. SSL_RETRY = .01
  60. def __init__(self, conn, server_side, do_handshake_on_connect,
  61. suppress_ragged_eofs, server_hostname, check_hostname):
  62. self._conn = conn
  63. self._do_handshake_on_connect = do_handshake_on_connect
  64. self._suppress_ragged_eofs = suppress_ragged_eofs
  65. self._check_hostname = check_hostname
  66. if server_side:
  67. self._conn.set_accept_state()
  68. else:
  69. if server_hostname:
  70. self._conn.set_tlsext_host_name(
  71. server_hostname.encode('utf-8')
  72. )
  73. self._server_hostname = server_hostname
  74. # FIXME does this override do_handshake_on_connect=False?
  75. self._conn.set_connect_state()
  76. if self.connected and self._do_handshake_on_connect:
  77. self.do_handshake()
  78. @property
  79. def connected(self):
  80. try:
  81. self._conn.getpeername()
  82. except socket.error as e:
  83. if e.errno != errno.ENOTCONN:
  84. # It's an exception other than the one we expected if we're not
  85. # connected.
  86. raise
  87. return False
  88. return True
  89. # Lovingly stolen from CherryPy
  90. # (http://svn.cherrypy.org/tags/cherrypy-3.2.1/cherrypy/wsgiserver/ssl_pyopenssl.py).
  91. def _safe_ssl_call(self, suppress_ragged_eofs, call, *args, **kwargs):
  92. """Wrap the given call with SSL error-trapping."""
  93. start = time.time()
  94. while True:
  95. try:
  96. return call(*args, **kwargs)
  97. except (ossl.WantReadError, ossl.WantWriteError):
  98. # Sleep and try again. This is dangerous, because it means
  99. # the rest of the stack has no way of differentiating
  100. # between a "new handshake" error and "client dropped".
  101. # Note this isn't an endless loop: there's a timeout below.
  102. time.sleep(self.SSL_RETRY)
  103. except ossl.Error as e:
  104. if suppress_ragged_eofs and e.args == (-1, 'Unexpected EOF'):
  105. return b''
  106. raise socket.error(e.args[0])
  107. if time.time() - start > self.SSL_TIMEOUT:
  108. raise socket.timeout('timed out')
  109. def connect(self, address):
  110. self._conn.connect(address)
  111. if self._do_handshake_on_connect:
  112. self.do_handshake()
  113. def do_handshake(self):
  114. self._safe_ssl_call(False, self._conn.do_handshake)
  115. if self._check_hostname:
  116. verify_hostname(self, self._server_hostname)
  117. def recv(self, bufsize, flags=None):
  118. return self._safe_ssl_call(
  119. self._suppress_ragged_eofs,
  120. self._conn.recv,
  121. bufsize,
  122. flags
  123. )
  124. def recv_into(self, buffer, bufsize=None, flags=None):
  125. # A temporary recv_into implementation. Should be replaced when
  126. # PyOpenSSL has merged pyca/pyopenssl#121.
  127. if bufsize is None:
  128. bufsize = len(buffer)
  129. data = self.recv(bufsize, flags)
  130. data_len = len(data)
  131. buffer[0:data_len] = data
  132. return data_len
  133. def send(self, data, flags=None):
  134. return self._safe_ssl_call(False, self._conn.send, data, flags)
  135. def sendall(self, data, flags=None):
  136. return self._safe_ssl_call(False, self._conn.sendall, data, flags)
  137. def selected_npn_protocol(self):
  138. proto = self._conn.get_next_proto_negotiated()
  139. if isinstance(proto, bytes):
  140. proto = proto.decode('ascii')
  141. return proto if proto else None
  142. def selected_alpn_protocol(self):
  143. proto = self._conn.get_alpn_proto_negotiated()
  144. if isinstance(proto, bytes):
  145. proto = proto.decode('ascii')
  146. return proto if proto else None
  147. def getpeercert(self):
  148. def resolve_alias(alias):
  149. return dict(
  150. C='countryName',
  151. ST='stateOrProvinceName',
  152. L='localityName',
  153. O='organizationName',
  154. OU='organizationalUnitName',
  155. CN='commonName',
  156. ).get(alias, alias)
  157. def to_components(name):
  158. # TODO Verify that these are actually *supposed* to all be
  159. # single-element tuples, and that's not just a quirk of the
  160. # examples I've seen.
  161. return tuple(
  162. [
  163. (resolve_alias(k.decode('utf-8'), v.decode('utf-8')),)
  164. for k, v in name.get_components()
  165. ]
  166. )
  167. # The standard getpeercert() takes the nice X509 object tree returned
  168. # by OpenSSL and turns it into a dict according to some format it seems
  169. # to have made up on the spot. Here, we do our best to emulate that.
  170. cert = self._conn.get_peer_certificate()
  171. result = dict(
  172. issuer=to_components(cert.get_issuer()),
  173. subject=to_components(cert.get_subject()),
  174. version=cert.get_subject(),
  175. serialNumber=cert.get_serial_number(),
  176. notBefore=cert.get_notBefore(),
  177. notAfter=cert.get_notAfter(),
  178. )
  179. # TODO extensions, including subjectAltName
  180. # (see _decode_certificate in _ssl.c)
  181. return result
  182. # a dash of magic to reduce boilerplate
  183. methods = ['accept', 'bind', 'close', 'getsockname', 'listen', 'fileno']
  184. for method in methods:
  185. locals()[method] = _proxy(method)
  186. class SSLContext(object):
  187. def __init__(self, protocol):
  188. self.protocol = protocol
  189. self._ctx = ossl.Context(protocol)
  190. self.options = OP_ALL
  191. self.check_hostname = False
  192. self.npn_protos = []
  193. @property
  194. def options(self):
  195. return self._options
  196. @options.setter
  197. def options(self, value):
  198. self._options = value
  199. self._ctx.set_options(value)
  200. @property
  201. def verify_mode(self):
  202. return self._ctx.get_verify_mode()
  203. @verify_mode.setter
  204. def verify_mode(self, value):
  205. # TODO verify exception is raised on failure
  206. self._ctx.set_verify(
  207. value, lambda conn, cert, errnum, errdepth, ok: ok
  208. )
  209. def set_default_verify_paths(self):
  210. self._ctx.set_default_verify_paths()
  211. def load_verify_locations(self, cafile=None, capath=None, cadata=None):
  212. # TODO factor out common code
  213. if cafile is not None:
  214. cafile = cafile.encode('utf-8')
  215. if capath is not None:
  216. capath = capath.encode('utf-8')
  217. self._ctx.load_verify_locations(cafile, capath)
  218. if cadata is not None:
  219. self._ctx.load_verify_locations(BytesIO(cadata))
  220. def load_cert_chain(self, certfile, keyfile=None, password=None):
  221. self._ctx.use_certificate_file(certfile)
  222. if password is not None:
  223. self._ctx.set_passwd_cb(
  224. lambda max_length, prompt_twice, userdata: password
  225. )
  226. self._ctx.use_privatekey_file(keyfile or certfile)
  227. def set_npn_protocols(self, protocols):
  228. self.protocols = list(map(lambda x: x.encode('ascii'), protocols))
  229. def cb(conn, protos):
  230. # Detect the overlapping set of protocols.
  231. overlap = set(protos) & set(self.protocols)
  232. # Select the option that comes last in the list in the overlap.
  233. for p in self.protocols:
  234. if p in overlap:
  235. return p
  236. else:
  237. return b''
  238. self._ctx.set_npn_select_callback(cb)
  239. def set_alpn_protocols(self, protocols):
  240. protocols = list(map(lambda x: x.encode('ascii'), protocols))
  241. self._ctx.set_alpn_protos(protocols)
  242. def wrap_socket(self,
  243. sock,
  244. server_side=False,
  245. do_handshake_on_connect=True,
  246. suppress_ragged_eofs=True,
  247. server_hostname=None):
  248. conn = ossl.Connection(self._ctx, sock)
  249. return SSLSocket(conn, server_side, do_handshake_on_connect,
  250. suppress_ragged_eofs, server_hostname,
  251. # TODO what if this is changed after the fact?
  252. self.check_hostname)