# api_jws.py
import binascii
import json
import warnings

try:
    from collections.abc import Mapping
except ImportError:  # Python 2: the ABCs are only available from `collections`
    from collections import Mapping

from .algorithms import (
    Algorithm, get_default_algorithms, has_crypto, requires_cryptography  # NOQA
)
from .compat import binary_type, string_types, text_type
from .exceptions import (
    DecodeError, InvalidAlgorithmError, InvalidSignatureError,
    InvalidTokenError
)
from .utils import base64url_decode, base64url_encode, force_bytes, merge_dict
  14. class PyJWS(object):
  15. header_typ = 'JWT'
  16. def __init__(self, algorithms=None, options=None):
  17. self._algorithms = get_default_algorithms()
  18. self._valid_algs = (set(algorithms) if algorithms is not None
  19. else set(self._algorithms))
  20. # Remove algorithms that aren't on the whitelist
  21. for key in list(self._algorithms.keys()):
  22. if key not in self._valid_algs:
  23. del self._algorithms[key]
  24. if not options:
  25. options = {}
  26. self.options = merge_dict(self._get_default_options(), options)
  27. @staticmethod
  28. def _get_default_options():
  29. return {
  30. 'verify_signature': True
  31. }
  32. def register_algorithm(self, alg_id, alg_obj):
  33. """
  34. Registers a new Algorithm for use when creating and verifying tokens.
  35. """
  36. if alg_id in self._algorithms:
  37. raise ValueError('Algorithm already has a handler.')
  38. if not isinstance(alg_obj, Algorithm):
  39. raise TypeError('Object is not of type `Algorithm`')
  40. self._algorithms[alg_id] = alg_obj
  41. self._valid_algs.add(alg_id)
  42. def unregister_algorithm(self, alg_id):
  43. """
  44. Unregisters an Algorithm for use when creating and verifying tokens
  45. Throws KeyError if algorithm is not registered.
  46. """
  47. if alg_id not in self._algorithms:
  48. raise KeyError('The specified algorithm could not be removed'
  49. ' because it is not registered.')
  50. del self._algorithms[alg_id]
  51. self._valid_algs.remove(alg_id)
  52. def get_algorithms(self):
  53. """
  54. Returns a list of supported values for the 'alg' parameter.
  55. """
  56. return list(self._valid_algs)
  57. def encode(self, payload, key, algorithm='HS256', headers=None,
  58. json_encoder=None):
  59. segments = []
  60. if algorithm is None:
  61. algorithm = 'none'
  62. if algorithm not in self._valid_algs:
  63. pass
  64. # Header
  65. header = {'typ': self.header_typ, 'alg': algorithm}
  66. if headers:
  67. self._validate_headers(headers)
  68. header.update(headers)
  69. json_header = force_bytes(
  70. json.dumps(
  71. header,
  72. separators=(',', ':'),
  73. cls=json_encoder
  74. )
  75. )
  76. segments.append(base64url_encode(json_header))
  77. segments.append(base64url_encode(payload))
  78. # Segments
  79. signing_input = b'.'.join(segments)
  80. try:
  81. alg_obj = self._algorithms[algorithm]
  82. key = alg_obj.prepare_key(key)
  83. signature = alg_obj.sign(signing_input, key)
  84. except KeyError:
  85. if not has_crypto and algorithm in requires_cryptography:
  86. raise NotImplementedError(
  87. "Algorithm '%s' could not be found. Do you have cryptography "
  88. "installed?" % algorithm
  89. )
  90. else:
  91. raise NotImplementedError('Algorithm not supported')
  92. segments.append(base64url_encode(signature))
  93. return b'.'.join(segments)
  94. def decode(self, jws, key='', verify=True, algorithms=None, options=None,
  95. **kwargs):
  96. merged_options = merge_dict(self.options, options)
  97. verify_signature = merged_options['verify_signature']
  98. if verify_signature and not algorithms:
  99. warnings.warn(
  100. 'It is strongly recommended that you pass in a ' +
  101. 'value for the "algorithms" argument when calling decode(). ' +
  102. 'This argument will be mandatory in a future version.',
  103. DeprecationWarning
  104. )
  105. payload, signing_input, header, signature = self._load(jws)
  106. if not verify:
  107. warnings.warn('The verify parameter is deprecated. '
  108. 'Please use verify_signature in options instead.',
  109. DeprecationWarning, stacklevel=2)
  110. elif verify_signature:
  111. self._verify_signature(payload, signing_input, header, signature,
  112. key, algorithms)
  113. return payload
  114. def get_unverified_header(self, jwt):
  115. """Returns back the JWT header parameters as a dict()
  116. Note: The signature is not verified so the header parameters
  117. should not be fully trusted until signature verification is complete
  118. """
  119. headers = self._load(jwt)[2]
  120. self._validate_headers(headers)
  121. return headers
  122. def _load(self, jwt):
  123. if isinstance(jwt, text_type):
  124. jwt = jwt.encode('utf-8')
  125. if not issubclass(type(jwt), binary_type):
  126. raise DecodeError("Invalid token type. Token must be a {0}".format(
  127. binary_type))
  128. try:
  129. signing_input, crypto_segment = jwt.rsplit(b'.', 1)
  130. header_segment, payload_segment = signing_input.split(b'.', 1)
  131. except ValueError:
  132. raise DecodeError('Not enough segments')
  133. try:
  134. header_data = base64url_decode(header_segment)
  135. except (TypeError, binascii.Error):
  136. raise DecodeError('Invalid header padding')
  137. try:
  138. header = json.loads(header_data.decode('utf-8'))
  139. except ValueError as e:
  140. raise DecodeError('Invalid header string: %s' % e)
  141. if not isinstance(header, Mapping):
  142. raise DecodeError('Invalid header string: must be a json object')
  143. try:
  144. payload = base64url_decode(payload_segment)
  145. except (TypeError, binascii.Error):
  146. raise DecodeError('Invalid payload padding')
  147. try:
  148. signature = base64url_decode(crypto_segment)
  149. except (TypeError, binascii.Error):
  150. raise DecodeError('Invalid crypto padding')
  151. return (payload, signing_input, header, signature)
  152. def _verify_signature(self, payload, signing_input, header, signature,
  153. key='', algorithms=None):
  154. alg = header.get('alg')
  155. if algorithms is not None and alg not in algorithms:
  156. raise InvalidAlgorithmError('The specified alg value is not allowed')
  157. try:
  158. alg_obj = self._algorithms[alg]
  159. key = alg_obj.prepare_key(key)
  160. if not alg_obj.verify(signing_input, key, signature):
  161. raise InvalidSignatureError('Signature verification failed')
  162. except KeyError:
  163. raise InvalidAlgorithmError('Algorithm not supported')
  164. def _validate_headers(self, headers):
  165. if 'kid' in headers:
  166. self._validate_kid(headers['kid'])
  167. def _validate_kid(self, kid):
  168. if not isinstance(kid, string_types):
  169. raise InvalidTokenError('Key ID header parameter must be a string')
  170. _jws_global_obj = PyJWS()
  171. encode = _jws_global_obj.encode
  172. decode = _jws_global_obj.decode
  173. register_algorithm = _jws_global_obj.register_algorithm
  174. unregister_algorithm = _jws_global_obj.unregister_algorithm
  175. get_unverified_header = _jws_global_obj.get_unverified_header