Skip to content

Commit 4b8f2c8

Browse files
authored
Merge pull request #1806 from jrobertboos/lcore-2308
LCORE-2308: LlamaStack Pydantic AI Provider
2 parents ba755d6 + 45ba392 commit 4b8f2c8

10 files changed

Lines changed: 1700 additions & 91 deletions

File tree

pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,8 @@ dependencies = [
7979
"python-dotenv>=1.2.2",
8080
# Used for token estimation before LLM calls (LCORE-1569 / conversation compaction)
8181
"tiktoken>=0.8.0",
82+
# Pydantic AI framework; required by the pydantic_ai_lightspeed Llama Stack provider
83+
"pydantic-ai>=1.99.0"
8284
]
8385

8486

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
"""Pydantic AI integrations/extensions for Lightspeed Core Stack."""
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
"""Pydantic AI provider for Llama Stack."""
2+
3+
from pydantic_ai_lightspeed.llamastack._provider import LlamaStackProvider
4+
5+
__all__ = ["LlamaStackProvider"]
Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
"""Llama Stack provider implementation for Pydantic AI."""
2+
3+
from __future__ import annotations as _annotations
4+
5+
from typing import TYPE_CHECKING
6+
7+
import httpx
8+
from openai import AsyncOpenAI
9+
from pydantic_ai import ModelProfile
10+
from pydantic_ai.models import create_async_http_client
11+
from pydantic_ai.profiles.openai import openai_model_profile
12+
from pydantic_ai.providers import Provider
13+
14+
from pydantic_ai_lightspeed.llamastack._transport import LlamaStackLibraryTransport
15+
16+
if TYPE_CHECKING:
17+
from llama_stack.core.library_client import AsyncLlamaStackAsLibraryClient
18+
19+
# Fallback OpenAI-compatible endpoint used in server mode when the caller
# does not pass an explicit ``base_url``.
DEFAULT_BASE_URL = "http://localhost:8321/v1"
20+
21+
22+
class LlamaStackProvider(Provider[AsyncOpenAI]):
    """Provider for Llama Stack — connects to a Llama Stack server's OpenAI-compatible API.

    Supports two modes:

    1. **Server mode** — connect to a running Llama Stack server via HTTP
    2. **Library mode** — run Llama Stack in-process via ``AsyncLlamaStackAsLibraryClient``

    In both modes the provider exposes an ``AsyncOpenAI`` client; in library
    mode that client is backed by an ``httpx.AsyncClient`` whose transport is a
    ``LlamaStackLibraryTransport`` instead of a real network connection.
    """

    # Placeholder host used for the in-process client; requests to it are
    # handled by the library transport rather than sent over the network.
    _LIBRARY_BASE_URL = "http://llama-stack-library"

    # Placeholder API key used whenever the caller does not supply one.
    _PLACEHOLDER_API_KEY = "not-needed"

    @property
    def name(self) -> str:
        """The provider name."""
        return "llama-stack"

    @property
    def base_url(self) -> str:
        """The base URL for the provider API."""
        return str(self._client.base_url)

    @property
    def client(self) -> AsyncOpenAI:
        """The OpenAI-compatible client for the provider."""
        return self._client

    @staticmethod
    def model_profile(model_name: str) -> ModelProfile | None:
        """Return the model profile for the named model, if available."""
        return openai_model_profile(model_name)

    def __init__(
        self,
        *,
        base_url: str | None = None,
        api_key: str | None = None,
        library_client: AsyncLlamaStackAsLibraryClient | None = None,
        http_client: httpx.AsyncClient | None = None,
    ) -> None:
        """Create a new Llama Stack provider.

        Args:
            base_url: The base URL for the Llama Stack server (OpenAI-compatible endpoint).
                Defaults to ``http://localhost:8321/v1``.
                Must be ``None`` when ``library_client`` is provided.
            api_key: The API key for authentication. Defaults to ``'not-needed'`` since
                local Llama Stack servers typically don't require one.
                Must be ``None`` when ``library_client`` is provided.
            library_client: An initialized ``AsyncLlamaStackAsLibraryClient`` for library mode.
                When provided, requests are dispatched in-process (no server needed).
                Mutually exclusive with ``base_url``, ``api_key``, and ``http_client``.
            http_client: An existing ``httpx.AsyncClient`` to use for making HTTP requests.
                Must be ``None`` when ``library_client`` is provided.

        Raises:
            ValueError: If ``library_client`` is combined with ``base_url``,
                ``api_key`` or ``http_client``.
        """
        # Always define the attribute so that server-mode instances expose the
        # same state as library-mode ones (``None`` means "server mode").
        self._library_client = library_client

        if library_client is not None:
            # Checked in declaration order so the first offending argument is
            # the one reported.
            server_only_args = (
                ("base_url", base_url),
                ("api_key", api_key),
                ("http_client", http_client),
            )
            for arg_name, arg_value in server_only_args:
                if arg_value is not None:
                    raise ValueError(
                        f"Cannot provide both `library_client` and `{arg_name}`"
                    )

            lib_http_client = httpx.AsyncClient(
                transport=LlamaStackLibraryTransport(library_client),
                base_url=self._LIBRARY_BASE_URL,
                # No HTTP timeout is applied to in-process dispatch.
                timeout=httpx.Timeout(None),
            )
            self._client = AsyncOpenAI(
                http_client=lib_http_client,
                base_url=f"{self._LIBRARY_BASE_URL}/v1",
                api_key=self._PLACEHOLDER_API_KEY,
            )
            return

        # Server mode: fall back to defaults for anything the caller omitted.
        if http_client is None:
            http_client = create_async_http_client()
        self._client = AsyncOpenAI(
            base_url=base_url or DEFAULT_BASE_URL,
            api_key=api_key or self._PLACEHOLDER_API_KEY,
            http_client=http_client,
        )

    def __repr__(self) -> str:
        """Return a string representation of the provider."""
        return f"LlamaStackProvider(name={self.name!r}, base_url={self.base_url!r})"

    def _set_http_client(self, http_client: httpx.AsyncClient) -> None:
        """Inject an httpx.AsyncClient into the underlying OpenAI client.

        Replaces the internal HTTP transport by assigning directly to the
        protected ``self._client._client`` attribute of the AsyncOpenAI instance.

        Args:
            http_client: The async HTTP client to use for subsequent requests.
        """
        self._client._client = http_client  # pyright: ignore[reportPrivateUsage]  # pylint: disable=protected-access
