Преглед изворни кода

feat(api): add TOTP verification endpoint

Peter Thomassen пре 2 година
родитељ
комит
7b15dfb436

+ 35 - 0
api/desecapi/models/mfa.py

@@ -1,11 +1,15 @@
 from __future__ import annotations
 from __future__ import annotations
 
 
 import base64
 import base64
+from datetime import datetime
+from functools import cached_property
 import secrets
 import secrets
 import uuid
 import uuid
 
 
+from django.conf import settings
 from django.db import models, transaction
 from django.db import models, transaction
 from django.utils import timezone
 from django.utils import timezone
+from pyotp import TOTP, utils as pyotp_utils
 
 
 
 
 class BaseFactor(models.Model):
 class BaseFactor(models.Model):
@@ -36,6 +40,37 @@ class TOTPFactor(BaseFactor):
     secret = models.BinaryField(max_length=32, default=_secret_default.__func__)
     secret = models.BinaryField(max_length=32, default=_secret_default.__func__)
     last_verified_timestep = models.PositiveIntegerField(default=0)
     last_verified_timestep = models.PositiveIntegerField(default=0)
 
 
+    @cached_property
+    def _totp(self):
+        # TODO switch to self.secret once https://github.com/pyauth/pyotp/pull/138 is released
+        return TOTP(self.base32_secret, digits=6)
+
     @property
     @property
     def base32_secret(self):
     def base32_secret(self):
         return base64.b32encode(self.secret).rstrip(b"=").decode("ascii")
         return base64.b32encode(self.secret).rstrip(b"=").decode("ascii")
+
+    @property
+    def uri(self):
+        return self._totp.provisioning_uri(
+            name=self.name,
+            issuer_name=f"desec.{settings.DESECSTACK_DOMAIN}",
+        )
+
+    @transaction.atomic
+    def verify(self, code):
+        now = timezone.now()
+        timestep_now = self._totp.timecode(now)
+
+        for offset in (-1, 0, 1):
+            timestep = timestep_now + offset
+            if not (self.last_verified_timestep < timestep):
+                continue
+            if pyotp_utils.strings_equal(str(code), self._totp.generate_otp(timestep)):
+                if not self.user.mfa_enabled:  # enabling MFA
+                    self.user.credentials_changed = now
+                    self.user.save()
+                self.last_used = now
+                self.last_verified_timestep = timestep
+                self.save()
+                return True
+        return False

+ 4 - 0
api/desecapi/models/users.py

@@ -82,6 +82,10 @@ class User(ExportModelOperationsMixin("User"), AbstractBaseUser):
         # Simplest possible answer: All admins are staff
         # Simplest possible answer: All admins are staff
         return self.is_admin
         return self.is_admin
 
 
+    @property
+    def mfa_enabled(self):
+        return self.basefactor_set.exclude(last_used__isnull=True).exists()
+
     def activate(self):
     def activate(self):
         self.is_active = True
         self.is_active = True
         self.needs_captcha = False
         self.needs_captcha = False

+ 1 - 1
api/desecapi/serializers/__init__.py

