Skip to content

Commit 3def1bc

Browse files
committed
(feat) initial implementation
1 parent 9700c70 commit 3def1bc

6 files changed

Lines changed: 1105 additions & 91 deletions

File tree

pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,8 @@ dependencies = [
7979
"python-dotenv>=1.2.2",
8080
# Used for token estimation before LLM calls (LCORE-1569 / conversation compaction)
8181
"tiktoken>=0.8.0",
82+
# Used for Pydantic AI
83+
"pydantic-ai>=1.99.0"
8284
]
8385

8486

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
"""Pydantic AI integrations/extensions for Lightspeed Core Stack."""
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
"""Pydantic AI provider for Llama Stack."""
2+
3+
from ._provider import LlamaStackProvider
4+
5+
__all__ = ['LlamaStackProvider']
Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
from __future__ import annotations as _annotations
2+
3+
import os
4+
from typing import TYPE_CHECKING
5+
6+
import httpx
7+
from openai import AsyncOpenAI
8+
9+
from pydantic_ai import ModelProfile
10+
from pydantic_ai.models import create_async_http_client
11+
from pydantic_ai.profiles.openai import openai_model_profile
12+
from pydantic_ai.providers import Provider
13+
14+
if TYPE_CHECKING:
15+
from llama_stack.core.library_client import AsyncLlamaStackAsLibraryClient
16+
17+
# Fallback OpenAI-compatible endpoint, used by `LlamaStackProvider` in server mode
# when neither the `base_url` argument nor the `LLAMA_STACK_BASE_URL` environment
# variable supplies a value.
DEFAULT_BASE_URL = 'http://localhost:8321/v1'
18+
19+
20+
class LlamaStackProvider(Provider[AsyncOpenAI]):
    """Provider for Llama Stack — connects to a Llama Stack server's OpenAI-compatible API.

    Supports two modes:

    1. **Server mode** — connect to a running Llama Stack server via HTTP
    2. **Library mode** — run Llama Stack in-process via ``AsyncLlamaStackAsLibraryClient``

    In both modes the provider exposes an ``AsyncOpenAI`` client; in library mode that
    client is backed by an ``httpx`` transport that dispatches requests in-process
    instead of opening a network connection.
    """

    @property
    def name(self) -> str:
        """Return the provider identifier, always ``'llama-stack'``."""
        return 'llama-stack'

    @property
    def base_url(self) -> str:
        """Return the base URL of the wrapped ``AsyncOpenAI`` client as a string."""
        return str(self._client.base_url)

    @property
    def client(self) -> AsyncOpenAI:
        """Return the ``AsyncOpenAI`` client built by ``__init__``."""
        return self._client

    @staticmethod
    def model_profile(model_name: str) -> ModelProfile | None:
        """Return the profile for ``model_name`` by delegating to ``openai_model_profile``."""
        return openai_model_profile(model_name)

    def __init__(
        self,
        *,
        base_url: str | None = None,
        api_key: str | None = None,
        library_client: AsyncLlamaStackAsLibraryClient | None = None,
        http_client: httpx.AsyncClient | None = None,
    ) -> None:
        """Create a new Llama Stack provider.

        Args:
            base_url: The base URL for the Llama Stack server (OpenAI-compatible endpoint).
                Defaults to ``LLAMA_STACK_BASE_URL`` env var, then ``http://localhost:8321/v1``.
                Must be ``None`` when ``library_client`` is provided.
            api_key: The API key for authentication. Defaults to ``LLAMA_STACK_API_KEY`` env
                var, then ``'not-needed'`` since local Llama Stack servers typically don't
                require one. Must be ``None`` when ``library_client`` is provided.
            library_client: An initialized ``AsyncLlamaStackAsLibraryClient`` for library mode.
                When provided, requests are dispatched in-process (no server needed).
                Mutually exclusive with ``base_url``, ``api_key``, and ``http_client``.
            http_client: An existing ``httpx.AsyncClient`` to use for making HTTP requests.
                Must be ``None`` when ``library_client`` is provided. When omitted in server
                mode, a client from ``create_async_http_client()`` is used.

        Raises:
            AssertionError: If ``library_client`` is combined with ``base_url``, ``api_key``
                or ``http_client``.
        """
        if library_client is not None:
            # Kept as assertions so callers observe the same AssertionError as before.
            # NOTE: these checks are skipped when Python runs with -O.
            assert base_url is None, 'Cannot provide both `library_client` and `base_url`'
            assert api_key is None, 'Cannot provide both `library_client` and `api_key`'
            assert http_client is None, 'Cannot provide both `library_client` and `http_client`'

            # Imported lazily so the transport module is only loaded in library mode.
            from ._transport import LlamaStackLibraryTransport

            self._library_client = library_client
            transport = LlamaStackLibraryTransport(library_client)
            # The host name is a placeholder: the custom transport handles every request.
            lib_http_client = httpx.AsyncClient(
                transport=transport, base_url='http://llama-stack-library'
            )
            self._client = AsyncOpenAI(
                http_client=lib_http_client,
                base_url='http://llama-stack-library/v1',
                api_key='not-needed',
            )
            return

        # Server mode. Define the attribute here as well so it exists on every instance
        # and reading it never raises AttributeError.
        self._library_client = None

        # Resolution order: explicit argument, then environment variable, then default.
        base_url = base_url or os.environ.get('LLAMA_STACK_BASE_URL') or DEFAULT_BASE_URL
        api_key = api_key or os.environ.get('LLAMA_STACK_API_KEY') or 'not-needed'

        if http_client is None:
            http_client = create_async_http_client()
        self._client = AsyncOpenAI(base_url=base_url, api_key=api_key, http_client=http_client)

    def __repr__(self) -> str:
        """Return a debug representation showing the provider name and base URL."""
        return f'LlamaStackProvider(name={self.name!r}, base_url={self.base_url!r})'

    def _set_http_client(self, http_client: httpx.AsyncClient) -> None:
        """Swap the ``httpx`` client used by the wrapped ``AsyncOpenAI`` instance.

        This assigns the OpenAI client's private ``_client`` attribute directly.
        """
        self._client._client = http_client  # pyright: ignore[reportPrivateUsage]
Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
from __future__ import annotations as _annotations
2+
3+
import json
4+
from collections.abc import AsyncGenerator, AsyncIterator
5+
from typing import Any
6+
7+
import httpx
8+
from llama_stack.core.library_client import AsyncLlamaStackAsLibraryClient, convert_pydantic_to_json_value
9+
from llama_stack.core.request_headers import PROVIDER_DATA_VAR, request_provider_data_context
10+
from llama_stack.core.server.routes import find_matching_route
11+
from llama_stack.core.utils.context import preserve_contexts_async_generator
12+
13+
14+
class _AsyncByteStream(httpx.AsyncByteStream):
    """Wraps an async byte generator as an httpx AsyncByteStream."""

    def __init__(self, gen: AsyncGenerator[bytes, None]) -> None:
        """Store the generator whose chunks become the response body.

        Args:
            gen: Async generator yielding the body as ``bytes`` chunks.
        """
        self._gen = gen

    async def __aiter__(self) -> AsyncIterator[bytes]:
        """Yield each chunk produced by the wrapped generator, in order."""
        async for chunk in self._gen:
            yield chunk

    async def aclose(self) -> None:
        """Close the wrapped generator when the response is closed.

        Without this override the base-class ``aclose()`` does nothing, so a consumer
        that stops reading early would leave the generator (and whatever it is
        iterating) unfinalised. Closing an already-exhausted generator is a no-op.
        """
        await self._gen.aclose()