Lines changed: 197 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,197 @@
1+
"""httpx transport that routes OpenAI-compatible requests through a Llama Stack library client."""
2+
3+
from __future__ import annotations as _annotations
4+
5+
import json
6+
from collections.abc import AsyncGenerator, AsyncIterator
7+
from typing import Any
8+
9+
import httpx
10+
from llama_stack.core.library_client import (
11+
AsyncLlamaStackAsLibraryClient,
12+
convert_pydantic_to_json_value,
13+
)
14+
from llama_stack.core.request_headers import (
15+
PROVIDER_DATA_VAR,
16+
request_provider_data_context,
17+
)
18+
from llama_stack.core.server.routes import find_matching_route
19+
from llama_stack.core.utils.context import preserve_contexts_async_generator
20+
21+
22+
class _AsyncByteStream(httpx.AsyncByteStream):
    """Wraps an async byte generator as an httpx AsyncByteStream."""

    def __init__(self, gen: AsyncGenerator[bytes, None]) -> None:
        """Store an async generator that yields raw bytes for streaming.

        Args:
            gen: An async generator producing byte chunks to stream.
        """
        self._gen = gen

    async def __aiter__(self) -> AsyncIterator[bytes]:
        """Yield bytes chunks from the wrapped generator.

        Returns:
            An async iterator of bytes fulfilling the httpx.AsyncByteStream contract.
        """
        async for chunk in self._gen:
            yield chunk

    async def aclose(self) -> None:
        """Close the wrapped generator when the response is closed.

        Without this override the base-class ``aclose`` is a no-op, so a
        response that is closed before the stream is exhausted would leave
        the generator suspended until it is garbage collected. Closing an
        already exhausted generator is harmless.
        """
        await self._gen.aclose()
