diff --git a/hindsight-api-slim/hindsight_api/api/http.py b/hindsight-api-slim/hindsight_api/api/http.py index c357f4116..9a0fc7d54 100644 --- a/hindsight-api-slim/hindsight_api/api/http.py +++ b/hindsight-api-slim/hindsight_api/api/http.py @@ -2885,6 +2885,57 @@ def get_request_context(authorization: str | None = Header(default=None)) -> Req api_key = authorization.strip() return RequestContext(api_key=api_key) + def precheck_for(operation: str): + """ + Build a FastAPI dependency that runs ``OperationValidator.precheck``. + + FastAPI resolves dependencies before deserialising the route's body + parameter. Wiring this dependency on the billable POST routes lets + an extension reject a request — e.g. with HTTP 402 when a tenant's + balance is exhausted — without the request body ever being read or + materialised in memory. + + The dependency intentionally: + - authenticates the tenant (so ``request_context.tenant_id`` is + resolved before the precheck runs); + - falls through silently when no validator is configured or the + validator's default no-op precheck is in effect; + - converts a rejection ``ValidationResult`` into the corresponding + ``HTTPException`` directly (the per-route ``OperationValidationError`` + catch blocks don't see exceptions raised in dependencies, so we + translate here instead of relying on each handler's try/except). + + Args: + operation: Short identifier for the route, e.g. ``"retain"``. + + Returns: + A FastAPI dependency callable suitable for ``Depends(...)``. + """ + + async def _precheck_dep( + bank_id: str, + request_context: RequestContext = Depends(get_request_context), + ) -> None: + validator = getattr(app.state.memory, "_operation_validator", None) + if validator is None: + return + from hindsight_api.extensions import PrecheckContext + + await app.state.memory._authenticate_tenant(request_context) + ctx = PrecheckContext( + operation=operation, + bank_id=bank_id, + request_context=request_context, + ) + result = await validator.precheck(ctx) + if not result.allowed: + raise HTTPException( + status_code=result.status_code, + detail=result.reason or "Operation not allowed", + ) + + return _precheck_dep + # Global exception handler for authentication errors @app.exception_handler(AuthenticationError) async def authentication_error_handler(request, exc: AuthenticationError): @@ -3142,7 +3193,10 @@ async def api_get_observation_history( ) @audited("recall") async def api_recall( - bank_id: str, request: RecallRequest, request_context: RequestContext = Depends(get_request_context) + bank_id: str, + request: RecallRequest, + request_context: RequestContext = Depends(get_request_context), + _precheck: None = Depends(precheck_for("recall")), ): """Run a recall and return results with trace.""" import time @@ -3330,7 +3384,10 @@ def _fact_to_result(fact: "MemoryFact") -> RecallResult: ) @audited("reflect") async def api_reflect( - bank_id: str, request: ReflectRequest, request_context: RequestContext = Depends(get_request_context) + bank_id: str, + request: ReflectRequest, + request_context: RequestContext = Depends(get_request_context), + _precheck: None = Depends(precheck_for("reflect")), ): metrics = get_metrics_collector() @@ -3828,6 +3885,7 @@ async def api_create_mental_model( bank_id: str, body: CreateMentalModelRequest, request_context: RequestContext = Depends(get_request_context), + _precheck: None = Depends(precheck_for("mental_model_create")), ): """Create a mental model (async - returns operation_id).""" try: @@ -3876,6 +3934,7 @@ async def api_refresh_mental_model( bank_id: str, mental_model_id: str, request_context: RequestContext = Depends(get_request_context), + _precheck: None = Depends(precheck_for("mental_model_refresh")), ): """Refresh a mental model by re-running its source query (async).""" try: @@ -5722,7 +5781,10 @@ async def api_list_webhook_deliveries( ) @audited("retain") async def api_retain( - bank_id: str, request: RetainRequest, request_context: RequestContext = Depends(get_request_context) + bank_id: str, + request: RetainRequest, + request_context: RequestContext = Depends(get_request_context), + _precheck: None = Depends(precheck_for("retain")), ): """Retain memories with optional async processing.""" metrics = get_metrics_collector() @@ -5892,6 +5954,7 @@ async def api_file_retain( files: list[UploadFile] = File(..., description="Files to upload and convert"), request: str = Form(..., description="JSON string with FileRetainRequest model"), request_context: RequestContext = Depends(get_request_context), + _precheck: None = Depends(precheck_for("files_retain")), ): """Upload and convert files to memories.""" from hindsight_api.config import get_config diff --git a/hindsight-api-slim/hindsight_api/extensions/__init__.py b/hindsight-api-slim/hindsight_api/extensions/__init__.py index e56085a74..0027a51d4 100644 --- a/hindsight-api-slim/hindsight_api/extensions/__init__.py +++ b/hindsight-api-slim/hindsight_api/extensions/__init__.py @@ -40,6 +40,7 @@ # Core operations OperationValidationError, OperationValidatorExtension, + PrecheckContext, RecallContext, RecallResult, ReflectContext, @@ -72,6 +73,7 @@ "DeferOperation", "OperationValidationError", "OperationValidatorExtension", + "PrecheckContext", "RecallContext", "RecallResult", "ReflectContext", diff --git a/hindsight-api-slim/hindsight_api/extensions/operation_validator.py b/hindsight-api-slim/hindsight_api/extensions/operation_validator.py index 6570cddb8..3ab3f7ffc 100644 --- a/hindsight-api-slim/hindsight_api/extensions/operation_validator.py +++ b/hindsight-api-slim/hindsight_api/extensions/operation_validator.py @@ -82,6 +82,33 @@ def reject(cls, reason: str, status_code: int = 403) -> "ValidationResult": # ============================================================================= +@dataclass +class PrecheckContext: + """Context for a pre-body-parse precheck on an operation. + + Unlike :class:`RetainContext` / :class:`RecallContext` / etc., this + context is constructed *before* the request body is deserialised. It + therefore intentionally carries only the cheap, already-resolved + pieces of request state: + + - ``operation``: a short string identifying the route, e.g. ``"retain"``, + ``"recall"``, ``"reflect"``, ``"files_retain"``, ``"mental_model_create"``, + ``"mental_model_refresh"``. + - ``bank_id``: parsed from the URL path. + - ``request_context``: the authenticated :class:`RequestContext` (tenant + already resolved by the tenant extension). + + Implementations should keep precheck cheap and side-effect-free. The + full per-request validators (``validate_retain`` / ``validate_recall`` + / ``validate_reflect``) still run after the body is parsed and remain + the source of truth for the precise per-call cost / quota arithmetic. + """ + + operation: str + bank_id: str + request_context: "RequestContext" + + @dataclass class RetainContext: """Context for a retain operation validation (pre-operation). @@ -407,6 +434,42 @@ class OperationValidatorExtension(Extension, ABC): - consolidate (mental models consolidation) """ + # ========================================================================= + # Pre-body-parse hook (optional - default no-op) + # ========================================================================= + + async def precheck(self, ctx: PrecheckContext) -> ValidationResult: + """ + Cheap pre-body-parse check, called before the request body is read. + + FastAPI resolves ``Depends`` callables before deserialising the route + body; routes that wire ``precheck`` as a dependency therefore short + -circuit here without ever materialising the JSON payload in memory. + That makes this the right hook for "should this caller be allowed to + spend resources on this request at all" decisions — e.g. a balance + is exhausted, a key is revoked, or a tenant is rate-limited. + + Implementations should: + - Be cheap: prefer cached lookups, avoid heavy DB queries. + - Use only data on ``ctx`` (operation name + bank_id + request_context); + the body is not yet available. + - Be conservative on errors: prefer ``ValidationResult.accept()`` so + a transient lookup failure doesn't turn into a request rejection. + The post-body ``validate_*`` hooks still run and remain the source + of truth for the precise per-call cost check. + + Default implementation accepts everything. Override to opt in. + + Args: + ctx: Pre-body context with operation name, bank_id, and + request_context (tenant already resolved). + + Returns: + ValidationResult indicating whether the request may proceed to + body parsing and the post-parse validators. + """ + return ValidationResult.accept() + # ========================================================================= # Pre-operation validation hooks (abstract - must be implemented) # ========================================================================= diff --git a/hindsight-api-slim/tests/test_extensions.py b/hindsight-api-slim/tests/test_extensions.py index 292da62b4..28c687eaa 100644 --- a/hindsight-api-slim/tests/test_extensions.py +++ b/hindsight-api-slim/tests/test_extensions.py @@ -13,6 +13,7 @@ HttpExtension, OperationValidationError, OperationValidatorExtension, + PrecheckContext, RecallContext, RecallResult, ReflectContext, @@ -814,3 +815,264 @@ def test_core_routes_still_work_with_extension(self, memory): # Banks list endpoint should work response = client.get("/v1/default/banks") assert response.status_code in (200, 500) # May fail if DB not ready + + +# ============================================================================ +# Precheck (pre-body-parse) tests +# ============================================================================ +# +# The precheck() hook is wired as a FastAPI Depends on the billable POST +# routes. FastAPI resolves dependencies before deserialising the route's body +# parameter, so a rejecting precheck never causes the request body to be read +# or materialised in memory. The test below uses a Pydantic model_validator +# that records every parse to assert that body parsing never runs on the +# rejection path. + + +class RecordingPrecheckValidator(OperationValidatorExtension): + """Validator that records every precheck call and can be configured to reject. + + Used to drive the FastAPI dependency that runs precheck() before body parse. + The validate_* hooks below are required-abstract no-ops so the class is + instantiable; the tests here only exercise precheck. + """ + + def __init__(self, *, reject: bool = False, status_code: int = 402, + reason: str = "rejected by precheck") -> None: + super().__init__(config={}) + self.reject = reject + self.status_code = status_code + self.reason = reason + self.precheck_calls: list[PrecheckContext] = [] + + async def precheck(self, ctx: PrecheckContext) -> ValidationResult: + self.precheck_calls.append(ctx) + if self.reject: + return ValidationResult.reject(self.reason, status_code=self.status_code) + return ValidationResult.accept() + + async def validate_retain(self, ctx: RetainContext) -> ValidationResult: + return ValidationResult.accept() + + async def validate_recall(self, ctx: RecallContext) -> ValidationResult: + return ValidationResult.accept() + + async def validate_reflect(self, ctx: ReflectContext) -> ValidationResult: + return ValidationResult.accept() + + +class TestPrecheckDefault: + """The base OperationValidatorExtension.precheck is a no-op accept.""" + + @pytest.mark.asyncio + async def test_default_precheck_accepts(self): + validator = RecordingPrecheckValidator(reject=False) + # Bypass our override by calling the base implementation directly. + ctx = PrecheckContext( + operation="retain", + bank_id="bank-x", + request_context=RequestContext(), + ) + result = await OperationValidatorExtension.precheck(validator, ctx) + assert result.allowed is True + assert result.reason is None + + +class TestPrecheckHttpWiring: + """precheck() is wired as a FastAPI Depends on the billable POST routes. + + These tests do NOT use the heavy ``memory`` fixture (which requires a + running pg0 + migrations). Instead they construct a minimal FastAPI + app that mirrors the same Depends ordering used in + ``hindsight_api.api.http`` (a ``Depends(precheck_for(...))`` resolved + before the Pydantic body parameter), so the contract under test — + "rejection happens before body parse" — can be exercised in isolation. + + The critical assertion in test_precheck_rejection_skips_body_parse is + that a rejection response is returned without the request body being + deserialised by Pydantic — i.e. the body parser was never invoked on + the rejection path. + """ + + @staticmethod + def _build_app(validator): + """Mirror the precheck wiring from ``hindsight_api.api.http`` in a + standalone FastAPI app.""" + from fastapi import Depends, FastAPI, HTTPException + from pydantic import BaseModel, model_validator + + from hindsight_api.extensions import PrecheckContext + from hindsight_api.models import RequestContext + + body_parses: list[str] = [] + + class _RetainBody(BaseModel): + items: list + + @model_validator(mode="before") + @classmethod + def _record(cls, v): + body_parses.append("retain") + return v + + class _RecallBody(BaseModel): + query: str + + @model_validator(mode="before") + @classmethod + def _record(cls, v): + body_parses.append("recall") + return v + + class _ReflectBody(BaseModel): + query: str + + @model_validator(mode="before") + @classmethod + def _record(cls, v): + body_parses.append("reflect") + return v + + async def _request_context() -> RequestContext: + return RequestContext() + + def _precheck_for(operation: str): + async def _dep( + bank_id: str, + request_context: RequestContext = Depends(_request_context), + ) -> None: + ctx = PrecheckContext( + operation=operation, + bank_id=bank_id, + request_context=request_context, + ) + result = await validator.precheck(ctx) + if not result.allowed: + raise HTTPException( + status_code=result.status_code, + detail=result.reason or "Operation not allowed", + ) + + return _dep + + app = FastAPI() + + @app.post("/v1/default/banks/{bank_id}/memories") + async def retain( + bank_id: str, + body: _RetainBody, + _: None = Depends(_precheck_for("retain")), + ): + return {"ok": True, "bank_id": bank_id, "n": len(body.items)} + + @app.post("/v1/default/banks/{bank_id}/memories/recall") + async def recall( + bank_id: str, + body: _RecallBody, + _: None = Depends(_precheck_for("recall")), + ): + return {"ok": True} + + @app.post("/v1/default/banks/{bank_id}/reflect") + async def reflect( + bank_id: str, + body: _ReflectBody, + _: None = Depends(_precheck_for("reflect")), + ): + return {"ok": True} + + @app.get("/v1/default/banks/{bank_id}/memories/list") + async def list_memories(bank_id: str): + return {"ok": True} + + return app, body_parses + + def test_precheck_accept_lets_request_through_to_body_parse(self): + validator = RecordingPrecheckValidator(reject=False) + app, body_parses = self._build_app(validator) + client = TestClient(app) + + resp = client.post( + "/v1/default/banks/precheck-bank/memories", + json={"items": [{"content": "x"}]}, + ) + assert resp.status_code == 200 + assert len(validator.precheck_calls) == 1 + assert validator.precheck_calls[0].operation == "retain" + assert validator.precheck_calls[0].bank_id == "precheck-bank" + assert body_parses == ["retain"] + + def test_precheck_rejection_returns_status_and_reason(self): + validator = RecordingPrecheckValidator( + reject=True, status_code=402, reason="Insufficient credits" + ) + app, _ = self._build_app(validator) + client = TestClient(app) + + resp = client.post( + "/v1/default/banks/precheck-bank/memories", + json={"items": [{"content": "x"}]}, + ) + assert resp.status_code == 402 + assert resp.json()["detail"] == "Insufficient credits" + + def test_precheck_rejection_skips_body_parse(self): + """The critical assertion: rejection happens before Pydantic + deserialises the body. We send an oversized body and verify the + body-parse counter never incremented. + """ + validator = RecordingPrecheckValidator( + reject=True, status_code=402, reason="rejected by precheck" + ) + app, body_parses = self._build_app(validator) + client = TestClient(app) + + resp = client.post( + "/v1/default/banks/precheck-bank/memories", + json={"items": [{"content": "x" * 100_000} for _ in range(50)]}, + ) + assert resp.status_code == 402 + assert "rejected by precheck" in resp.json()["detail"] + assert body_parses == [], ( + "request body was deserialised despite a rejecting precheck — " + "the Depends-before-body-parse contract is broken" + ) + + def test_precheck_rejection_skips_body_parse_for_recall(self): + validator = RecordingPrecheckValidator( + reject=True, status_code=402, reason="rejected" + ) + app, body_parses = self._build_app(validator) + client = TestClient(app) + + resp = client.post( + "/v1/default/banks/precheck-bank/memories/recall", + json={"query": "x" * 100_000}, + ) + assert resp.status_code == 402 + assert validator.precheck_calls[-1].operation == "recall" + assert body_parses == [] + + def test_precheck_rejection_skips_body_parse_for_reflect(self): + validator = RecordingPrecheckValidator( + reject=True, status_code=402, reason="rejected" + ) + app, body_parses = self._build_app(validator) + client = TestClient(app) + + resp = client.post( + "/v1/default/banks/precheck-bank/reflect", + json={"query": "x" * 100_000}, + ) + assert resp.status_code == 402 + assert validator.precheck_calls[-1].operation == "reflect" + assert body_parses == [] + + def test_precheck_does_not_run_on_get(self): + validator = RecordingPrecheckValidator(reject=True) + app, _ = self._build_app(validator) + client = TestClient(app) + + resp = client.get("/v1/default/banks/precheck-bank/memories/list") + assert resp.status_code == 200 + assert len(validator.precheck_calls) == 0