Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 104 additions & 0 deletions src/utils/pydantic_ai.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
"""Helpers for running Pydantic AI agents against Llama Stack (Responses API compatibility)."""

from __future__ import annotations

from typing import Any, Final, cast

from llama_stack.core.library_client import AsyncLlamaStackAsLibraryClient
from llama_stack_client import AsyncLlamaStackClient
from pydantic_ai import Agent
from pydantic_ai.models.openai import OpenAIResponsesModel, OpenAIResponsesModelSettings

from models.common.responses.responses_api_params import ResponsesApiParams
from pydantic_ai_lightspeed.llamastack import LlamaStackProvider

# Names of dumped ``ResponsesApiParams`` fields that are copied verbatim into
# the ``extra_body`` model setting instead of being mapped onto a dedicated
# Pydantic AI setting (see ``_model_settings_from_responses_params``).
_LLS_RESPONSES_EXTRA_FIELDS: Final[frozenset[str]] = frozenset(
    (
        "conversation",
        "include",
        "max_infer_iters",
        "max_tool_calls",
        "metadata",
        "prompt",
        "reasoning",
        "safety_identifier",
        "text",
        "tool_choice",
        "tools",
    )
)


def _llama_stack_provider_from_client(
    client: AsyncLlamaStackClient | AsyncLlamaStackAsLibraryClient,
) -> LlamaStackProvider:
    """Wrap an existing Llama Stack client in a Pydantic AI ``LlamaStackProvider``.

    A library client is handed to the provider unchanged via ``library_client``.
    For any other client the provider is built from the client's API key
    (``"not-needed"`` when the key is falsy), its base URL with a ``/v1``
    suffix ensured, and the client's internal ``_client`` object, which is
    passed on as ``http_client``.

    Parameters:
        client: Llama Stack client (remote or library) to wrap.

    Returns:
        ``LlamaStackProvider`` backed by the given client.
    """
    if isinstance(client, AsyncLlamaStackAsLibraryClient):
        return LlamaStackProvider(library_client=client)

    key = client.api_key or "not-needed"

    # NOTE(review): a reviewer recalled that a utility function may already
    # exist for this base-URL normalisation -- confirm and reuse it if so.
    root = str(client.base_url).rstrip("/")
    if not root.endswith("/v1"):
        root = f"{root}/v1"

    return LlamaStackProvider(
        base_url=root,
        api_key=key,
        http_client=client._client,  # pylint: disable=protected-access
    )
Comment thread
asimurka marked this conversation as resolved.


def _model_settings_from_responses_params(
    responses_params: ResponsesApiParams,
) -> OpenAIResponsesModelSettings:
    """Translate ``ResponsesApiParams`` into OpenAI Responses model settings.

    Fields named in ``_LLS_RESPONSES_EXTRA_FIELDS`` that survive
    ``model_dump(exclude_none=True)`` are grouped under ``extra_body``; the
    key is omitted when none are present. ``max_output_tokens`` becomes
    ``max_tokens``; ``temperature`` and ``parallel_tool_calls`` keep their
    names; each is copied only when not ``None``. Truthy ``extra_headers``
    are copied into a new dict. ``store`` is always written as
    ``openai_store`` and ``previous_response_id``, when not ``None``, as
    ``openai_previous_response_id``.

    Parameters:
        responses_params: Request parameters to translate.

    Returns:
        The assembled settings dictionary, cast to
        ``OpenAIResponsesModelSettings``.
    """
    dumped = responses_params.model_dump(exclude_none=True)
    passthrough = {
        name: value
        for name, value in dumped.items()
        if name in _LLS_RESPONSES_EXTRA_FIELDS
    }

    settings: dict[str, Any] = {}
    if passthrough:
        settings["extra_body"] = passthrough

    # Optional scalar settings: copied only when explicitly provided.
    for setting_key, value in (
        ("max_tokens", responses_params.max_output_tokens),
        ("temperature", responses_params.temperature),
        ("parallel_tool_calls", responses_params.parallel_tool_calls),
    ):
        if value is not None:
            settings[setting_key] = value

    headers = responses_params.extra_headers
    if headers:
        settings["extra_headers"] = dict(headers)

    # ``store`` is forwarded unconditionally, whatever its value.
    settings["openai_store"] = responses_params.store

    previous_id = responses_params.previous_response_id
    if previous_id is not None:
        settings["openai_previous_response_id"] = previous_id

    return cast(OpenAIResponsesModelSettings, settings)


