You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

173 lines
5.2 KiB

4 years ago
  1. # This file is dual licensed under the terms of the Apache License, Version
  2. # 2.0, and the BSD License. See the LICENSE file in the root of this repository
  3. # for complete details.
  4. from __future__ import absolute_import, division, print_function
  5. import base64
  6. import binascii
  7. import os
  8. import struct
  9. import time
  10. import six
  11. from cryptography.exceptions import InvalidSignature
  12. from cryptography.hazmat.backends import default_backend
  13. from cryptography.hazmat.primitives import hashes, padding
  14. from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
  15. from cryptography.hazmat.primitives.hmac import HMAC
  16. class InvalidToken(Exception):
  17. pass
  18. _MAX_CLOCK_SKEW = 60
  19. class Fernet(object):
  20. def __init__(self, key, backend=None):
  21. if backend is None:
  22. backend = default_backend()
  23. key = base64.urlsafe_b64decode(key)
  24. if len(key) != 32:
  25. raise ValueError(
  26. "Fernet key must be 32 url-safe base64-encoded bytes."
  27. )
  28. self._signing_key = key[:16]
  29. self._encryption_key = key[16:]
  30. self._backend = backend
  31. @classmethod
  32. def generate_key(cls):
  33. return base64.urlsafe_b64encode(os.urandom(32))
  34. def encrypt(self, data):
  35. current_time = int(time.time())
  36. iv = os.urandom(16)
  37. return self._encrypt_from_parts(data, current_time, iv)
  38. def _encrypt_from_parts(self, data, current_time, iv):
  39. if not isinstance(data, bytes):
  40. raise TypeError("data must be bytes.")
  41. padder = padding.PKCS7(algorithms.AES.block_size).padder()
  42. padded_data = padder.update(data) + padder.finalize()
  43. encryptor = Cipher(
  44. algorithms.AES(self._encryption_key), modes.CBC(iv), self._backend
  45. ).encryptor()
  46. ciphertext = encryptor.update(padded_data) + encryptor.finalize()
  47. basic_parts = (
  48. b"\x80" + struct.pack(">Q", current_time) + iv + ciphertext
  49. )
  50. h = HMAC(self._signing_key, hashes.SHA256(), backend=self._backend)
  51. h.update(basic_parts)
  52. hmac = h.finalize()
  53. return base64.urlsafe_b64encode(basic_parts + hmac)
  54. def decrypt(self, token, ttl=None):
  55. timestamp, data = Fernet._get_unverified_token_data(token)
  56. return self._decrypt_data(data, timestamp, ttl)
  57. def extract_timestamp(self, token):
  58. timestamp, data = Fernet._get_unverified_token_data(token)
  59. # Verify the token was not tampered with.
  60. self._verify_signature(data)
  61. return timestamp
  62. @staticmethod
  63. def _get_unverified_token_data(token):
  64. if not isinstance(token, bytes):
  65. raise TypeError("token must be bytes.")
  66. try:
  67. data = base64.urlsafe_b64decode(token)
  68. except (TypeError, binascii.Error):
  69. raise InvalidToken
  70. if not data or six.indexbytes(data, 0) != 0x80:
  71. raise InvalidToken
  72. try:
  73. timestamp, = struct.unpack(">Q", data[1:9])
  74. except struct.error:
  75. raise InvalidToken
  76. return timestamp, data
  77. def _verify_signature(self, data):
  78. h = HMAC(self._signing_key, hashes.SHA256(), backend=self._backend)
  79. h.update(data[:-32])
  80. try:
  81. h.verify(data[-32:])
  82. except InvalidSignature:
  83. raise InvalidToken
  84. def _decrypt_data(self, data, timestamp, ttl):
  85. current_time = int(time.time())
  86. if ttl is not None:
  87. if timestamp + ttl < current_time:
  88. raise InvalidToken
  89. if current_time + _MAX_CLOCK_SKEW < timestamp:
  90. raise InvalidToken
  91. self._verify_signature(data)
  92. iv = data[9:25]
  93. ciphertext = data[25:-32]
  94. decryptor = Cipher(
  95. algorithms.AES(self._encryption_key), modes.CBC(iv), self._backend
  96. ).decryptor()
  97. plaintext_padded = decryptor.update(ciphertext)
  98. try:
  99. plaintext_padded += decryptor.finalize()
  100. except ValueError:
  101. raise InvalidToken
  102. unpadder = padding.PKCS7(algorithms.AES.block_size).unpadder()
  103. unpadded = unpadder.update(plaintext_padded)
  104. try:
  105. unpadded += unpadder.finalize()
  106. except ValueError:
  107. raise InvalidToken
  108. return unpadded
  109. class MultiFernet(object):
  110. def __init__(self, fernets):
  111. fernets = list(fernets)
  112. if not fernets:
  113. raise ValueError(
  114. "MultiFernet requires at least one Fernet instance"
  115. )
  116. self._fernets = fernets
  117. def encrypt(self, msg):
  118. return self._fernets[0].encrypt(msg)
  119. def rotate(self, msg):
  120. timestamp, data = Fernet._get_unverified_token_data(msg)
  121. for f in self._fernets:
  122. try:
  123. p = f._decrypt_data(data, timestamp, None)
  124. break
  125. except InvalidToken:
  126. pass
  127. else:
  128. raise InvalidToken
  129. iv = os.urandom(16)
  130. return self._fernets[0]._encrypt_from_parts(p, timestamp, iv)
  131. def decrypt(self, msg, ttl=None):
  132. for f in self._fernets:
  133. try:
  134. return f.decrypt(msg, ttl)
  135. except InvalidToken:
  136. pass
  137. raise InvalidToken