Skip to content

Commit fb16fc4

Browse files
sireikaclaudenicoloboschi
authored
feat(api): per-bank provider cost attribution via OpenAI user field (#1965)
* feat(api): per-bank provider cost attribution via OpenAI user field Lets operators attribute Hindsight's provider spend per bank. - Add a `_current_bank_id` engine ContextVar (mirroring the existing `_current_schema` pattern) bound in recall_async, retain_async, retain_batch_async, and execute_task, with a `get_current_bank_id()` accessor. Bindings use a token + finally reset. - Add `HINDSIGHT_API_LLM_SEND_BANK_AS_USER` (bool, default off). When on, outbound OpenAI-compatible LLM and embedding calls are tagged with `user=<bank_id>` so downstream cost gateways (OpenRouter usage accounting, LiteLLM, Helicone) can key spend per bank. Injection is centralized per call_params construction site and never overrides a `user` the caller already set. - Propagate the bank ContextVar into the embedding executor thread: generate_embeddings_batch now copies the current context before the run_in_executor offload (run_in_executor does not inherit contextvars), preserving the existing exception wrapping and 1:1 length validation. - Make the OpenRouter reranker base URL configurable via `HINDSIGHT_API_RERANKER_OPENROUTER_BASE_URL` (default unchanged: https://openrouter.ai/api/v1/rerank) so rerank can route through a metering gateway. The URL is a credential field (not bank-configurable). Tests cover ContextVar set/reset including on exception, user injection gated on flag + bank presence + no caller override (chat and tool-calling paths plus embeddings), real-executor context propagation, and the configurable rerank base URL. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com> * refactor(engine): bind the bank ContextVar via decorator, not inline try/finally The inline token/try/finally wraps re-indented the entire bodies of execute_task, retain_batch_async, and recall_async — ~1,130 lines of indentation-only churn in the diff for a ~40-line feature. Replace the four inline bindings with a @_bind_bank_id decorator that binds _current_bank_id from the method's bank_id argument (or a key in a dict argument, for execute_task's task_dict) with the same token + finally-reset semantics. Method bodies return to their original indentation, shrinking the memory_engine.py diff to +51/-1. Behavior is unchanged and now directly unit-tested: the decorator gets its own tests for positional/keyword binding, dict-key extraction, reset-on-exception, and non-string fallback to None. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com> * refactor(api): dedupe bank-attribution helper into shared module Collapse the two identical _apply_bank_attribution copies (embeddings + OpenAI-compatible LLM) into engine/bank_attribution.apply_bank_attribution. Add a docs note that the bank id is transmitted to the provider as the end-user identifier, and de-pad the new config rows. --------- Co-authored-by: Claude Opus 4.8 (1M context) <noreply@anthropic.com> Co-authored-by: Nicolò Boschi <boschi1997@gmail.com>
1 parent 82c7df7 commit fb16fc4

12 files changed

Lines changed: 589 additions & 4 deletions

File tree

hindsight-api-slim/hindsight_api/config.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,7 @@ def normalize_config_dict(config: dict[str, Any]) -> dict[str, Any]:
145145
ENV_LLM_EXTRA_BODY = "HINDSIGHT_API_LLM_EXTRA_BODY"
146146
ENV_LLM_DEFAULT_HEADERS = "HINDSIGHT_API_LLM_DEFAULT_HEADERS"
147147
ENV_LLM_STRICT_SCHEMA = "HINDSIGHT_API_LLM_STRICT_SCHEMA"
148+
ENV_LLM_SEND_BANK_AS_USER = "HINDSIGHT_API_LLM_SEND_BANK_AS_USER"
148149

149150
# LiteLLM Router chain — provider-specific config consumed by the "litellmrouter"
150151
# provider. Each entry is a deployment; the Router tries them in declared order and
@@ -254,6 +255,7 @@ def normalize_config_dict(config: dict[str, Any]) -> dict[str, Any]:
254255
ENV_EMBEDDINGS_OPENROUTER_MODEL = "HINDSIGHT_API_EMBEDDINGS_OPENROUTER_MODEL"
255256
ENV_RERANKER_OPENROUTER_API_KEY = "HINDSIGHT_API_RERANKER_OPENROUTER_API_KEY"
256257
ENV_RERANKER_OPENROUTER_MODEL = "HINDSIGHT_API_RERANKER_OPENROUTER_MODEL"
258+
ENV_RERANKER_OPENROUTER_BASE_URL = "HINDSIGHT_API_RERANKER_OPENROUTER_BASE_URL"
257259

258260
# ZeroEntropy configuration (embeddings)
259261
ENV_EMBEDDINGS_ZEROENTROPY_API_KEY = "HINDSIGHT_API_EMBEDDINGS_ZEROENTROPY_API_KEY"
@@ -620,6 +622,7 @@ def normalize_config_dict(config: dict[str, Any]) -> dict[str, Any]:
620622
DEFAULT_LLM_MAX_BACKOFF = 60.0 # Max backoff cap in seconds for retry exponential backoff
621623
DEFAULT_LLM_TIMEOUT = 120.0 # seconds
622624
DEFAULT_LLM_REASONING_EFFORT = "low"
625+
DEFAULT_LLM_SEND_BANK_AS_USER = False # Opt-in: tag provider calls with user=<bank_id>
623626

624627
# Vertex AI defaults
625628
DEFAULT_LLM_VERTEXAI_PROJECT_ID = None # Required for Vertex AI
@@ -740,6 +743,7 @@ def _parse_strategy_boosts(raw: str | None) -> dict[str, str]:
740743
# OpenRouter defaults
741744
DEFAULT_EMBEDDINGS_OPENROUTER_MODEL = "perplexity/pplx-embed-v1-0.6b"
742745
DEFAULT_RERANKER_OPENROUTER_MODEL = "cohere/rerank-v3.5"
746+
DEFAULT_RERANKER_OPENROUTER_BASE_URL = "https://openrouter.ai/api/v1/rerank"
743747

744748
# ZeroEntropy defaults
745749
DEFAULT_EMBEDDINGS_ZEROENTROPY_MODEL = "zembed-1"
@@ -1229,6 +1233,11 @@ class HindsightConfig:
12291233
dict | None
12301234
) # Custom headers passed as default_headers to provider SDK clients (e.g. {"X-Component-Id": "hindsight"} for proxies / request tracing)
12311235
llm_strict_schema: bool # Grammar-enforce structured output via the provider's strongest schema mode (see DEFAULT_LLM_STRICT_SCHEMA)
1236+
# Tags outbound OpenAI-compatible LLM + embedding calls with `user=<bank_id>` for
1237+
# per-bank cost attribution. Downstream cost gateways (OpenRouter usage accounting,
1238+
# LiteLLM, Helicone) key attribution on the OpenAI `user` field. Opt-in; never
1239+
# overrides a `user` the caller already set.
1240+
llm_send_bank_as_user: bool
12321241

12331242
# LiteLLM Router chain (provider-specific; consumed by the "litellmrouter" provider).
12341243
# List of deployment dicts evaluated in order with fallback on transient errors.
@@ -1360,6 +1369,7 @@ class HindsightConfig:
13601369
reranker_cohere_timeout: float
13611370
reranker_openrouter_api_key: str | None
13621371
reranker_openrouter_model: str
1372+
reranker_openrouter_base_url: str
13631373
reranker_openrouter_timeout: float
13641374
reranker_litellm_api_base: str
13651375
reranker_litellm_api_key: str | None
@@ -1610,6 +1620,7 @@ class HindsightConfig:
16101620
"embeddings_tei_base_url",
16111621
"reranker_tei_base_url",
16121622
"reranker_cohere_base_url",
1623+
"reranker_openrouter_base_url",
16131624
"embeddings_zeroentropy_base_url",
16141625
"reranker_zeroentropy_base_url",
16151626
"reranker_siliconflow_base_url",
@@ -1904,6 +1915,8 @@ def from_env(cls) -> "HindsightConfig":
19041915
llm_extra_body=json.loads(os.getenv(ENV_LLM_EXTRA_BODY, "null")),
19051916
llm_default_headers=json.loads(os.getenv(ENV_LLM_DEFAULT_HEADERS, "null")),
19061917
llm_strict_schema=os.getenv(ENV_LLM_STRICT_SCHEMA, str(DEFAULT_LLM_STRICT_SCHEMA)).lower() in ("true", "1"),
1918+
llm_send_bank_as_user=os.getenv(ENV_LLM_SEND_BANK_AS_USER, str(DEFAULT_LLM_SEND_BANK_AS_USER)).lower()
1919+
in ("true", "1"),
19071920
llm_litellmrouter_config=_parse_llm_router_config(ENV_LLM_LITELLMROUTER_CONFIG),
19081921
# Vertex AI
19091922
llm_vertexai_project_id=os.getenv(ENV_LLM_VERTEXAI_PROJECT_ID) or DEFAULT_LLM_VERTEXAI_PROJECT_ID,
@@ -2183,6 +2196,9 @@ def from_env(cls) -> "HindsightConfig":
21832196
or os.getenv(ENV_OPENROUTER_API_KEY)
21842197
or os.getenv(ENV_LLM_API_KEY),
21852198
reranker_openrouter_model=os.getenv(ENV_RERANKER_OPENROUTER_MODEL, DEFAULT_RERANKER_OPENROUTER_MODEL),
2199+
reranker_openrouter_base_url=os.getenv(
2200+
ENV_RERANKER_OPENROUTER_BASE_URL, DEFAULT_RERANKER_OPENROUTER_BASE_URL
2201+
),
21862202
reranker_openrouter_timeout=float(
21872203
os.getenv(ENV_RERANKER_OPENROUTER_TIMEOUT, str(DEFAULT_RERANKER_OPENROUTER_TIMEOUT))
21882204
),
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
"""Per-bank provider cost attribution via the OpenAI ``user`` field.
2+
3+
Shared by the OpenAI-compatible LLM path and the OpenAI embeddings path so both
4+
tag outbound requests identically. Opt-in via ``HINDSIGHT_API_LLM_SEND_BANK_AS_USER``;
5+
downstream cost gateways (OpenRouter usage accounting, LiteLLM, Helicone) key spend
6+
on the OpenAI ``user`` field.
7+
8+
Note: when enabled, the bank id is transmitted to the upstream provider as the
9+
end-user identifier. Banks that are themselves end-user identifiers are therefore
10+
forwarded to the provider — which is exactly what the OpenAI ``user`` field is for,
11+
but operators should opt in with that in mind.
12+
"""
13+
14+
from typing import Any
15+
16+
17+
def apply_bank_attribution(request: dict[str, Any]) -> None:
18+
"""Tag ``request`` with ``user=<bank_id>`` for per-bank cost attribution.
19+
20+
Mutates ``request`` in place. No-op when the flag is off, no bank is in context,
21+
or the caller already set ``user`` — we never override an explicit value.
22+
"""
23+
if "user" in request:
24+
return
25+
# Lazy imports: memory_engine imports the embeddings/provider modules that call
26+
# this, so a top-level import of memory_engine here would be circular.
27+
from ..config import get_config
28+
from .memory_engine import get_current_bank_id
29+
30+
if not get_config().llm_send_bank_as_user:
31+
return
32+
bank_id = get_current_bank_id()
33+
if bank_id:
34+
request["user"] = bank_id

hindsight-api-slim/hindsight_api/engine/cross_encoder.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1679,7 +1679,7 @@ def create_cross_encoder_from_env() -> CrossEncoderModel:
16791679
return CohereCrossEncoder(
16801680
api_key=api_key,
16811681
model=config.reranker_openrouter_model,
1682-
base_url="https://openrouter.ai/api/v1/rerank",
1682+
base_url=config.reranker_openrouter_base_url,
16831683
timeout=config.reranker_openrouter_timeout,
16841684
)
16851685
elif provider == "flashrank":

hindsight-api-slim/hindsight_api/engine/embeddings.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@
5757
ENV_EMBEDDINGS_ZEROENTROPY_ENCODING_FORMAT,
5858
ENV_LLM_API_KEY,
5959
)
60+
from .bank_attribution import apply_bank_attribution
6061

