diff --git a/.github/scripts/fetch_auth0_test_token.py b/.github/scripts/fetch_auth0_test_token.py index 1e490628..7cf65a7d 100644 --- a/.github/scripts/fetch_auth0_test_token.py +++ b/.github/scripts/fetch_auth0_test_token.py @@ -10,14 +10,16 @@ def main() -> int: url = f"https://{os.environ['AUTH0_DOMAIN']}/oauth/token" - payload = json.dumps( - { - "client_id": os.environ["AUTH0_CLIENT_ID"], - "client_secret": os.environ["AUTH0_CLIENT_SECRET"], - "audience": os.environ["AUTH0_AUDIENCE"], - "grant_type": "client_credentials", - } - ).encode("utf-8") + token_request = { + "client_id": os.environ["AUTH0_CLIENT_ID"], + "client_secret": os.environ["AUTH0_CLIENT_SECRET"], + "audience": os.environ["AUTH0_AUDIENCE"], + "grant_type": "client_credentials", + } + scope = os.environ.get("AUTH0_TEST_TOKEN_SCOPES") + if scope: + token_request["scope"] = scope + payload = json.dumps(token_request).encode("utf-8") request = urllib.request.Request( url, diff --git a/.github/workflows/deploy-staged.yml b/.github/workflows/deploy-staged.yml index d446049b..d002a0a0 100644 --- a/.github/workflows/deploy-staged.yml +++ b/.github/workflows/deploy-staged.yml @@ -264,6 +264,7 @@ jobs: AUTH0_AUDIENCE: ${{ secrets.AUTH0_AUDIENCE_NO_DOMAIN }} AUTH0_CLIENT_ID: ${{ secrets.AUTH0_TEST_TOKEN_CLIENT_ID }} AUTH0_CLIENT_SECRET: ${{ secrets.AUTH0_TEST_TOKEN_CLIENT_SECRET }} + AUTH0_TEST_TOKEN_SCOPES: ${{ secrets.AUTH0_TEST_TOKEN_SCOPES }} run: python .github/scripts/fetch_auth0_test_token.py - name: Run deployed integration tests diff --git a/changelog.d/calculate-analytics-requests-endpoint.added.md b/changelog.d/calculate-analytics-requests-endpoint.added.md new file mode 100644 index 00000000..7c7857d8 --- /dev/null +++ b/changelog.d/calculate-analytics-requests-endpoint.added.md @@ -0,0 +1,2 @@ +Added an authenticated, scoped calculate analytics endpoint for value-free +request and unique variable-key queries. diff --git a/config/README.md b/config/README.md index 083da266..4ae89586 100644 --- a/config/README.md +++ b/config/README.md @@ -168,6 +168,7 @@ auth: address: Auth0 domain (without https:// or trailing slash) audience: Auth0 audience/API identifier test_token: JWT token used only for pre-deployment GitHub Actions tests + test_token_scopes: Space-delimited OAuth scopes for the static test token ai: enabled: Whether AI features are enabled (true/false) (these features are only used in the alpha-mode AI explainer endpoint) @@ -308,6 +309,7 @@ AUTH0_AUDIENCE_NO_DOMAIN=https://your-api-identifier When Auth0 is enabled, the following endpoints require valid JWT tokens: - `//calculate` - Main calculation endpoint - `//ai-analysis` - AI analysis endpoint (remains in alpha) +- `/analytics/calculate/requests` - Calculate analytics endpoint; additionally requires the `read:calculate-analytics` scope The following endpoints remain unprotected: - `/` - Home endpoint @@ -341,6 +343,7 @@ AUTH__ENABLED=true # Enable Auth0 authentication AUTH0_ADDRESS_NO_DOMAIN=${{ secrets.AUTH0_ADDRESS_NO_DOMAIN }} AUTH0_AUDIENCE_NO_DOMAIN=${{ secrets.AUTH0_AUDIENCE_NO_DOMAIN }} AUTH0_TEST_TOKEN_NO_DOMAIN=${{ secrets.AUTH0_TEST_TOKEN_NO_DOMAIN }} # Used for local testing purposes +AUTH0_TEST_TOKEN_SCOPES=read:calculate-analytics # Used for scoped local testing # Analytics configuration (opt-in) ANALYTICS__ENABLED=true # Enable user analytics @@ -390,7 +393,8 @@ auth: auth0: address: ${AUTH0_DOMAIN} audience: ${AUTH0_AUDIENCE} - test_bearer_token: ${AUTH0_TOKEN} + test_token: ${AUTH0_TOKEN} + test_token_scopes: ${AUTH0_TOKEN_SCOPES} ai: enabled: true diff --git a/config/default.yaml b/config/default.yaml index e7bddf27..27d89aa9 100644 --- a/config/default.yaml +++ b/config/default.yaml @@ -47,6 +47,8 @@ auth: audience: "" # Override with AUTH0_AUDIENCE_NO_DOMAIN # Test JWT token used only for GitHub Actions tests pre-deployment test_token: "" # Override with AUTH0_TEST_TOKEN_NO_DOMAIN + # Space-delimited OAuth scopes for the static test token + test_token_scopes: "" # Override with AUTH0_TEST_TOKEN_SCOPES # AI services configuration ai: diff --git a/config/production.yaml.example b/config/production.yaml.example index c921824a..e8433604 100644 --- a/config/production.yaml.example +++ b/config/production.yaml.example @@ -23,6 +23,7 @@ auth: address: ${AUTH0_ADDRESS_NO_DOMAIN} # From env var audience: ${AUTH0_AUDIENCE_NO_DOMAIN} # From env var test_token: "" # Used only for executing pre-deploy integration tests + test_token_scopes: "" # Used only for executing pre-deploy integration tests ai: enabled: true diff --git a/config/test_with_auth.yaml b/config/test_with_auth.yaml index 657ea8af..90f543c0 100644 --- a/config/test_with_auth.yaml +++ b/config/test_with_auth.yaml @@ -12,6 +12,7 @@ auth: address: ${AUTH0_ADDRESS_NO_DOMAIN} # From env var audience: ${AUTH0_AUDIENCE_NO_DOMAIN} # From env var test_token: ${AUTH0_TEST_TOKEN_NO_DOMAIN} # From env var + test_token_scopes: ${AUTH0_TEST_TOKEN_SCOPES} # From env var ai: enabled: false diff --git a/policyengine_household_api/api.py b/policyengine_household_api/api.py index 96474dbc..47e44303 100644 --- a/policyengine_household_api/api.py +++ b/policyengine_household_api/api.py @@ -21,7 +21,7 @@ ) # Internal imports -from .decorators.auth import create_auth_decorator +from .decorators.auth import ANALYTICS_READ_SCOPE, create_auth_decorator from policyengine_household_api.decorators.analytics import ( log_analytics_if_enabled, ) @@ -29,6 +29,7 @@ # Endpoints from .endpoints import ( get_home, + get_calculate_analytics_requests, get_calculate, generate_ai_explainer, ) @@ -78,6 +79,13 @@ def calculate(country_id): return get_calculate(country_id) +@app.route("/analytics/calculate/requests", methods=["GET"]) +@require_auth_if_enabled([ANALYTICS_READ_SCOPE]) +@limiter.limit("60 per minute") +def calculate_analytics_requests(): + return get_calculate_analytics_requests() + + @app.route("//ai-analysis", methods=["POST"]) @require_auth_if_enabled() def ai_analysis(country_id: str): diff --git a/policyengine_household_api/decorators/auth.py b/policyengine_household_api/decorators/auth.py index a675b45f..c5611bf1 100644 --- a/policyengine_household_api/decorators/auth.py +++ b/policyengine_household_api/decorators/auth.py @@ -10,7 +10,9 @@ from authlib.integrations.flask_oauth2 import ResourceProtector from authlib.oauth2.rfc6750 import BearerTokenValidator from ..auth.validation import Auth0JWTBearerTokenValidator -from ..utils.config_loader import get_config, get_config_value +from ..utils.config_loader import get_config_value + +ANALYTICS_READ_SCOPE = "read:calculate-analytics" class StaticBearerToken: @@ -33,15 +35,16 @@ def get_scope(self) -> str: class StaticBearerTokenValidator(BearerTokenValidator): """Accept a single configured bearer token for test environments.""" - def __init__(self, expected_token: str): + def __init__(self, expected_token: str, scopes: str | None = ""): super().__init__() self.expected_token = expected_token + self.scopes = scopes or "" def authenticate_token( self, token_string: Optional[str] ) -> Optional[StaticBearerToken]: if token_string == self.expected_token: - return StaticBearerToken(token_string) + return StaticBearerToken(token_string, scope=self.scopes) return None @@ -98,6 +101,9 @@ def _setup_authentication(self) -> None: self._auth_enabled = get_config_value("auth.enabled", False) app_environment = get_config_value("app.environment", "") auth0_test_token = get_config_value("auth.auth0.test_token", "") + auth0_test_token_scopes = get_config_value( + "auth.auth0.test_token_scopes", "" + ) # Get Auth0 configuration values auth0_address = get_config_value("auth.auth0.address", "") @@ -108,7 +114,9 @@ def _setup_authentication(self) -> None: if app_environment == "test_with_auth" and auth0_test_token: resource_protector = ResourceProtector() resource_protector.register_token_validator( - StaticBearerTokenValidator(auth0_test_token) + StaticBearerTokenValidator( + auth0_test_token, auth0_test_token_scopes + ) ) self._decorator = resource_protector elif auth0_address and auth0_audience: diff --git a/policyengine_household_api/endpoints/__init__.py b/policyengine_household_api/endpoints/__init__.py index 9591df9a..d4a74acf 100644 --- a/policyengine_household_api/endpoints/__init__.py +++ b/policyengine_household_api/endpoints/__init__.py @@ -1,3 +1,8 @@ -from .home import get_home -from .household import get_calculate -from .household_explainer import generate_ai_explainer +from .analytics import ( + get_calculate_analytics_requests as get_calculate_analytics_requests, +) +from .home import get_home as get_home +from .household import get_calculate as get_calculate +from .household_explainer import ( + generate_ai_explainer as generate_ai_explainer, +) diff --git a/policyengine_household_api/endpoints/analytics.py b/policyengine_household_api/endpoints/analytics.py new file mode 100644 index 00000000..fde18adb --- /dev/null +++ b/policyengine_household_api/endpoints/analytics.py @@ -0,0 +1,319 @@ +from __future__ import annotations + +from dataclasses import dataclass +from datetime import datetime, timezone +import json +from typing import Any + +from flask import Response, request +from sqlalchemy import func + +from policyengine_household_api.data.analytics_setup import ( + is_analytics_enabled, + is_analytics_schema_ready, +) +from policyengine_household_api.data.models import ( + CalculateRequest, + CalculateRequestVariable, +) + + +DEFAULT_REQUEST_LIMIT = 1_000 +MAX_REQUEST_LIMIT = 10_000 +TRUE_VALUES = {"1", "true", "yes"} +FALSE_VALUES = {"0", "false", "no"} + + +@dataclass(frozen=True) +class CalculateAnalyticsQuery: + start_time: datetime | None + end_time: datetime | None + unique: bool + limit: int + + +def get_calculate_analytics_requests() -> Response: + try: + query = _parse_query_args() + except ValueError as e: + return _json_response( + {"status": "error", "message": str(e)}, + status=400, + ) + + analytics_storage_error = _analytics_storage_error_response() + if analytics_storage_error is not None: + return analytics_storage_error + + response_body: dict[str, Any] = { + "status": "ok", + "message": None, + "start_time": _datetime_to_json(query.start_time), + "end_time": _datetime_to_json(query.end_time), + "unique": query.unique, + } + if query.unique: + response_body["unique_keys"] = _unique_variable_keys(query) + else: + response_body["requests"] = _calculate_requests(query) + + return _json_response(response_body) + + +def _analytics_storage_error_response() -> Response | None: + if not is_analytics_enabled(): + return _json_response( + { + "status": "error", + "message": "Analytics is not enabled for this API instance.", + }, + status=503, + ) + + if not is_analytics_schema_ready(): + return _json_response( + { + "status": "error", + "message": "Analytics storage is not ready.", + }, + status=503, + ) + + return None + + +def _parse_query_args() -> CalculateAnalyticsQuery: + start_time = _parse_optional_datetime( + _first_query_arg("start_time", "start"), + "start_time", + ) + end_time = _parse_optional_datetime( + _first_query_arg("end_time", "end"), + "end_time", + ) + if start_time and end_time and start_time > end_time: + raise ValueError("`start_time` must be before or equal to `end_time`") + + return CalculateAnalyticsQuery( + start_time=start_time, + end_time=end_time, + unique=_parse_bool(request.args.get("unique"), default=False), + limit=_parse_limit(request.args.get("limit")), + ) + + +def _first_query_arg(*names: str) -> str | None: + for name in names: + value = request.args.get(name) + if value: + return value + return None + + +def _parse_optional_datetime( + value: str | None, + name: str, +) -> datetime | None: + if value is None: + return None + + try: + parsed = datetime.fromisoformat(value.replace("Z", "+00:00")) + except ValueError as e: + raise ValueError(f"`{name}` must be an ISO 8601 datetime") from e + + if parsed.tzinfo is None: + return parsed + return parsed.astimezone(timezone.utc).replace(tzinfo=None) + + +def _parse_bool(value: str | None, *, default: bool) -> bool: + if value is None: + return default + + normalized = value.lower() + if normalized in TRUE_VALUES: + return True + if normalized in FALSE_VALUES: + return False + raise ValueError("`unique` must be true or false") + + +def _parse_limit(value: str | None) -> int: + if value is None: + return DEFAULT_REQUEST_LIMIT + + try: + limit = int(value) + except ValueError as e: + raise ValueError("`limit` must be an integer") from e + + if limit < 1 or limit > MAX_REQUEST_LIMIT: + raise ValueError(f"`limit` must be between 1 and {MAX_REQUEST_LIMIT}") + return limit + + +def _calculate_requests( + query: CalculateAnalyticsQuery, +) -> list[dict[str, Any]]: + request_query = _apply_time_filters(CalculateRequest.query, query) + calculate_requests = ( + request_query.order_by(CalculateRequest.created_at.desc()) + .limit(query.limit) + .all() + ) + if not calculate_requests: + return [] + + variables_by_request_id = _variables_by_request_id( + [calculate_request.id for calculate_request in calculate_requests] + ) + return [ + _request_to_dict( + calculate_request, + variables_by_request_id.get(calculate_request.id, []), + ) + for calculate_request in calculate_requests + ] + + +def _variables_by_request_id( + request_ids: list[int], +) -> dict[int, list[CalculateRequestVariable]]: + variable_rows = ( + CalculateRequestVariable.query.filter( + CalculateRequestVariable.request_id.in_(request_ids) + ) + .order_by( + CalculateRequestVariable.request_id, + CalculateRequestVariable.variable_name, + CalculateRequestVariable.entity_type, + CalculateRequestVariable.source, + ) + .all() + ) + variables_by_request_id: dict[int, list[CalculateRequestVariable]] = {} + for variable_row in variable_rows: + variables_by_request_id.setdefault( + variable_row.request_id, + [], + ).append(variable_row) + return variables_by_request_id + + +def _unique_variable_keys( + query: CalculateAnalyticsQuery, +) -> list[dict[str, Any]]: + variable_query = _apply_time_filters(CalculateRequestVariable.query, query) + rows = ( + variable_query.with_entities( + CalculateRequestVariable.variable_name, + CalculateRequestVariable.entity_type, + CalculateRequestVariable.source, + CalculateRequestVariable.period_granularity, + CalculateRequestVariable.availability_status, + CalculateRequestVariable.variable_name_truncated, + func.count(func.distinct(CalculateRequestVariable.request_id)), + func.sum(CalculateRequestVariable.occurrence_count), + func.min(CalculateRequestVariable.created_at), + func.max(CalculateRequestVariable.created_at), + ) + .group_by( + CalculateRequestVariable.variable_name, + CalculateRequestVariable.entity_type, + CalculateRequestVariable.source, + CalculateRequestVariable.period_granularity, + CalculateRequestVariable.availability_status, + CalculateRequestVariable.variable_name_truncated, + ) + .order_by( + CalculateRequestVariable.variable_name, + CalculateRequestVariable.entity_type, + CalculateRequestVariable.source, + ) + .all() + ) + + return [ + { + "variable_name": row[0], + "entity_type": row[1], + "source": row[2], + "period_granularity": row[3], + "availability_status": row[4], + "variable_name_truncated": bool(row[5]), + "request_count": int(row[6] or 0), + "occurrence_count": int(row[7] or 0), + "first_seen": _datetime_to_json(row[8]), + "last_seen": _datetime_to_json(row[9]), + } + for row in rows + ] + + +def _apply_time_filters(query, filters: CalculateAnalyticsQuery): + model = query.column_descriptions[0]["entity"] + created_at = model.created_at + if filters.start_time: + query = query.filter(created_at >= filters.start_time) + if filters.end_time: + query = query.filter(created_at <= filters.end_time) + return query + + +def _request_to_dict( + calculate_request: CalculateRequest, + variable_rows: list[CalculateRequestVariable], +) -> dict[str, Any]: + return { + "request_uuid": calculate_request.request_uuid, + "created_at": _datetime_to_json(calculate_request.created_at), + "api_version": calculate_request.api_version, + "country_id": calculate_request.country_id, + "model_version": calculate_request.model_version, + "endpoint": calculate_request.endpoint, + "method": calculate_request.method, + "response_status_code": calculate_request.response_status_code, + "distinct_variable_count": calculate_request.distinct_variable_count, + "unsupported_variable_count": ( + calculate_request.unsupported_variable_count + ), + "deprecated_allowlisted_variable_count": ( + calculate_request.deprecated_allowlisted_variable_count + ), + "variables": [ + _variable_to_dict(variable_row) for variable_row in variable_rows + ], + } + + +def _variable_to_dict( + variable_row: CalculateRequestVariable, +) -> dict[str, Any]: + return { + "variable_name": variable_row.variable_name, + "entity_type": variable_row.entity_type, + "source": variable_row.source, + "period_granularity": variable_row.period_granularity, + "entity_count": variable_row.entity_count, + "period_count": variable_row.period_count, + "occurrence_count": variable_row.occurrence_count, + "availability_status": variable_row.availability_status, + "variable_name_truncated": bool(variable_row.variable_name_truncated), + } + + +def _datetime_to_json(value: datetime | None) -> str | None: + if value is None: + return None + if value.tzinfo is not None: + value = value.astimezone(timezone.utc).replace(tzinfo=None) + return value.isoformat() + "Z" + + +def _json_response(payload: dict[str, Any], *, status: int = 200) -> Response: + return Response( + json.dumps(payload), + status=status, + mimetype="application/json", + ) diff --git a/policyengine_household_api/openapi_spec.yaml b/policyengine_household_api/openapi_spec.yaml index c92fb9e2..0a68233a 100644 --- a/policyengine_household_api/openapi_spec.yaml +++ b/policyengine_household_api/openapi_spec.yaml @@ -686,6 +686,125 @@ paths: type: string message: type: string + /analytics/calculate/requests: + get: + summary: Get calculate request variable analytics + operationId: get_calculate_analytics_requests + description: Returns value-free analytics records for inbound calculate requests. Requires a bearer token with the read:calculate-analytics scope. Each record includes request metadata and grouped variable keys, but not household values, household entity IDs, request bodies, response bodies, or client IDs. + security: + - bearerAuth: [] + parameters: + - name: start_time + in: query + description: "Optional inclusive ISO 8601 lower bound for request creation time. Alias: start." + required: false + schema: + type: string + format: date-time + - name: end_time + in: query + description: "Optional inclusive ISO 8601 upper bound for request creation time. Alias: end." + required: false + schema: + type: string + format: date-time + - name: unique + in: query + description: When true, return one row per unique variable key instead of one row per calculate request. + required: false + schema: + type: boolean + default: false + - name: limit + in: query + description: Maximum number of request records returned when unique is false. + required: false + schema: + type: integer + default: 1000 + minimum: 1 + maximum: 10000 + responses: + 200: + description: Calculate request analytics. + content: + application/json: + schema: + $ref: "#/components/schemas/CalculateAnalyticsResponse" + examples: + request_records: + summary: Request records + value: + status: ok + message: null + start_time: "2026-05-07T00:00:00Z" + end_time: "2026-05-13T00:00:00Z" + unique: false + requests: + - request_uuid: "018f79e7-6ee3-7621-9eda-6d29cf0cf910" + created_at: "2026-05-10T12:00:00Z" + api_version: "0.17.0" + country_id: us + model_version: "1.691.1" + endpoint: calculate + method: POST + response_status_code: 200 + distinct_variable_count: 2 + unsupported_variable_count: 0 + deprecated_allowlisted_variable_count: 0 + variables: + - variable_name: employment_income + entity_type: person + source: household_input + period_granularity: year + entity_count: 1 + period_count: 1 + occurrence_count: 1 + availability_status: supported + variable_name_truncated: false + unique_keys: + summary: Unique variable keys + value: + status: ok + message: null + start_time: null + end_time: null + unique: true + unique_keys: + - variable_name: age + entity_type: person + source: household_input + period_granularity: year + availability_status: supported + variable_name_truncated: false + request_count: 12 + occurrence_count: 18 + first_seen: "2026-05-01T00:00:00Z" + last_seen: "2026-05-13T00:00:00Z" + 400: + description: Invalid analytics query parameter. + content: + application/json: + schema: + $ref: "#/components/schemas/ApiErrorResponse" + 401: + description: Missing or invalid bearer token. + content: + application/json: + schema: + $ref: "#/components/schemas/AuthErrorResponse" + 403: + description: The bearer token is valid but does not include the read:calculate-analytics scope. + content: + application/json: + schema: + $ref: "#/components/schemas/AuthErrorResponse" + 503: + description: Analytics storage is disabled or not ready. + content: + application/json: + schema: + $ref: "#/components/schemas/ApiErrorResponse" /{country_id}/economy/{policy_id}/over/{baseline_policy_id}: get: summary: Calculate the economic impact of a policy @@ -976,7 +1095,7 @@ components: type: http scheme: bearer bearerFormat: JWT - description: Hosted API requests require an Auth0 bearer token. Local Docker requests usually run without authentication. + description: Hosted API requests require an Auth0 bearer token. Local Docker requests usually run without authentication. Analytics endpoints may require additional OAuth scopes. schemas: CalculateRequest: type: object @@ -1168,3 +1287,148 @@ components: type: string error_description: type: string + CalculateAnalyticsResponse: + type: object + required: + - status + - message + - start_time + - end_time + - unique + properties: + status: + type: string + enum: + - ok + message: + type: string + nullable: true + start_time: + type: string + format: date-time + nullable: true + end_time: + type: string + format: date-time + nullable: true + unique: + type: boolean + requests: + type: array + items: + $ref: "#/components/schemas/CalculateAnalyticsRequestRecord" + unique_keys: + type: array + items: + $ref: "#/components/schemas/CalculateAnalyticsVariableKeySummary" + CalculateAnalyticsRequestRecord: + type: object + properties: + request_uuid: + type: string + format: uuid + created_at: + type: string + format: date-time + api_version: + type: string + nullable: true + country_id: + type: string + model_version: + type: string + nullable: true + endpoint: + type: string + nullable: true + method: + type: string + enum: + - DELETE + - GET + - HEAD + - OPTIONS + - PATCH + - POST + - PUT + response_status_code: + type: integer + nullable: true + distinct_variable_count: + type: integer + unsupported_variable_count: + type: integer + deprecated_allowlisted_variable_count: + type: integer + variables: + type: array + items: + $ref: "#/components/schemas/CalculateAnalyticsVariableRecord" + CalculateAnalyticsVariableRecord: + type: object + properties: + variable_name: + type: string + entity_type: + type: string + source: + $ref: "#/components/schemas/CalculateAnalyticsVariableSource" + period_granularity: + $ref: "#/components/schemas/CalculateAnalyticsPeriodGranularity" + entity_count: + type: integer + period_count: + type: integer + occurrence_count: + type: integer + availability_status: + $ref: "#/components/schemas/CalculateAnalyticsAvailabilityStatus" + variable_name_truncated: + type: boolean + CalculateAnalyticsVariableKeySummary: + type: object + properties: + variable_name: + type: string + entity_type: + type: string + source: + $ref: "#/components/schemas/CalculateAnalyticsVariableSource" + period_granularity: + $ref: "#/components/schemas/CalculateAnalyticsPeriodGranularity" + availability_status: + $ref: "#/components/schemas/CalculateAnalyticsAvailabilityStatus" + variable_name_truncated: + type: boolean + request_count: + type: integer + occurrence_count: + type: integer + first_seen: + type: string + format: date-time + last_seen: + type: string + format: date-time + CalculateAnalyticsVariableSource: + type: string + enum: + - household_input + - requested_output + - mixed + - axis + CalculateAnalyticsAvailabilityStatus: + type: string + enum: + - supported + - deprecated_allowlisted + - unsupported + CalculateAnalyticsPeriodGranularity: + type: string + enum: + - year + - month + - day + - mixed + - none + - unknown diff --git a/policyengine_household_api/utils/config_loader.py b/policyengine_household_api/utils/config_loader.py index e19934e1..13dabd2f 100644 --- a/policyengine_household_api/utils/config_loader.py +++ b/policyengine_household_api/utils/config_loader.py @@ -50,6 +50,7 @@ class ConfigLoader: "AUTH0_ADDRESS_NO_DOMAIN": "auth.auth0.address", "AUTH0_AUDIENCE_NO_DOMAIN": "auth.auth0.audience", "AUTH0_TEST_TOKEN_NO_DOMAIN": "auth.auth0.test_token", + "AUTH0_TEST_TOKEN_SCOPES": "auth.auth0.test_token_scopes", # AI settings "ANTHROPIC_API_KEY": "ai.anthropic.api_key", # Server settings diff --git a/tests/conftest.py b/tests/conftest.py index 72eebd67..efccf28f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -4,8 +4,10 @@ pytest_plugins = [ "tests.fixtures.data.analytics_setup", "tests.fixtures.data.analytics_setup_patches", + "tests.fixtures.decorators.auth", "tests.fixtures.decorators.analytics", "tests.fixtures.decorators.analytics_patches", + "tests.fixtures.endpoints.analytics", "tests.fixtures.endpoints.household", ] diff --git a/tests/fixtures/decorators/auth.py b/tests/fixtures/decorators/auth.py index 533fce7c..67892ce9 100644 --- a/tests/fixtures/decorators/auth.py +++ b/tests/fixtures/decorators/auth.py @@ -3,8 +3,10 @@ """ import pytest -from unittest.mock import Mock, patch, MagicMock -from typing import Dict, Any +from typing import Any +from unittest.mock import Mock, patch + +from policyengine_household_api.decorators.auth import ANALYTICS_READ_SCOPE # Sample Auth0 configuration data AUTH0_CONFIG_DATA = { @@ -29,6 +31,7 @@ "auth0": { **AUTH0_CONFIG_DATA, "test_token": "test-jwt-token", + "test_token_scopes": ANALYTICS_READ_SCOPE, }, }, } @@ -126,6 +129,7 @@ def config_side_effect(path: str, default: Any = None) -> Any: "auth.auth0.address": AUTH0_CONFIG_DATA["address"], "auth.auth0.audience": AUTH0_CONFIG_DATA["audience"], "auth.auth0.test_token": "test-jwt-token", + "auth.auth0.test_token_scopes": ANALYTICS_READ_SCOPE, } return config_map.get(path, default) diff --git a/tests/fixtures/endpoints/analytics.py b/tests/fixtures/endpoints/analytics.py new file mode 100644 index 00000000..1166c16e --- /dev/null +++ b/tests/fixtures/endpoints/analytics.py @@ -0,0 +1,200 @@ +from datetime import datetime + +import pytest +from flask import Flask, Response + +from policyengine_household_api.data.analytics_setup import db +from policyengine_household_api.data.models import ( + CalculateRequest, + CalculateRequestVariable, + Visit, +) +from policyengine_household_api.decorators.auth import ( + ANALYTICS_READ_SCOPE, + create_auth_decorator, +) +from policyengine_household_api.endpoints.analytics import ( + get_calculate_analytics_requests, +) + + +TEST_AUTH_TOKEN = "test-jwt-token" + + +@pytest.fixture +def analytics_endpoint_app(tmp_path, monkeypatch): + monkeypatch.setattr( + "policyengine_household_api.endpoints.analytics.is_analytics_enabled", + lambda: True, + ) + monkeypatch.setattr( + "policyengine_household_api.endpoints.analytics." + "is_analytics_schema_ready", + lambda: True, + ) + + app = Flask(__name__) + app.config["SQLALCHEMY_DATABASE_URI"] = ( + f"sqlite:///{tmp_path / 'analytics.db'}" + ) + db.init_app(app) + + with app.app_context(): + db.create_all() + yield app + db.session.remove() + db.drop_all() + + +@pytest.fixture +def add_calculate_analytics_request(): + return _add_calculate_analytics_request + + +@pytest.fixture +def calculate_analytics_variable(): + return _calculate_analytics_variable + + +@pytest.fixture +def scoped_analytics_client_factory(tmp_path, monkeypatch): + apps = [] + + def factory(scopes: str = ""): + _patch_test_auth_config(monkeypatch, scopes) + _patch_analytics_storage_ready(monkeypatch) + + app = Flask(__name__) + app.config["SQLALCHEMY_DATABASE_URI"] = ( + f"sqlite:///{tmp_path / f'analytics-{len(apps)}.db'}" + ) + db.init_app(app) + + auth = create_auth_decorator() + app.add_url_rule( + "/analytics/calculate/requests", + "calculate_analytics_requests", + auth([ANALYTICS_READ_SCOPE])(get_calculate_analytics_requests), + methods=["GET"], + ) + app.add_url_rule( + "/us/calculate", + "calculate", + auth()(lambda: Response("ok", status=200)), + methods=["POST"], + ) + + with app.app_context(): + db.create_all() + + apps.append(app) + return app.test_client() + + yield factory + + for app in apps: + with app.app_context(): + db.session.remove() + db.drop_all() + + +def _add_calculate_analytics_request( + request_uuid: str, + created_at: datetime, + variable_rows: list[dict], +) -> CalculateRequest: + visit = Visit() + visit.client_id = "test-client" + visit.datetime = created_at + visit.api_version = "0.17.0" + visit.endpoint = "calculate" + visit.method = "POST" + visit.content_length_bytes = 123 + db.session.add(visit) + db.session.flush() + + calculate_request = CalculateRequest() + calculate_request.visit_id = visit.id + calculate_request.request_uuid = request_uuid + calculate_request.client_id = "test-client" + calculate_request.api_version = "0.17.0" + calculate_request.country_id = "us" + calculate_request.model_version = "1.691.1" + calculate_request.endpoint = "calculate" + calculate_request.method = "POST" + calculate_request.content_length_bytes = 123 + calculate_request.response_status_code = 200 + calculate_request.distinct_variable_count = len(variable_rows) + calculate_request.unsupported_variable_count = sum( + variable["availability_status"] == "unsupported" + for variable in variable_rows + ) + calculate_request.deprecated_allowlisted_variable_count = 0 + calculate_request.created_at = created_at + db.session.add(calculate_request) + db.session.flush() + + for variable_row in variable_rows: + variable = CalculateRequestVariable() + variable.request_id = calculate_request.id + variable.client_id = "test-client" + variable.created_at = created_at + variable.country_id = "us" + variable.api_version = "0.17.0" + variable.model_version = "1.691.1" + variable.response_status_code = 200 + for key, value in variable_row.items(): + setattr(variable, key, value) + db.session.add(variable) + + db.session.commit() + return calculate_request + + +def _patch_test_auth_config(monkeypatch, scopes: str) -> None: + def get_config_value(path: str, default=None): + config = { + "app.environment": "test_with_auth", + "auth.enabled": True, + "auth.auth0.address": "test-tenant.auth0.com", + "auth.auth0.audience": "https://test-api-identifier", + "auth.auth0.test_token": TEST_AUTH_TOKEN, + "auth.auth0.test_token_scopes": scopes, + } + return config.get(path, default) + + monkeypatch.setattr( + "policyengine_household_api.decorators.auth.get_config_value", + get_config_value, + ) + + +def _patch_analytics_storage_ready(monkeypatch) -> None: + monkeypatch.setattr( + "policyengine_household_api.endpoints.analytics.is_analytics_enabled", + lambda: True, + ) + monkeypatch.setattr( + "policyengine_household_api.endpoints.analytics." + "is_analytics_schema_ready", + lambda: True, + ) + + +def _calculate_analytics_variable( + variable_name: str, + *, + occurrence_count: int = 1, + status: str = "supported", +) -> dict: + return { + "variable_name": variable_name, + "variable_name_truncated": False, + "entity_type": "person", + "source": "household_input", + "period_granularity": "year", + "entity_count": 1, + "period_count": 1, + "occurrence_count": occurrence_count, + "availability_status": status, + } diff --git a/tests/fixtures/utils/config_loader.py b/tests/fixtures/utils/config_loader.py index b638e4e5..b85f8594 100644 --- a/tests/fixtures/utils/config_loader.py +++ b/tests/fixtures/utils/config_loader.py @@ -48,6 +48,7 @@ "USER_ANALYTICS_DB_PASSWORD": "test-password", "AUTH0_ADDRESS_NO_DOMAIN": "test-auth0-address", "AUTH0_AUDIENCE_NO_DOMAIN": "test-auth0-audience", + "AUTH0_TEST_TOKEN_SCOPES": "read:calculate-analytics", "ANTHROPIC_API_KEY": "sk-ant-test-key", "PORT": "9090", } @@ -459,6 +460,7 @@ def temp_realistic_values_file(): "AUTH0_AUDIENCE_NO_DOMAIN=https://household.api.policyengine.org\n" ) f.write("AUTH0_TEST_TOKEN_NO_DOMAIN=test-jwt-token\n") + f.write("AUTH0_TEST_TOKEN_SCOPES=read:calculate-analytics\n") f.write("USER_ANALYTICS_DB_CONNECTION_NAME=project:region:instance\n") f.write("USER_ANALYTICS_DB_USERNAME=analytics_user\n") f.write("USER_ANALYTICS_DB_PASSWORD=analytics_pass\n") @@ -489,6 +491,7 @@ def temp_realistic_config_with_vars(): address: ${AUTH0_ADDRESS_NO_DOMAIN} audience: ${AUTH0_AUDIENCE_NO_DOMAIN} test_token: ${AUTH0_TEST_TOKEN_NO_DOMAIN} + test_token_scopes: ${AUTH0_TEST_TOKEN_SCOPES} analytics: enabled: true diff --git a/tests/unit/auth/test_validation.py b/tests/unit/auth/test_validation.py index c92fd346..0907f2ca 100644 --- a/tests/unit/auth/test_validation.py +++ b/tests/unit/auth/test_validation.py @@ -4,12 +4,18 @@ import time from unittest.mock import patch +import pytest import jwt +from authlib.oauth2.rfc6750.errors import ( + InsufficientScopeError, + InvalidTokenError, +) from cryptography.hazmat.primitives import serialization from cryptography.hazmat.primitives.asymmetric import rsa from jwt.algorithms import RSAAlgorithm from policyengine_household_api.auth import validation +from policyengine_household_api.decorators.auth import ANALYTICS_READ_SCOPE class TestAuth0JWTBearerTokenValidator: @@ -129,56 +135,177 @@ def test__given_successful_fetch__is_cached(self): def test__given_rs256_jwks__authenticates_signed_token(self): """Regression guard for Authlib 1.7's joserfc key path.""" - private_key = rsa.generate_private_key( - public_exponent=65537, - key_size=2048, - ) - public_jwk = json.loads(RSAAlgorithm.to_jwk(private_key.public_key())) - public_jwk.update( + private_key = _private_key() + validator = _validator_for_key(private_key) + token = _signed_token( + private_key, { - "kid": "test-key", - "use": "sig", - "alg": "RS256", - } + "iss": "https://tenant.example/", + "aud": "audience", + "exp": int(time.time()) + 300, + "sub": "client-id", + }, ) - jwks = {"keys": [public_jwk]} - private_pem = private_key.private_bytes( - serialization.Encoding.PEM, - serialization.PrivateFormat.PKCS8, - serialization.NoEncryption(), - ) - token = jwt.encode( + claims = validator.authenticate_token(token) + + assert claims["sub"] == "client-id" + + def test__given_valid_jwt_with_required_scope__validate_token_accepts( + self, + ): + private_key = _private_key() + validator = _validator_for_key(private_key) + token = _signed_token( + private_key, { "iss": "https://tenant.example/", "aud": "audience", "exp": int(time.time()) + 300, "sub": "client-id", + "scope": ANALYTICS_READ_SCOPE, }, - private_pem, - algorithm="RS256", - headers={"kid": "test-key"}, ) - class FakeResponse: - def __enter__(self): - return self + claims = validator.authenticate_token(token) + validator.validate_token(claims, [ANALYTICS_READ_SCOPE], None) - def __exit__(self, *_args): - return None + @pytest.mark.parametrize( + "claim_overrides", + [ + {"aud": "wrong-audience"}, + {"iss": "https://wrong-tenant.example/"}, + {"exp": int(time.time()) - 300}, + ], + ) + def test__given_jwt_with_invalid_standard_claim__validate_token_rejects( + self, + claim_overrides, + ): + private_key = _private_key() + validator = _validator_for_key(private_key) + claims = { + "iss": "https://tenant.example/", + "aud": "audience", + "exp": int(time.time()) + 300, + "sub": "client-id", + "scope": ANALYTICS_READ_SCOPE, + **claim_overrides, + } + token = _signed_token(private_key, claims) - def read(self): - return json.dumps(jwks).encode() + parsed_claims = validator.authenticate_token(token) + with pytest.raises(InvalidTokenError): + validator.validate_token( + parsed_claims, [ANALYTICS_READ_SCOPE], None + ) - with patch( - "policyengine_household_api.auth.validation.urlopen", - return_value=FakeResponse(), - ): - validator = validation.Auth0JWTBearerTokenValidator( - "tenant.example", - "audience", + def test__given_jwt_signed_by_wrong_key__validate_token_rejects(self): + trusted_key = _private_key() + untrusted_key = _private_key() + validator = _validator_for_key(trusted_key) + token = _signed_token( + untrusted_key, + { + "iss": "https://tenant.example/", + "aud": "audience", + "exp": int(time.time()) + 300, + "sub": "client-id", + "scope": ANALYTICS_READ_SCOPE, + }, + ) + + parsed_claims = validator.authenticate_token(token) + with pytest.raises(InvalidTokenError): + validator.validate_token( + parsed_claims, [ANALYTICS_READ_SCOPE], None ) + def test__given_jwt_without_required_scope__validate_token_rejects(self): + private_key = _private_key() + validator = _validator_for_key(private_key) + token = _signed_token( + private_key, + { + "iss": "https://tenant.example/", + "aud": "audience", + "exp": int(time.time()) + 300, + "sub": "client-id", + }, + ) + claims = validator.authenticate_token(token) + with pytest.raises(InsufficientScopeError): + validator.validate_token(claims, [ANALYTICS_READ_SCOPE], None) - assert claims["sub"] == "client-id" + def test__given_jwt_with_permissions_but_no_scope__validate_token_rejects( + self, + ): + private_key = _private_key() + validator = _validator_for_key(private_key) + token = _signed_token( + private_key, + { + "iss": "https://tenant.example/", + "aud": "audience", + "exp": int(time.time()) + 300, + "sub": "client-id", + "permissions": [ANALYTICS_READ_SCOPE], + }, + ) + + claims = validator.authenticate_token(token) + with pytest.raises(InsufficientScopeError): + validator.validate_token(claims, [ANALYTICS_READ_SCOPE], None) + + +def _private_key(): + return rsa.generate_private_key( + public_exponent=65537, + key_size=2048, + ) + + +def _signed_token(private_key, claims: dict) -> str: + private_pem = private_key.private_bytes( + serialization.Encoding.PEM, + serialization.PrivateFormat.PKCS8, + serialization.NoEncryption(), + ) + return jwt.encode( + claims, + private_pem, + algorithm="RS256", + headers={"kid": "test-key"}, + ) + + +def _validator_for_key(private_key): + public_jwk = json.loads(RSAAlgorithm.to_jwk(private_key.public_key())) + public_jwk.update( + { + "kid": "test-key", + "use": "sig", + "alg": "RS256", + } + ) + jwks = {"keys": [public_jwk]} + + class FakeResponse: + def __enter__(self): + return self + + def __exit__(self, *_args): + return None + + def read(self): + return json.dumps(jwks).encode() + + with patch( + "policyengine_household_api.auth.validation.urlopen", + return_value=FakeResponse(), + ): + return validation.Auth0JWTBearerTokenValidator( + "tenant.example", + "audience", + ) diff --git a/tests/unit/decorators/test_auth.py b/tests/unit/decorators/test_auth.py index a2089d43..b01d8428 100644 --- a/tests/unit/decorators/test_auth.py +++ b/tests/unit/decorators/test_auth.py @@ -4,6 +4,7 @@ from unittest.mock import Mock from policyengine_household_api.decorators.auth import ( + ANALYTICS_READ_SCOPE, NoOpDecorator, ConditionalAuthDecorator, create_auth_decorator, @@ -11,16 +12,6 @@ ) from tests.fixtures.decorators.auth import ( AUTH0_CONFIG_DATA, - auth_enabled_environment, - auth_test_environment, - auth_disabled_environment, - auth_enabled_missing_config_environment, - auth_backward_compat_environment, - auth_partial_config_environment, - mock_resource_protector, - mock_auth0_validator, - mock_flask_app, - sample_view_function, ) @@ -85,11 +76,15 @@ def test__given_test_auth_environment__uses_static_token_validator( ) assert isinstance(registered_validator, StaticBearerTokenValidator) assert registered_validator.expected_token == "test-jwt-token" + assert registered_validator.scopes == ANALYTICS_READ_SCOPE assert decorator.get_decorator() is mock_protector_instance assert decorator.is_enabled is True auth_test_environment.assert_any_call("app.environment", "") auth_test_environment.assert_any_call("auth.auth0.test_token", "") + auth_test_environment.assert_any_call( + "auth.auth0.test_token_scopes", "" + ) def test__given_auth_enabled_with_valid_config__auth0_is_configured( self, @@ -207,3 +202,23 @@ def test__given_auth_disabled__returns_noop_decorator( decorator = create_auth_decorator() assert isinstance(decorator, NoOpDecorator) + + +class TestStaticBearerTokenValidator: + def test__given_static_token_without_scopes__token_has_empty_scope(self): + validator = StaticBearerTokenValidator("test-jwt-token") + + token = validator.authenticate_token("test-jwt-token") + + assert token is not None + assert token.get_scope() == "" + + def test__given_static_token_with_scopes__token_exposes_scopes(self): + validator = StaticBearerTokenValidator( + "test-jwt-token", ANALYTICS_READ_SCOPE + ) + + token = validator.authenticate_token("test-jwt-token") + + assert token is not None + assert token.get_scope() == ANALYTICS_READ_SCOPE diff --git a/tests/unit/endpoints/test_analytics_endpoint.py b/tests/unit/endpoints/test_analytics_endpoint.py new file mode 100644 index 00000000..78c3afe1 --- /dev/null +++ b/tests/unit/endpoints/test_analytics_endpoint.py @@ -0,0 +1,286 @@ +from datetime import datetime +import json + +import pytest + +from policyengine_household_api.decorators.auth import ANALYTICS_READ_SCOPE +from policyengine_household_api.endpoints.analytics import ( + get_calculate_analytics_requests, +) +from tests.fixtures.endpoints.analytics import TEST_AUTH_TOKEN + + +def test__calculate_analytics_requests__filters_by_time_window( + analytics_endpoint_app, + add_calculate_analytics_request, + calculate_analytics_variable, +): + old_request = add_calculate_analytics_request( + "old-request", + datetime(2026, 5, 1, 12, 0, 0), + [calculate_analytics_variable("age")], + ) + included_request = add_calculate_analytics_request( + "included-request", + datetime(2026, 5, 10, 12, 0, 0), + [calculate_analytics_variable("employment_income")], + ) + + with analytics_endpoint_app.test_request_context( + "/analytics/calculate/requests?" + "start_time=2026-05-07T00:00:00Z&" + "end_time=2026-05-13T00:00:00Z" + ): + response = get_calculate_analytics_requests() + + payload = json.loads(response.data) + assert response.status_code == 200 + assert payload["unique"] is False + assert [request["request_uuid"] for request in payload["requests"]] == [ + included_request.request_uuid + ] + assert payload["requests"][0]["variables"] == [ + { + "variable_name": "employment_income", + "entity_type": "person", + "source": "household_input", + "period_granularity": "year", + "entity_count": 1, + "period_count": 1, + "occurrence_count": 1, + "availability_status": "supported", + "variable_name_truncated": False, + } + ] + assert old_request.request_uuid not in { + request["request_uuid"] for request in payload["requests"] + } + assert "client_id" not in payload["requests"][0] + + +def test__calculate_analytics_requests__unique_returns_grouped_keys( + analytics_endpoint_app, + add_calculate_analytics_request, + calculate_analytics_variable, +): + add_calculate_analytics_request( + "request-one", + datetime(2026, 5, 10, 12, 0, 0), + [calculate_analytics_variable("age", occurrence_count=2)], + ) + add_calculate_analytics_request( + "request-two", + datetime(2026, 5, 11, 12, 0, 0), + [ + calculate_analytics_variable("age"), + calculate_analytics_variable("bad_input", status="unsupported"), + ], + ) + + with analytics_endpoint_app.test_request_context( + "/analytics/calculate/requests?unique=true" + ): + response = get_calculate_analytics_requests() + + payload = json.loads(response.data) + assert response.status_code == 200 + assert payload["unique"] is True + assert "requests" not in payload + assert payload["unique_keys"] == [ + { + "variable_name": "age", + "entity_type": "person", + "source": "household_input", + "period_granularity": "year", + "availability_status": "supported", + "variable_name_truncated": False, + "request_count": 2, + "occurrence_count": 3, + "first_seen": "2026-05-10T12:00:00Z", + "last_seen": "2026-05-11T12:00:00Z", + }, + { + "variable_name": "bad_input", + "entity_type": "person", + "source": "household_input", + "period_granularity": "year", + "availability_status": "unsupported", + "variable_name_truncated": False, + "request_count": 1, + "occurrence_count": 1, + "first_seen": "2026-05-11T12:00:00Z", + "last_seen": "2026-05-11T12:00:00Z", + }, + ] + + +def test__calculate_analytics_requests__invalid_time_returns_400( + analytics_endpoint_app, +): + with analytics_endpoint_app.test_request_context( + "/analytics/calculate/requests?start_time=not-a-time" + ): + response = get_calculate_analytics_requests() + + payload = json.loads(response.data) + assert response.status_code == 400 + assert payload["status"] == "error" + assert "start_time" in payload["message"] + + +def test__calculate_analytics_requests__analytics_disabled_returns_503( + analytics_endpoint_app, + monkeypatch, +): + monkeypatch.setattr( + "policyengine_household_api.endpoints.analytics.is_analytics_enabled", + lambda: False, + ) + + with analytics_endpoint_app.test_request_context( + "/analytics/calculate/requests" + ): + response = get_calculate_analytics_requests() + + payload = json.loads(response.data) + assert response.status_code == 503 + assert payload == { + "status": "error", + "message": "Analytics is not enabled for this API instance.", + } + + +def test__calculate_analytics_requests__schema_not_ready_returns_503( + analytics_endpoint_app, + monkeypatch, +): + monkeypatch.setattr( + "policyengine_household_api.endpoints.analytics." + "is_analytics_schema_ready", + lambda: False, + ) + + with analytics_endpoint_app.test_request_context( + "/analytics/calculate/requests" + ): + response = get_calculate_analytics_requests() + + payload = json.loads(response.data) + assert response.status_code == 503 + assert payload == { + "status": "error", + "message": "Analytics storage is not ready.", + } + + +def test__calculate_analytics_requests_route__missing_token_returns_401( + scoped_analytics_client_factory, +): + client = scoped_analytics_client_factory(ANALYTICS_READ_SCOPE) + + response = client.get("/analytics/calculate/requests") + + assert response.status_code == 401 + + +def test__calculate_analytics_requests_route__token_without_scope_returns_403( + scoped_analytics_client_factory, +): + client = scoped_analytics_client_factory("") + + response = client.get( + "/analytics/calculate/requests", + headers={"Authorization": f"Bearer {TEST_AUTH_TOKEN}"}, + ) + + assert response.status_code == 403 + + +@pytest.mark.parametrize( + "authorization_header", + [ + f"Bearer {TEST_AUTH_TOKEN}-wrong", + "Bearer not-a-valid-token", + f"Basic {TEST_AUTH_TOKEN}", + "Bearer", + ], +) +def test__calculate_analytics_requests_route__malformed_or_wrong_token_returns_401( + scoped_analytics_client_factory, + authorization_header, +): + client = scoped_analytics_client_factory(ANALYTICS_READ_SCOPE) + + response = client.get( + "/analytics/calculate/requests", + headers={"Authorization": authorization_header}, + ) + + assert response.status_code == 401 + + +@pytest.mark.parametrize( + "scopes", + [ + "read:calculate-analytics-extra", + "prefix:read:calculate-analytics", + "read:calculate", + ], +) +def test__calculate_analytics_requests_route__deceptive_scope_returns_403( + scoped_analytics_client_factory, + scopes, +): + client = scoped_analytics_client_factory(scopes) + + response = client.get( + "/analytics/calculate/requests", + headers={"Authorization": f"Bearer {TEST_AUTH_TOKEN}"}, + ) + + assert response.status_code == 403 + + +def test__calculate_analytics_requests_route__token_with_scope_returns_200( + scoped_analytics_client_factory, +): + client = scoped_analytics_client_factory(ANALYTICS_READ_SCOPE) + + response = client.get( + "/analytics/calculate/requests", + headers={"Authorization": f"Bearer {TEST_AUTH_TOKEN}"}, + ) + + payload = json.loads(response.data) + assert response.status_code == 200 + assert payload["requests"] == [] + + +def test__calculate_analytics_requests_route__token_with_scope_among_others_returns_200( + scoped_analytics_client_factory, +): + client = scoped_analytics_client_factory( + f"openid profile {ANALYTICS_READ_SCOPE}" + ) + + response = client.get( + "/analytics/calculate/requests", + headers={"Authorization": f"Bearer {TEST_AUTH_TOKEN}"}, + ) + + payload = json.loads(response.data) + assert response.status_code == 200 + assert payload["requests"] == [] + + +def test__normal_protected_route__token_without_analytics_scope_returns_200( + scoped_analytics_client_factory, +): + client = scoped_analytics_client_factory("") + + response = client.post( + "/us/calculate", + headers={"Authorization": f"Bearer {TEST_AUTH_TOKEN}"}, + ) + + assert response.status_code == 200 diff --git a/tests/unit/utils/test_config_loader.py b/tests/unit/utils/test_config_loader.py index 033e860a..251788e3 100644 --- a/tests/unit/utils/test_config_loader.py +++ b/tests/unit/utils/test_config_loader.py @@ -201,6 +201,10 @@ def test__given_traditional_env_vars__config_maps_them_correctly( config["auth"]["auth0"]["audience"] == ENV_VAR_TEST_DATA["AUTH0_AUDIENCE_NO_DOMAIN"] ) + assert ( + config["auth"]["auth0"]["test_token_scopes"] + == ENV_VAR_TEST_DATA["AUTH0_TEST_TOKEN_SCOPES"] + ) assert ( config["ai"]["anthropic"]["api_key"] == ENV_VAR_TEST_DATA["ANTHROPIC_API_KEY"] @@ -616,6 +620,7 @@ def test__given_auth0_config_with_env_vars__substitution_enables_auth( clean_env.setenv("AUTH0_ADDRESS_NO_DOMAIN", "test.auth0.com") clean_env.setenv("AUTH0_AUDIENCE_NO_DOMAIN", "https://test-api") clean_env.setenv("AUTH0_TEST_TOKEN_NO_DOMAIN", "test-jwt-token") + clean_env.setenv("AUTH0_TEST_TOKEN_SCOPES", "read:calculate-analytics") config_data = { "auth": { @@ -624,6 +629,7 @@ def test__given_auth0_config_with_env_vars__substitution_enables_auth( "address": "${AUTH0_ADDRESS_NO_DOMAIN}", "audience": "${AUTH0_AUDIENCE_NO_DOMAIN}", "test_token": "${AUTH0_TEST_TOKEN_NO_DOMAIN}", + "test_token_scopes": "${AUTH0_TEST_TOKEN_SCOPES}", }, } } @@ -638,6 +644,10 @@ def test__given_auth0_config_with_env_vars__substitution_enables_auth( assert config["auth"]["auth0"]["address"] == "test.auth0.com" assert config["auth"]["auth0"]["audience"] == "https://test-api" assert config["auth"]["auth0"]["test_token"] == "test-jwt-token" + assert ( + config["auth"]["auth0"]["test_token_scopes"] + == "read:calculate-analytics" + ) def test__given_external_config_with_env_vars__substitution_occurs( self, tmp_path, clean_env @@ -955,6 +965,10 @@ def test__integration_with_real_config_structure( == "https://household.api.policyengine.org" ) assert config["auth"]["auth0"]["test_token"] == "test-jwt-token" + assert ( + config["auth"]["auth0"]["test_token_scopes"] + == "read:calculate-analytics" + ) assert ( config["analytics"]["database"]["connection_name"] == "project:region:instance" diff --git a/uv.lock b/uv.lock index a9922ef8..9a2a4fc6 100644 --- a/uv.lock +++ b/uv.lock @@ -2191,7 +2191,7 @@ wheels = [ [[package]] name = "policyengine-household-api" -version = "0.16.2" +version = "0.17.0" source = { editable = "." } dependencies = [ { name = "alembic" },