|
| 1 | +from __future__ import annotations |
| 2 | + |
| 3 | +import asyncio |
| 4 | +from types import SimpleNamespace |
| 5 | + |
| 6 | +import pytest |
| 7 | +from fastapi import Depends, FastAPI |
| 8 | +from fastapi.testclient import TestClient |
| 9 | + |
| 10 | +from src.api import dependencies as deps |
| 11 | +from src.api.middleware import RequestContextMiddleware, SecurityHeadersMiddleware |
| 12 | +from src.api.routes.health import router as health_router |
| 13 | +from src.schemas.retrieval import RetrievalResult |
| 14 | + |
| 15 | + |
| 16 | +class FakeIngestPipeline: |
| 17 | + model = SimpleNamespace(model="fake-ingest") |
| 18 | + |
| 19 | + async def run(self, **kwargs): |
| 20 | + return {"classification_result": SimpleNamespace(classifications=[])} |
| 21 | + |
| 22 | + def close(self): |
| 23 | + pass |
| 24 | + |
| 25 | + |
| 26 | +class FakeRetrievalPipeline: |
| 27 | + model = SimpleNamespace(model="fake-retrieval") |
| 28 | + |
| 29 | + async def run(self, query: str, user_id: str, top_k: int = 5): |
| 30 | + return RetrievalResult(query=query, answer=f"answer for {user_id}", sources=[], confidence=0.1) |
| 31 | + |
| 32 | + def close(self): |
| 33 | + pass |
| 34 | + |
| 35 | + |
| 36 | +@pytest.fixture |
| 37 | +def dependency_app(monkeypatch): |
| 38 | + monkeypatch.setattr(deps.settings, "api_keys", ["test-static-key"], raising=False) |
| 39 | + deps._init_error = None |
| 40 | + deps._pipelines_ready.set() |
| 41 | + deps.set_pipelines(FakeIngestPipeline(), FakeRetrievalPipeline()) |
| 42 | + |
| 43 | + app = FastAPI() |
| 44 | + app.add_middleware(SecurityHeadersMiddleware) |
| 45 | + app.add_middleware(RequestContextMiddleware) |
| 46 | + app.include_router(health_router) |
| 47 | + |
| 48 | + @app.get("/protected") |
| 49 | + async def protected(user: dict = Depends(deps.require_api_key)): |
| 50 | + return {"user_id": user["id"], "email": user["email"]} |
| 51 | + |
| 52 | + @app.get("/pipeline") |
| 53 | + async def pipeline(_ready=Depends(deps.require_ready)): |
| 54 | + return {"ingest": deps.get_ingest_pipeline().model.model} |
| 55 | + |
| 56 | + return app |
| 57 | + |
| 58 | + |
| 59 | +def test_health_route_uses_readiness_state(dependency_app): |
| 60 | + deps.set_startup_time(0) |
| 61 | + |
| 62 | + response = TestClient(dependency_app).get("/health") |
| 63 | + |
| 64 | + assert response.status_code == 200 |
| 65 | + assert response.json()["data"]["status"] == "ready" |
| 66 | + |
| 67 | + |
| 68 | +def test_auth_dependency_rejects_missing_and_accepts_static_bearer_key(dependency_app): |
| 69 | + client = TestClient(dependency_app) |
| 70 | + |
| 71 | + missing = client.get("/protected") |
| 72 | + assert missing.status_code == 401 |
| 73 | + |
| 74 | + ok = client.get("/protected", headers={"Authorization": "Bearer test-static-key"}) |
| 75 | + assert ok.status_code == 200 |
| 76 | + assert ok.json()["email"] == "static@xmem.ai" |
| 77 | + assert ok.headers["x-content-type-options"] == "nosniff" |
| 78 | + assert "x-request-id" in ok.headers |
| 79 | + |
| 80 | + |
| 81 | +def test_dependency_injection_returns_configured_pipeline(dependency_app): |
| 82 | + response = TestClient(dependency_app).get("/pipeline") |
| 83 | + |
| 84 | + assert response.status_code == 200 |
| 85 | + assert response.json() == {"ingest": "fake-ingest"} |
| 86 | + |
| 87 | + |
| 88 | +@pytest.mark.asyncio |
| 89 | +async def test_rate_limiter_blocks_after_limit(monkeypatch): |
| 90 | + limiter = deps._SlidingWindowRateLimiter(max_requests=1, window_seconds=60) |
| 91 | + assert await limiter.check("user-1") == (True, 0) |
| 92 | + assert await limiter.check("user-1") == (False, 0) |
0 commit comments