concatkdf.py 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131
  1. # This file is dual licensed under the terms of the Apache License, Version
  2. # 2.0, and the BSD License. See the LICENSE file in the root of this repository
  3. # for complete details.
  4. from __future__ import absolute_import, division, print_function
  5. import struct
  6. from cryptography import utils
  7. from cryptography.exceptions import (
  8. AlreadyFinalized,
  9. InvalidKey,
  10. UnsupportedAlgorithm,
  11. _Reasons,
  12. )
  13. from cryptography.hazmat.backends import _get_backend
  14. from cryptography.hazmat.backends.interfaces import HMACBackend
  15. from cryptography.hazmat.backends.interfaces import HashBackend
  16. from cryptography.hazmat.primitives import constant_time, hashes, hmac
  17. from cryptography.hazmat.primitives.kdf import KeyDerivationFunction
  18. def _int_to_u32be(n):
  19. return struct.pack(">I", n)
  20. def _common_args_checks(algorithm, length, otherinfo):
  21. max_length = algorithm.digest_size * (2 ** 32 - 1)
  22. if length > max_length:
  23. raise ValueError(
  24. "Can not derive keys larger than {} bits.".format(max_length)
  25. )
  26. if otherinfo is not None:
  27. utils._check_bytes("otherinfo", otherinfo)
  28. def _concatkdf_derive(key_material, length, auxfn, otherinfo):
  29. utils._check_byteslike("key_material", key_material)
  30. output = [b""]
  31. outlen = 0
  32. counter = 1
  33. while length > outlen:
  34. h = auxfn()
  35. h.update(_int_to_u32be(counter))
  36. h.update(key_material)
  37. h.update(otherinfo)
  38. output.append(h.finalize())
  39. outlen += len(output[-1])
  40. counter += 1
  41. return b"".join(output)[:length]
  42. @utils.register_interface(KeyDerivationFunction)
  43. class ConcatKDFHash(object):
  44. def __init__(self, algorithm, length, otherinfo, backend=None):
  45. backend = _get_backend(backend)
  46. _common_args_checks(algorithm, length, otherinfo)
  47. self._algorithm = algorithm
  48. self._length = length
  49. self._otherinfo = otherinfo
  50. if self._otherinfo is None:
  51. self._otherinfo = b""
  52. if not isinstance(backend, HashBackend):
  53. raise UnsupportedAlgorithm(
  54. "Backend object does not implement HashBackend.",
  55. _Reasons.BACKEND_MISSING_INTERFACE,
  56. )
  57. self._backend = backend
  58. self._used = False
  59. def _hash(self):
  60. return hashes.Hash(self._algorithm, self._backend)
  61. def derive(self, key_material):
  62. if self._used:
  63. raise AlreadyFinalized
  64. self._used = True
  65. return _concatkdf_derive(
  66. key_material, self._length, self._hash, self._otherinfo
  67. )
  68. def verify(self, key_material, expected_key):
  69. if not constant_time.bytes_eq(self.derive(key_material), expected_key):
  70. raise InvalidKey
  71. @utils.register_interface(KeyDerivationFunction)
  72. class ConcatKDFHMAC(object):
  73. def __init__(self, algorithm, length, salt, otherinfo, backend=None):
  74. backend = _get_backend(backend)
  75. _common_args_checks(algorithm, length, otherinfo)
  76. self._algorithm = algorithm
  77. self._length = length
  78. self._otherinfo = otherinfo
  79. if self._otherinfo is None:
  80. self._otherinfo = b""
  81. if salt is None:
  82. salt = b"\x00" * algorithm.block_size
  83. else:
  84. utils._check_bytes("salt", salt)
  85. self._salt = salt
  86. if not isinstance(backend, HMACBackend):
  87. raise UnsupportedAlgorithm(
  88. "Backend object does not implement HMACBackend.",
  89. _Reasons.BACKEND_MISSING_INTERFACE,
  90. )
  91. self._backend = backend
  92. self._used = False
  93. def _hmac(self):
  94. return hmac.HMAC(self._salt, self._algorithm, self._backend)
  95. def derive(self, key_material):
  96. if self._used:
  97. raise AlreadyFinalized
  98. self._used = True
  99. return _concatkdf_derive(
  100. key_material, self._length, self._hmac, self._otherinfo
  101. )
  102. def verify(self, key_material, expected_key):
  103. if not constant_time.bytes_eq(self.derive(key_material), expected_key):
  104. raise InvalidKey