6162
logger = logging.getLogger(__name__)
6263

@@ -705,6 +706,7 @@ def encode(self, texts: list[str]) -> list[list[float]]:
705706
}
706707
if self.dimensions is not None:
707708
request["dimensions"] = self.dimensions
709+
apply_bank_attribution(request)
708710

709711
response = self._client.embeddings.create(**request)
710712

hindsight-api-slim/hindsight_api/engine/memory_engine.py

Lines changed: 51 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,16 @@
1111

1212
import asyncio
1313
import contextvars
14+
import functools
15+
import inspect
1416
import json
1517
import logging
1618
import time
1719
import uuid
1820
from collections.abc import Awaitable, Callable
1921
from dataclasses import dataclass, field
2022
from datetime import UTC, datetime, timedelta, timezone
21-
from typing import TYPE_CHECKING, Any, Literal, cast, overload
23+
from typing import TYPE_CHECKING, Any, Literal, ParamSpec, TypeVar, cast, overload
2224

2325
import asyncpg
2426
import httpx
@@ -67,6 +69,12 @@
6769
# Context variable for current schema (async-safe, per-task isolation)
6870
# Note: default is None, actual default comes from config via get_current_schema()
6971
_current_schema: contextvars.ContextVar[str | None] = contextvars.ContextVar("current_schema", default=None)
72+
73+
# Context variable for the bank an operation runs for (async-safe, per-task isolation).
74+
# Set by the engine wherever it learns the bank (recall/retain/batch/task execution) so
75+
# downstream provider calls can attribute spend per bank — e.g. tagging the OpenAI `user`
76+
# field for cost gateways. None outside a bank-scoped operation.
77+
_current_bank_id: contextvars.ContextVar[str | None] = contextvars.ContextVar("current_bank_id", default=None)
7078
MENTAL_MODEL_PENDING_CONTENT = "Generating content..."
7179