def build_agent(
    client: AsyncLlamaStackClient | AsyncLlamaStackAsLibraryClient,
    responses_params: ResponsesApiParams,
) -> Agent[None, str]:
    """Create a Pydantic AI agent configured from ``responses_params``.

    The agent's model is an ``OpenAIResponsesModel`` named by
    ``responses_params.model``. Its provider is derived from ``client`` by
    ``_llama_stack_provider_from_client`` and its settings from
    ``responses_params`` by ``_model_settings_from_responses_params``, which
    places the Llama-Stack-specific fields under ``extra_body``.

    Parameters:
        client: Llama Stack client (remote or library) the provider is built on.
        responses_params: Parameters supplying the model name, instructions and
            model settings.

    Returns:
        ``Agent`` constructed with ``responses_params.instructions`` and
        ``defer_model_check=True``.
    """
    responses_model = OpenAIResponsesModel(
        responses_params.model,
        provider=_llama_stack_provider_from_client(client),
        settings=_model_settings_from_responses_params(responses_params),
    )
    return Agent(
        responses_model,
        instructions=responses_params.instructions,
        defer_model_check=True,
    )
270 changes: 270 additions & 0 deletions tests/unit/utils/test_pydantic_ai.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,270 @@
"""Unit tests for utils/pydantic_ai module."""

# pylint: disable=protected-access

import httpx
import pytest
from llama_stack.core.library_client import AsyncLlamaStackAsLibraryClient
from pytest_mock import MockerFixture

from utils.pydantic_ai import (
_LLS_RESPONSES_EXTRA_FIELDS,
_llama_stack_provider_from_client,
_model_settings_from_responses_params,
build_agent,
)


class TestLlamaStackProviderFromClient:
    """Exercise ``_llama_stack_provider_from_client`` with library and remote clients."""

    def test_library_client(self, mocker: MockerFixture) -> None:
        """A library client ends up stored on the provider unchanged."""
        library = mocker.Mock(spec=AsyncLlamaStackAsLibraryClient)
        library.provider_data = None

        result = _llama_stack_provider_from_client(library)

        assert result._library_client is library

    def test_remote_client_with_api_key(self, mocker: MockerFixture) -> None:
        """The remote client's own api_key and host are carried over."""
        remote = mocker.Mock()
        remote.base_url = "http://my-server:8321"
        remote.api_key = "my-secret"
        remote._client = mocker.Mock(spec=httpx.AsyncClient)

        result = _llama_stack_provider_from_client(remote)

        assert result.client.api_key == "my-secret"
        assert "my-server:8321" in result.base_url

    def test_remote_client_without_api_key(self, mocker: MockerFixture) -> None:
        """A missing api_key is replaced by the 'not-needed' placeholder."""
        remote = mocker.Mock()
        remote.base_url = "http://my-server:8321"
        remote.api_key = None
        remote._client = mocker.Mock(spec=httpx.AsyncClient)

        result = _llama_stack_provider_from_client(remote)

        assert result.client.api_key == "not-needed"

    def test_remote_client_passes_http_client(self, mocker: MockerFixture) -> None:
        """The remote client's internal HTTP client is reused by the provider."""
        transport = mocker.Mock(spec=httpx.AsyncClient)
        remote = mocker.Mock()
        remote.base_url = "http://my-server:8321"
        remote.api_key = "key"
        remote._client = transport

        result = _llama_stack_provider_from_client(remote)

        assert result._client._client is transport


