diff --git a/.gitignore b/.gitignore index 641714d163..e9bc835f1a 100644 --- a/.gitignore +++ b/.gitignore @@ -14,7 +14,7 @@ /env/ MANIFEST coverage.* - +venv/ !.github !.gitignore !.pre-commit-config.yaml diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index 5f34b00194..4542337531 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -664,7 +664,24 @@ def run_child_validation(self, data): self.child.initial_data = data return super().run_child_validation(data) """ - return self.child.run_validation(data) + if not hasattr(self.child, 'instance'): + return self.child.run_validation(data) + + original_instance = self.child.instance + try: + if ( + hasattr(self, '_instance_map') and + isinstance(data, Mapping) and + original_instance is self.instance + ): + data_pk = data.get('id') + if data_pk is None: + data_pk = data.get('pk') + self.child.instance = self._instance_map.get(str(data_pk)) if data_pk is not None else None + + return self.child.run_validation(data) + finally: + self.child.instance = original_instance def to_internal_value(self, data): """ @@ -674,12 +691,16 @@ def to_internal_value(self, data): data = html.parse_html_list(data, default=[]) if not isinstance(data, list): - message = self.error_messages['not_a_list'].format( - input_type=type(data).__name__ - ) raise ValidationError({ - api_settings.NON_FIELD_ERRORS_KEY: [message] - }, code='not_a_list') + api_settings.NON_FIELD_ERRORS_KEY: [ + ErrorDetail( + self.error_messages['not_a_list'].format( + input_type=type(data).__name__ + ), + code='not_a_list' + ) + ] + }) if not self.allow_empty and len(data) == 0: message = self.error_messages['empty'] @@ -702,19 +723,39 @@ def to_internal_value(self, data): ret = [] errors = [] - for item in data: - try: - validated = self.run_child_validation(item) - except ValidationError as exc: - errors.append(exc.detail) - else: - ret.append(validated) - errors.append({}) + # Build a primary key lookup for instance matching in many=True updates. + instance_map = {} + if self.instance is not None: + if isinstance(self.instance, Mapping): + instance_map = {str(k): v for k, v in self.instance.items()} + elif isinstance(self.instance, (list, tuple, models.query.QuerySet)): + for obj in self.instance: + pk = getattr(obj, 'pk', getattr(obj, 'id', None)) + if pk is not None: + key = str(pk) + # If duplicate keys are present, keep the last value, + # matching standard mapping assignment behavior. + instance_map[key] = obj + + self._instance_map = instance_map - if any(errors): - raise ValidationError(errors) + try: + for item in data: + try: + validated = self.run_child_validation(item) + except ValidationError as exc: + errors.append(exc.detail) + else: + ret.append(validated) + errors.append({}) - return ret + if any(errors): + raise ValidationError(errors) + + return ret + finally: + if hasattr(self, '_instance_map'): + delattr(self, '_instance_map') def to_representation(self, data): """ @@ -749,6 +790,13 @@ def save(self, **kwargs): """ Save and return a list of object instances. """ + assert hasattr(self, '_errors'), ( + 'You must call `.is_valid()` before calling `.save()`.' + ) + assert not self.errors, ( + 'You cannot call `.save()` on a serializer with invalid data.' + ) + # Guard against incorrect use of `serializer.save(commit=False)` assert 'commit' not in kwargs, ( "'commit' is not a valid keyword argument to the 'save()' method. " @@ -758,6 +806,14 @@ def save(self, **kwargs): "need to set extra attributes on the saved model instance. " "For example: 'serializer.save(owner=request.user)'.'" ) + assert not hasattr(self, '_data'), ( + "You cannot call `.save()` after accessing `serializer.data`." + "If you need to access data before committing to the database then " + "inspect 'serializer.validated_data' instead. " + ) + assert hasattr(self, '_validated_data'), ( + 'You must call `.is_valid()` before calling `.save()`.' + ) validated_data = [ {**attrs, **kwargs} for attrs in self.validated_data diff --git a/tests/test_serializer_lists.py b/tests/test_serializer_lists.py index f76451a5ad..6db87aa6e8 100644 --- a/tests/test_serializer_lists.py +++ b/tests/test_serializer_lists.py @@ -775,3 +775,32 @@ def test(self): queryset = NullableOneToOneSource.objects.all() serializer = self.serializer(queryset, many=True) assert serializer.data + + +def test_many_true_instance_level_validation_uses_matched_instance(): + class Obj: + def __init__(self, id, valid): + self.id = id + self.valid = valid + + class TestSerializer(serializers.Serializer): + id = serializers.IntegerField() + status = serializers.CharField() + + def validate_status(self, value): + if self.instance is None: + raise serializers.ValidationError("Instance not matched") + if not self.instance.valid: + raise serializers.ValidationError("Invalid instance") + return value + + objs = [Obj(1, True), Obj(2, False)] + serializer = TestSerializer( + instance=objs, + data=[{"id": 1, "status": "ok"}, {"id": 2, "status": "fail"}], + many=True, + partial=True, + ) + + assert not serializer.is_valid() + assert serializer.errors == [{}, {'status': ['Invalid instance']}]