Skip to content

Commit 861295d

Browse files
authored
refactor: replace set_gemini_safety_settings() with LLMProvider.with_config() (#474)
* refactor: replace set_gemini_safety_settings() with LLMProvider.with_config()

  Removes the fragile ContextVar-setter pattern where callers had to remember to call set_gemini_safety_settings() at every operation entry point. Instead, LLMProvider.with_config(resolved_config) returns a ConfiguredLLMProvider wrapper that:

  - injects per-bank settings (Gemini safety settings) on every call via token-based ContextVar set/reset — properly scoped, no leakage
  - proxies all attribute access to the underlying provider via __getattr__
  - requires zero changes to LLMInterface or any provider implementations

  Call sites (retain, reflect, consolidation) now pass llm_config.with_config(resolved_config) to sub-components instead of setting a global context var and hoping nothing else runs in between.

  This pattern also composes naturally with a future per-bank provider factory: callers always receive something with a .call() method.

* fix: pass messages/tools as kwargs in ConfiguredLLMProvider to preserve class-level patch compatibility
1 parent 15f4b87 commit 861295d

5 files changed

Lines changed: 149 additions & 84 deletions

File tree

hindsight-api/hindsight_api/engine/consolidation/consolidator.py

Lines changed: 9 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -127,10 +127,9 @@ async def run_consolidation_job(
127127
# Resolve bank-specific config with hierarchical overrides
128128
config = await memory_engine._config_resolver.resolve_full_config(bank_id, request_context)
129129

130-
# Apply bank-specific Gemini safety settings for this request context
131-
from ..providers.gemini_llm import set_gemini_safety_settings
132-
133-
set_gemini_safety_settings(config.llm_gemini_safety_settings)
130+
# Build a configured LLM wrapper that applies per-bank settings (e.g. safety settings)
131+
# to every call without leaking across operations.
132+
llm_config = memory_engine._consolidation_llm_config.with_config(config)
134133

135134
perf = ConsolidationPerfLog(bank_id)
136135
max_memories_per_batch = config.consolidation_batch_size
@@ -281,6 +280,7 @@ async def run_consolidation_job(
281280
pass_results = await _process_memory_batch(
282281
conn=conn,
283282
memory_engine=memory_engine,
283+
llm_config=llm_config,
284284
bank_id=bank_id,
285285
memories=llm_batch,
286286
request_context=request_context,
@@ -318,6 +318,7 @@ async def run_consolidation_job(
318318
results = await _process_memory_batch(
319319
conn=conn,
320320
memory_engine=memory_engine,
321+
llm_config=llm_config,
321322
bank_id=bank_id,
322323
memories=llm_batch,
323324
request_context=request_context,
@@ -513,6 +514,7 @@ async def _trigger_mental_model_refreshes(
513514
async def _process_memory_batch(
514515
conn: "Connection",
515516
memory_engine: "MemoryEngine",
517+
llm_config: Any,
516518
bank_id: str,
517519
memories: list[dict[str, Any]],
518520
request_context: "RequestContext",
@@ -581,7 +583,7 @@ async def _process_memory_batch(
581583
# 3. Single LLM call
582584
t0 = time.time()
583585
llm_result = await _consolidate_batch_with_llm(
584-
memory_engine=memory_engine,
586+
llm_config=llm_config,
585587
memories=memories,
586588
union_observations=union_observations,
587589
union_source_facts=union_source_facts,
@@ -945,7 +947,7 @@ def _build_observations_for_llm(
945947

946948

947949
async def _consolidate_batch_with_llm(
948-
memory_engine: "MemoryEngine",
950+
llm_config: Any,
949951
memories: list[dict[str, Any]],
950952
union_observations: "list[MemoryFact]",
951953
union_source_facts: "dict[str, MemoryFact]",
@@ -981,7 +983,7 @@ def _fact_line(m: dict[str, Any]) -> str:
981983
last_exc: Exception | None = None
982984
for attempt in range(1, max_attempts + 1):
983985
try:
984-
response: _ConsolidationBatchResponse = await memory_engine._consolidation_llm_config.call(
986+
response: _ConsolidationBatchResponse = await llm_config.call(
985987
messages=[{"role": "user", "content": prompt}],
986988
response_format=_ConsolidationBatchResponse,
987989
scope="consolidation",

hindsight-api/hindsight_api/engine/llm_wrapper.py

Lines changed: 70 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -622,6 +622,23 @@ def _verify_claude_code_available(self) -> None:
622622
# SDK will automatically check for authentication when first used
623623
# No need to verify here - let it fail gracefully on first call with helpful error
624624

625+
def with_config(self, config: Any) -> "ConfiguredLLMProvider":
626+
"""
627+
Return a configured wrapper for a specific bank operation.
628+
629+
The wrapper applies per-bank overrides (e.g. Gemini safety settings)
630+
to every ``call()`` / ``call_with_tools()`` invocation without
631+
changing the underlying provider or its long-lived client connection.
632+
633+
Args:
634+
config: Resolved ``HindsightConfig`` for the current bank/request.
635+
636+
Returns:
637+
A ``ConfiguredLLMProvider`` that delegates to this provider with
638+
the supplied config applied.
639+
"""
640+
return ConfiguredLLMProvider(self, config.llm_gemini_safety_settings)
641+
625642
async def cleanup(self) -> None:
626643
"""Clean up resources."""
627644
pass
@@ -683,5 +700,58 @@ def for_judge(cls) -> "LLMProvider":
683700
return cls(provider=provider, api_key=api_key, base_url=base_url, model=model, reasoning_effort="high")
684701

685702

703+
class ConfiguredLLMProvider:
704+
"""
705+
Thin wrapper around LLMProvider that applies bank-specific config to every call.
706+
707+
Obtained via ``LLMProvider.with_config(resolved_config)``. The wrapper
708+
sets any provider-specific overrides (currently Gemini safety settings)
709+
immediately before each call using a ContextVar token, then resets it
710+
afterwards — so nesting is safe and the configuration cannot leak across
711+
operations.
712+
713+
All attribute access falls through to the underlying provider so callers
714+
that read ``llm.provider``, ``llm.model``, etc. continue to work without
715+
any changes.
716+
"""
717+
718+
def __init__(self, provider: "LLMProvider", gemini_safety_settings: list | None) -> None:
719+
# Use object.__setattr__ to avoid triggering __getattr__
720+
object.__setattr__(self, "_provider", provider)
721+
object.__setattr__(self, "_gemini_safety_settings", gemini_safety_settings)
722+
723+
# ── attribute passthrough ──────────────────────────────────────────────────
724+
725+
def __getattr__(self, name: str) -> Any:
726+
return getattr(object.__getattribute__(self, "_provider"), name)
727+
728+
# ── overridden call methods ────────────────────────────────────────────────
729+
730+
async def call(self, messages: list[dict[str, Any]], **kwargs: Any) -> Any:
731+
from .providers.gemini_llm import _safety_settings_ctx
732+
733+
token = _safety_settings_ctx.set(object.__getattribute__(self, "_gemini_safety_settings"))
734+
try:
735+
return await object.__getattribute__(self, "_provider").call(messages=messages, **kwargs)
736+
finally:
737+
_safety_settings_ctx.reset(token)
738+
739+
async def call_with_tools(
740+
self,
741+
messages: list[dict[str, Any]],
742+
tools: list[dict[str, Any]],
743+
**kwargs: Any,
744+
) -> "LLMToolCallResult":
745+
from .providers.gemini_llm import _safety_settings_ctx
746+
747+
token = _safety_settings_ctx.set(object.__getattribute__(self, "_gemini_safety_settings"))
748+
try:
749+
return await object.__getattribute__(self, "_provider").call_with_tools(
750+
messages=messages, tools=tools, **kwargs
751+
)
752+
finally:
753+
_safety_settings_ctx.reset(token)
754+
755+
686756
# Backwards compatibility alias
687757
LLMConfig = LLMProvider

hindsight-api/hindsight_api/engine/memory_engine.py

Lines changed: 2 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -1831,17 +1831,12 @@ async def _retain_batch_async_internal(
18311831
# Resolve bank-specific config for this operation
18321832
resolved_config = await self._config_resolver.resolve_full_config(bank_id, request_context)
18331833

1834-
# Apply bank-specific Gemini safety settings for this request context
1835-
from .providers.gemini_llm import set_gemini_safety_settings
1836-
1837-
set_gemini_safety_settings(resolved_config.llm_gemini_safety_settings)
1838-
18391834
# Create parent span for retain operation
18401835
with create_operation_span("retain", bank_id):
18411836
return await orchestrator.retain_batch(
18421837
pool=pool,
18431838
embeddings_model=self.embeddings,
1844-
llm_config=self._retain_llm_config,
1839+
llm_config=self._retain_llm_config.with_config(resolved_config),
18451840
entity_resolver=self.entity_resolver,
18461841
format_date_fn=self._format_readable_date,
18471842
bank_id=bank_id,
@@ -4468,11 +4463,7 @@ async def reflect_async(
44684463
# The agent can call lookup() to list available models if needed.
44694464
# This is critical for banks with many mental models to avoid huge prompts.
44704465

4471-
# Apply bank-specific Gemini safety settings for this request context
44724466
resolved_reflect_config = await self._config_resolver.resolve_full_config(bank_id, request_context)
4473-
from .providers.gemini_llm import set_gemini_safety_settings
4474-
4475-
set_gemini_safety_settings(resolved_reflect_config.llm_gemini_safety_settings)
44764467

44774468
# Compute max iterations based on budget
44784469
config = get_config()
@@ -4576,7 +4567,7 @@ async def expand_fn(memory_ids: list[str], depth: str) -> dict[str, Any]:
45764567

45774568
try:
45784569
agent_result = await run_reflect_agent(
4579-
llm_config=self._reflect_llm_config,
4570+
llm_config=self._reflect_llm_config.with_config(resolved_reflect_config),
45804571
bank_id=bank_id,
45814572
query=query,
45824573
bank_profile=profile,

hindsight-api/hindsight_api/engine/providers/gemini_llm.py

Lines changed: 3 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -25,25 +25,12 @@
2525

2626
logger = logging.getLogger(__name__)
2727

28-
# Context variable for per-request Gemini safety settings override (supports per-bank configuration)
28+
# Per-request Gemini safety settings override.
29+
# Set exclusively by ConfiguredLLMProvider.call() / call_with_tools() via token-based
30+
# set/reset, so it is properly scoped to each individual LLM call and never leaks.
2931
_safety_settings_ctx: ContextVar[list | None] = ContextVar("gemini_safety_settings", default=None)
3032

3133

32-
def set_gemini_safety_settings(settings: list | None) -> None:
33-
"""
34-
Set Gemini safety settings for the current async context.
35-
36-
This allows per-bank safety settings to be applied without changing
37-
the LLM provider interface. Call this before making LLM calls within
38-
an operation that has resolved bank-specific configuration.
39-
40-
Args:
41-
settings: List of safety setting dicts with 'category' and 'threshold' keys,
42-
or None to use the instance default (from env var).
43-
"""
44-
_safety_settings_ctx.set(settings)
45-
46-
4734
# Vertex AI imports (optional)
4835
try:
4936
import google.auth

hindsight-api/tests/test_gemini_safety_settings.py

Lines changed: 65 additions & 50 deletions
Original file line number | Diff line number | Diff line change
@@ -216,81 +216,96 @@ async def test_call_with_tools_applies_safety_settings():
216216
assert "HARM_CATEGORY_HARASSMENT" in categories
217217

218218

219-
# ─── Context variable override ────────────────────────────────────────────────
219+
# ─── with_config() override ───────────────────────────────────────────────────
220220

221221

222-
@pytest.mark.asyncio
223-
async def test_context_var_overrides_instance_settings():
224-
"""The context var safety settings take precedence over instance defaults."""
225-
from hindsight_api.engine.providers.gemini_llm import set_gemini_safety_settings
222+
def _make_llm_provider(safety_settings=None):
223+
"""Return an LLMProvider (wrapping GeminiLLM) with a mocked genai.Client."""
224+
with patch("google.genai.Client") as mock_client_cls:
225+
mock_client_cls.return_value = MagicMock()
226+
from hindsight_api.engine.llm_wrapper import LLMProvider
226227

227-
# Instance has settings, but we'll override via context var with different settings
228-
instance_settings = [{"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_ONLY_HIGH"}]
229-
ctx_settings = [{"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_NONE"}]
228+
provider = LLMProvider(
229+
provider="gemini",
230+
api_key="fake-api-key",
231+
base_url="",
232+
model="gemini-2.5-flash",
233+
gemini_safety_settings=safety_settings,
234+
)
235+
# Replace the underlying Gemini client with a fresh mock
236+
provider._provider_impl._client = MagicMock()
237+
return provider
230238

231-
provider = _make_gemini_provider(safety_settings=instance_settings)
232239

233-
fake_response = MagicMock()
234-
fake_response.text = "hello"
235-
fake_response.candidates = [MagicMock(finish_reason="STOP")]
236-
fake_response.usage_metadata = MagicMock(prompt_token_count=5, candidates_token_count=2)
240+
def _fake_response():
241+
r = MagicMock()
242+
r.text = "hello"
243+
r.candidates = [MagicMock(finish_reason="STOP")]
244+
r.usage_metadata = MagicMock(prompt_token_count=5, candidates_token_count=2)
245+
return r
237246

238-
provider._client.aio.models.generate_content = AsyncMock(return_value=fake_response)
239247

240-
# Set context var override
241-
set_gemini_safety_settings(ctx_settings)
242-
try:
243-
await provider.call(
244-
messages=[{"role": "user", "content": "hi"}],
245-
scope="test",
246-
)
247-
finally:
248-
set_gemini_safety_settings(None) # Reset context
248+
def _make_config(safety_settings):
249+
"""Return a minimal config-like object with llm_gemini_safety_settings."""
250+
cfg = MagicMock()
251+
cfg.llm_gemini_safety_settings = safety_settings
252+
return cfg
249253

250-
call_args = provider._client.aio.models.generate_content.call_args
251-
config_arg = call_args.kwargs.get("config")
252254

253-
assert config_arg is not None
254-
assert config_arg.safety_settings is not None
255+
@pytest.mark.asyncio
256+
async def test_with_config_overrides_instance_settings():
257+
"""with_config() settings take precedence over the provider instance defaults."""
258+
instance_settings = [{"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_ONLY_HIGH"}]
259+
override_settings = [{"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_NONE"}]
260+
261+
provider = _make_llm_provider(safety_settings=instance_settings)
262+
provider._provider_impl._client.aio.models.generate_content = AsyncMock(return_value=_fake_response())
255263

256-
# Should use ctx_settings (HATE_SPEECH/BLOCK_NONE), not instance_settings (HARASSMENT/BLOCK_ONLY_HIGH)
264+
configured = provider.with_config(_make_config(override_settings))
265+
await configured.call(messages=[{"role": "user", "content": "hi"}], scope="test")
266+
267+
config_arg = provider._provider_impl._client.aio.models.generate_content.call_args.kwargs.get("config")
268+
assert config_arg is not None
257269
categories = [s.category.value if hasattr(s.category, "value") else str(s.category) for s in config_arg.safety_settings]
270+
# Should use override_settings (HATE_SPEECH), not instance_settings (HARASSMENT)
258271
assert "HARM_CATEGORY_HATE_SPEECH" in categories
259272
assert "HARM_CATEGORY_HARASSMENT" not in categories
260273

261274

262275
@pytest.mark.asyncio
263-
async def test_context_var_none_falls_back_to_instance():
264-
"""When context var is None (not set), instance settings are used."""
265-
from hindsight_api.engine.providers.gemini_llm import set_gemini_safety_settings
266-
276+
async def test_with_config_none_falls_back_to_instance():
277+
"""When with_config() supplies None, the instance default is used."""
267278
instance_settings = [{"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"}]
268-
provider = _make_gemini_provider(safety_settings=instance_settings)
269-
270-
fake_response = MagicMock()
271-
fake_response.text = "hello"
272-
fake_response.candidates = [MagicMock(finish_reason="STOP")]
273-
fake_response.usage_metadata = MagicMock(prompt_token_count=5, candidates_token_count=2)
274-
275-
provider._client.aio.models.generate_content = AsyncMock(return_value=fake_response)
276-
277-
# Explicitly set context var to None (fallback)
278-
set_gemini_safety_settings(None)
279279

280-
await provider.call(
281-
messages=[{"role": "user", "content": "hi"}],
282-
scope="test",
283-
)
280+
provider = _make_llm_provider(safety_settings=instance_settings)
281+
provider._provider_impl._client.aio.models.generate_content = AsyncMock(return_value=_fake_response())
284282

285-
call_args = provider._client.aio.models.generate_content.call_args
286-
config_arg = call_args.kwargs.get("config")
283+
configured = provider.with_config(_make_config(None))
284+
await configured.call(messages=[{"role": "user", "content": "hi"}], scope="test")
287285

286+
config_arg = provider._provider_impl._client.aio.models.generate_content.call_args.kwargs.get("config")
288287
assert config_arg is not None
289-
assert config_arg.safety_settings is not None
290288
categories = [s.category.value if hasattr(s.category, "value") else str(s.category) for s in config_arg.safety_settings]
291289
assert "HARM_CATEGORY_HARASSMENT" in categories
292290

293291

292+
@pytest.mark.asyncio
293+
async def test_with_config_resets_after_call():
294+
"""The ContextVar is properly reset after a with_config() call (no leakage)."""
295+
from hindsight_api.engine.providers.gemini_llm import _safety_settings_ctx
296+
297+
settings = [{"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"}]
298+
provider = _make_llm_provider(safety_settings=None)
299+
provider._provider_impl._client.aio.models.generate_content = AsyncMock(return_value=_fake_response())
300+
301+
before = _safety_settings_ctx.get()
302+
configured = provider.with_config(_make_config(settings))
303+
await configured.call(messages=[{"role": "user", "content": "hi"}], scope="test")
304+
after = _safety_settings_ctx.get()
305+
306+
assert after == before # ContextVar restored to its original value
307+
308+
294309
# ─── LLMProvider reads safety settings from config ────────────────────────────
295310

296311

0 commit comments

Comments
 (0)