소스 검색

feat(api): crypto.decrypt() now returns ciphertext's timestamp

Peter Thomassen 3 년 전
부모
커밋
2115b37570
4개의 변경된 파일12개의 추가작업 그리고 7개의 파일을 삭제
  1. 3 2
      api/desecapi/crypto.py
  2. 4 3
      api/desecapi/serializers.py
  3. 4 1
      api/desecapi/tests/test_crypto.py
  4. 1 1
      api/desecapi/tests/test_user_management.py

+ 3 - 2
api/desecapi/crypto.py

@@ -35,9 +35,10 @@ def encrypt(data, *, context):
 
 def decrypt(token, *, context, ttl=None):
     key = retrieve_key(label=b'crypt', context=context)
+    f = Fernet(key=key)
     try:
-        value = Fernet(key=key).decrypt(token, ttl=ttl)
+        ret = f.extract_timestamp(token), f.decrypt(token, ttl=ttl)
         metrics.get('desecapi_key_decryption_success').labels(context).inc()
-        return value
+        return ret
     except InvalidToken:
         raise ValueError

+ 4 - 3
api/desecapi/serializers.py

@@ -729,6 +729,7 @@ class AuthenticatedActionSerializer(serializers.ModelSerializer):
     validity_period = settings.VALIDITY_PERIOD_VERIFICATION_SIGNATURE
 
     _crypto_context = 'desecapi.serializers.AuthenticatedActionSerializer'
+    timestamp = None  # is set to the code's timestamp during validation
 
     class Meta:
         model = models.AuthenticatedAction
@@ -744,8 +745,8 @@ class AuthenticatedActionSerializer(serializers.ModelSerializer):
     def _unpack_code(cls, code, *, ttl):
         code += -len(code) % 4 * '='
         try:
-            payload = crypto.decrypt(code.encode(), context=cls._crypto_context, ttl=ttl)
-            return json.loads(payload.decode())
+            timestamp, payload = crypto.decrypt(code.encode(), context=cls._crypto_context, ttl=ttl)
+            return timestamp, json.loads(payload.decode())
         except (TypeError, UnicodeDecodeError, UnicodeEncodeError, json.JSONDecodeError, binascii.Error):
             raise ValueError
 
@@ -766,7 +767,7 @@ class AuthenticatedActionSerializer(serializers.ModelSerializer):
 
         # decode from single string
         try:
-            unpacked_data = self._unpack_code(self.context['code'], ttl=ttl)
+            self.timestamp, unpacked_data = self._unpack_code(self.context['code'], ttl=ttl)
         except KeyError:
             raise serializers.ValidationError({'code': ['This field is required.']})
         except ValueError:

+ 4 - 1
api/desecapi/tests/test_crypto.py

@@ -1,4 +1,5 @@
 from math import log
+import time
 
 from django.test import TestCase
 
@@ -42,7 +43,9 @@ class CryptoTestCase(TestCase):
     def test_encrypt_decrypt(self):
         plain = b'test'
         ciphertext = crypto.encrypt(plain, context=self.context)
-        self.assertEqual(plain, crypto.decrypt(ciphertext, context=self.context))
+        timestamp, decrypted = crypto.decrypt(ciphertext, context=self.context)
+        self.assertEqual(plain, decrypted)
+        self.assertTrue(0 <= time.time() - timestamp <= 1)
 
     def test_encrypt_decrypt_raises_on_tampering(self):
         ciphertext = crypto.encrypt(b'test', context=self.context)

+ 1 - 1
api/desecapi/tests/test_user_management.py

@@ -443,7 +443,7 @@ class UserManagementTestCase(DesecTestCase, PublicSuffixMockMixin):
         if tampered_domain is not None:
             path = urlparse(confirmation_link).path
             code = resolve(path).kwargs.get('code')
-            data = AuthenticatedActionSerializer._unpack_code(code, ttl=None)
+            _, data = AuthenticatedActionSerializer._unpack_code(code, ttl=None)
             data['domain'] = tampered_domain
             tampered_code = AuthenticatedActionSerializer._pack_code(data)
             confirmation_link = confirmation_link.replace(code, tampered_code)