7280

@@ -79,6 +87,44 @@ def get_current_schema() -> str:
7987
return schema
8088

8189

90+
def get_current_bank_id() -> str | None:
91+
"""Get the bank id of the in-flight operation, or None outside a bank-scoped context."""
92+
return _current_bank_id.get()
93+
94+
95+
_P = ParamSpec("_P")
96+
_R = TypeVar("_R")
97+
98+
99+
def _bind_bank_id(
100+
arg: str = "bank_id", key: str | None = None
101+
) -> Callable[[Callable[_P, Awaitable[_R]]], Callable[_P, Awaitable[_R]]]:
102+
"""Bind ``_current_bank_id`` to an argument of the wrapped coroutine for the call's duration.
103+
104+
``arg`` names the parameter carrying the bank id; ``key`` optionally pulls it out of a
105+
dict-valued argument (e.g. ``task_dict["bank_id"]``). Token-based set/reset (including on
106+
exception) keeps the binding scoped to the call.
107+
"""
108+
109+
def decorate(func: Callable[_P, Awaitable[_R]]) -> Callable[_P, Awaitable[_R]]:
110+
sig = inspect.signature(func)
111+
112+
@functools.wraps(func)
113+
async def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
114+
value = sig.bind(*args, **kwargs).arguments.get(arg)
115+
if key is not None and isinstance(value, dict):
116+
value = value.get(key)
117+
token = _current_bank_id.set(value if isinstance(value, str) else None)
118+
try:
119+
return await func(*args, **kwargs)
120+
finally:
121+
_current_bank_id.reset(token)
122+
123+
return wrapper
124+
125+
return decorate
126+
127+
82128
def count_tokens(text: str) -> int:
83129
"""Count tokens in text using tiktoken (cl100k_base encoding for GPT-4/3.5)."""
84130
return len(_get_tiktoken_encoding().encode(text))
@@ -1626,6 +1672,7 @@ async def _handle_refresh_mental_model(self, task_dict: dict[str, Any]):
16261672

