diff --git a/api/core/constants.py b/api/core/constants.py index 12707fcbaeff..5cbb9e0427f1 100644 --- a/api/core/constants.py +++ b/api/core/constants.py @@ -8,3 +8,4 @@ FLAGSMITH_SIGNATURE_HEADER = "X-Flagsmith-Signature" FLAGSMITH_UPDATED_AT_HEADER = "X-Flagsmith-Document-Updated-At" +SDK_ENVIRONMENT_KEY_HEADER = "X_ENVIRONMENT_KEY" diff --git a/api/environments/identities/views.py b/api/environments/identities/views.py index 19f3199f9eaf..9240700b6042 100644 --- a/api/environments/identities/views.py +++ b/api/environments/identities/views.py @@ -10,13 +10,14 @@ from django.utils import timezone from django.utils.decorators import method_decorator from django.views.decorators.cache import cache_page +from django.views.decorators.vary import vary_on_headers from drf_yasg.utils import swagger_auto_schema # type: ignore[import-untyped] from rest_framework import status, viewsets from rest_framework.permissions import IsAuthenticated from rest_framework.response import Response from app.pagination import CustomPagination -from core.constants import FLAGSMITH_UPDATED_AT_HEADER +from core.constants import FLAGSMITH_UPDATED_AT_HEADER, SDK_ENVIRONMENT_KEY_HEADER from core.request_origin import RequestOrigin from edge_api.identities.tasks import forward_identity_request from environments.identities.models import Identity @@ -161,6 +162,7 @@ class SDKIdentities(SDKAPIView): query_serializer=SDKIdentitiesQuerySerializer(), operation_id="identify_user", ) + @method_decorator(vary_on_headers(SDK_ENVIRONMENT_KEY_HEADER)) @method_decorator( cache_page( timeout=settings.GET_IDENTITIES_ENDPOINT_CACHE_SECONDS, diff --git a/api/features/views.py b/api/features/views.py index 4f0138247bda..ef55ffe82978 100644 --- a/api/features/views.py +++ b/api/features/views.py @@ -10,6 +10,7 @@ from django.utils import timezone from django.utils.decorators import method_decorator from django.views.decorators.cache import cache_page +from django.views.decorators.vary import vary_on_headers from drf_yasg import openapi # type: ignore[import-untyped] from drf_yasg.utils import swagger_auto_schema # type: ignore[import-untyped] from rest_framework import mixins, serializers, status, viewsets @@ -24,7 +25,7 @@ from app.pagination import CustomPagination from app_analytics.analytics_db_service import get_feature_evaluation_data from app_analytics.influxdb_wrapper import get_multiple_event_list_for_feature -from core.constants import FLAGSMITH_UPDATED_AT_HEADER +from core.constants import FLAGSMITH_UPDATED_AT_HEADER, SDK_ENVIRONMENT_KEY_HEADER from core.request_origin import RequestOrigin from environments.authentication import EnvironmentKeyAuthentication from environments.identities.models import Identity @@ -782,6 +783,7 @@ class SDKFeatureStates(GenericAPIView): # type: ignore[type-arg] query_serializer=SDKFeatureStatesQuerySerializer(), responses={200: FeatureStateSerializerFull(many=True)}, ) + @method_decorator(vary_on_headers(SDK_ENVIRONMENT_KEY_HEADER)) @method_decorator( cache_page( timeout=settings.GET_FLAGS_ENDPOINT_CACHE_SECONDS, diff --git a/api/tests/unit/conftest.py b/api/tests/unit/conftest.py index 5c324ef36c08..84c63b752b90 100644 --- a/api/tests/unit/conftest.py +++ b/api/tests/unit/conftest.py @@ -1,7 +1,8 @@ from unittest.mock import MagicMock import pytest -from django.core.cache import BaseCache +from django.core.cache import BaseCache, caches +from django.core.cache.backends.locmem import LocMemCache from pytest_django.fixtures import SettingsWrapper from pytest_mock import MockerFixture @@ -232,3 +233,23 @@ def populate_environment_document_cache( persistent_environment_document_cache.get.return_value = ( map_environment_to_environment_document(environment) ) + + +@pytest.fixture() +def use_local_mem_cache_for_cache_middleware(mocker: MockerFixture) -> None: + # Ensure the default cache is LocMemCache + default_cache = caches["default"] + assert isinstance(default_cache, LocMemCache) + + # Patch CacheMiddleware to use 'default' cache and a non-zero timeout + # This is necessary because override_settings doesn't reliably affect middleware behavior + from django.middleware.cache import CacheMiddleware + + original_init = CacheMiddleware.__init__ + + def custom_init(self, *args, **kwargs) -> None: # type: ignore[no-untyped-def] + original_init(self, *args, **kwargs) + self.page_timeout = 10 # enable caching for the view + self.cache_alias = "default" # force use of in-memory test cache + + mocker.patch.object(CacheMiddleware, "__init__", custom_init) diff --git a/api/tests/unit/environments/identities/test_unit_identities_views.py b/api/tests/unit/environments/identities/test_unit_identities_views.py index a053a924fe21..02e637de19d7 100644 --- a/api/tests/unit/environments/identities/test_unit_identities_views.py +++ b/api/tests/unit/environments/identities/test_unit_identities_views.py @@ -17,7 +17,11 @@ from rest_framework.permissions import IsAuthenticated from rest_framework.test import APIClient -from core.constants import FLAGSMITH_UPDATED_AT_HEADER, STRING +from core.constants import ( + FLAGSMITH_UPDATED_AT_HEADER, + SDK_ENVIRONMENT_KEY_HEADER, + STRING, +) from environments.identities.helpers import ( get_hashed_percentage_for_object_ids, ) @@ -338,6 +342,54 @@ def test_identities_endpoint_returns_all_feature_states_for_identity_if_feature_ assert len(response.data["flags"]) == 2 +def test_get_flags_for_identities_with_cache( + environment: Environment, + feature: Feature, + django_assert_num_queries: DjangoAssertNumQueries, + use_local_mem_cache_for_cache_middleware: None, + project_two_feature: Feature, + project_two_environment: Environment, +) -> None: + # Given + base_url = reverse("api-v1:sdk-identities") + url = base_url + "?identifier=some-identifier" + + # Create clients for two separate environments + environment_one_client = APIClient( + headers={SDK_ENVIRONMENT_KEY_HEADER: environment.api_key} + ) + project_two_environment_client = APIClient( + headers={SDK_ENVIRONMENT_KEY_HEADER: project_two_environment.api_key} + ) + + # Fetch flags for both environments once to warm the cache + environment_one_response = environment_one_client.get(url) + assert environment_one_response.status_code == status.HTTP_200_OK + + project_two_environment_response = project_two_environment_client.get(url) + assert project_two_environment_response.status_code == status.HTTP_200_OK + + # When + with django_assert_num_queries(0): + for _ in range(10): + environment_one_response = environment_one_client.get(url) + assert environment_one_response.status_code == status.HTTP_200_OK + + project_two_environment_response = project_two_environment_client.get(url) + assert project_two_environment_response.status_code == status.HTTP_200_OK + + # Then + # Each response must return the correct feature for its environment + assert ( + environment_one_response.json()["flags"][0]["feature"]["id"] + == feature.id + ) + assert ( + project_two_environment_response.json()["flags"][0]["feature"]["id"] + == project_two_feature.id + ) + + @mock.patch("integrations.amplitude.amplitude.AmplitudeWrapper.identify_user_async") def test_identities_endpoint_get_all_feature_amplitude_called( mock_amplitude_wrapper: mock.MagicMock, diff --git a/api/tests/unit/features/test_unit_features_views.py b/api/tests/unit/features/test_unit_features_views.py index 23c12dc07155..552f2ebed230 100644 --- a/api/tests/unit/features/test_unit_features_views.py +++ b/api/tests/unit/features/test_unit_features_views.py @@ -34,7 +34,7 @@ IDENTITY_FEATURE_STATE_UPDATED_MESSAGE, ) from audit.models import AuditLog, RelatedObjectType # type: ignore[attr-defined] -from core.constants import FLAGSMITH_UPDATED_AT_HEADER +from core.constants import FLAGSMITH_UPDATED_AT_HEADER, SDK_ENVIRONMENT_KEY_HEADER from environments.dynamodb import ( DynamoEnvironmentV2Wrapper, DynamoIdentityWrapper, @@ -799,6 +799,49 @@ def test_get_flags__server_key_only_feature__return_expected( assert not response.json() +def test_get_flags_cache( + environment: Environment, + feature: Feature, + django_assert_num_queries: DjangoAssertNumQueries, + project_two_feature: Feature, + project_two_environment: Environment, + use_local_mem_cache_for_cache_middleware: None, +) -> None: + # Given + url = reverse("api-v1:flags") + + # Create clients for two separate environments + environment_one_client = APIClient( + headers={SDK_ENVIRONMENT_KEY_HEADER: environment.api_key} + ) + project_two_environment_client = APIClient( + headers={SDK_ENVIRONMENT_KEY_HEADER: project_two_environment.api_key} + ) + # Fetch flags for both environments once to warm the cache + environment_one_response = environment_one_client.get(url) + assert environment_one_response.status_code == status.HTTP_200_OK + + project_two_environment_response = project_two_environment_client.get(url) + assert project_two_environment_response.status_code == status.HTTP_200_OK + + # When + with django_assert_num_queries(0): + for _ in range(10): + environment_one_response = environment_one_client.get(url) + assert environment_one_response.status_code == status.HTTP_200_OK + + project_two_environment_response = project_two_environment_client.get(url) + assert project_two_environment_response.status_code == status.HTTP_200_OK + + # Then + # Each response must return the correct feature for its environment + assert environment_one_response.json()[0]["feature"]["id"] == feature.id + assert ( + project_two_environment_response.json()[0]["feature"]["id"] + == project_two_feature.id + ) + + def test_get_flags__server_key_only_feature__server_key_auth__return_expected( api_client: APIClient, environment_api_key: EnvironmentAPIKey,