patch.py 3.0 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485
  1. # -*- coding: utf-8 -*-
  2. # !/usr/bin/env python
  3. __import__('urllib3.contrib.pyopenssl')
  4. import requests
  5. from OpenSSL.crypto import PKCS12, X509, PKey
  6. def _is_key_file_encrypted(keyfile):
  7. '''In memory key is not encrypted'''
  8. if isinstance(keyfile, PKey):
  9. return False
  10. return _is_key_file_encrypted.original(keyfile)
  11. class PyOpenSSLContext(requests.packages.urllib3.contrib.pyopenssl.PyOpenSSLContext):
  12. def load_cert_chain(self, certfile, keyfile = None, password = None):
  13. if isinstance(certfile, X509) and isinstance(keyfile, PKey):
  14. self._ctx.use_certificate(certfile)
  15. self._ctx.use_privatekey(keyfile)
  16. else:
  17. super(PyOpenSSLContext, self).load_cert_chain(certfile, keyfile = keyfile, password = password)
  18. class HTTPAdapter(requests.adapters.HTTPAdapter):
  19. '''Handle a variety of cert types'''
  20. def cert_verify(self, conn, url, verify, cert):
  21. if cert:
  22. # PKCS12
  23. if isinstance(cert, PKCS12):
  24. conn.cert_file = cert.get_certificate()
  25. conn.key_file = cert.get_privatekey()
  26. cert = None
  27. elif isinstance(cert, tuple) and len(cert) == 2:
  28. # X509 and PKey
  29. if isinstance(cert[0], X509) and isinstance(cert[1], PKey):
  30. conn.cert_file = cert[0]
  31. conn.key_file = cert[1]
  32. cert = None
  33. # cryptography objects
  34. elif hasattr(cert[0], 'public_bytes') and hasattr(cert[1], 'private_bytes'):
  35. conn.cert_file = X509.from_cryptography(cert[0])
  36. conn.key_file = PKey.from_cryptography_key(cert[1])
  37. cert = None
  38. super(HTTPAdapter, self).cert_verify(conn, url, verify, cert)
  39. class OldMethod(object):
  40. old_amqp_channel_wait = None
  41. def amqp_channel_wait(self, allowed_methods = None, timeout = None):
  42. if not timeout:
  43. timeout = 1
  44. return OldMethod.old_amqp_channel_wait(self, allowed_methods, timeout)
  45. def patch_requests(adapter = True):
  46. '''You can perform a full patch and use requests as usual:
  47. >>> patch_requests()
  48. >>> requests.get('https://httpbin.org/get')
  49. or use the adapter explicitly:
  50. >>> patch_requests(adapter=False)
  51. >>> session = requests.Session()
  52. >>> session.mount('https', HTTPAdapter())
  53. >>> session.get('https://httpbin.org/get')
  54. '''
  55. if hasattr(requests.packages.urllib3.util.ssl_, '_is_key_file_encrypted'):
  56. _is_key_file_encrypted.original = requests.packages.urllib3.util.ssl_._is_key_file_encrypted
  57. requests.packages.urllib3.util.ssl_._is_key_file_encrypted = _is_key_file_encrypted
  58. requests.packages.urllib3.util.ssl_.SSLContext = PyOpenSSLContext
  59. if adapter:
  60. requests.sessions.HTTPAdapter = HTTPAdapter
  61. import amqp
  62. OldMethod.old_amqp_channel_wait = amqp.abstract_channel.AbstractChannel.wait
  63. amqp.abstract_channel.AbstractChannel.wait = amqp_channel_wait