oauth2_session.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376
  1. from __future__ import unicode_literals
  2. import logging
  3. from oauthlib.common import generate_token, urldecode
  4. from oauthlib.oauth2 import WebApplicationClient, InsecureTransportError
  5. from oauthlib.oauth2 import TokenExpiredError, is_secure_transport
  6. import requests
  7. log = logging.getLogger(__name__)
  8. class TokenUpdated(Warning):
  9. def __init__(self, token):
  10. super(TokenUpdated, self).__init__()
  11. self.token = token
  12. class OAuth2Session(requests.Session):
  13. """Versatile OAuth 2 extension to :class:`requests.Session`.
  14. Supports any grant type adhering to :class:`oauthlib.oauth2.Client` spec
  15. including the four core OAuth 2 grants.
  16. Can be used to create authorization urls, fetch tokens and access protected
  17. resources using the :class:`requests.Session` interface you are used to.
  18. - :class:`oauthlib.oauth2.WebApplicationClient` (default): Authorization Code Grant
  19. - :class:`oauthlib.oauth2.MobileApplicationClient`: Implicit Grant
  20. - :class:`oauthlib.oauth2.LegacyApplicationClient`: Password Credentials Grant
  21. - :class:`oauthlib.oauth2.BackendApplicationClient`: Client Credentials Grant
  22. Note that the only time you will be using Implicit Grant from python is if
  23. you are driving a user agent able to obtain URL fragments.
  24. """
  25. def __init__(self, client_id=None, client=None, auto_refresh_url=None,
  26. auto_refresh_kwargs=None, scope=None, redirect_uri=None, token=None,
  27. state=None, token_updater=None, **kwargs):
  28. """Construct a new OAuth 2 client session.
  29. :param client_id: Client id obtained during registration
  30. :param client: :class:`oauthlib.oauth2.Client` to be used. Default is
  31. WebApplicationClient which is useful for any
  32. hosted application but not mobile or desktop.
  33. :param scope: List of scopes you wish to request access to
  34. :param redirect_uri: Redirect URI you registered as callback
  35. :param token: Token dictionary, must include access_token
  36. and token_type.
  37. :param state: State string used to prevent CSRF. This will be given
  38. when creating the authorization url and must be supplied
  39. when parsing the authorization response.
  40. Can be either a string or a no argument callable.
  41. :auto_refresh_url: Refresh token endpoint URL, must be HTTPS. Supply
  42. this if you wish the client to automatically refresh
  43. your access tokens.
  44. :auto_refresh_kwargs: Extra arguments to pass to the refresh token
  45. endpoint.
  46. :token_updater: Method with one argument, token, to be used to update
  47. your token databse on automatic token refresh. If not
  48. set a TokenUpdated warning will be raised when a token
  49. has been refreshed. This warning will carry the token
  50. in its token argument.
  51. :param kwargs: Arguments to pass to the Session constructor.
  52. """
  53. super(OAuth2Session, self).__init__(**kwargs)
  54. self._client = client or WebApplicationClient(client_id, token=token)
  55. self.token = token or {}
  56. self.scope = scope
  57. self.redirect_uri = redirect_uri
  58. self.state = state or generate_token
  59. self._state = state
  60. self.auto_refresh_url = auto_refresh_url
  61. self.auto_refresh_kwargs = auto_refresh_kwargs or {}
  62. self.token_updater = token_updater
  63. # Allow customizations for non compliant providers through various
  64. # hooks to adjust requests and responses.
  65. self.compliance_hook = {
  66. 'access_token_response': set([]),
  67. 'refresh_token_response': set([]),
  68. 'protected_request': set([]),
  69. }
  70. def new_state(self):
  71. """Generates a state string to be used in authorizations."""
  72. try:
  73. self._state = self.state()
  74. log.debug('Generated new state %s.', self._state)
  75. except TypeError:
  76. self._state = self.state
  77. log.debug('Re-using previously supplied state %s.', self._state)
  78. return self._state
  79. @property
  80. def client_id(self):
  81. return getattr(self._client, "client_id", None)
  82. @client_id.setter
  83. def client_id(self, value):
  84. self._client.client_id = value
  85. @client_id.deleter
  86. def client_id(self):
  87. del self._client.client_id
  88. @property
  89. def token(self):
  90. return getattr(self._client, "token", None)
  91. @token.setter
  92. def token(self, value):
  93. self._client.token = value
  94. self._client._populate_attributes(value)
  95. @property
  96. def access_token(self):
  97. return getattr(self._client, "access_token", None)
  98. @access_token.setter
  99. def access_token(self, value):
  100. self._client.access_token = value
  101. @access_token.deleter
  102. def access_token(self):
  103. del self._client.access_token
  104. @property
  105. def authorized(self):
  106. """Boolean that indicates whether this session has an OAuth token
  107. or not. If `self.authorized` is True, you can reasonably expect
  108. OAuth-protected requests to the resource to succeed. If
  109. `self.authorized` is False, you need the user to go through the OAuth
  110. authentication dance before OAuth-protected requests to the resource
  111. will succeed.
  112. """
  113. return bool(self.access_token)
  114. def authorization_url(self, url, state=None, **kwargs):
  115. """Form an authorization URL.
  116. :param url: Authorization endpoint url, must be HTTPS.
  117. :param state: An optional state string for CSRF protection. If not
  118. given it will be generated for you.
  119. :param kwargs: Extra parameters to include.
  120. :return: authorization_url, state
  121. """
  122. state = state or self.new_state()
  123. return self._client.prepare_request_uri(url,
  124. redirect_uri=self.redirect_uri,
  125. scope=self.scope,
  126. state=state,
  127. **kwargs), state
  128. def fetch_token(self, token_url, code=None, authorization_response=None,
  129. body='', auth=None, username=None, password=None, method='POST',
  130. timeout=None, headers=None, verify=True, proxies=None, **kwargs):
  131. """Generic method for fetching an access token from the token endpoint.
  132. If you are using the MobileApplicationClient you will want to use
  133. token_from_fragment instead of fetch_token.
  134. :param token_url: Token endpoint URL, must use HTTPS.
  135. :param code: Authorization code (used by WebApplicationClients).
  136. :param authorization_response: Authorization response URL, the callback
  137. URL of the request back to you. Used by
  138. WebApplicationClients instead of code.
  139. :param body: Optional application/x-www-form-urlencoded body to add the
  140. include in the token request. Prefer kwargs over body.
  141. :param auth: An auth tuple or method as accepted by requests.
  142. :param username: Username used by LegacyApplicationClients.
  143. :param password: Password used by LegacyApplicationClients.
  144. :param method: The HTTP method used to make the request. Defaults
  145. to POST, but may also be GET. Other methods should
  146. be added as needed.
  147. :param headers: Dict to default request headers with.
  148. :param timeout: Timeout of the request in seconds.
  149. :param verify: Verify SSL certificate.
  150. :param kwargs: Extra parameters to include in the token request.
  151. :return: A token dict
  152. """
  153. if not is_secure_transport(token_url):
  154. raise InsecureTransportError()
  155. if not code and authorization_response:
  156. self._client.parse_request_uri_response(authorization_response,
  157. state=self._state)
  158. code = self._client.code
  159. elif not code and isinstance(self._client, WebApplicationClient):
  160. code = self._client.code
  161. if not code:
  162. raise ValueError('Please supply either code or '
  163. 'authorization_response parameters.')
  164. body = self._client.prepare_request_body(code=code, body=body,
  165. redirect_uri=self.redirect_uri, username=username,
  166. password=password, **kwargs)
  167. client_id = kwargs.get('client_id', '')
  168. if auth is None:
  169. if client_id:
  170. log.debug('Encoding client_id "%s" with client_secret as Basic auth credentials.', client_id)
  171. client_secret = kwargs.get('client_secret', '')
  172. client_secret = client_secret if client_secret is not None else ''
  173. auth = requests.auth.HTTPBasicAuth(client_id, client_secret)
  174. elif username:
  175. if password is None:
  176. raise ValueError('Username was supplied, but not password.')
  177. log.debug('Encoding username, password as Basic auth credentials.')
  178. auth = requests.auth.HTTPBasicAuth(username, password)
  179. headers = headers or {
  180. 'Accept': 'application/json',
  181. 'Content-Type': 'application/x-www-form-urlencoded;charset=UTF-8',
  182. }
  183. self.token = {}
  184. if method.upper() == 'POST':
  185. r = self.post(token_url, data=dict(urldecode(body)),
  186. timeout=timeout, headers=headers, auth=auth,
  187. verify=verify, proxies=proxies)
  188. log.debug('Prepared fetch token request body %s', body)
  189. elif method.upper() == 'GET':
  190. # if method is not 'POST', switch body to querystring and GET
  191. r = self.get(token_url, params=dict(urldecode(body)),
  192. timeout=timeout, headers=headers, auth=auth,
  193. verify=verify, proxies=proxies)
  194. log.debug('Prepared fetch token request querystring %s', body)
  195. else:
  196. raise ValueError('The method kwarg must be POST or GET.')
  197. log.debug('Request to fetch token completed with status %s.',
  198. r.status_code)
  199. log.debug('Request headers were %s', r.request.headers)
  200. log.debug('Request body was %s', r.request.body)
  201. log.debug('Response headers were %s and content %s.',
  202. r.headers, r.text)
  203. log.debug('Invoking %d token response hooks.',
  204. len(self.compliance_hook['access_token_response']))
  205. for hook in self.compliance_hook['access_token_response']:
  206. log.debug('Invoking hook %s.', hook)
  207. r = hook(r)
  208. self._client.parse_request_body_response(r.text, scope=self.scope)
  209. self.token = self._client.token
  210. log.debug('Obtained token %s.', self.token)
  211. return self.token
  212. def token_from_fragment(self, authorization_response):
  213. """Parse token from the URI fragment, used by MobileApplicationClients.
  214. :param authorization_response: The full URL of the redirect back to you
  215. :return: A token dict
  216. """
  217. self._client.parse_request_uri_response(authorization_response,
  218. state=self._state)
  219. self.token = self._client.token
  220. return self.token
  221. def refresh_token(self, token_url, refresh_token=None, body='', auth=None,
  222. timeout=None, headers=None, verify=True, proxies=None, **kwargs):
  223. """Fetch a new access token using a refresh token.
  224. :param token_url: The token endpoint, must be HTTPS.
  225. :param refresh_token: The refresh_token to use.
  226. :param body: Optional application/x-www-form-urlencoded body to add the
  227. include in the token request. Prefer kwargs over body.
  228. :param auth: An auth tuple or method as accepted by requests.
  229. :param timeout: Timeout of the request in seconds.
  230. :param verify: Verify SSL certificate.
  231. :param kwargs: Extra parameters to include in the token request.
  232. :return: A token dict
  233. """
  234. if not token_url:
  235. raise ValueError('No token endpoint set for auto_refresh.')
  236. if not is_secure_transport(token_url):
  237. raise InsecureTransportError()
  238. refresh_token = refresh_token or self.token.get('refresh_token')
  239. log.debug('Adding auto refresh key word arguments %s.',
  240. self.auto_refresh_kwargs)
  241. kwargs.update(self.auto_refresh_kwargs)
  242. body = self._client.prepare_refresh_body(body=body,
  243. refresh_token=refresh_token, scope=self.scope, **kwargs)
  244. log.debug('Prepared refresh token request body %s', body)
  245. if headers is None:
  246. headers = {
  247. 'Accept': 'application/json',
  248. 'Content-Type': (
  249. 'application/x-www-form-urlencoded;charset=UTF-8'
  250. ),
  251. }
  252. r = self.post(token_url, data=dict(urldecode(body)), auth=auth,
  253. timeout=timeout, headers=headers, verify=verify, withhold_token=True, proxies=proxies)
  254. log.debug('Request to refresh token completed with status %s.',
  255. r.status_code)
  256. log.debug('Response headers were %s and content %s.',
  257. r.headers, r.text)
  258. log.debug('Invoking %d token response hooks.',
  259. len(self.compliance_hook['refresh_token_response']))
  260. for hook in self.compliance_hook['refresh_token_response']:
  261. log.debug('Invoking hook %s.', hook)
  262. r = hook(r)
  263. self.token = self._client.parse_request_body_response(r.text, scope=self.scope)
  264. if not 'refresh_token' in self.token:
  265. log.debug('No new refresh token given. Re-using old.')
  266. self.token['refresh_token'] = refresh_token
  267. return self.token
  268. def request(self, method, url, data=None, headers=None, withhold_token=False,
  269. client_id=None, client_secret=None, **kwargs):
  270. """Intercept all requests and add the OAuth 2 token if present."""
  271. if not is_secure_transport(url):
  272. raise InsecureTransportError()
  273. if self.token and not withhold_token:
  274. log.debug('Invoking %d protected resource request hooks.',
  275. len(self.compliance_hook['protected_request']))
  276. for hook in self.compliance_hook['protected_request']:
  277. log.debug('Invoking hook %s.', hook)
  278. url, headers, data = hook(url, headers, data)
  279. log.debug('Adding token %s to request.', self.token)
  280. try:
  281. url, headers, data = self._client.add_token(url,
  282. http_method=method, body=data, headers=headers)
  283. # Attempt to retrieve and save new access token if expired
  284. except TokenExpiredError:
  285. if self.auto_refresh_url:
  286. log.debug('Auto refresh is set, attempting to refresh at %s.',
  287. self.auto_refresh_url)
  288. # We mustn't pass auth twice.
  289. auth = kwargs.pop('auth', None)
  290. if client_id and client_secret and (auth is None):
  291. log.debug('Encoding client_id "%s" with client_secret as Basic auth credentials.', client_id)
  292. auth = requests.auth.HTTPBasicAuth(client_id, client_secret)
  293. token = self.refresh_token(
  294. self.auto_refresh_url, auth=auth, **kwargs
  295. )
  296. if self.token_updater:
  297. log.debug('Updating token to %s using %s.',
  298. token, self.token_updater)
  299. self.token_updater(token)
  300. url, headers, data = self._client.add_token(url,
  301. http_method=method, body=data, headers=headers)
  302. else:
  303. raise TokenUpdated(token)
  304. else:
  305. raise
  306. log.debug('Requesting url %s using method %s.', url, method)
  307. log.debug('Supplying headers %s and data %s', headers, data)
  308. log.debug('Passing through key word arguments %s.', kwargs)
  309. return super(OAuth2Session, self).request(method, url,
  310. headers=headers, data=data, **kwargs)
  311. def register_compliance_hook(self, hook_type, hook):
  312. """Register a hook for request/response tweaking.
  313. Available hooks are:
  314. access_token_response invoked before token parsing.
  315. refresh_token_response invoked before refresh token parsing.
  316. protected_request invoked before making a request.
  317. If you find a new hook is needed please send a GitHub PR request
  318. or open an issue.
  319. """
  320. if hook_type not in self.compliance_hook:
  321. raise ValueError('Hook type %s is not in %s.',
  322. hook_type, self.compliance_hook)
  323. self.compliance_hook[hook_type].add(hook)