diff --git a/api/environments/serializers.py b/api/environments/serializers.py index c6a3ae8724b8..448c8d261e46 100644 --- a/api/environments/serializers.py +++ b/api/environments/serializers.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, Any +from typing import Any from rest_framework import serializers @@ -117,7 +117,7 @@ class Meta(EnvironmentSerializerWithMetadata.Meta): ) -class CreateUpdateEnvironmentSerializer( +class _BaseCreateUpdateEnvironmentSerializer( ReadOnlyIfNotValidPlanMixin, EnvironmentSerializerWithMetadata ): invalid_plans = ("free",) @@ -132,30 +132,32 @@ class Meta(EnvironmentSerializerWithMetadata.Meta): ) ] + +class CreateEnvironmentSerializer(_BaseCreateUpdateEnvironmentSerializer): def get_subscription(self) -> Subscription | None: view = self.context["view"] + if getattr(view, "swagger_fake_view", False): + return None + + project_id = view.request.data["project"] + project = Project.objects.select_related( + "organisation", "organisation__subscription" + ).get(id=project_id) + + return getattr(project.organisation, "subscription", None) - if view.action == "create": - # handle `project` not being part of the data - # When request comes from drf-spectacular (as part of schema generation) - project_id = view.request.data.get("project") - if not project_id: - return None - - project = Project.objects.select_related( - "organisation", "organisation__subscription" - ).get(id=project_id) - - return getattr(project.organisation, "subscription", None) - elif view.action in ("update", "partial_update"): - # Handle schema generation when instance is None. - if self.instance is None: - return None - if TYPE_CHECKING: - assert isinstance(self.instance, Environment) - return getattr(self.instance.project.organisation, "subscription", None) - - return None + +class UpdateEnvironmentSerializer(_BaseCreateUpdateEnvironmentSerializer): + class Meta(_BaseCreateUpdateEnvironmentSerializer.Meta): + read_only_fields = EnvironmentSerializerLight.Meta.read_only_fields + ( # type: ignore[assignment] + "project", + ) + + def get_subscription(self) -> Subscription | None: + view = self.context["view"] + if getattr(view, "swagger_fake_view", False): + return None + return getattr(self.instance.project.organisation, "subscription", None) # type: ignore[union-attr] class CloneEnvironmentSerializer(EnvironmentSerializerLight): diff --git a/api/environments/views.py b/api/environments/views.py index 3dc02109dc84..d1cacb352a43 100644 --- a/api/environments/views.py +++ b/api/environments/views.py @@ -54,10 +54,11 @@ ) from .serializers import ( CloneEnvironmentSerializer, - CreateUpdateEnvironmentSerializer, + CreateEnvironmentSerializer, EnvironmentAPIKeySerializer, EnvironmentRetrieveSerializerWithMetadata, EnvironmentSerializerWithMetadata, + UpdateEnvironmentSerializer, WebhookSerializer, ) @@ -100,8 +101,10 @@ def get_serializer_class(self): # type: ignore[no-untyped-def] return CloneEnvironmentSerializer if self.action == "retrieve": return EnvironmentRetrieveSerializerWithMetadata - elif self.action in ("create", "update", "partial_update"): - return CreateUpdateEnvironmentSerializer + elif self.action == "create": + return CreateEnvironmentSerializer + elif self.action in ("update", "partial_update"): + return UpdateEnvironmentSerializer return EnvironmentSerializerWithMetadata def get_serializer_context(self): # type: ignore[no-untyped-def] diff --git a/api/tests/unit/environments/test_unit_environments_views.py b/api/tests/unit/environments/test_unit_environments_views.py index 88e5efd86b33..7ad0c0635a97 100644 --- a/api/tests/unit/environments/test_unit_environments_views.py +++ b/api/tests/unit/environments/test_unit_environments_views.py @@ -1088,6 +1088,38 @@ def test_environment_update_cannot_change_is_creating( assert response.json()["is_creating"] is False +def test_environment_update_cannot_change_project( + environment: Environment, + project: Project, + organisation: Organisation, + admin_client_new: APIClient, +) -> None: + # Given - an environment in the original project + original_project = project + url = reverse("api-v1:environments:environment-detail", args=[environment.api_key]) + + # and a different project + other_project = Project.objects.create( + name="Other Project", organisation=organisation + ) + + data = { + "project": other_project.id, + "name": environment.name, + } + + # When + response = admin_client_new.put( + url, data=json.dumps(data), content_type="application/json" + ) + + # Then - the project should NOT change + assert response.status_code == status.HTTP_200_OK + environment.refresh_from_db() + assert environment.project_id == original_project.id + assert response.json()["project"] == original_project.id + + def test_get_document( environment: Environment, project: Project,