Skip to content

Commit 53e1b86

Browse files
committed
(feat) initial implementation
1 parent 45ba392 commit 53e1b86

1 file changed

Lines changed: 110 additions & 0 deletions

File tree

src/utils/pydantic_ai.py

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
"""Helpers for running Pydantic AI agents against Llama Stack (Responses API compatibility)."""
2+
3+
from __future__ import annotations
4+
5+
from typing import Any, Final, cast
6+
7+
from llama_stack.core.library_client import AsyncLlamaStackAsLibraryClient
8+
from llama_stack_client import AsyncLlamaStackClient
9+
from pydantic_ai import Agent
10+
from pydantic_ai.models.openai import OpenAIResponsesModel, OpenAIResponsesModelSettings
11+
from pydantic_ai_lightspeed.llamastack import LlamaStackProvider
12+
13+
from models.common.responses.responses_api_params import ResponsesApiParams
14+
15+
# Names of ``ResponsesApiParams`` fields that are copied verbatim into the
# ``extra_body`` entry of the model settings (see
# ``_model_settings_from_responses_params``) rather than being mapped onto a
# dedicated settings key.
_LLS_RESPONSES_EXTRA_FIELDS: Final[frozenset[str]] = frozenset(
    (
        "conversation",
        "max_infer_iters",
        "tools",
        "tool_choice",
        "include",
        "text",
        "reasoning",
        "prompt",
        "metadata",
        "max_tool_calls",
        "safety_identifier",
    )
)
30+
31+
32+
def _openai_v1_base_url(client: AsyncLlamaStackClient) -> str:
    """Build the OpenAI-compatible ``/v1`` base URL for ``client``.

    Trailing slashes are stripped from ``client.base_url`` first; ``/v1`` is
    then appended unless the stripped URL already ends with it.
    """
    trimmed = str(client.base_url).rstrip("/")
    return trimmed if trimmed.endswith("/v1") else f"{trimmed}/v1"
38+
39+
40+
def _llama_stack_provider_from_client(
    client: AsyncLlamaStackClient | AsyncLlamaStackAsLibraryClient,
) -> LlamaStackProvider:
    """Create the Pydantic AI ``LlamaStackProvider`` that wraps ``client``.

    A library client is handed to the provider directly. Any other client is
    described to the provider by its ``/v1`` base URL, its API key, and the
    object stored in its ``_client`` attribute (passed as ``http_client``).
    """
    if not isinstance(client, AsyncLlamaStackAsLibraryClient):
        # Substitute a placeholder when the client carries no (or an empty) key.
        key = client.api_key or "not-needed"
        return LlamaStackProvider(
            base_url=_openai_v1_base_url(client),
            api_key=key,
            http_client=client._client,
        )
    return LlamaStackProvider(library_client=client)
52+
53+
54+
def _model_settings_from_responses_params(
    responses_params: ResponsesApiParams,
) -> OpenAIResponsesModelSettings:
    """Translate ``ResponsesApiParams`` into Pydantic AI OpenAI Responses settings.

    Non-``None`` fields listed in ``_LLS_RESPONSES_EXTRA_FIELDS`` are grouped
    under ``extra_body``. ``max_output_tokens``, ``temperature`` and
    ``parallel_tool_calls`` are copied when not ``None`` (the first one under
    the key ``max_tokens``). ``extra_headers`` is copied when truthy,
    ``store`` is always written to ``openai_store``, and
    ``previous_response_id`` goes to ``openai_previous_response_id`` when set.
    """
    dumped = responses_params.model_dump(exclude_none=True)
    passthrough = {
        name: value
        for name, value in dumped.items()
        if name in _LLS_RESPONSES_EXTRA_FIELDS
    }

    result: dict[str, Any] = {}
    if passthrough:
        result["extra_body"] = passthrough

    # (settings key, ResponsesApiParams attribute) pairs copied only when set.
    optional_scalars = (
        ("max_tokens", "max_output_tokens"),
        ("temperature", "temperature"),
        ("parallel_tool_calls", "parallel_tool_calls"),
    )
    for settings_key, attr_name in optional_scalars:
        value = getattr(responses_params, attr_name)
        if value is not None:
            result[settings_key] = value

    headers = responses_params.extra_headers
    if headers:
        # Copy so the settings do not alias the params' own mapping.
        result["extra_headers"] = dict(headers)

    # Unlike the optional fields above, ``store`` is forwarded unconditionally.
    result["openai_store"] = responses_params.store

    previous_id = responses_params.previous_response_id
    if previous_id is not None:
        result["openai_previous_response_id"] = previous_id

    return cast(OpenAIResponsesModelSettings, result)
77+
78+
79+
def build_agent(
    client: AsyncLlamaStackClient | AsyncLlamaStackAsLibraryClient,
    responses_params: ResponsesApiParams,
) -> Agent[None, str]:
    """Create a Pydantic AI agent configured from ``responses_params``.

    The agent's model is an ``OpenAIResponsesModel`` named by
    ``responses_params.model``. Its provider is a ``LlamaStackProvider`` built
    from ``client``, and its settings are derived from ``responses_params``
    (fields listed in ``_LLS_RESPONSES_EXTRA_FIELDS`` end up in
    ``settings['extra_body']``).

    Parameters:
        client: Llama Stack client (service or library flavour) the provider
            is built from.
        responses_params: Request parameters supplying the model name, the
            instructions, and the model settings.

    Returns:
        An ``Agent`` whose instructions are ``responses_params.instructions``,
        created with ``defer_model_check=True``.
    """
    llama_provider = _llama_stack_provider_from_client(client)
    model_settings = _model_settings_from_responses_params(responses_params)

    responses_model = OpenAIResponsesModel(
        responses_params.model,
        provider=llama_provider,
        settings=model_settings,
    )
    return Agent(
        responses_model,
        instructions=responses_params.instructions,
        defer_model_check=True,
    )

0 commit comments

Comments
 (0)