41+
42+
43+
class LlamaStackLibraryTransport(httpx.AsyncBaseTransport):
    """Custom httpx transport that dispatches requests through a Llama Stack library client.

    Instead of making real HTTP calls, this transport routes requests directly
    to the Llama Stack's in-process route handlers via the library client's
    route matching and body conversion logic.
    """

    def __init__(self, client: AsyncLlamaStackAsLibraryClient) -> None:
        """Initialize the transport with a Llama Stack library client.

        Args:
            client: An initialized ``AsyncLlamaStackAsLibraryClient`` whose route
                handlers will receive dispatched requests.
        """
        self._client = client

    async def handle_async_request(self, request: httpx.Request) -> httpx.Response:
        """Dispatch an httpx request to the in-process Llama Stack route handlers.

        Args:
            request: The outgoing httpx request to route.

        Returns:
            An httpx response built from the matched route handler result.

        Raises:
            RuntimeError: If the library client has not been initialized.
        """
        if self._client.route_impls is None:
            raise RuntimeError(
                "Llama Stack library client not initialized. Call initialize() first."
            )

        method = request.method

        # ``raw_path`` is the request target, i.e. "path?query". Route matching
        # must only see the path, otherwise any request carrying query
        # parameters fails to match (or pollutes a trailing path parameter).
        path = request.url.raw_path.partition(b"?")[0].decode("utf-8")

        json_body = json.loads(request.content) if request.content else {}
        # Streaming is decided from the JSON payload only, as before.
        is_stream = json_body.get("stream", False)

        # Query parameters are forwarded as handler arguments; JSON body
        # fields win on a name clash. Query values arrive as strings.
        body = {**dict(request.url.params), **json_body}

        headers: dict[str, str] = {
            self._as_text(raw_name): self._as_text(raw_value)
            for raw_name, raw_value in request.headers.raw
        }

        if self._client.provider_data:
            keys = ["X-LlamaStack-Provider-Data", "x-llamastack-provider-data"]
            # Only inject the client's provider data when the caller did not
            # already send the header under one of these spellings.
            if all(key not in headers for key in keys):
                headers["X-LlamaStack-Provider-Data"] = json.dumps(
                    self._client.provider_data
                )

        with request_provider_data_context(headers):
            if is_stream:
                return await self._handle_streaming(request, method, path, body)
            return await self._handle_non_streaming(request, method, path, body)

    @staticmethod
    def _as_text(value: bytes | str) -> str:
        """Return ``value`` as ``str``, decoding UTF-8 bytes when necessary."""
        return value.decode("utf-8") if isinstance(value, bytes) else value

    async def _call_route(self, method: str, path: str, body: dict[str, Any]) -> Any:
        """Resolve the route for ``method``/``path`` and await its handler.

        Args:
            method: The HTTP method (e.g. ``"POST"``).
            path: The URL path used for route matching.
            body: Request arguments; path parameters extracted from the route
                are merged on top of these before conversion.

        Returns:
            Whatever the matched route handler returns.

        Raises:
            RuntimeError: If route_impls is not initialized.
        """
        route_impls = self._client.route_impls
        if route_impls is None:
            raise RuntimeError("route_impls is not initialized")

        handler, path_params, _, _ = find_matching_route(method, path, route_impls)
        call_kwargs = self._client._convert_body(  # pylint: disable=protected-access
            handler, {**body, **path_params}
        )
        return await handler(**call_kwargs)

    async def _handle_non_streaming(
        self,
        request: httpx.Request,
        method: str,
        path: str,
        body: dict[str, Any],
    ) -> httpx.Response:
        """Dispatch a non-streaming request to the matched route handler.

        Args:
            request: The original httpx request (attached to the response).
            method: The HTTP method (e.g. ``"POST"``).
            path: The decoded URL path used for route matching.
            body: The parsed JSON request body.

        Returns:
            An httpx.Response containing the JSON-serialized handler result,
            or an empty 204 response for a ``DELETE`` whose handler returned
            ``None``.

        Raises:
            RuntimeError: If route_impls is not initialized.
        """
        result = await self._call_route(method, path, body)

        if method.upper() == "DELETE" and result is None:
            status_code = httpx.codes.NO_CONTENT
            json_content = ""
        else:
            status_code = httpx.codes.OK
            json_content = json.dumps(convert_pydantic_to_json_value(result))

        return httpx.Response(
            status_code=status_code,
            content=json_content.encode("utf-8"),
            headers={"Content-Type": "application/json"},
            request=request,
        )

    async def _handle_streaming(
        self,
        request: httpx.Request,
        method: str,
        path: str,
        body: dict[str, Any],
    ) -> httpx.Response:
        """Dispatch a streaming request and return an SSE event-stream response.

        Args:
            request: The original httpx request (attached to the response).
            method: The HTTP method (e.g. ``"POST"``).
            path: The decoded URL path used for route matching.
            body: The parsed JSON request body (must contain ``stream: True``).

        Returns:
            An httpx.Response with a streaming body of SSE-formatted chunks.

        Raises:
            RuntimeError: If route_impls is not initialized.
        """
        result = await self._call_route(method, path, body)

        async def gen() -> AsyncGenerator[bytes, None]:
            # Each chunk becomes one SSE "data:" event.
            async for chunk in result:
                data = json.dumps(convert_pydantic_to_json_value(chunk))
                yield f"data: {data}\n\n".encode("utf-8")

        # The body is consumed after this method returns, so the provider-data
        # context variable is re-applied around the generator.
        wrapped_gen = preserve_contexts_async_generator(gen(), [PROVIDER_DATA_VAR])

        return httpx.Response(
            status_code=httpx.codes.OK,
            stream=_AsyncByteStream(wrapped_gen),
            headers={"Content-Type": "text/event-stream"},
            request=request,
        )
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
"""Unit tests for the pydantic_ai_lightspeed package."""
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
"""Unit tests for pydantic_ai_lightspeed.llamastack sub-package."""

0 commit comments

Comments
 (0)