-
Notifications
You must be signed in to change notification settings - Fork 94
LCORE-2309: Added Pydantic AI Bridge #1817
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
4 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,104 @@ | ||
| """Helpers for running Pydantic AI agents against Llama Stack (Responses API compatibility).""" | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
| from typing import Any, Final, cast | ||
|
|
||
| from llama_stack.core.library_client import AsyncLlamaStackAsLibraryClient | ||
| from llama_stack_client import AsyncLlamaStackClient | ||
| from pydantic_ai import Agent | ||
| from pydantic_ai.models.openai import OpenAIResponsesModel, OpenAIResponsesModelSettings | ||
|
|
||
| from models.common.responses.responses_api_params import ResponsesApiParams | ||
| from pydantic_ai_lightspeed.llamastack import LlamaStackProvider | ||
|
|
||
| _LLS_RESPONSES_EXTRA_FIELDS: Final[frozenset[str]] = frozenset( | ||
| { | ||
| "conversation", | ||
| "max_infer_iters", | ||
| "tools", | ||
| "tool_choice", | ||
| "include", | ||
| "text", | ||
| "reasoning", | ||
| "prompt", | ||
| "metadata", | ||
| "max_tool_calls", | ||
| "safety_identifier", | ||
| } | ||
| ) | ||
|
|
||
|
|
||
def _llama_stack_provider_from_client(
    client: AsyncLlamaStackClient | AsyncLlamaStackAsLibraryClient,
) -> LlamaStackProvider:
    """Create the Pydantic AI ``LlamaStackProvider`` that wraps ``client``.

    A library-mode client is handed to the provider as-is. For any other
    client the provider is built from the client's base URL (a ``/v1`` suffix
    is appended when missing), its API key (``"not-needed"`` when the client
    has none) and its underlying HTTP client.

    Parameters:
        client: Llama Stack client, either remote or library-mode.

    Returns:
        Provider configured from ``client``.
    """
    if isinstance(client, AsyncLlamaStackAsLibraryClient):
        return LlamaStackProvider(library_client=client)

    # NOTE(review): a reviewer remarked that a utility function for this
    # base-URL normalization may already exist; prefer it if that is confirmed.
    root = str(client.base_url).rstrip("/")
    if not root.endswith("/v1"):
        root = f"{root}/v1"
    return LlamaStackProvider(
        base_url=root,
        api_key=client.api_key or "not-needed",
        http_client=client._client,  # pylint: disable=protected-access
    )
|
asimurka marked this conversation as resolved.
|
||
|
|
||
|
|
||
def _model_settings_from_responses_params(
    responses_params: ResponsesApiParams,
) -> OpenAIResponsesModelSettings:
    """Translate ``ResponsesApiParams`` into OpenAI Responses model settings.

    Fields named in ``_LLS_RESPONSES_EXTRA_FIELDS`` that are not ``None`` are
    grouped under ``extra_body``; the key is omitted when no such field is set.
    ``max_output_tokens`` becomes ``max_tokens``; ``temperature`` and
    ``parallel_tool_calls`` keep their names. ``store`` is always emitted as
    ``openai_store``.

    Parameters:
        responses_params: Request parameters to translate.

    Returns:
        Settings mapping, cast to ``OpenAIResponsesModelSettings``.
    """
    dumped = responses_params.model_dump(exclude_none=True)
    lls_only = {
        name: value
        for name, value in dumped.items()
        if name in _LLS_RESPONSES_EXTRA_FIELDS
    }
    mapped: dict[str, Any] = {"extra_body": lls_only} if lls_only else {}

    # (settings key, source value) pairs that are copied only when set.
    optional_pairs = (
        ("max_tokens", responses_params.max_output_tokens),
        ("temperature", responses_params.temperature),
        ("parallel_tool_calls", responses_params.parallel_tool_calls),
    )
    mapped.update(
        {key: value for key, value in optional_pairs if value is not None}
    )

    # Truthiness check on purpose: an empty header mapping is skipped too.
    if responses_params.extra_headers:
        mapped["extra_headers"] = dict(responses_params.extra_headers)

    mapped["openai_store"] = responses_params.store

    previous_id = responses_params.previous_response_id
    if previous_id is not None:
        mapped["openai_previous_response_id"] = previous_id
    return cast(OpenAIResponsesModelSettings, mapped)