@@ -12,7 +12,7 @@ from .authenticated_actions import (
 from .captcha import CaptchaSerializer, CaptchaSolutionSerializer
 from .captcha import CaptchaSerializer, CaptchaSolutionSerializer
 from .domains import DomainSerializer
 from .domains import DomainSerializer
 from .donation import DonationSerializer
 from .donation import DonationSerializer
-from .mfa import TOTPFactorSerializer
+from .mfa import TOTPCodeSerializer, TOTPFactorSerializer
 from .records import RRsetSerializer
 from .records import RRsetSerializer
 from .tokens import TokenDomainPolicySerializer, TokenSerializer
 from .tokens import TokenDomainPolicySerializer, TokenSerializer
 from .users import (
 from .users import (

+ 17 - 2
api/desecapi/serializers/mfa.py

@@ -9,8 +9,8 @@ class TOTPFactorSerializer(serializers.ModelSerializer):
 
 
     class Meta:
     class Meta:
         model = TOTPFactor
         model = TOTPFactor
-        fields = ("id", "created", "last_used", "name", "secret", "user")
-        read_only_fields = ("id", "created", "last_used", "secret", "user")
+        fields = ("id", "created", "last_used", "name", "secret", "uri", "user")
+        read_only_fields = ("id", "created", "last_used", "secret", "uri", "user")
         extra_kwargs = {
         extra_kwargs = {
             # needed for uniqueness, https://github.com/encode/django-rest-framework/issues/7489
             # needed for uniqueness, https://github.com/encode/django-rest-framework/issues/7489
             "name": {"default": ""}
             "name": {"default": ""}
@@ -31,6 +31,7 @@ class TOTPFactorSerializer(serializers.ModelSerializer):
         fields = super().get_fields()
         fields = super().get_fields()
         if not self.include_secret:
         if not self.include_secret:
             fields.pop("secret")
             fields.pop("secret")
+            fields.pop("uri")
         return fields
         return fields
 
 
     def to_representation(self, instance):
     def to_representation(self, instance):
@@ -38,3 +39,17 @@ class TOTPFactorSerializer(serializers.ModelSerializer):
         if "secret" in ret:
         if "secret" in ret:
             ret["secret"] = instance.base32_secret
             ret["secret"] = instance.base32_secret
         return ret
         return ret
+
+
+class TOTPCodeSerializer(serializers.Serializer):
+    # length requirements preserve leading zeros
+    code = serializers.RegexField("^[0-9]+$", max_length=6, min_length=6)
+
+    class Meta:
+        fields = ("code",)
+
+    def validate_code(self, value):
+        factor = self.context["view"].get_object()
+        if not factor.verify(value):
+            raise serializers.ValidationError("Invalid code.")
+        return value

+ 71 - 2
api/desecapi/tests/test_totp.py

@@ -1,3 +1,6 @@
+from datetime import datetime, timedelta
+
+from pyotp import TOTP
 from rest_framework import status
 from rest_framework import status
 
 
 from desecapi.tests.base import DomainOwnerTestCase
 from desecapi.tests.base import DomainOwnerTestCase
@@ -30,10 +33,18 @@ class TOTPFactorTestCase(DomainOwnerTestCase):
         response = self.client.post(confirmation_link)
         response = self.client.post(confirmation_link)
         self.assertResponse(response, status.HTTP_200_OK)
         self.assertResponse(response, status.HTTP_200_OK)
         totp = response.data
         totp = response.data
-        self.assertEqual(totp.keys(), {"id", "created", "last_used", "name", "secret"})
+        self.assertEqual(
+            totp.keys(), {"id", "created", "last_used", "name", "secret", "uri"}
+        )
         self.assertEqual(totp["name"], "")
         self.assertEqual(totp["name"], "")
         self.assertIsNone(totp["last_used"])
         self.assertIsNone(totp["last_used"])
         self.assertRegex(totp["secret"], r"^[A-Z0-9]{52}$")  # 32 bytes make 52 chars
         self.assertRegex(totp["secret"], r"^[A-Z0-9]{52}$")  # 32 bytes make 52 chars
+        self.assertResponse(
+            self.assertRegex(
+                totp["uri"],
+                r"^otpauth://totp/.*:Secret[?]secret=[A-Z0-9]{52}&issuer=.*$",
+            )
+        )
         self.assertEqual(
         self.assertEqual(
             self.owner.basefactor_set.get().totpfactor.last_verified_timestep, 0
             self.owner.basefactor_set.get().totpfactor.last_verified_timestep, 0
         )
         )
@@ -41,9 +52,67 @@ class TOTPFactorTestCase(DomainOwnerTestCase):
         # Can't fetch the secret
         # Can't fetch the secret
         response = self.client.get(self.reverse("v1:totp-detail", pk=totp["id"]))
         response = self.client.get(self.reverse("v1:totp-detail", pk=totp["id"]))
         self.assertEqual(
         self.assertEqual(
-            response.data, {k: v for k, v in totp.items() if k != "secret"}
+            response.data, {k: v for k, v in totp.items() if k not in ("secret", "uri")}
         )
         )
 
 
         # Ensure that MFA is not active yet
         # Ensure that MFA is not active yet
         response = self.client.get(self.reverse("v1:domain-list"))
         response = self.client.get(self.reverse("v1:domain-list"))
         self.assertEqual(len(response.data), 2)
         self.assertEqual(len(response.data), 2)
+        self.assertFalse(self.owner.mfa_enabled)
+
+        # Verify requires a code
+        url = self.reverse("v1:totp-detail", pk=totp["id"]) + "verify/"
+        response = self.client.post(url)
+        self.assertResponse(
+            response, status.HTTP_400_BAD_REQUEST, {"code": ["This field is required."]}
+        )
+
+        # Wrong code won't work
+        now = datetime.now()
+        step = timedelta(seconds=30)
+        authenticator = TOTP(totp["secret"], digits=6)
+        url = self.reverse("v1:totp-detail", pk=totp["id"]) + "verify/"
+        for message, codes in {
+            "This field may not be blank.": [""],
+            "Invalid code.": [
+                "000000",
+                authenticator.at(now - 2 * step),
+                authenticator.at(now + 2 * step),
+            ],
+        }.items():
+            for code in codes:
+                response = self.client.post(url, {"code": code})
+                self.assertResponse(
+                    response, status.HTTP_400_BAD_REQUEST, {"code": [message]}
+                )
+
+        # Correct code works
+        credentials_changed = self.owner.credentials_changed
+        response = self.client.post(url, {"code": authenticator.at(now)})
+        self.assertResponse(
+            response, status.HTTP_200_OK, {"detail": "The code was correct."}
+        )
+        self.assertTrue(self.owner.mfa_enabled)
+        self.owner.refresh_from_db()
+
+        # Successful verification activates MFA and registers credential change
+        self.assertTrue(self.owner.mfa_enabled)
+        self.assertGreater(self.owner.credentials_changed, credentials_changed)
+
+        # Graceful validation window
+        factor = self.owner.basefactor_set.get().totpfactor
+        factor.last_verified_timestep -= 2
+        factor.save()
+        window_codes = [authenticator.at(now + i * step) for i in (-1, 0, 1)]
+        for code in window_codes:
+            response = self.client.post(url, {"code": code})
+            self.assertResponse(
+                response, status.HTTP_200_OK, {"detail": "The code was correct."}
+            )
+
+        # Replay won't work
+        for code in window_codes:
+            response = self.client.post(url, {"code": code})
+            self.assertResponse(
+                response, status.HTTP_400_BAD_REQUEST, {"code": ["Invalid code."]}
+            )

+ 10 - 0
api/desecapi/views/mfa.py

@@ -1,10 +1,12 @@
 from rest_framework import status, viewsets
 from rest_framework import status, viewsets
+from rest_framework.decorators import action
 from rest_framework.permissions import IsAuthenticated
 from rest_framework.permissions import IsAuthenticated
 from rest_framework.response import Response
 from rest_framework.response import Response
 
 
 from desecapi import permissions
 from desecapi import permissions
 from desecapi.serializers import (
 from desecapi.serializers import (
     AuthenticatedCreateTOTPFactorUserActionSerializer,
     AuthenticatedCreateTOTPFactorUserActionSerializer,
+    TOTPCodeSerializer,
     TOTPFactorSerializer,
     TOTPFactorSerializer,
 )
 )
 
 
@@ -31,3 +33,11 @@ class TOTPViewSet(IdempotentDestroyMixin, viewsets.ModelViewSet):
         AuthenticatedCreateTOTPFactorUserActionSerializer.build_and_save(
         AuthenticatedCreateTOTPFactorUserActionSerializer.build_and_save(
             user=self.request.user, name=serializer.validated_data.get("name", "")
             user=self.request.user, name=serializer.validated_data.get("name", "")
         )
         )
+
+    @action(detail=True, methods=["post"])
+    def verify(self, request, pk=None):
+        serializer = TOTPCodeSerializer(
+            data=request.data, context=self.get_serializer_context()
+        )
+        serializer.is_valid(raise_exception=True)
+        return Response({"detail": "The code was correct."})

+ 1 - 0
api/requirements.txt

@@ -11,6 +11,7 @@ django-pgtrigger~=2.5.1  # Upgrade to 3.x on occasion. Trigger management syntax
 django-prometheus~=2.2.0
 django-prometheus~=2.2.0
 dnspython~=2.2.0
 dnspython~=2.2.0
 httpretty~=1.0.5
 httpretty~=1.0.5
+pyotp~=2.6.0
 psycopg2~=2.9.2
 psycopg2~=2.9.2
 prometheus-client~=0.12.0  # added to control django-prometheus' dependency version
 prometheus-client~=0.12.0  # added to control django-prometheus' dependency version
 psl-dns~=1.1.0
 psl-dns~=1.1.0