pyopenssl_context.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324
  1. # Copyright 2019-present MongoDB, Inc.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License"); you
  4. # may not use this file except in compliance with the License. You
  5. # may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
  12. # implied. See the License for the specific language governing
  13. # permissions and limitations under the License.
  14. """A CPython compatible SSLContext implementation wrapping PyOpenSSL's
  15. context.
  16. """
  17. import socket as _socket
  18. import ssl as _stdlibssl
  19. from errno import EINTR as _EINTR
  20. # service_identity requires this for py27, so it should always be available
  21. from ipaddress import ip_address as _ip_address
  22. from OpenSSL import SSL as _SSL
  23. from service_identity.pyopenssl import (
  24. verify_hostname as _verify_hostname,
  25. verify_ip_address as _verify_ip_address)
  26. from service_identity import (
  27. CertificateError as _SICertificateError,
  28. VerificationError as _SIVerificationError)
  29. from cryptography.hazmat.backends import default_backend as _default_backend
  30. from bson.py3compat import _unicode
  31. from pymongo.errors import CertificateError as _CertificateError
  32. from pymongo.monotonic import time as _time
  33. from pymongo.ocsp_support import (
  34. _load_trusted_ca_certs,
  35. _ocsp_callback)
  36. from pymongo.ocsp_cache import _OCSPCache
  37. from pymongo.socket_checker import (
  38. _errno_from_exception, SocketChecker as _SocketChecker)
  # Protocol constant exposed under the stdlib ssl module's name.
  PROTOCOL_SSLv23 = _SSL.SSLv23_METHOD

  # Option flags that PyOpenSSL always provides.
  OP_NO_SSLv2 = _SSL.OP_NO_SSLv2
  OP_NO_SSLv3 = _SSL.OP_NO_SSLv3
  OP_NO_COMPRESSION = _SSL.OP_NO_COMPRESSION
  # Not currently documented for PyOpenSSL, so fall back to 0 (a no-op
  # when OR-ed into an options bitmask) if the attribute is missing.
  OP_NO_RENEGOTIATION = getattr(_SSL, "OP_NO_RENEGOTIATION", 0)

  # Capability flags; always true for this implementation.
  HAS_SNI = True
  CHECK_HOSTNAME_SAFE = True
  IS_PYOPENSSL = True

  # Base exception class raised by PyOpenSSL.
  SSLError = _SSL.Error
  # Translation table from the stdlib ssl.CERT_* constants to PyOpenSSL's
  # VERIFY_* flags. See CPython's equivalent mapping:
  # https://github.com/python/cpython/blob/v3.8.0/Modules/_ssl.c#L2995-L3002
  _VERIFY_MAP = {
      _stdlibssl.CERT_NONE: _SSL.VERIFY_NONE,
      _stdlibssl.CERT_OPTIONAL: _SSL.VERIFY_PEER,
      _stdlibssl.CERT_REQUIRED: (
          _SSL.VERIFY_PEER | _SSL.VERIFY_FAIL_IF_NO_PEER_CERT),
  }

  # Inverse of _VERIFY_MAP: PyOpenSSL VERIFY_* flags -> ssl.CERT_* constant.
  _REVERSE_VERIFY_MAP = {
      flags: cert_reqs for cert_reqs, flags in _VERIFY_MAP.items()}
  def _is_ip_address(address):
      """Return True if `address` can be parsed as an IP address.

      The value is converted with `_unicode` and handed to
      `ipaddress.ip_address`; a ValueError or UnicodeError from either step
      means it is not an IP address and False is returned.
      """
      try:
          _ip_address(_unicode(address))
      except (ValueError, UnicodeError):
          return False
      else:
          return True
  # Exceptions after which the failed PyOpenSSL call is simply retried.
  # According to the docs for Connection.send it can raise
  # WantX509LookupError and should be retried.
  _RETRY_ERRORS = (
      _SSL.WantReadError,
      _SSL.WantWriteError,
      _SSL.WantX509LookupError,
  )
  70. def _ragged_eof(exc):
  71. """Return True if the OpenSSL.SSL.SysCallError is a ragged EOF."""
  72. return exc.args == (-1, 'Unexpected EOF')
  73. # https://github.com/pyca/pyopenssl/issues/168
  74. # https://github.com/pyca/pyopenssl/issues/176
  75. # https://docs.python.org/3/library/ssl.html#notes-on-non-blocking-sockets
  class _sslConn(_SSL.Connection):
      """PyOpenSSL Connection that retries would-block errors, enforces the
      socket timeout and can suppress ragged EOFs like the stdlib does.
      """

      def __init__(self, ctx, sock, suppress_ragged_eofs):
          """Wrap `sock` using the PyOpenSSL context `ctx`.

          `suppress_ragged_eofs` controls whether recv()/recv_into() turn
          an unexpected EOF into an empty read instead of raising.
          """
          self.socket_checker = _SocketChecker()
          self.suppress_ragged_eofs = suppress_ragged_eofs
          super(_sslConn, self).__init__(ctx, sock)

      def _call(self, call, *args, **kwargs):
          """Invoke `call(*args, **kwargs)`, retrying on _RETRY_ERRORS.

          After each retryable error the socket checker waits for the
          socket to become readable or writable. Raises socket.timeout once
          more than the socket's timeout has elapsed since the first
          attempt.
          """
          timeout = self.gettimeout()
          if timeout:
              # Only needed (and only read) when a timeout is configured.
              start = _time()
          while True:
              try:
                  return call(*args, **kwargs)
              except _RETRY_ERRORS:
                  # Note: the full timeout, not the remaining time, is
                  # passed to every select() call.
                  self.socket_checker.select(
                      self, True, True, timeout)
                  if timeout and _time() - start > timeout:
                      raise _socket.timeout("timed out")
                  continue

      def do_handshake(self, *args, **kwargs):
          """Perform the TLS handshake, retrying until it completes."""
          return self._call(super(_sslConn, self).do_handshake, *args, **kwargs)

      def recv(self, *args, **kwargs):
          """Receive data; returns b"" on a suppressed ragged EOF."""
          try:
              return self._call(super(_sslConn, self).recv, *args, **kwargs)
          except _SSL.SysCallError as exc:
              # Suppress ragged EOFs to match the stdlib.
              if self.suppress_ragged_eofs and _ragged_eof(exc):
                  return b""
              raise

      def recv_into(self, *args, **kwargs):
          """Receive data into a buffer; returns 0 on a suppressed ragged
          EOF.
          """
          try:
              return self._call(super(_sslConn, self).recv_into, *args, **kwargs)
          except _SSL.SysCallError as exc:
              # Suppress ragged EOFs to match the stdlib.
              if self.suppress_ragged_eofs and _ragged_eof(exc):
                  return 0
              raise

      def sendall(self, buf, flags=0):
          """Send all of `buf`, looping over partial sends.

          Raises socket.error("Connection closed") if a send reports that
          no bytes were written. socket.error is used rather than a bare
          Exception so callers that handle network errors
          (IOError/OSError/socket.error) see this as a connection failure;
          it is still an Exception subclass.
          """
          view = memoryview(buf)
          total_length = len(buf)
          total_sent = 0
          while total_sent < total_length:
              try:
                  sent = self._call(
                      super(_sslConn, self).send, view[total_sent:], flags)
              # XXX: It's not clear if this can actually happen. PyOpenSSL
              # doesn't appear to have any interrupt handling, nor any
              # interrupt errors for OpenSSL connections.
              except (IOError, OSError) as exc:
                  if _errno_from_exception(exc) == _EINTR:
                      continue
                  raise
              # https://github.com/pyca/pyopenssl/blob/19.1.0/src/OpenSSL/SSL.py#L1756
              # https://www.openssl.org/docs/man1.0.2/man3/SSL_write.html
              if sent <= 0:
                  raise _socket.error("Connection closed")
              total_sent += sent
  class _CallbackData(object):
      """Container for the state handed to the OCSP callback."""

      def __init__(self):
          # Cache object created once per instance.
          self.ocsp_response_cache = _OCSPCache()
          # Both start unset; the owner of this object fills them in.
          self.check_ocsp_endpoint = None
          self.trusted_ca_certs = None
  class SSLContext(object):
      """A CPython compatible SSLContext implementation wrapping PyOpenSSL's
      context.
      """

      __slots__ = ('_protocol', '_ctx', '_callback_data', '_check_hostname')

      def __init__(self, protocol):
          """Create a PyOpenSSL context for `protocol` with hostname checking
          and OCSP endpoint checking enabled.
          """
          self._protocol = protocol
          self._ctx = _SSL.Context(self._protocol)
          self._callback_data = _CallbackData()
          self._check_hostname = True
          # OCSP
          # XXX: Find a better place to do this someday, since this is client
          # side configuration and wrap_socket tries to support both client and
          # server side sockets.
          self._callback_data.check_ocsp_endpoint = True
          self._ctx.set_ocsp_client_callback(
              callback=_ocsp_callback, data=self._callback_data)

      @property
      def protocol(self):
          """The protocol version chosen when constructing the context.
          This attribute is read-only.
          """
          return self._protocol

      def __get_verify_mode(self):
          """Whether to try to verify other peers' certificates and how to
          behave if verification fails. This attribute must be one of
          ssl.CERT_NONE, ssl.CERT_OPTIONAL or ssl.CERT_REQUIRED.
          """
          return _REVERSE_VERIFY_MAP[self._ctx.get_verify_mode()]

      def __set_verify_mode(self, value):
          """Setter for verify_mode."""
          def _cb(connobj, x509obj, errnum, errdepth, retcode):
              # It seems we don't need to do anything here. Twisted doesn't,
              # and OpenSSL's SSL_CTX_set_verify lets you pass NULL
              # for the callback option. It's weird that PyOpenSSL requires
              # this.
              return retcode
          self._ctx.set_verify(_VERIFY_MAP[value], _cb)

      verify_mode = property(__get_verify_mode, __set_verify_mode)

      def __get_check_hostname(self):
          """Whether wrap_socket verifies the peer certificate's hostname."""
          return self._check_hostname

      def __set_check_hostname(self, value):
          """Setter for check_hostname; raises TypeError for non-bool values."""
          if not isinstance(value, bool):
              raise TypeError("check_hostname must be True or False")
          self._check_hostname = value

      check_hostname = property(__get_check_hostname, __set_check_hostname)

      def __get_check_ocsp_endpoint(self):
          """The check_ocsp_endpoint flag stored on the OCSP callback data."""
          return self._callback_data.check_ocsp_endpoint

      def __set_check_ocsp_endpoint(self, value):
          """Setter for check_ocsp_endpoint; raises TypeError for non-bool
          values.
          """
          if not isinstance(value, bool):
              # Name the attribute callers actually set.
              raise TypeError("check_ocsp_endpoint must be True or False")
          self._callback_data.check_ocsp_endpoint = value

      check_ocsp_endpoint = property(__get_check_ocsp_endpoint,
                                     __set_check_ocsp_endpoint)

      def __get_options(self):
          """Return the context's current options bitmask."""
          # Calling set_options adds the option to the existing bitmask and
          # returns the new bitmask.
          # https://www.pyopenssl.org/en/stable/api/ssl.html#OpenSSL.SSL.Context.set_options
          return self._ctx.set_options(0)

      def __set_options(self, value):
          """Add `value` to the context's options bitmask."""
          # Explicitly convert to int, since newer CPython versions
          # use enum.IntFlag for options. The values are the same
          # regardless of implementation.
          self._ctx.set_options(int(value))

      options = property(__get_options, __set_options)

      def load_cert_chain(self, certfile, keyfile=None, password=None):
          """Load a private key and the corresponding certificate. The certfile
          string must be the path to a single file in PEM format containing the
          certificate as well as any number of CA certificates needed to
          establish the certificate's authenticity. The keyfile string, if
          present, must point to a file containing the private key. Otherwise
          the private key will be taken from certfile as well.
          """
          # Match CPython behavior
          # https://github.com/python/cpython/blob/v3.8.0/Modules/_ssl.c#L3930-L3971
          # Password callback MUST be set first or it will be ignored.
          if password:
              def _pwcb(max_length, prompt_twice, user_data):
                  # XXX:We could check the password length against what OpenSSL
                  # tells us is the max, but we can't raise an exception, so...
                  # warn?
                  return password.encode('utf-8')
              self._ctx.set_passwd_cb(_pwcb)
          self._ctx.use_certificate_chain_file(certfile)
          self._ctx.use_privatekey_file(keyfile or certfile)
          self._ctx.check_privatekey()

      def load_verify_locations(self, cafile=None, capath=None):
          """Load a set of "certification authority"(CA) certificates used to
          validate other peers' certificates when `~verify_mode` is other than
          ssl.CERT_NONE.
          """
          self._ctx.load_verify_locations(cafile, capath)
          # Only a CA *file* can be parsed for the OCSP callback. When just
          # capath is given there is no file to read, so leave the callback
          # data untouched instead of handing None to the loader.
          if cafile is not None:
              self._callback_data.trusted_ca_certs = _load_trusted_ca_certs(
                  cafile)

      def set_default_verify_paths(self):
          """Specify that the platform provided CA certificates are to be used
          for verification purposes."""
          # Note: See PyOpenSSL's docs for limitations, which are similar
          # but not the same as CPython's.
          self._ctx.set_default_verify_paths()

      def wrap_socket(self, sock, server_side=False,
                      do_handshake_on_connect=True,
                      suppress_ragged_eofs=True,
                      server_hostname=None, session=None):
          """Wrap an existing Python socket sock and return a TLS socket
          object.

          Raises pymongo.errors.CertificateError if hostname checking is
          enabled and the peer certificate does not match `server_hostname`.
          Hostname checking only runs when `do_handshake_on_connect` is true.
          """
          ssl_conn = _sslConn(self._ctx, sock, suppress_ragged_eofs)
          if session:
              ssl_conn.set_session(session)
          if server_side is True:
              ssl_conn.set_accept_state()
          else:
              # SNI
              if server_hostname and not _is_ip_address(server_hostname):
                  # XXX: Do this in a callback registered with
                  # SSLContext.set_info_callback? See Twisted for an example.
                  ssl_conn.set_tlsext_host_name(server_hostname.encode('idna'))
              if self.verify_mode != _stdlibssl.CERT_NONE:
                  # Request a stapled OCSP response.
                  ssl_conn.request_ocsp()
              ssl_conn.set_connect_state()
          # If this wasn't true the caller of wrap_socket would call
          # do_handshake()
          if do_handshake_on_connect:
              # XXX: If we do hostname checking in a callback we can get rid
              # of this call to do_handshake() since the handshake
              # will happen automatically later.
              ssl_conn.do_handshake()
              # XXX: Do this in a callback registered with
              # SSLContext.set_info_callback? See Twisted for an example.
              if self.check_hostname and server_hostname is not None:
                  try:
                      if _is_ip_address(server_hostname):
                          _verify_ip_address(ssl_conn, _unicode(server_hostname))
                      else:
                          _verify_hostname(ssl_conn, _unicode(server_hostname))
                  except (_SICertificateError, _SIVerificationError) as exc:
                      raise _CertificateError(str(exc))
          return ssl_conn