Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 5 additions & 7 deletions extropy/core/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
Configure via `extropy config` CLI or programmatically via extropy.config.configure().
"""

from .providers import get_provider
from .providers import get_provider, get_simulation_provider
from .providers.base import TokenUsage, ValidatorCallback, RetryCallback
from ..config import get_config, parse_model_string

Expand Down Expand Up @@ -80,12 +80,10 @@ async def simple_call_async(
Model is passed explicitly from simulation caller (provider/model format).
Returns (structured_data, token_usage) tuple.
"""
if model:
provider, model_name = _resolve_provider_and_model(model)
else:
config = get_config()
model_string = config.resolve_sim_strong()
provider, model_name = _resolve_provider_and_model(model_string)
config = get_config()
model_string = model or config.resolve_sim_strong()
_, model_name = parse_model_string(model_string)
provider = get_simulation_provider(model_string)
return await provider.simple_call_async(
prompt=prompt,
response_schema=response_schema,
Expand Down
12 changes: 7 additions & 5 deletions extropy/core/providers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,14 +178,16 @@ def get_pipeline_provider() -> LLMProvider:
return _get_or_create_provider(provider, f"pipeline:{provider}")


def get_simulation_provider(model_string: str | None = None) -> LLMProvider:
    """Get a cached provider for simulation phase async calls.

    Args:
        model_string: Optional explicit model string ("provider/model"). If
            omitted, uses resolved simulation strong model from config.

    Returns:
        The cached ``LLMProvider`` instance for the resolved provider name
        (one cache entry per provider under the ``simulation:`` key prefix).
    """
    config = get_config()
    # An explicit model wins; otherwise fall back to the configured
    # simulation-strong model string.
    resolved_model = model_string or config.resolve_sim_strong()
    # Only the provider half of "provider/model" matters for caching;
    # the model name is passed separately by callers.
    provider, _ = parse_model_string(resolved_model)
    # Cache key is provider-scoped, so different models on the same
    # provider share a single client instance.
    return _get_or_create_provider(provider, f"simulation:{provider}")


Expand Down
107 changes: 107 additions & 0 deletions tests/test_llm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
import asyncio
from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock, Mock

from extropy.core import llm
from extropy.core.providers.base import TokenUsage


def test_simple_call_async_uses_cached_simulation_provider_default_model(monkeypatch):
    """With no explicit model, simple_call_async must resolve the configured
    simulation-strong model and obtain its provider via the simulation cache,
    never via the plain get_provider factory."""
    fake_usage = TokenUsage(input_tokens=11, output_tokens=7)
    fake_provider = MagicMock()
    fake_provider.simple_call_async = AsyncMock(return_value=({"ok": True}, fake_usage))

    fake_config = SimpleNamespace(resolve_sim_strong=lambda: "openai/gpt-5", providers={})
    sim_factory = Mock(return_value=fake_provider)
    # Tripwire: any call to the non-simulation factory fails the test immediately.
    forbidden_factory = Mock(
        side_effect=AssertionError("simple_call_async should use simulation cache")
    )

    monkeypatch.setattr(llm, "get_config", lambda: fake_config)
    monkeypatch.setattr(llm, "get_simulation_provider", sim_factory)
    monkeypatch.setattr(llm, "get_provider", forbidden_factory)

    coro = llm.simple_call_async(
        prompt="hello",
        response_schema={"type": "object"},
        schema_name="response",
    )
    result, usage = asyncio.run(coro)

    assert result == {"ok": True}
    assert (usage.input_tokens, usage.output_tokens) == (11, 7)
    # The cache lookup must receive the full "provider/model" string...
    sim_factory.assert_called_once_with("openai/gpt-5")
    # ...while the provider call itself gets only the bare model name.
    fake_provider.simple_call_async.assert_awaited_once()
    assert fake_provider.simple_call_async.await_args.kwargs["model"] == "gpt-5"


def test_simple_call_async_uses_cached_simulation_provider_for_explicit_model(
    monkeypatch,
):
    """An explicitly passed model must still route through the simulation
    provider cache (keyed by the explicit string), not through get_provider."""
    fake_provider = MagicMock()
    fake_provider.simple_call_async = AsyncMock(return_value=({"ok": True}, TokenUsage()))

    fake_config = SimpleNamespace(resolve_sim_strong=lambda: "openai/gpt-5", providers={})
    sim_factory = Mock(return_value=fake_provider)
    # Tripwire: reaching the non-simulation factory is a failure.
    forbidden_factory = Mock(
        side_effect=AssertionError(
            "explicit async model should still use simulation cache"
        )
    )

    monkeypatch.setattr(llm, "get_config", lambda: fake_config)
    monkeypatch.setattr(llm, "get_simulation_provider", sim_factory)
    monkeypatch.setattr(llm, "get_provider", forbidden_factory)

    coro = llm.simple_call_async(
        prompt="hello",
        response_schema={"type": "object"},
        schema_name="response",
        model="anthropic/claude-sonnet-4-5",
    )
    asyncio.run(coro)

    # The explicit string (not the config default) must reach the cache lookup,
    # and the provider call must receive only the bare model name.
    sim_factory.assert_called_once_with("anthropic/claude-sonnet-4-5")
    fake_provider.simple_call_async.assert_awaited_once()
    assert fake_provider.simple_call_async.await_args.kwargs["model"] == "claude-sonnet-4-5"


def test_simple_call_sync_path_still_uses_regular_provider_factory(monkeypatch):
    """The synchronous simple_call must keep using get_provider (with the
    provider name and config.providers) and never the simulation cache."""
    fake_provider = MagicMock()
    fake_provider.simple_call.return_value = {"ok": True}

    fake_config = SimpleNamespace(
        resolve_pipeline_fast=lambda: "openai/gpt-5-mini", providers={}
    )
    plain_factory = Mock(return_value=fake_provider)
    # Tripwire: the sync path must never touch the simulation cache.
    forbidden_sim_factory = Mock(
        side_effect=AssertionError(
            "sync calls should not use simulation provider cache"
        )
    )

    monkeypatch.setattr(llm, "get_config", lambda: fake_config)
    monkeypatch.setattr(llm, "get_provider", plain_factory)
    monkeypatch.setattr(llm, "get_simulation_provider", forbidden_sim_factory)

    result = llm.simple_call(
        prompt="hello",
        response_schema={"type": "object"},
        schema_name="response",
        model="openai/gpt-5-mini",
    )

    assert result == {"ok": True}
    # Factory receives the parsed provider name plus the provider config map;
    # the provider call receives only the bare model name.
    plain_factory.assert_called_once_with("openai", fake_config.providers)
    fake_provider.simple_call.assert_called_once()
    assert fake_provider.simple_call.call_args.kwargs["model"] == "gpt-5-mini"
Loading