16271673
logger.info(f"[REFRESH_MENTAL_MODEL_TASK] Completed for bank_id={bank_id}, mental_model_id={mental_model_id}")
16281674

1675+
@_bind_bank_id("task_dict", key="bank_id")
16291676
async def execute_task(self, task_dict: dict[str, Any]):
16301677
"""
16311678
Execute a task by routing it to the appropriate handler.
@@ -2933,6 +2980,7 @@ def retain(
29332980
ctx = request_context if request_context is not None else RC()
29342981
return asyncio.run(self.retain_async(bank_id, content, context, event_date, request_context=ctx))
29352982

2983+
@_bind_bank_id()
29362984
async def retain_async(
29372985
self,
29382986
bank_id: str,
@@ -2979,6 +3027,7 @@ async def retain_async(
29793027
# Return the first (and only) list of unit IDs
29803028
return result[0] if result else []
29813029

3030+
@_bind_bank_id()
29823031
async def retain_batch_async(
29833032
self,
29843033
bank_id: str,
@@ -3706,6 +3755,7 @@ def recall(
37063755
)
37073756
)
37083757

3758+
@_bind_bank_id()
37093759
async def recall_async(
37103760
self,
37113761
bank_id: str,

hindsight-api-slim/hindsight_api/engine/providers/openai_compatible_llm.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
from openai import APIConnectionError, APIStatusError, AsyncOpenAI, LengthFinishReasonError
3434

3535
from hindsight_api.config import DEFAULT_LLM_TIMEOUT, ENV_LLM_TIMEOUT
36+
from hindsight_api.engine.bank_attribution import apply_bank_attribution
3637
from hindsight_api.engine.llm_interface import LLMInterface, OutputTooLongError
3738
from hindsight_api.engine.response_models import LLMToolCall, LLMToolCallResult, TokenUsage
3839
from hindsight_api.metrics import get_metrics_collector
@@ -595,6 +596,8 @@ async def call(
595596
call_params["messages"] = _ensure_json_word_in_user_message(call_params["messages"])
596597
call_params["response_format"] = {"type": "json_object"}
597598

599+
apply_bank_attribution(call_params)
600+
598601
last_exception = None
599602

600603
for attempt in range(max_retries + 1):
@@ -945,6 +948,8 @@ async def call_with_tools(
945948
if extra_body:
946949
call_params["extra_body"] = extra_body
947950

951+
apply_bank_attribution(call_params)
952+
948953
last_exception = None
949954

950955
for attempt in range(max_retries + 1):

hindsight-api-slim/hindsight_api/engine/retain/embedding_utils.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
"""
44

55
import asyncio
6+
import contextvars
67
import logging
78
from typing import Literal, Protocol
89

@@ -89,7 +90,14 @@ async def generate_embeddings_batch(
8990
"""
9091
try:
9192
loop = asyncio.get_event_loop()
92-
embeddings = await loop.run_in_executor(None, _encode_with_input_type, embeddings_backend, texts, input_type)
93+
# run_in_executor runs the encode in a worker thread, which does NOT inherit
94+
# the caller's contextvars. Capture the current context and run the encode
95+
# inside it so context-dependent behavior (e.g. per-bank `user` attribution
96+
# read via get_current_bank_id()) survives the thread hop.
97+
ctx = contextvars.copy_context()
98+
embeddings = await loop.run_in_executor(
99+
None, lambda: ctx.run(_encode_with_input_type, embeddings_backend, texts, input_type)
100+
)
93101
except Exception as e:
94102
raise Exception(f"Failed to generate batch embeddings: {str(e)}")
95103

0 commit comments

Comments
 (0)