Skip to content

Commit 52b4770

Browse files
authored
Merge pull request #1817 from jrobertboos/lcore-2309
LCORE-2309: Added Pydantic AI Bridge
2 parents 3190c8a + 08bee6f commit 52b4770

2 files changed

Lines changed: 374 additions & 0 deletions

File tree

src/utils/pydantic_ai.py

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
"""Helpers for running Pydantic AI agents against Llama Stack (Responses API compatibility)."""
2+
3+
from __future__ import annotations
4+
5+
from typing import Any, Final, cast
6+
7+
from llama_stack.core.library_client import AsyncLlamaStackAsLibraryClient
8+
from llama_stack_client import AsyncLlamaStackClient
9+
from pydantic_ai import Agent
10+
from pydantic_ai.models.openai import OpenAIResponsesModel, OpenAIResponsesModelSettings
11+
12+
from models.common.responses.responses_api_params import ResponsesApiParams
13+
from pydantic_ai_lightspeed.llamastack import LlamaStackProvider
14+
15+
_LLS_RESPONSES_EXTRA_FIELDS: Final[frozenset[str]] = frozenset(
16+
{
17+
"conversation",
18+
"max_infer_iters",
19+
"tools",
20+
"tool_choice",
21+
"include",
22+
"text",
23+
"reasoning",
24+
"prompt",
25+
"metadata",
26+
"max_tool_calls",
27+
"safety_identifier",
28+
}
29+
)
30+
31+
32+
def _llama_stack_provider_from_client(
33+
client: AsyncLlamaStackClient | AsyncLlamaStackAsLibraryClient,
34+
) -> LlamaStackProvider:
35+
"""Construct a Pydantic AI Llama Stack provider backed by the same client as ``/query``."""
36+
if isinstance(client, AsyncLlamaStackAsLibraryClient):
37+
return LlamaStackProvider(library_client=client)
38+
api_key = client.api_key or "not-needed"
39+
base = str(client.base_url).rstrip("/")
40+
base_url = base if base.endswith("/v1") else f"{base}/v1"
41+
return LlamaStackProvider(
42+
base_url=base_url,
43+
api_key=api_key,
44+
http_client=client._client, # pylint: disable=protected-access
45+
)
46+
47+
48+
def _model_settings_from_responses_params(
49+
responses_params: ResponsesApiParams,
50+
) -> OpenAIResponsesModelSettings:
51+
"""Map ``ResponsesApiParams`` into Pydantic AI OpenAI Responses model settings."""
52+
payload = responses_params.model_dump(exclude_none=True)
53+
extra_body = {k: v for k, v in payload.items() if k in _LLS_RESPONSES_EXTRA_FIELDS}
54+
settings_dict: dict[str, Any] = {}
55+
if extra_body:
56+
settings_dict["extra_body"] = extra_body
57+
if responses_params.max_output_tokens is not None:
58+
settings_dict["max_tokens"] = responses_params.max_output_tokens
59+
if responses_params.temperature is not None:
60+
settings_dict["temperature"] = responses_params.temperature
61+
if responses_params.parallel_tool_calls is not None:
62+
settings_dict["parallel_tool_calls"] = responses_params.parallel_tool_calls
63+
if responses_params.extra_headers:
64+
settings_dict["extra_headers"] = dict(responses_params.extra_headers)
65+
settings_dict["openai_store"] = responses_params.store
66+
if responses_params.previous_response_id is not None:
67+
settings_dict["openai_previous_response_id"] = (
68+
responses_params.previous_response_id
69+
)
70+
return cast(OpenAIResponsesModelSettings, settings_dict)
71+
72+
73+
def build_agent(
74+
client: AsyncLlamaStackClient | AsyncLlamaStackAsLibraryClient,
75+
responses_params: ResponsesApiParams,
76+
) -> Agent[None, str]:
77+
"""Build a Pydantic AI agent that mirrors ``responses_params`` on the Llama Stack backend.
78+
79+
Uses ``LlamaStackProvider`` with the same ``AsyncLlamaStackClient`` (or library client)
80+
as the query endpoint, and ``OpenAIResponsesModel`` so requests follow the Responses API.
81+
Llama-Stack-specific fields (conversation, tools, MCP headers, etc.) are passed via
82+
``model_settings['extra_body']`` so they merge into the OpenAI client request body.
83+
84+
Parameters:
85+
client: Initialized Llama Stack client from ``AsyncLlamaStackClientHolder().get_client()``.
86+
responses_params: Parameters produced by ``prepare_responses_params`` for this turn.
87+
88+
Returns:
89+
``Agent`` configured for ``await agent.run(...)`` (or streaming) against the same
90+
stack configuration as ``client.responses.create(**responses_params.model_dump())``.
91+
"""
92+
provider = _llama_stack_provider_from_client(client)
93+
settings = _model_settings_from_responses_params(responses_params)
94+
95+
model = OpenAIResponsesModel(
96+
responses_params.model,
97+
provider=provider,
98+
settings=settings,
99+
)
100+
return Agent(
101+
model,
102+
instructions=responses_params.instructions,
103+
defer_model_check=True,
104+
)
Lines changed: 270 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,270 @@
1+
"""Unit tests for utils/pydantic_ai module."""
2+
3+
# pylint: disable=protected-access
4+
5+
import httpx
6+
import pytest
7+
from llama_stack.core.library_client import AsyncLlamaStackAsLibraryClient
8+
from pytest_mock import MockerFixture
9+
10+
from utils.pydantic_ai import (
11+
_LLS_RESPONSES_EXTRA_FIELDS,
12+
_llama_stack_provider_from_client,
13+
_model_settings_from_responses_params,
14+
build_agent,
15+
)
16+
17+
18+
class TestLlamaStackProviderFromClient:
19+
"""Tests for _llama_stack_provider_from_client factory."""
20+
21+
def test_library_client(self, mocker: MockerFixture) -> None:
22+
"""Test that a library client creates a provider with library_client kwarg."""
23+
mock_lib_client = mocker.Mock(spec=AsyncLlamaStackAsLibraryClient)
24+
mock_lib_client.provider_data = None
25+
26+
provider = _llama_stack_provider_from_client(mock_lib_client)
27+
28+
assert provider._library_client is mock_lib_client
29+
30+
def test_remote_client_with_api_key(self, mocker: MockerFixture) -> None:
31+
"""Test that a remote client uses its api_key."""
32+
mock_client = mocker.Mock()
33+
mock_client.base_url = "http://my-server:8321"
34+
mock_client.api_key = "my-secret"
35+
mock_client._client = mocker.Mock(spec=httpx.AsyncClient)
36+
37+
provider = _llama_stack_provider_from_client(mock_client)
38+
39+
assert provider.client.api_key == "my-secret"
40+
assert "my-server:8321" in provider.base_url
41+
42+
def test_remote_client_without_api_key(self, mocker: MockerFixture) -> None:
43+
"""Test that a remote client without api_key defaults to 'not-needed'."""
44+
mock_client = mocker.Mock()
45+
mock_client.base_url = "http://my-server:8321"
46+
mock_client.api_key = None
47+
mock_client._client = mocker.Mock(spec=httpx.AsyncClient)
48+
49+
provider = _llama_stack_provider_from_client(mock_client)
50+
51+
assert provider.client.api_key == "not-needed"
52+
53+
def test_remote_client_passes_http_client(self, mocker: MockerFixture) -> None:
54+
"""Test that a remote client's internal http_client is forwarded."""
55+
mock_http_client = mocker.Mock(spec=httpx.AsyncClient)
56+
mock_client = mocker.Mock()
57+
mock_client.base_url = "http://my-server:8321"
58+
mock_client.api_key = "key"
59+
mock_client._client = mock_http_client
60+
61+
provider = _llama_stack_provider_from_client(mock_client)
62+
63+
assert provider._client._client is mock_http_client
64+
65+
66+
class TestModelSettingsFromResponsesParams:
67+
"""Tests for _model_settings_from_responses_params mapping."""
68+
69+
@pytest.fixture(name="minimal_params")
70+
def minimal_params_fixture(self, mocker: MockerFixture) -> object:
71+
"""Create minimal ResponsesApiParams mock with required fields only."""
72+
params = mocker.Mock()
73+
params.model_dump.return_value = {"model": "test/model", "input": "hello"}
74+
params.max_output_tokens = None
75+
params.temperature = None
76+
params.parallel_tool_calls = None
77+
params.extra_headers = None
78+
params.store = False
79+
params.previous_response_id = None
80+
return params
81+
82+
def test_minimal_params_returns_store_false(self, minimal_params: object) -> None:
83+
"""Test that minimal params produce settings with openai_store=False."""
84+
settings = _model_settings_from_responses_params(minimal_params) # type: ignore[arg-type]
85+
assert settings["openai_store"] is False
86+
87+
def test_minimal_params_no_extra_body(self, minimal_params: object) -> None:
88+
"""Test that minimal params without extra fields omit extra_body."""
89+
settings = _model_settings_from_responses_params(minimal_params) # type: ignore[arg-type]
90+
assert "extra_body" not in settings
91+
92+
def test_max_output_tokens_mapped(self, minimal_params: object) -> None:
93+
"""Test that max_output_tokens is mapped to max_tokens."""
94+
minimal_params.max_output_tokens = 1024 # type: ignore[attr-defined]
95+
settings = _model_settings_from_responses_params(minimal_params) # type: ignore[arg-type]
96+
assert settings["max_tokens"] == 1024
97+
98+
def test_temperature_mapped(self, minimal_params: object) -> None:
99+
"""Test that temperature is passed through."""
100+
minimal_params.temperature = 0.7 # type: ignore[attr-defined]
101+
settings = _model_settings_from_responses_params(minimal_params) # type: ignore[arg-type]
102+
assert settings["temperature"] == 0.7
103+
104+
def test_parallel_tool_calls_mapped(self, minimal_params: object) -> None:
105+
"""Test that parallel_tool_calls is passed through."""
106+
minimal_params.parallel_tool_calls = True # type: ignore[attr-defined]
107+
settings = _model_settings_from_responses_params(minimal_params) # type: ignore[arg-type]
108+
assert settings["parallel_tool_calls"] is True
109+
110+
def test_extra_headers_mapped(self, minimal_params: object) -> None:
111+
"""Test that extra_headers are converted to a dict."""
112+
minimal_params.extra_headers = {"x-custom": "value"} # type: ignore[attr-defined]
113+
settings = _model_settings_from_responses_params(minimal_params) # type: ignore[arg-type]
114+
assert settings["extra_headers"] == {"x-custom": "value"}
115+
116+
def test_store_true_mapped(self, minimal_params: object) -> None:
117+
"""Test that store=True is passed as openai_store."""
118+
minimal_params.store = True # type: ignore[attr-defined]
119+
settings = _model_settings_from_responses_params(minimal_params) # type: ignore[arg-type]
120+
assert settings["openai_store"] is True
121+
122+
def test_previous_response_id_mapped(self, minimal_params: object) -> None:
123+
"""Test that previous_response_id is passed as openai_previous_response_id."""
124+
minimal_params.previous_response_id = "resp_abc123" # type: ignore[attr-defined]
125+
settings = _model_settings_from_responses_params(minimal_params) # type: ignore[arg-type]
126+
assert settings["openai_previous_response_id"] == "resp_abc123"
127+
128+
def test_extra_body_from_lls_fields(self, mocker: MockerFixture) -> None:
129+
"""Test that LLS-specific fields are placed into extra_body."""
130+
params = mocker.Mock()
131+
params.model_dump.return_value = {
132+
"model": "test/model",
133+
"conversation": "conv-123",
134+
"max_infer_iters": 5,
135+
"tools": [{"type": "function"}],
136+
"tool_choice": "auto",
137+
}
138+
params.max_output_tokens = None
139+
params.temperature = None
140+
params.parallel_tool_calls = None
141+
params.extra_headers = None
142+
params.store = False
143+
params.previous_response_id = None
144+
145+
settings = _model_settings_from_responses_params(params)
146+
147+
assert "extra_body" in settings
148+
assert settings["extra_body"]["conversation"] == "conv-123"
149+
assert settings["extra_body"]["max_infer_iters"] == 5
150+
assert settings["extra_body"]["tools"] == [{"type": "function"}]
151+
assert settings["extra_body"]["tool_choice"] == "auto"
152+
153+
def test_extra_body_only_includes_known_fields(self, mocker: MockerFixture) -> None:
154+
"""Test that extra_body only includes fields in _LLS_RESPONSES_EXTRA_FIELDS."""
155+
params = mocker.Mock()
156+
params.model_dump.return_value = {
157+
"model": "test/model",
158+
"conversation": "conv-1",
159+
"unknown_field": "should-not-appear",
160+
}
161+
params.max_output_tokens = None
162+
params.temperature = None
163+
params.parallel_tool_calls = None
164+
params.extra_headers = None
165+
params.store = False
166+
params.previous_response_id = None
167+
168+
settings = _model_settings_from_responses_params(params)
169+
170+
assert "unknown_field" not in settings.get("extra_body", {})
171+
assert settings["extra_body"]["conversation"] == "conv-1"
172+
173+
174+
class TestLlsResponsesExtraFields:
175+
"""Tests for the _LLS_RESPONSES_EXTRA_FIELDS constant."""
176+
177+
def test_is_frozenset(self) -> None:
178+
"""Test that _LLS_RESPONSES_EXTRA_FIELDS is a frozenset."""
179+
assert isinstance(_LLS_RESPONSES_EXTRA_FIELDS, frozenset)
180+
181+
def test_contains_expected_fields(self) -> None:
182+
"""Test that key fields are present."""
183+
expected = {
184+
"conversation",
185+
"max_infer_iters",
186+
"tools",
187+
"tool_choice",
188+
"include",
189+
"text",
190+
"reasoning",
191+
"prompt",
192+
"metadata",
193+
"max_tool_calls",
194+
"safety_identifier",
195+
}
196+
assert expected == _LLS_RESPONSES_EXTRA_FIELDS
197+
198+
199+
class TestBuildAgent:
200+
"""Tests for the build_agent factory function."""
201+
202+
def test_returns_agent_with_correct_model(self, mocker: MockerFixture) -> None:
203+
"""Test that build_agent returns an Agent with the specified model name."""
204+
mock_client = mocker.Mock()
205+
mock_client.base_url = "http://localhost:8321"
206+
mock_client.api_key = "test-key"
207+
mock_client._client = mocker.Mock(spec=httpx.AsyncClient)
208+
209+
mock_params = mocker.Mock()
210+
mock_params.model = "provider/my-model"
211+
mock_params.instructions = "Be helpful."
212+
mock_params.model_dump.return_value = {
213+
"model": "provider/my-model",
214+
"conversation": "conv-1",
215+
}
216+
mock_params.max_output_tokens = None
217+
mock_params.temperature = None
218+
mock_params.parallel_tool_calls = None
219+
mock_params.extra_headers = None
220+
mock_params.store = False
221+
mock_params.previous_response_id = None
222+
223+
agent = build_agent(mock_client, mock_params)
224+
225+
assert agent is not None
226+
227+
def test_agent_has_instructions(self, mocker: MockerFixture) -> None:
228+
"""Test that build_agent passes instructions to the Agent."""
229+
mock_client = mocker.Mock()
230+
mock_client.base_url = "http://localhost:8321"
231+
mock_client.api_key = "test-key"
232+
mock_client._client = mocker.Mock(spec=httpx.AsyncClient)
233+
234+
mock_params = mocker.Mock()
235+
mock_params.model = "provider/my-model"
236+
mock_params.instructions = "You are a helpful assistant."
237+
mock_params.model_dump.return_value = {"model": "provider/my-model"}
238+
mock_params.max_output_tokens = None
239+
mock_params.temperature = None
240+
mock_params.parallel_tool_calls = None
241+
mock_params.extra_headers = None
242+
mock_params.store = False
243+
mock_params.previous_response_id = None
244+
245+
agent = build_agent(mock_client, mock_params)
246+
247+
assert "You are a helpful assistant." in agent._instructions
248+
249+
def test_agent_with_library_client(self, mocker: MockerFixture) -> None:
250+
"""Test that build_agent works with a library client."""
251+
mock_lib_client = mocker.Mock(spec=AsyncLlamaStackAsLibraryClient)
252+
mock_lib_client.provider_data = None
253+
254+
mock_params = mocker.Mock()
255+
mock_params.model = "provider/my-model"
256+
mock_params.instructions = None
257+
mock_params.model_dump.return_value = {
258+
"model": "provider/my-model",
259+
"conversation": "conv-1",
260+
}
261+
mock_params.max_output_tokens = None
262+
mock_params.temperature = None
263+
mock_params.parallel_tool_calls = None
264+
mock_params.extra_headers = None
265+
mock_params.store = True
266+
mock_params.previous_response_id = None
267+
268+
agent = build_agent(mock_lib_client, mock_params)
269+
270+
assert agent is not None

0 commit comments

Comments
 (0)