Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions hindsight-api-slim/hindsight_api/engine/llm_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,9 +272,11 @@ def create_llm_provider(
VertexAI and LiteLLM providers (each merges them in its own parameter
space). Keys must use each provider's native names (e.g. ``max_tokens``
for OpenAI/Anthropic vs ``max_output_tokens`` for Gemini).
default_headers: Custom headers passed as ``default_headers`` to provider SDK clients
(used by operators routing through proxies / request-tracing middleware). Currently
wired into the Anthropic provider; other providers may opt in as needed.
default_headers: Custom headers passed to provider SDK clients (used by operators
routing through proxies / request-tracing middleware). Wired into the Anthropic
provider (SDK ``default_headers``) and the LiteLLM-backed providers — ``litellm``,
``litellmrouter`` and ``bedrock`` — as the LiteLLM ``extra_headers`` completion
kwarg; other providers may opt in as needed.
vertexai_project_id: Vertex AI project ID (for VertexAI provider).
vertexai_region: Vertex AI region (for VertexAI provider).
vertexai_credentials: Vertex AI credentials object (for VertexAI provider).
Expand Down Expand Up @@ -375,6 +377,7 @@ def create_llm_provider(
model=model,
reasoning_effort=reasoning_effort,
extra_body=extra_body,
default_headers=default_headers,
)

elif provider_lower == "litellmrouter":
Expand All @@ -393,6 +396,7 @@ def create_llm_provider(
config=litellmrouter_config,
reasoning_effort=reasoning_effort,
extra_body=extra_body,
default_headers=default_headers,
)

elif provider_lower == "bedrock":
Expand All @@ -405,6 +409,7 @@ def create_llm_provider(
model=bedrock_model,
reasoning_effort=reasoning_effort,
extra_body=extra_body,
default_headers=default_headers,
bedrock_service_tier=bedrock_service_tier,
)

Expand Down
15 changes: 15 additions & 0 deletions hindsight-api-slim/hindsight_api/engine/providers/litellm_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ def __init__(
timeout: float | None = None,
extra_body: dict[str, Any] | None = None,
bedrock_service_tier: str | None = None,
default_headers: dict[str, Any] | None = None,
**kwargs: Any,
):
super().__init__(provider, api_key, base_url, model, reasoning_effort, **kwargs)
Expand All @@ -84,6 +85,13 @@ def __init__(
# drops any the target model rejects (litellm.drop_params=True below).
# Sourced from llm_extra_body (env: HINDSIGHT_API_LLM_EXTRA_BODY).
self._extra_body: dict[str, Any] = extra_body or {}
# Operator-configured default headers forwarded to litellm.acompletion as
# ``extra_headers`` (used by deployments routing through proxies / request-
# tracing middleware). Mirrors the Anthropic provider's default_headers
# wiring. Sourced from llm_default_headers (env: HINDSIGHT_API_LLM_DEFAULT_HEADERS).
# Copied so a caller-owned dict can't be mutated through us, and a fresh
# copy is handed to each call below to avoid cross-request contamination.
self._default_headers: dict[str, Any] = dict(default_headers or {})
self.bedrock_service_tier = bedrock_service_tier

try:
Expand Down Expand Up @@ -144,6 +152,13 @@ def _build_common_kwargs(
for key, value in self._extra_body.items():
kwargs.setdefault(key, value)

# Forward operator-configured default headers as ``extra_headers`` so they
# reach the provider behind LiteLLM (proxies / request-tracing middleware).
# ``setdefault`` keeps any explicit per-call ``extra_headers`` authoritative;
# a per-call copy prevents LiteLLM/downstream from mutating the stored dict.
if self._default_headers:
kwargs.setdefault("extra_headers", dict(self._default_headers))

# Bedrock service tier: flex (50% cheaper), priority, or reserved
if self.model.startswith("bedrock/") and self.bedrock_service_tier is not None:
kwargs["service_tier"] = self.bedrock_service_tier
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,16 @@ def _build_common_kwargs(
kwargs["max_completion_tokens"] = self._cap_max_completion_tokens(max_completion_tokens)
if temperature is not None:
kwargs["temperature"] = temperature

# Forward operator-configured default headers as ``extra_headers`` so they
# reach the provider behind the Router (proxies / request-tracing middleware).
# This override deliberately omits api_key/base_url/extra_body (those live in
# the per-deployment Router config), but headers are a cross-cutting operator
# concern, so we inject them here too — mirroring the base provider.
# ``setdefault`` keeps any explicit per-call ``extra_headers`` authoritative;
# a per-call copy prevents LiteLLM/downstream from mutating the stored dict.
if self._default_headers:
kwargs.setdefault("extra_headers", dict(self._default_headers))
return kwargs

async def verify_connection(self) -> None:
Expand Down
6 changes: 2 additions & 4 deletions hindsight-api-slim/hindsight_api/migrations.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,7 @@ def _ensure_pgvector_extension_in_public(conn: Connection) -> None:
else:
# Extension in wrong schema - try to fix if we have permissions
logger.warning(
f"pgvector extension found in schema '{ext_schema}' instead of 'public'. "
f"Attempting to relocate..."
f"pgvector extension found in schema '{ext_schema}' instead of 'public'. Attempting to relocate..."
)
try:
conn.execute(text("DROP EXTENSION vector CASCADE"))
Expand Down Expand Up @@ -131,8 +130,7 @@ def _ensure_pgvector_extension_in_public(conn: Connection) -> None:
f"See: https://github.com/pgvector/pgvector#installation"
)
raise RuntimeError(
"pgvector extension is required but not installed. "
"Please install it with: CREATE EXTENSION vector;"
"pgvector extension is required but not installed. Please install it with: CREATE EXTENSION vector;"
) from e


Expand Down
108 changes: 107 additions & 1 deletion hindsight-api-slim/tests/test_llm_extra_body.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import pytest

EXTRA_BODY = {"temperature": 0.2, "top_p": 0.9}
DEFAULT_HEADERS = {"X-Component-Id": "hindsight", "X-Trace": "abc"}


# ─── config / env parsing ─────────────────────────────────────────────────────
Expand Down Expand Up @@ -348,7 +349,7 @@ class StructuredAnswer(BaseModel):
# ─── LiteLLM ──────────────────────────────────────────────────────────────────


def _make_litellm_provider(extra_body=None):
def _make_litellm_provider(extra_body=None, default_headers=None):
pytest.importorskip("litellm")
from hindsight_api.engine.providers.litellm_llm import LiteLLMLLM

Expand All @@ -358,6 +359,7 @@ def _make_litellm_provider(extra_body=None):
base_url="",
model="gpt-4o",
extra_body=extra_body,
default_headers=default_headers,
)


Expand Down Expand Up @@ -419,3 +421,107 @@ def test_litellm_router_forwards_extra_body():
extra_body=EXTRA_BODY,
)
assert provider._extra_body == EXTRA_BODY