|
|
||
|
|
||
def build_agent(
    client: AsyncLlamaStackClient | AsyncLlamaStackAsLibraryClient,
    responses_params: ResponsesApiParams,
) -> Agent[None, str]:
    """Create a Pydantic AI agent configured from ``responses_params``.

    The agent wraps an ``OpenAIResponsesModel`` named after
    ``responses_params.model``. Its provider is derived from ``client`` by
    ``_llama_stack_provider_from_client`` and its settings from
    ``responses_params`` by ``_model_settings_from_responses_params``.
    ``responses_params.instructions`` is forwarded as the agent instructions
    and ``defer_model_check=True`` is passed to ``Agent``.

    Parameters:
        client: Initialized Llama Stack client, e.g. the one returned by
            ``AsyncLlamaStackClientHolder().get_client()``.
        responses_params: Parameters produced by ``prepare_responses_params``
            for this turn.

    Returns:
        The configured ``Agent``.
    """
    return Agent(
        OpenAIResponsesModel(
            responses_params.model,
            provider=_llama_stack_provider_from_client(client),
            settings=_model_settings_from_responses_params(responses_params),
        ),
        instructions=responses_params.instructions,
        defer_model_check=True,
    )
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,270 @@ | ||
| """Unit tests for utils/pydantic_ai module.""" | ||
|
|
||
| # pylint: disable=protected-access | ||
|
|
||
| import httpx | ||
| import pytest | ||
| from llama_stack.core.library_client import AsyncLlamaStackAsLibraryClient | ||
| from pytest_mock import MockerFixture | ||
|
|
||
| from utils.pydantic_ai import ( | ||
| _LLS_RESPONSES_EXTRA_FIELDS, | ||
| _llama_stack_provider_from_client, | ||
| _model_settings_from_responses_params, | ||
| build_agent, | ||
| ) | ||
|
|
||
|
|
||
class TestLlamaStackProviderFromClient:
    """Checks for the ``_llama_stack_provider_from_client`` factory."""

    def test_library_client(self, mocker: MockerFixture) -> None:
        """A library-mode client is stored on the provider unchanged."""
        library_client = mocker.Mock(spec=AsyncLlamaStackAsLibraryClient)
        library_client.provider_data = None

        result = _llama_stack_provider_from_client(library_client)

        assert result._library_client is library_client

    def test_remote_client_with_api_key(self, mocker: MockerFixture) -> None:
        """The remote client's API key and host are carried over to the provider."""
        remote = mocker.Mock()
        remote.base_url = "http://my-server:8321"
        remote.api_key = "my-secret"
        remote._client = mocker.Mock(spec=httpx.AsyncClient)

        result = _llama_stack_provider_from_client(remote)

        assert result.client.api_key == "my-secret"
        assert "my-server:8321" in result.base_url

    def test_remote_client_without_api_key(self, mocker: MockerFixture) -> None:
        """A remote client whose api_key is None falls back to 'not-needed'."""
        remote = mocker.Mock()
        remote.base_url = "http://my-server:8321"
        remote.api_key = None
        remote._client = mocker.Mock(spec=httpx.AsyncClient)

        result = _llama_stack_provider_from_client(remote)

        assert result.client.api_key == "not-needed"

    def test_remote_client_passes_http_client(self, mocker: MockerFixture) -> None:
        """The remote client's underlying HTTP client is reused by the provider."""
        shared_http = mocker.Mock(spec=httpx.AsyncClient)
        remote = mocker.Mock()
        remote.base_url = "http://my-server:8321"
        remote.api_key = "key"
        remote._client = shared_http

        result = _llama_stack_provider_from_client(remote)

        assert result._client._client is shared_http
