handler.py 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142
  1. # -*- coding: utf-8 -*-
  2. """
  3. requests_toolbelt.auth.handler
  4. ==============================
  5. This holds all of the implementation details of the Authentication Handler.
  6. """
  7. from requests.auth import AuthBase, HTTPBasicAuth
  8. from requests.compat import urlparse, urlunparse
  9. class AuthHandler(AuthBase):
  10. """
  11. The ``AuthHandler`` object takes a dictionary of domains paired with
  12. authentication strategies and will use this to determine which credentials
  13. to use when making a request. For example, you could do the following:
  14. .. code-block:: python
  15. from requests import HTTPDigestAuth
  16. from requests_toolbelt.auth.handler import AuthHandler
  17. import requests
  18. auth = AuthHandler({
  19. 'https://api.github.com': ('sigmavirus24', 'fakepassword'),
  20. 'https://example.com': HTTPDigestAuth('username', 'password')
  21. })
  22. r = requests.get('https://api.github.com/user', auth=auth)
  23. # => <Response [200]>
  24. r = requests.get('https://example.com/some/path', auth=auth)
  25. # => <Response [200]>
  26. s = requests.Session()
  27. s.auth = auth
  28. r = s.get('https://api.github.com/user')
  29. # => <Response [200]>
  30. .. warning::
  31. :class:`requests.auth.HTTPDigestAuth` is not yet thread-safe. If you
  32. use :class:`AuthHandler` across multiple threads you should
  33. instantiate a new AuthHandler for each thread with a new
  34. HTTPDigestAuth instance for each thread.
  35. """
  36. def __init__(self, strategies):
  37. self.strategies = dict(strategies)
  38. self._make_uniform()
  39. def __call__(self, request):
  40. auth = self.get_strategy_for(request.url)
  41. return auth(request)
  42. def __repr__(self):
  43. return '<AuthHandler({0!r})>'.format(self.strategies)
  44. def _make_uniform(self):
  45. existing_strategies = list(self.strategies.items())
  46. self.strategies = {}
  47. for (k, v) in existing_strategies:
  48. self.add_strategy(k, v)
  49. @staticmethod
  50. def _key_from_url(url):
  51. parsed = urlparse(url)
  52. return urlunparse((parsed.scheme.lower(),
  53. parsed.netloc.lower(),
  54. '', '', '', ''))
  55. def add_strategy(self, domain, strategy):
  56. """Add a new domain and authentication strategy.
  57. :param str domain: The domain you wish to match against. For example:
  58. ``'https://api.github.com'``
  59. :param str strategy: The authentication strategy you wish to use for
  60. that domain. For example: ``('username', 'password')`` or
  61. ``requests.HTTPDigestAuth('username', 'password')``
  62. .. code-block:: python
  63. a = AuthHandler({})
  64. a.add_strategy('https://api.github.com', ('username', 'password'))
  65. """
  66. # Turn tuples into Basic Authentication objects
  67. if isinstance(strategy, tuple):
  68. strategy = HTTPBasicAuth(*strategy)
  69. key = self._key_from_url(domain)
  70. self.strategies[key] = strategy
  71. def get_strategy_for(self, url):
  72. """Retrieve the authentication strategy for a specified URL.
  73. :param str url: The full URL you will be making a request against. For
  74. example, ``'https://api.github.com/user'``
  75. :returns: Callable that adds authentication to a request.
  76. .. code-block:: python
  77. import requests
  78. a = AuthHandler({'example.com', ('foo', 'bar')})
  79. strategy = a.get_strategy_for('http://example.com/example')
  80. assert isinstance(strategy, requests.auth.HTTPBasicAuth)
  81. """
  82. key = self._key_from_url(url)
  83. return self.strategies.get(key, NullAuthStrategy())
  84. def remove_strategy(self, domain):
  85. """Remove the domain and strategy from the collection of strategies.
  86. :param str domain: The domain you wish remove. For example,
  87. ``'https://api.github.com'``.
  88. .. code-block:: python
  89. a = AuthHandler({'example.com', ('foo', 'bar')})
  90. a.remove_strategy('example.com')
  91. assert a.strategies == {}
  92. """
  93. key = self._key_from_url(domain)
  94. if key in self.strategies:
  95. del self.strategies[key]
  96. class NullAuthStrategy(AuthBase):
  97. def __repr__(self):
  98. return '<NullAuthStrategy>'
  99. def __call__(self, r):
  100. return r