23+
24+
25+
class LlamaStackLibraryTransport(httpx.AsyncBaseTransport):
    """Custom httpx transport that dispatches requests through a Llama Stack library client.

    Instead of making real HTTP calls, this transport routes requests directly
    to the Llama Stack's in-process route handlers via the library client's
    route matching and body conversion logic.
    """

    def __init__(self, client: AsyncLlamaStackAsLibraryClient) -> None:
        """Store the library client whose route table serves the requests.

        Args:
            client: The library client; its ``route_impls`` must be populated before
                the first request is handled.
        """
        self._client = client

    async def handle_async_request(self, request: httpx.Request) -> httpx.Response:
        """Dispatch ``request`` to the matching in-process route handler.

        Args:
            request: The outgoing httpx request. A non-empty body is parsed as JSON.

        Returns:
            A JSON response, or a ``text/event-stream`` response when the request
            body sets a truthy ``stream`` field.

        Raises:
            RuntimeError: If the library client's ``route_impls`` is ``None``.
        """
        if self._client.route_impls is None:
            raise RuntimeError('Llama Stack library client not initialized. Call initialize() first.')

        method = request.method
        # `raw_path` holds the path *and* the query string; route matching needs the
        # path only. Splitting the raw bytes keeps the path percent-encoded as sent.
        path = request.url.raw_path.partition(b'?')[0].decode('utf-8')

        # Query parameters seed the handler arguments; JSON body keys take precedence.
        # NOTE(review): a repeated query key presumably keeps only its first value — confirm.
        query = dict(request.url.params)
        payload = json.loads(request.content) if request.content else {}
        body = {**query, **payload} if query else payload

        headers: dict[str, str] = {}
        for raw_key, raw_value in request.headers.raw:
            key = raw_key.decode('utf-8') if isinstance(raw_key, bytes) else raw_key
            value = raw_value.decode('utf-8') if isinstance(raw_value, bytes) else raw_value
            headers[key] = value

        if self._client.provider_data:
            # Inject the client's provider data unless the caller already sent the header.
            keys = ['X-LlamaStack-Provider-Data', 'x-llamastack-provider-data']
            if all(key not in headers for key in keys):
                headers['X-LlamaStack-Provider-Data'] = json.dumps(self._client.provider_data)

        # The provider-data context is active while the route handler is awaited.
        with request_provider_data_context(headers):
            if body.get('stream', False):
                return await self._handle_streaming(request, method, path, body)
            return await self._handle_non_streaming(request, method, path, body)

    async def _invoke(self, method: str, path: str, body: dict[str, Any]) -> Any:
        """Resolve the route for ``method``/``path`` and await its handler.

        Path parameters extracted by the route match are merged into ``body`` (in
        place) before the library client converts it into handler keyword arguments.
        """
        assert self._client.route_impls is not None

        func, path_params, _, _ = find_matching_route(method, path, self._client.route_impls)
        body |= path_params
        kwargs = self._client._convert_body(func, body)
        return await func(**kwargs)

    async def _handle_non_streaming(
        self,
        request: httpx.Request,
        method: str,
        path: str,
        body: dict[str, Any],
    ) -> httpx.Response:
        """Call the matched handler and wrap its result in a JSON response.

        Returns:
            A 200 response carrying the JSON-serialised result, or an empty 204
            response when a ``DELETE`` handler returns ``None``.
        """
        result = await self._invoke(method, path, body)

        if method.upper() == 'DELETE' and result is None:
            status_code = httpx.codes.NO_CONTENT
            json_content = ''
        else:
            status_code = httpx.codes.OK
            json_content = json.dumps(convert_pydantic_to_json_value(result))

        return httpx.Response(
            status_code=status_code,
            content=json_content.encode('utf-8'),
            headers={'Content-Type': 'application/json'},
            request=request,
        )

    async def _handle_streaming(
        self,
        request: httpx.Request,
        method: str,
        path: str,
        body: dict[str, Any],
    ) -> httpx.Response:
        """Call the matched handler and expose its async result as a server-sent-event stream.

        Returns:
            A 200 ``text/event-stream`` response whose body yields one
            ``data: <json>`` event per chunk produced by the handler.
        """
        result = await self._invoke(method, path, body)

        async def gen() -> AsyncGenerator[bytes, None]:
            async for chunk in result:
                data = json.dumps(convert_pydantic_to_json_value(chunk))
                yield f'data: {data}\n\n'.encode('utf-8')

        # The stream is consumed after this method returns; presumably this wrapper keeps
        # PROVIDER_DATA_VAR set for the handler while it runs — confirm against llama_stack.
        wrapped_gen = preserve_contexts_async_generator(gen(), [PROVIDER_DATA_VAR])

        return httpx.Response(
            status_code=httpx.codes.OK,
            stream=_AsyncByteStream(wrapped_gen),
            headers={'Content-Type': 'text/event-stream'},
            request=request,
        )

0 commit comments

Comments
 (0)