|
|
||
|
|
||
class TestModelSettingsFromResponsesParams:
    """Checks for the ``_model_settings_from_responses_params`` mapping."""

    @pytest.fixture(name="minimal_params")
    def minimal_params_fixture(self, mocker: MockerFixture) -> object:
        """Return a ResponsesApiParams stand-in with every optional field unset."""
        stub = mocker.Mock()
        stub.model_dump.return_value = {"model": "test/model", "input": "hello"}
        stub.store = False
        for unset_field in (
            "max_output_tokens",
            "temperature",
            "parallel_tool_calls",
            "extra_headers",
            "previous_response_id",
        ):
            setattr(stub, unset_field, None)
        return stub

    def test_minimal_params_returns_store_false(self, minimal_params: object) -> None:
        """With nothing optional set, openai_store mirrors store=False."""
        result = _model_settings_from_responses_params(minimal_params)  # type: ignore[arg-type]
        assert result["openai_store"] is False

    def test_minimal_params_no_extra_body(self, minimal_params: object) -> None:
        """With no Llama Stack specific field set, extra_body is left out."""
        result = _model_settings_from_responses_params(minimal_params)  # type: ignore[arg-type]
        assert "extra_body" not in result

    def test_max_output_tokens_mapped(self, minimal_params: object) -> None:
        """max_output_tokens is emitted under the max_tokens key."""
        minimal_params.max_output_tokens = 1024  # type: ignore[attr-defined]
        result = _model_settings_from_responses_params(minimal_params)  # type: ignore[arg-type]
        assert result["max_tokens"] == 1024

    def test_temperature_mapped(self, minimal_params: object) -> None:
        """temperature keeps its name and value."""
        minimal_params.temperature = 0.7  # type: ignore[attr-defined]
        result = _model_settings_from_responses_params(minimal_params)  # type: ignore[arg-type]
        assert result["temperature"] == 0.7

    def test_parallel_tool_calls_mapped(self, minimal_params: object) -> None:
        """parallel_tool_calls keeps its name and value."""
        minimal_params.parallel_tool_calls = True  # type: ignore[attr-defined]
        result = _model_settings_from_responses_params(minimal_params)  # type: ignore[arg-type]
        assert result["parallel_tool_calls"] is True

    def test_extra_headers_mapped(self, minimal_params: object) -> None:
        """extra_headers is emitted as an equal dict."""
        minimal_params.extra_headers = {"x-custom": "value"}  # type: ignore[attr-defined]
        result = _model_settings_from_responses_params(minimal_params)  # type: ignore[arg-type]
        assert result["extra_headers"] == {"x-custom": "value"}

    def test_store_true_mapped(self, minimal_params: object) -> None:
        """store=True is emitted as openai_store=True."""
        minimal_params.store = True  # type: ignore[attr-defined]
        result = _model_settings_from_responses_params(minimal_params)  # type: ignore[arg-type]
        assert result["openai_store"] is True

    def test_previous_response_id_mapped(self, minimal_params: object) -> None:
        """previous_response_id is emitted as openai_previous_response_id."""
        minimal_params.previous_response_id = "resp_abc123"  # type: ignore[attr-defined]
        result = _model_settings_from_responses_params(minimal_params)  # type: ignore[arg-type]
        assert result["openai_previous_response_id"] == "resp_abc123"

    def test_extra_body_from_lls_fields(self, minimal_params: object) -> None:
        """Llama Stack specific fields in the dump are grouped under extra_body."""
        lls_fields = {
            "conversation": "conv-123",
            "max_infer_iters": 5,
            "tools": [{"type": "function"}],
            "tool_choice": "auto",
        }
        minimal_params.model_dump.return_value = {  # type: ignore[attr-defined]
            "model": "test/model",
            **lls_fields,
        }

        result = _model_settings_from_responses_params(minimal_params)  # type: ignore[arg-type]

        assert "extra_body" in result
        for field_name, field_value in lls_fields.items():
            assert result["extra_body"][field_name] == field_value

    def test_extra_body_only_includes_known_fields(self, minimal_params: object) -> None:
        """Dump keys outside _LLS_RESPONSES_EXTRA_FIELDS never reach extra_body."""
        minimal_params.model_dump.return_value = {  # type: ignore[attr-defined]
            "model": "test/model",
            "conversation": "conv-1",
            "unknown_field": "should-not-appear",
        }

        result = _model_settings_from_responses_params(minimal_params)  # type: ignore[arg-type]

        assert "unknown_field" not in result.get("extra_body", {})
        assert result["extra_body"]["conversation"] == "conv-1"
