Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,8 @@ dependencies = [
"python-dotenv>=1.2.2",
# Used for token estimation before LLM calls (LCORE-1569 / conversation compaction)
"tiktoken>=0.8.0",
# Pydantic AI agent framework; required by the pydantic_ai_lightspeed Llama Stack provider
"pydantic-ai>=1.99.0"
]


Expand Down
1 change: 1 addition & 0 deletions src/pydantic_ai_lightspeed/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Pydantic AI integrations/extensions for Lightspeed Core Stack."""
5 changes: 5 additions & 0 deletions src/pydantic_ai_lightspeed/llamastack/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
"""Pydantic AI provider for Llama Stack."""

from ._provider import LlamaStackProvider
Comment thread
coderabbitai[bot] marked this conversation as resolved.
Outdated

__all__ = ['LlamaStackProvider']
103 changes: 103 additions & 0 deletions src/pydantic_ai_lightspeed/llamastack/_provider.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
from __future__ import annotations as _annotations

import os
from typing import TYPE_CHECKING

import httpx
from openai import AsyncOpenAI

from pydantic_ai import ModelProfile
from pydantic_ai.models import create_async_http_client
from pydantic_ai.profiles.openai import openai_model_profile
from pydantic_ai.providers import Provider

if TYPE_CHECKING:
from llama_stack.core.library_client import AsyncLlamaStackAsLibraryClient

DEFAULT_BASE_URL = 'http://localhost:8321/v1'


class LlamaStackProvider(Provider[AsyncOpenAI]):
    """Provider for Llama Stack — connects to a Llama Stack server's OpenAI-compatible API.

    Supports two modes:

    1. **Server mode** — connect to a running Llama Stack server via HTTP
    2. **Library mode** — run Llama Stack in-process via ``AsyncLlamaStackAsLibraryClient``
    """

    @property
    def name(self) -> str:
        """Return the provider identifier, ``'llama-stack'``."""
        return 'llama-stack'

    @property
    def base_url(self) -> str:
        """Return the base URL of the underlying ``AsyncOpenAI`` client as a string."""
        return str(self._client.base_url)

    @property
    def client(self) -> AsyncOpenAI:
        """Return the ``AsyncOpenAI`` client built by ``__init__``."""
        return self._client

    @staticmethod
    def model_profile(model_name: str) -> ModelProfile | None:
        """Return the OpenAI model profile for ``model_name``.

        Delegates to ``openai_model_profile`` from ``pydantic_ai.profiles.openai``.
        """
        return openai_model_profile(model_name)

    def __init__(
        self,
        *,
        base_url: str | None = None,
        api_key: str | None = None,
        library_client: AsyncLlamaStackAsLibraryClient | None = None,
        http_client: httpx.AsyncClient | None = None,
    ) -> None:
        """Create a new Llama Stack provider.

        Args:
            base_url: The base URL for the Llama Stack server (OpenAI-compatible endpoint).
                Defaults to ``LLAMA_STACK_BASE_URL`` env var, then ``http://localhost:8321/v1``.
                Must be ``None`` when ``library_client`` is provided.
            api_key: The API key for authentication. Defaults to ``LLAMA_STACK_API_KEY`` env
                var, then ``'not-needed'`` since local Llama Stack servers typically don't
                require one. Must be ``None`` when ``library_client`` is provided.
            library_client: An initialized ``AsyncLlamaStackAsLibraryClient`` for library mode.
                When provided, requests are dispatched in-process (no server needed).
                Mutually exclusive with ``base_url``, ``api_key``, and ``http_client``.
            http_client: An existing ``httpx.AsyncClient`` to use for making HTTP requests.
                Must be ``None`` when ``library_client`` is provided.

        Raises:
            ValueError: If ``library_client`` is combined with ``base_url``, ``api_key``
                or ``http_client``.
        """
        # Always defined so the attribute can be read in either mode.
        self._library_client = library_client

        if library_client is not None:
            # Explicit raises instead of `assert`: assertions are removed under
            # `python -O`, which would silently accept conflicting arguments.
            conflicting = {
                'base_url': base_url,
                'api_key': api_key,
                'http_client': http_client,
            }
            for arg_name, arg_value in conflicting.items():
                if arg_value is not None:
                    raise ValueError(f'Cannot provide both `library_client` and `{arg_name}`')

            # Deferred import: `_transport` imports `llama_stack` at module level,
            # while this module only needs it for type checking, so server mode
            # keeps working without loading it.
            from ._transport import LlamaStackLibraryTransport

            transport = LlamaStackLibraryTransport(library_client)
            lib_http_client = httpx.AsyncClient(
                transport=transport,
                # Placeholder host; requests are handed to the in-process transport.
                base_url='http://llama-stack-library',
                # Stated explicitly (same value as httpx's documented default) so
                # the client never relies on an implicit timeout configuration.
                timeout=httpx.Timeout(5.0),
            )
            self._client = AsyncOpenAI(
                http_client=lib_http_client,
                base_url='http://llama-stack-library/v1',
                api_key='not-needed',
            )
            return

        # Server mode: explicit argument, then environment variable, then default.
        base_url = base_url or os.environ.get('LLAMA_STACK_BASE_URL') or DEFAULT_BASE_URL
        api_key = api_key or os.environ.get('LLAMA_STACK_API_KEY') or 'not-needed'

        if http_client is None:
            http_client = create_async_http_client()
        self._client = AsyncOpenAI(base_url=base_url, api_key=api_key, http_client=http_client)

    def __repr__(self) -> str:
        """Return a debug representation showing the provider name and base URL."""
        return f'LlamaStackProvider(name={self.name!r}, base_url={self.base_url!r})'

    def _set_http_client(self, http_client: httpx.AsyncClient) -> None:
        """Replace the HTTP client used by the wrapped ``AsyncOpenAI`` instance.

        Writes to the private ``_client`` attribute of ``AsyncOpenAI``.
        """
        self._client._client = http_client  # pyright: ignore[reportPrivateUsage]
123 changes: 123 additions & 0 deletions src/pydantic_ai_lightspeed/llamastack/_transport.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
from __future__ import annotations as _annotations

import json
from collections.abc import AsyncGenerator, AsyncIterator
from typing import Any

import httpx
from llama_stack.core.library_client import AsyncLlamaStackAsLibraryClient, convert_pydantic_to_json_value
from llama_stack.core.request_headers import PROVIDER_DATA_VAR, request_provider_data_context
from llama_stack.core.server.routes import find_matching_route
from llama_stack.core.utils.context import preserve_contexts_async_generator


class _AsyncByteStream(httpx.AsyncByteStream):
    """Wraps an async byte generator as an httpx AsyncByteStream."""

    def __init__(self, gen: AsyncGenerator[bytes, None]) -> None:
        """Store the generator whose chunks form the response body."""
        self._gen = gen

    async def __aiter__(self) -> AsyncIterator[bytes]:
        """Yield each chunk produced by the wrapped generator."""
        async for chunk in self._gen:
            yield chunk

    async def aclose(self) -> None:
        """Close the wrapped generator.

        Without this, a response that is closed before the stream is fully
        consumed would leave the generator suspended, so its cleanup code
        would only run whenever it happens to be garbage-collected.
        """
        await self._gen.aclose()
Comment thread
coderabbitai[bot] marked this conversation as resolved.


class LlamaStackLibraryTransport(httpx.AsyncBaseTransport):
    """Custom httpx transport that dispatches requests through a Llama Stack library client.

    Instead of making real HTTP calls, this transport routes requests directly
    to the Llama Stack's in-process route handlers via the library client's
    route matching and body conversion logic.
    """

    def __init__(self, client: AsyncLlamaStackAsLibraryClient) -> None:
        """Store the library client whose route handlers serve the requests."""
        self._client = client

    def _require_route_impls(self) -> Any:
        """Return the client's route table.

        Raises:
            RuntimeError: If the library client has no route table yet.
        """
        route_impls = self._client.route_impls
        if route_impls is None:
            raise RuntimeError('Llama Stack library client not initialized. Call initialize() first.')
        return route_impls

    def _resolve_call(
        self, method: str, path: str, body: dict[str, Any]
    ) -> tuple[Any, dict[str, Any]]:
        """Find the handler for ``method``/``path`` and build its keyword arguments."""
        handler, path_params, _, _ = find_matching_route(method, path, self._require_route_impls())
        # Build a new dict rather than updating `body` in place; path parameters
        # take precedence over same-named body keys, as before.
        merged = {**body, **path_params}
        return handler, self._client._convert_body(handler, merged)

    async def handle_async_request(self, request: httpx.Request) -> httpx.Response:
        """Serve ``request`` in-process and return the resulting response.

        Args:
            request: The outgoing request. Its content, when present, must be a
                JSON object.

        Returns:
            A JSON response, or a ``text/event-stream`` response when the
            request body has a truthy ``stream`` key.

        Raises:
            RuntimeError: If the library client has no route table yet.
            TypeError: If the request body is valid JSON but not an object.
        """
        self._require_route_impls()

        method = request.method
        # `raw_path` also carries the query string, which must not be part of
        # the path handed to route matching.
        path = request.url.path

        body = json.loads(request.content) if request.content else {}
        if not isinstance(body, dict):
            raise TypeError('Llama Stack library transport expects a JSON object request body')

        # Forward query parameters as call arguments; body keys win on conflict.
        # NOTE(review): assumes handlers accept query parameters by keyword the
        # same way they accept body fields — confirm against `_convert_body`.
        query_params = dict(request.url.params)
        if query_params:
            body = {**query_params, **body}

        headers: dict[str, str] = {}
        for raw_key, raw_value in request.headers.raw:
            key = raw_key.decode('utf-8') if isinstance(raw_key, bytes) else raw_key
            value = raw_value.decode('utf-8') if isinstance(raw_value, bytes) else raw_value
            headers[key] = value

        if self._client.provider_data:
            # Only inject the client's provider data when the request did not
            # supply the header itself.
            keys = ['X-LlamaStack-Provider-Data', 'x-llamastack-provider-data']
            if all(key not in headers for key in keys):
                headers['X-LlamaStack-Provider-Data'] = json.dumps(self._client.provider_data)

        with request_provider_data_context(headers):
            if body.get('stream', False):
                return await self._handle_streaming(request, method, path, body)
            return await self._handle_non_streaming(request, method, path, body)

    async def _handle_non_streaming(
        self,
        request: httpx.Request,
        method: str,
        path: str,
        body: dict[str, Any],
    ) -> httpx.Response:
        """Call the matched handler once and wrap its result in a JSON response.

        A ``DELETE`` whose handler returns ``None`` produces an empty
        ``204 No Content`` response; everything else is ``200 OK``.
        """
        handler, call_kwargs = self._resolve_call(method, path, body)
        result = await handler(**call_kwargs)

        if method.upper() == 'DELETE' and result is None:
            status_code = httpx.codes.NO_CONTENT
            payload = b''
        else:
            status_code = httpx.codes.OK
            payload = json.dumps(convert_pydantic_to_json_value(result)).encode('utf-8')

        return httpx.Response(
            status_code=status_code,
            content=payload,
            headers={'Content-Type': 'application/json'},
            request=request,
        )

    async def _handle_streaming(
        self,
        request: httpx.Request,
        method: str,
        path: str,
        body: dict[str, Any],
    ) -> httpx.Response:
        """Call the matched handler and expose its async result as an event stream.

        Each item yielded by the handler's result is serialised to JSON and
        emitted as one ``data: ...`` event.
        """
        handler, call_kwargs = self._resolve_call(method, path, body)
        result = await handler(**call_kwargs)

        async def gen() -> AsyncGenerator[bytes, None]:
            async for chunk in result:
                data = json.dumps(convert_pydantic_to_json_value(chunk))
                yield f'data: {data}\n\n'.encode('utf-8')

        # The stream is consumed after this method returns, i.e. outside the
        # caller's `request_provider_data_context` block.
        # NOTE(review): presumably the wrapper carries PROVIDER_DATA_VAR across
        # that boundary — confirm against llama_stack.
        wrapped_gen = preserve_contexts_async_generator(gen(), [PROVIDER_DATA_VAR])

        return httpx.Response(
            status_code=httpx.codes.OK,
            stream=_AsyncByteStream(wrapped_gen),
            headers={'Content-Type': 'text/event-stream'},
            request=request,
        )
Loading
Loading