validators.py 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687
  1. from django.db import DataError
  2. from django.db.models import Model
  3. from rest_framework import serializers
  4. from rest_framework.exceptions import ValidationError
  5. from rest_framework.validators import qs_exists, qs_filter, UniqueTogetherValidator
  def qs_exclude(queryset, **kwargs):
      """
      Return ``queryset.exclude(**kwargs)``, degrading to an empty queryset.

      If building the exclusion raises ``TypeError``, ``ValueError`` or
      ``DataError`` (e.g. a lookup value of the wrong type), the result of
      ``queryset.none()`` is returned instead of propagating the error.
      """
      try:
          remaining = queryset.exclude(**kwargs)
      except (TypeError, ValueError, DataError):
          # An invalid lookup value cannot match anything; report "no rows".
          remaining = queryset.none()
      return remaining
  class ExclusionConstraintValidator(UniqueTogetherValidator):
      """
      Validator that implements ExclusionConstraints, currently very basic with
      support for one field only.

      Should be applied to the serializer class, not to an individual field.
      No-op if the root serializer is a list serializer (many=True). We expect
      the list serializer to assure exclusivity.

      ``exclusion_condition`` is indexed as ``(field_name, value)``: the name of
      a serializer field and the value that decides which existing rows count
      as conflicting (see ``filter_queryset``).
      """

      message = "This field violates an exclusion constraint."

      def __init__(self, queryset, fields, exclusion_condition, message=None):
          super().__init__(queryset, fields, message)
          self.exclusion_condition = exclusion_condition

      def filter_queryset(self, attrs, queryset, serializer):
          """
          Narrow ``queryset`` to the rows that conflict with ``attrs``.

          The parent class first filters on ``self.fields``. Then, if the
          incoming value of the exclusion field equals the condition value,
          rows carrying that value are excluded; otherwise only rows carrying
          that value are kept.
          """
          qs = super().filter_queryset(attrs, queryset, serializer)

          # Determine the exclusion filters and prepare the queryset.
          field_name = self.exclusion_condition[0]
          value = self.exclusion_condition[1]
          source = serializer.fields[field_name].source

          # On (partial) update the field may be absent from the payload; fall
          # back to the value currently stored on the instance.
          if serializer.instance is not None and source not in attrs:
              attrs[source] = getattr(serializer.instance, source)

          # NOTE(review): on create, ``attrs[source]`` raises KeyError when the
          # exclusion field was omitted from the payload -- confirm that the
          # field is always required for this validator's serializers.
          exclusion_method = qs_exclude if attrs[source] == value else qs_filter

          # Filter on the model attribute (``source``), not on the serializer
          # field name: the two differ when the field declares ``source=...``.
          return exclusion_method(qs, **{source: value})

      def __call__(self, attrs, serializer, *args, **kwargs):
          """
          Raise ``ValidationError`` (code ``"exclusive"``) if a conflicting row
          exists. Returns ``None`` without checking when the root serializer
          has a truthy ``many`` attribute.
          """
          # Ignore validation if the many flag is set
          if getattr(serializer.root, "many", False):
              return

          self.enforce_required_fields(attrs, serializer)
          queryset = self.filter_queryset(attrs, self.queryset, serializer)
          queryset = self.exclude_current_instance(attrs, queryset, serializer.instance)

          # Ignore validation if any field is None
          checked_values = [
              value for field, value in attrs.items() if field in self.fields
          ]
          if None not in checked_values and qs_exists(queryset):
              # The conflicting rows' ``type`` values feed an optional
              # ``{types}`` placeholder in the message. Coerce to ``str`` so
              # that non-string values do not break ``join``.
              # NOTE(review): assumes the queryset's model has a ``type`` field.
              types = ", ".join(
                  str(type_value)
                  for type_value in queryset.values_list("type", flat=True)
              )
              message = self.message.format(types=types)
              raise ValidationError(message, code="exclusive")
  class Validator:
      """
      Minimal base class for callable validators.

      Subclasses override ``__call__``; the base implementation raises
      ``NotImplementedError``. A custom ``message`` may be supplied at
      construction time; a falsy one falls back to the class-level default.
      """

      message = "This field did not pass validation."

      def __init__(self, message=None):
          # A falsy message (None, "") keeps the class-level default.
          self.message = message if message else self.message
          # Initialised to None here; nothing in this class assigns them later.
          self.field_name = None
          self.instance = None

      def __call__(self, value):
          raise NotImplementedError

      def __repr__(self):
          return "<{}>".format(self.__class__.__name__)
  class ReadOnlyOnUpdateValidator(Validator):
      """
      Field validator that rejects changes to a value once the object exists.

      ``requires_context`` is set so that the validator is called with the
      serializer field as its second argument.
      """

      message = "Can only be written on create."
      requires_context = True

      def __call__(self, value, serializer_field):
          # The last element of the field's source path names the attribute
          # to compare against on the existing instance.
          attr_name = serializer_field.source_attrs[-1]
          current = getattr(serializer_field.parent, "instance", None)

          # No model instance on the parent serializer: nothing to compare.
          if not isinstance(current, Model):
              return

          if value != getattr(current, attr_name):
              raise serializers.ValidationError(self.message, code="read-only-on-update")