|
|
||
|
|
||
class TestLlsResponsesExtraFields:
    """Checks for the ``_LLS_RESPONSES_EXTRA_FIELDS`` constant."""

    def test_is_frozenset(self) -> None:
        """The constant is an immutable frozenset."""
        assert isinstance(_LLS_RESPONSES_EXTRA_FIELDS, frozenset)

    def test_contains_expected_fields(self) -> None:
        """The constant holds exactly the expected field names."""
        wanted = {
            "conversation",
            "max_infer_iters",
            "tools",
            "tool_choice",
            "include",
            "text",
            "reasoning",
            "prompt",
            "metadata",
            "max_tool_calls",
            "safety_identifier",
        }
        assert wanted == _LLS_RESPONSES_EXTRA_FIELDS
|
|
||
|
|
||
class TestBuildAgent:
    """Tests for the build_agent factory function."""

    def test_returns_agent_with_correct_model(self, mocker: MockerFixture) -> None:
        """Test that build_agent returns an Agent with the specified model name."""
        mock_client = mocker.Mock()
        mock_client.base_url = "http://localhost:8321"
        mock_client.api_key = "test-key"
        mock_client._client = mocker.Mock(spec=httpx.AsyncClient)

        mock_params = mocker.Mock()
        mock_params.model = "provider/my-model"
        mock_params.instructions = "Be helpful."
        mock_params.model_dump.return_value = {
            "model": "provider/my-model",
            "conversation": "conv-1",
        }
        mock_params.max_output_tokens = None
        mock_params.temperature = None
        mock_params.parallel_tool_calls = None
        mock_params.extra_headers = None
        mock_params.store = False
        mock_params.previous_response_id = None

        agent = build_agent(mock_client, mock_params)

        assert agent is not None
        # Verify what the test name promises: the agent's model carries the
        # requested model name. getattr keeps the check valid for type
        # checkers, since Agent.model may also be a plain string or None.
        assert getattr(agent.model, "model_name", None) == "provider/my-model"

    def test_agent_has_instructions(self, mocker: MockerFixture) -> None:
        """Test that build_agent passes instructions to the Agent."""
        mock_client = mocker.Mock()
        mock_client.base_url = "http://localhost:8321"
        mock_client.api_key = "test-key"
        mock_client._client = mocker.Mock(spec=httpx.AsyncClient)

        mock_params = mocker.Mock()
        mock_params.model = "provider/my-model"
        mock_params.instructions = "You are a helpful assistant."
        mock_params.model_dump.return_value = {"model": "provider/my-model"}
        mock_params.max_output_tokens = None
        mock_params.temperature = None
        mock_params.parallel_tool_calls = None
        mock_params.extra_headers = None
        mock_params.store = False
        mock_params.previous_response_id = None

        agent = build_agent(mock_client, mock_params)

        assert "You are a helpful assistant." in agent._instructions

    def test_agent_with_library_client(self, mocker: MockerFixture) -> None:
        """Test that build_agent works with a library client."""
        mock_lib_client = mocker.Mock(spec=AsyncLlamaStackAsLibraryClient)
        mock_lib_client.provider_data = None

        mock_params = mocker.Mock()
        mock_params.model = "provider/my-model"
        mock_params.instructions = None
        mock_params.model_dump.return_value = {
            "model": "provider/my-model",
            "conversation": "conv-1",
        }
        mock_params.max_output_tokens = None
        mock_params.temperature = None
        mock_params.parallel_tool_calls = None
        mock_params.extra_headers = None
        mock_params.store = True
        mock_params.previous_response_id = None

        agent = build_agent(mock_lib_client, mock_params)

        assert agent is not None
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.