def test_litellm_stores_default_headers():
provider = _make_litellm_provider(default_headers=DEFAULT_HEADERS)
assert provider._default_headers == DEFAULT_HEADERS


def test_litellm_empty_default_headers_defaults_to_dict():
provider = _make_litellm_provider(default_headers=None)
assert provider._default_headers == {}


@pytest.mark.asyncio
async def test_litellm_call_passes_default_headers_as_extra_headers():
"""``call()`` forwards default_headers to acompletion via ``extra_headers``."""
provider = _make_litellm_provider(default_headers=DEFAULT_HEADERS)
provider._acompletion = AsyncMock(return_value=_fake_litellm_response())

with patch("hindsight_api.engine.providers.litellm_llm.get_metrics_collector"):
await provider.call(messages=[{"role": "user", "content": "hi"}], scope="test", max_retries=0)

assert provider._acompletion.call_args.kwargs.get("extra_headers") == DEFAULT_HEADERS


@pytest.mark.asyncio
async def test_litellm_no_default_headers_omits_extra_headers():
"""``call()`` does not pass ``extra_headers`` when none are configured."""
provider = _make_litellm_provider(default_headers=None)
provider._acompletion = AsyncMock(return_value=_fake_litellm_response())

with patch("hindsight_api.engine.providers.litellm_llm.get_metrics_collector"):
await provider.call(messages=[{"role": "user", "content": "hi"}], scope="test", max_retries=0)

assert "extra_headers" not in provider._acompletion.call_args.kwargs


@pytest.mark.asyncio
async def test_litellm_default_headers_passed_as_fresh_copy():
"""Each call gets its own ``extra_headers`` copy so downstream mutation can't
contaminate the stored headers or other requests."""
provider = _make_litellm_provider(default_headers=DEFAULT_HEADERS)
provider._acompletion = AsyncMock(return_value=_fake_litellm_response())

with patch("hindsight_api.engine.providers.litellm_llm.get_metrics_collector"):
await provider.call(messages=[{"role": "user", "content": "hi"}], scope="test", max_retries=0)

passed = provider._acompletion.call_args.kwargs["extra_headers"]
assert passed == DEFAULT_HEADERS
assert passed is not provider._default_headers
passed["X-Injected"] = "1"
assert "X-Injected" not in provider._default_headers


def test_litellm_default_headers_copied_from_caller_dict():
"""A caller-owned dict cannot be mutated through the provider."""
caller_dict = {"X-Component-Id": "hindsight"}
provider = _make_litellm_provider(default_headers=caller_dict)
caller_dict["X-Mutated"] = "1"
assert "X-Mutated" not in provider._default_headers


def _make_litellm_router_provider(default_headers=None):
pytest.importorskip("litellm")
from hindsight_api.engine.providers.litellm_router_llm import LiteLLMRouterLLM

config = {"model_list": [{"model_name": "default", "litellm_params": {"model": "gpt-4o", "api_key": "x"}}]}
return LiteLLMRouterLLM(
provider="litellmrouter",
api_key="",
base_url="",
model="default",
config=config,
default_headers=default_headers,
)


def test_litellm_router_stores_default_headers():
provider = _make_litellm_router_provider(default_headers=DEFAULT_HEADERS)
assert provider._default_headers == DEFAULT_HEADERS


@pytest.mark.asyncio
async def test_litellm_router_call_passes_default_headers_as_extra_headers():
"""The Router's ``_build_common_kwargs`` override must also forward default_headers
as ``extra_headers`` — storage alone doesn't reach the provider behind the Router."""
provider = _make_litellm_router_provider(default_headers=DEFAULT_HEADERS)
provider._acompletion = AsyncMock(return_value=_fake_litellm_response())

with patch("hindsight_api.engine.providers.litellm_llm.get_metrics_collector"):
await provider.call(messages=[{"role": "user", "content": "hi"}], scope="test", max_retries=0)

assert provider._acompletion.call_args.kwargs.get("extra_headers") == DEFAULT_HEADERS


@pytest.mark.asyncio
async def test_litellm_router_no_default_headers_omits_extra_headers():
"""The Router omits ``extra_headers`` entirely when none are configured."""
provider = _make_litellm_router_provider(default_headers=None)
provider._acompletion = AsyncMock(return_value=_fake_litellm_response())

with patch("hindsight_api.engine.providers.litellm_llm.get_metrics_collector"):
await provider.call(messages=[{"role": "user", "content": "hi"}], scope="test", max_retries=0)

assert "extra_headers" not in provider._acompletion.call_args.kwargs