class TestModelSettingsFromResponsesParams:
    """Check how ``_model_settings_from_responses_params`` maps each parameter."""

    @staticmethod
    def _stub_params(mocker: MockerFixture, dumped: dict[str, object]) -> object:
        """Return a params mock with every optional setting unset and store=False."""
        stub = mocker.Mock()
        stub.model_dump.return_value = dumped
        stub.max_output_tokens = None
        stub.temperature = None
        stub.parallel_tool_calls = None
        stub.extra_headers = None
        stub.store = False
        stub.previous_response_id = None
        return stub

    @pytest.fixture(name="minimal_params")
    def minimal_params_fixture(self, mocker: MockerFixture) -> object:
        """Provide a params mock that dumps only the required fields."""
        return self._stub_params(mocker, {"model": "test/model", "input": "hello"})

    def test_minimal_params_returns_store_false(self, minimal_params: object) -> None:
        """With nothing optional set, openai_store is still emitted as False."""
        result = _model_settings_from_responses_params(minimal_params)  # type: ignore[arg-type]
        assert result["openai_store"] is False

    def test_minimal_params_no_extra_body(self, minimal_params: object) -> None:
        """No extra_body key appears when no LLS-specific field is dumped."""
        result = _model_settings_from_responses_params(minimal_params)  # type: ignore[arg-type]
        assert "extra_body" not in result

    def test_max_output_tokens_mapped(self, minimal_params: object) -> None:
        """max_output_tokens surfaces under the max_tokens key."""
        minimal_params.max_output_tokens = 1024  # type: ignore[attr-defined]
        result = _model_settings_from_responses_params(minimal_params)  # type: ignore[arg-type]
        assert result["max_tokens"] == 1024

    def test_temperature_mapped(self, minimal_params: object) -> None:
        """temperature keeps its name and value."""
        minimal_params.temperature = 0.7  # type: ignore[attr-defined]
        result = _model_settings_from_responses_params(minimal_params)  # type: ignore[arg-type]
        assert result["temperature"] == 0.7

    def test_parallel_tool_calls_mapped(self, minimal_params: object) -> None:
        """parallel_tool_calls keeps its name and value."""
        minimal_params.parallel_tool_calls = True  # type: ignore[attr-defined]
        result = _model_settings_from_responses_params(minimal_params)  # type: ignore[arg-type]
        assert result["parallel_tool_calls"] is True

    def test_extra_headers_mapped(self, minimal_params: object) -> None:
        """extra_headers are copied into the settings as a dict."""
        minimal_params.extra_headers = {"x-custom": "value"}  # type: ignore[attr-defined]
        result = _model_settings_from_responses_params(minimal_params)  # type: ignore[arg-type]
        assert result["extra_headers"] == {"x-custom": "value"}

    def test_store_true_mapped(self, minimal_params: object) -> None:
        """store=True surfaces as openai_store=True."""
        minimal_params.store = True  # type: ignore[attr-defined]
        result = _model_settings_from_responses_params(minimal_params)  # type: ignore[arg-type]
        assert result["openai_store"] is True

    def test_previous_response_id_mapped(self, minimal_params: object) -> None:
        """previous_response_id surfaces as openai_previous_response_id."""
        minimal_params.previous_response_id = "resp_abc123"  # type: ignore[attr-defined]
        result = _model_settings_from_responses_params(minimal_params)  # type: ignore[arg-type]
        assert result["openai_previous_response_id"] == "resp_abc123"

    def test_extra_body_from_lls_fields(self, mocker: MockerFixture) -> None:
        """Dumped LLS-specific fields are collected under extra_body."""
        stub = self._stub_params(
            mocker,
            {
                "model": "test/model",
                "conversation": "conv-123",
                "max_infer_iters": 5,
                "tools": [{"type": "function"}],
                "tool_choice": "auto",
            },
        )

        result = _model_settings_from_responses_params(stub)  # type: ignore[arg-type]

        assert "extra_body" in result
        body = result["extra_body"]
        assert body["conversation"] == "conv-123"
        assert body["max_infer_iters"] == 5
        assert body["tools"] == [{"type": "function"}]
        assert body["tool_choice"] == "auto"

    def test_extra_body_only_includes_known_fields(self, mocker: MockerFixture) -> None:
        """Fields outside _LLS_RESPONSES_EXTRA_FIELDS never reach extra_body."""
        stub = self._stub_params(
            mocker,
            {
                "model": "test/model",
                "conversation": "conv-1",
                "unknown_field": "should-not-appear",
            },
        )

        result = _model_settings_from_responses_params(stub)  # type: ignore[arg-type]

        assert "unknown_field" not in result.get("extra_body", {})
        assert result["extra_body"]["conversation"] == "conv-1"


