aead.py 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166
  1. # This file is dual licensed under the terms of the Apache License, Version
  2. # 2.0, and the BSD License. See the LICENSE file in the root of this repository
  3. # for complete details.
  4. from __future__ import absolute_import, division, print_function
  5. from cryptography.exceptions import InvalidTag
# Values for the ``operation`` argument of ``_aead_setup``. The direction flag
# handed to EVP_CipherInit_ex is derived as ``int(operation == _ENCRYPT)``.
_ENCRYPT = 1
_DECRYPT = 0
  8. def _aead_cipher_name(cipher):
  9. from cryptography.hazmat.primitives.ciphers.aead import (
  10. AESCCM,
  11. AESGCM,
  12. ChaCha20Poly1305,
  13. )
  14. if isinstance(cipher, ChaCha20Poly1305):
  15. return b"chacha20-poly1305"
  16. elif isinstance(cipher, AESCCM):
  17. return "aes-{}-ccm".format(len(cipher._key) * 8).encode("ascii")
  18. else:
  19. assert isinstance(cipher, AESGCM)
  20. return "aes-{}-gcm".format(len(cipher._key) * 8).encode("ascii")
def _aead_setup(backend, cipher_name, key, nonce, tag, tag_len, operation):
    """Create and fully initialise an OpenSSL EVP cipher context for an AEAD.

    :param backend: the OpenSSL backend; supplies ``_lib``, ``_ffi`` and
        ``openssl_assert``.
    :param cipher_name: OpenSSL cipher name as bytes (callers in this module
        pass the result of ``_aead_cipher_name``).
    :param key: key bytes; ``len(key)`` is set as the context key length.
    :param nonce: nonce bytes; ``len(nonce)`` is set as the IV length.
    :param tag: expected tag bytes, only used when ``operation`` is
        ``_DECRYPT``.
    :param tag_len: tag length configured when encrypting with a cipher whose
        name ends in ``-ccm``; otherwise unused here.
    :param operation: ``_ENCRYPT`` or ``_DECRYPT``.
    :returns: the context, wrapped with ``ffi.gc`` so ``EVP_CIPHER_CTX_free``
        runs when it is garbage collected.

    Every OpenSSL failure is reported through ``backend.openssl_assert``.
    """
    # A NULL result means OpenSSL has no cipher registered under this name.
    evp_cipher = backend._lib.EVP_get_cipherbyname(cipher_name)
    backend.openssl_assert(evp_cipher != backend._ffi.NULL)
    ctx = backend._lib.EVP_CIPHER_CTX_new()
    # Tie the native context's lifetime to the Python object.
    ctx = backend._ffi.gc(ctx, backend._lib.EVP_CIPHER_CTX_free)
    # First init: select cipher and direction only (key and IV are NULL) so
    # that key length, IV length and tag parameters can be configured before
    # the key and nonce are supplied by the second init below.
    res = backend._lib.EVP_CipherInit_ex(
        ctx,
        evp_cipher,
        backend._ffi.NULL,
        backend._ffi.NULL,
        backend._ffi.NULL,
        int(operation == _ENCRYPT),
    )
    backend.openssl_assert(res != 0)
    res = backend._lib.EVP_CIPHER_CTX_set_key_length(ctx, len(key))
    backend.openssl_assert(res != 0)
    res = backend._lib.EVP_CIPHER_CTX_ctrl(
        ctx,
        backend._lib.EVP_CTRL_AEAD_SET_IVLEN,
        len(nonce),
        backend._ffi.NULL,
    )
    backend.openssl_assert(res != 0)
    if operation == _DECRYPT:
        # Decryption: give OpenSSL the expected tag to check against.
        res = backend._lib.EVP_CIPHER_CTX_ctrl(
            ctx, backend._lib.EVP_CTRL_AEAD_SET_TAG, len(tag), tag
        )
        backend.openssl_assert(res != 0)
    elif cipher_name.endswith(b"-ccm"):
        # CCM encryption: a NULL tag pointer sets only the tag length, which
        # has to be fixed before the key/IV are supplied.
        res = backend._lib.EVP_CIPHER_CTX_ctrl(
            ctx, backend._lib.EVP_CTRL_AEAD_SET_TAG, tag_len, backend._ffi.NULL
        )
        backend.openssl_assert(res != 0)
    # Second init: NULL cipher keeps the one selected above; now supply the
    # key and nonce (passed as pointers into the caller's buffers).
    nonce_ptr = backend._ffi.from_buffer(nonce)
    key_ptr = backend._ffi.from_buffer(key)
    res = backend._lib.EVP_CipherInit_ex(
        ctx,
        backend._ffi.NULL,
        backend._ffi.NULL,
        key_ptr,
        nonce_ptr,
        int(operation == _ENCRYPT),
    )
    backend.openssl_assert(res != 0)
    return ctx
  66. def _set_length(backend, ctx, data_len):
  67. intptr = backend._ffi.new("int *")
  68. res = backend._lib.EVP_CipherUpdate(
  69. ctx, backend._ffi.NULL, intptr, backend._ffi.NULL, data_len
  70. )
  71. backend.openssl_assert(res != 0)
  72. def _process_aad(backend, ctx, associated_data):
  73. outlen = backend._ffi.new("int *")
  74. res = backend._lib.EVP_CipherUpdate(
  75. ctx, backend._ffi.NULL, outlen, associated_data, len(associated_data)
  76. )
  77. backend.openssl_assert(res != 0)
  78. def _process_data(backend, ctx, data):
  79. outlen = backend._ffi.new("int *")
  80. buf = backend._ffi.new("unsigned char[]", len(data))
  81. res = backend._lib.EVP_CipherUpdate(ctx, buf, outlen, data, len(data))
  82. backend.openssl_assert(res != 0)
  83. return backend._ffi.buffer(buf, outlen[0])[:]
  84. def _encrypt(backend, cipher, nonce, data, associated_data, tag_length):
  85. from cryptography.hazmat.primitives.ciphers.aead import AESCCM
  86. cipher_name = _aead_cipher_name(cipher)
  87. ctx = _aead_setup(
  88. backend, cipher_name, cipher._key, nonce, None, tag_length, _ENCRYPT
  89. )
  90. # CCM requires us to pass the length of the data before processing anything
  91. # However calling this with any other AEAD results in an error
  92. if isinstance(cipher, AESCCM):
  93. _set_length(backend, ctx, len(data))
  94. _process_aad(backend, ctx, associated_data)
  95. processed_data = _process_data(backend, ctx, data)
  96. outlen = backend._ffi.new("int *")
  97. res = backend._lib.EVP_CipherFinal_ex(ctx, backend._ffi.NULL, outlen)
  98. backend.openssl_assert(res != 0)
  99. backend.openssl_assert(outlen[0] == 0)
  100. tag_buf = backend._ffi.new("unsigned char[]", tag_length)
  101. res = backend._lib.EVP_CIPHER_CTX_ctrl(
  102. ctx, backend._lib.EVP_CTRL_AEAD_GET_TAG, tag_length, tag_buf
  103. )
  104. backend.openssl_assert(res != 0)
  105. tag = backend._ffi.buffer(tag_buf)[:]
  106. return processed_data + tag
  107. def _decrypt(backend, cipher, nonce, data, associated_data, tag_length):
  108. from cryptography.hazmat.primitives.ciphers.aead import AESCCM
  109. if len(data) < tag_length:
  110. raise InvalidTag
  111. tag = data[-tag_length:]
  112. data = data[:-tag_length]
  113. cipher_name = _aead_cipher_name(cipher)
  114. ctx = _aead_setup(
  115. backend, cipher_name, cipher._key, nonce, tag, tag_length, _DECRYPT
  116. )
  117. # CCM requires us to pass the length of the data before processing anything
  118. # However calling this with any other AEAD results in an error
  119. if isinstance(cipher, AESCCM):
  120. _set_length(backend, ctx, len(data))
  121. _process_aad(backend, ctx, associated_data)
  122. # CCM has a different error path if the tag doesn't match. Errors are
  123. # raised in Update and Final is irrelevant.
  124. if isinstance(cipher, AESCCM):
  125. outlen = backend._ffi.new("int *")
  126. buf = backend._ffi.new("unsigned char[]", len(data))
  127. res = backend._lib.EVP_CipherUpdate(ctx, buf, outlen, data, len(data))
  128. if res != 1:
  129. backend._consume_errors()
  130. raise InvalidTag
  131. processed_data = backend._ffi.buffer(buf, outlen[0])[:]
  132. else:
  133. processed_data = _process_data(backend, ctx, data)
  134. outlen = backend._ffi.new("int *")
  135. res = backend._lib.EVP_CipherFinal_ex(ctx, backend._ffi.NULL, outlen)
  136. if res == 0:
  137. backend._consume_errors()
  138. raise InvalidTag
  139. return processed_data