class TestLlsResponsesExtraFields:
    """Pin the type and exact contents of ``_LLS_RESPONSES_EXTRA_FIELDS``."""

    def test_is_frozenset(self) -> None:
        """The constant is an immutable frozenset."""
        assert isinstance(_LLS_RESPONSES_EXTRA_FIELDS, frozenset)

    def test_contains_expected_fields(self) -> None:
        """The constant holds exactly the expected field names, no more, no fewer."""
        wanted = {
            "conversation",
            "include",
            "max_infer_iters",
            "max_tool_calls",
            "metadata",
            "prompt",
            "reasoning",
            "safety_identifier",
            "text",
            "tool_choice",
            "tools",
        }
        assert wanted == _LLS_RESPONSES_EXTRA_FIELDS


class TestBuildAgent:
    """Tests for the build_agent factory function."""

    @staticmethod
    def _remote_client(mocker: MockerFixture) -> object:
        """Return a mock remote client exposing base_url, api_key and _client."""
        client = mocker.Mock()
        client.base_url = "http://localhost:8321"
        client.api_key = "test-key"
        client._client = mocker.Mock(spec=httpx.AsyncClient)
        return client

    @staticmethod
    def _params(
        mocker: MockerFixture,
        *,
        instructions: str | None,
        dumped: dict[str, object],
        store: bool = False,
    ) -> object:
        """Return a mock params object for model 'provider/my-model'.

        Every optional setting is left unset; ``dumped`` is what
        ``model_dump`` will return.
        """
        params = mocker.Mock()
        params.model = "provider/my-model"
        params.instructions = instructions
        params.model_dump.return_value = dumped
        params.max_output_tokens = None
        params.temperature = None
        params.parallel_tool_calls = None
        params.extra_headers = None
        params.store = store
        params.previous_response_id = None
        return params

    def test_returns_agent_with_correct_model(self, mocker: MockerFixture) -> None:
        """Test that build_agent returns an Agent with the specified model name."""
        client = self._remote_client(mocker)
        params = self._params(
            mocker,
            instructions="Be helpful.",
            dumped={"model": "provider/my-model", "conversation": "conv-1"},
        )

        agent = build_agent(client, params)  # type: ignore[arg-type]

        assert agent is not None
        # Verify the model name itself; a bare "is not None" check would pass
        # even if build_agent ignored responses_params.model.
        assert getattr(agent.model, "model_name", None) == "provider/my-model"

    def test_agent_has_instructions(self, mocker: MockerFixture) -> None:
        """Test that build_agent passes instructions to the Agent."""
        client = self._remote_client(mocker)
        params = self._params(
            mocker,
            instructions="You are a helpful assistant.",
            dumped={"model": "provider/my-model"},
        )

        agent = build_agent(client, params)  # type: ignore[arg-type]

        assert "You are a helpful assistant." in agent._instructions

    def test_agent_with_library_client(self, mocker: MockerFixture) -> None:
        """Test that build_agent works with a library client."""
        lib_client = mocker.Mock(spec=AsyncLlamaStackAsLibraryClient)
        lib_client.provider_data = None
        params = self._params(
            mocker,
            instructions=None,
            dumped={"model": "provider/my-model", "conversation": "conv-1"},
            store=True,
        )

        agent = build_agent(lib_client, params)  # type: ignore[arg-type]

        assert agent is not None
        # The library-client path must still yield the requested model.
        assert getattr(agent.model, "model_name", None) == "provider/my-model"
Loading