-
Notifications
You must be signed in to change notification settings - Fork 94
LCORE-2308: LlamaStack Pydantic AI Provider #1806
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
3def1bc
2c38478
9926b0c
5460b50
7d77ef4
45ba392
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1 @@ | ||
| """Pydantic AI integrations/extensions for Lightspeed Core Stack.""" |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,5 @@ | ||
| """Pydantic AI provider for Llama Stack.""" | ||
|
|
||
| from ._provider import LlamaStackProvider | ||
|
|
||
| __all__ = ['LlamaStackProvider'] | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,103 @@ | ||
| from __future__ import annotations as _annotations | ||
|
|
||
| import os | ||
| from typing import TYPE_CHECKING | ||
|
|
||
| import httpx | ||
| from openai import AsyncOpenAI | ||
|
|
||
| from pydantic_ai import ModelProfile | ||
| from pydantic_ai.models import create_async_http_client | ||
| from pydantic_ai.profiles.openai import openai_model_profile | ||
| from pydantic_ai.providers import Provider | ||
|
|
||
| if TYPE_CHECKING: | ||
| from llama_stack.core.library_client import AsyncLlamaStackAsLibraryClient | ||
|
|
||
| DEFAULT_BASE_URL = 'http://localhost:8321/v1' | ||
|
|
||
|
|
||
| class LlamaStackProvider(Provider[AsyncOpenAI]): | ||
| """Provider for Llama Stack — connects to a Llama Stack server's OpenAI-compatible API. | ||
|
|
||
| Supports two modes: | ||
|
|
||
| 1. **Server mode** — connect to a running Llama Stack server via HTTP | ||
| 2. **Library mode** — run Llama Stack in-process via ``AsyncLlamaStackAsLibraryClient`` | ||
| """ | ||
|
|
||
| @property | ||
| def name(self) -> str: | ||
| return 'llama-stack' | ||
|
|
||
| @property | ||
| def base_url(self) -> str: | ||
| return str(self._client.base_url) | ||
|
|
||
| @property | ||
| def client(self) -> AsyncOpenAI: | ||
| return self._client | ||
|
|
||
| @staticmethod | ||
| def model_profile(model_name: str) -> ModelProfile | None: | ||
| return openai_model_profile(model_name) | ||
|
|
||
| def __init__( | ||
| self, | ||
| *, | ||
| base_url: str | None = None, | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. we prefer to use |
||
| api_key: str | None = None, | ||
| library_client: AsyncLlamaStackAsLibraryClient | None = None, | ||
| http_client: httpx.AsyncClient | None = None, | ||
| ) -> None: | ||
| """Create a new Llama Stack provider. | ||
|
|
||
| Args: | ||
| base_url: The base URL for the Llama Stack server (OpenAI-compatible endpoint). | ||
| Defaults to ``LLAMA_STACK_BASE_URL`` env var, then ``http://localhost:8321/v1``. | ||
| Must be ``None`` when ``library_client`` is provided. | ||
| api_key: The API key for authentication. Defaults to ``LLAMA_STACK_API_KEY`` env | ||
| var, then ``'not-needed'`` since local Llama Stack servers typically don't | ||
| require one. Must be ``None`` when ``library_client`` is provided. | ||
| library_client: An initialized ``AsyncLlamaStackAsLibraryClient`` for library mode. | ||
| When provided, requests are dispatched in-process (no server needed). | ||
| Mutually exclusive with ``base_url``, ``api_key``, and ``http_client``. | ||
| http_client: An existing ``httpx.AsyncClient`` to use for making HTTP requests. | ||
| Must be ``None`` when ``library_client`` is provided. | ||
| """ | ||
| if library_client is not None: | ||
| assert base_url is None, 'Cannot provide both `library_client` and `base_url`' | ||
Check noticeCode scanning / Bandit Use of assert detected. The enclosed code will be removed when compiling to optimised byte code. Note
Use of assert detected. The enclosed code will be removed when compiling to optimised byte code.
|
||
|
github-advanced-security[bot] marked this conversation as resolved.
Fixed
|
||
| assert api_key is None, 'Cannot provide both `library_client` and `api_key`' | ||
Check noticeCode scanning / Bandit Use of assert detected. The enclosed code will be removed when compiling to optimised byte code. Note
Use of assert detected. The enclosed code will be removed when compiling to optimised byte code.
|
||
|
github-advanced-security[bot] marked this conversation as resolved.
Fixed
|
||
| assert http_client is None, 'Cannot provide both `library_client` and `http_client`' | ||
Check noticeCode scanning / Bandit Use of assert detected. The enclosed code will be removed when compiling to optimised byte code. Note
Use of assert detected. The enclosed code will be removed when compiling to optimised byte code.
|
||
|
github-advanced-security[bot] marked this conversation as resolved.
Fixed
|
||
|
|
||
| from ._transport import LlamaStackLibraryTransport | ||
|
|
||
| self._library_client = library_client | ||
| transport = LlamaStackLibraryTransport(library_client) | ||
| lib_http_client = httpx.AsyncClient( | ||
| transport=transport, base_url='http://llama-stack-library' | ||
Check warningCode scanning / Bandit Call to httpx without timeout Warning
Call to httpx without timeout
|
||
|
github-advanced-security[bot] marked this conversation as resolved.
Fixed
|
||
| ) | ||
| self._client = AsyncOpenAI( | ||
| http_client=lib_http_client, | ||
| base_url='http://llama-stack-library/v1', | ||
| api_key='not-needed', | ||
| ) | ||
| else: | ||
| base_url = base_url or os.environ.get('LLAMA_STACK_BASE_URL') or DEFAULT_BASE_URL | ||
| api_key = api_key or os.environ.get('LLAMA_STACK_API_KEY') or 'not-needed' | ||
|
|
||
| if http_client is not None: | ||
| self._client = AsyncOpenAI( | ||
| base_url=base_url, api_key=api_key, http_client=http_client | ||
| ) | ||
| else: | ||
| oai_http_client = create_async_http_client() | ||
| self._client = AsyncOpenAI( | ||
| base_url=base_url, api_key=api_key, http_client=oai_http_client | ||
| ) | ||
|
|
||
| def __repr__(self) -> str: | ||
| return f'LlamaStackProvider(name={self.name!r}, base_url={self.base_url!r})' | ||
|
|
||
| def _set_http_client(self, http_client: httpx.AsyncClient) -> None: | ||
| self._client._client = http_client # pyright: ignore[reportPrivateUsage] | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,123 @@ | ||
| from __future__ import annotations as _annotations | ||
|
|
||
| import json | ||
| from collections.abc import AsyncGenerator, AsyncIterator | ||
| from typing import Any | ||
|
|
||
| import httpx | ||
| from llama_stack.core.library_client import AsyncLlamaStackAsLibraryClient, convert_pydantic_to_json_value | ||
| from llama_stack.core.request_headers import PROVIDER_DATA_VAR, request_provider_data_context | ||
| from llama_stack.core.server.routes import find_matching_route | ||
| from llama_stack.core.utils.context import preserve_contexts_async_generator | ||
|
|
||
|
|
||
| class _AsyncByteStream(httpx.AsyncByteStream): | ||
| """Wraps an async byte generator as an httpx AsyncByteStream.""" | ||
|
|
||
| def __init__(self, gen: AsyncGenerator[bytes, None]) -> None: | ||
| self._gen = gen | ||
|
|
||
| async def __aiter__(self) -> AsyncIterator[bytes]: | ||
| async for chunk in self._gen: | ||
| yield chunk | ||
|
coderabbitai[bot] marked this conversation as resolved.
|
||
|
|
||
|
|
||
| class LlamaStackLibraryTransport(httpx.AsyncBaseTransport): | ||
| """Custom httpx transport that dispatches requests through a Llama Stack library client. | ||
|
|
||
| Instead of making real HTTP calls, this transport routes requests directly | ||
| to the Llama Stack's in-process route handlers via the library client's | ||
| route matching and body conversion logic. | ||
| """ | ||
|
|
||
| def __init__(self, client: AsyncLlamaStackAsLibraryClient) -> None: | ||
| self._client = client | ||
|
|
||
| async def handle_async_request(self, request: httpx.Request) -> httpx.Response: | ||
| if self._client.route_impls is None: | ||
| raise RuntimeError('Llama Stack library client not initialized. Call initialize() first.') | ||
|
|
||
| method = request.method | ||
| path = request.url.raw_path.decode('utf-8') | ||
|
|
||
| body = json.loads(request.content) if request.content else {} | ||
|
|
||
| headers: dict[str, str] = { | ||
| k.decode('utf-8') if isinstance(k, bytes) else k: v.decode('utf-8') | ||
| if isinstance(v, bytes) | ||
| else v | ||
| for k, v in request.headers.raw | ||
| } | ||
|
|
||
| if self._client.provider_data: | ||
| keys = ['X-LlamaStack-Provider-Data', 'x-llamastack-provider-data'] | ||
| if all(key not in headers for key in keys): | ||
| headers['X-LlamaStack-Provider-Data'] = json.dumps(self._client.provider_data) | ||
|
|
||
| with request_provider_data_context(headers): | ||
| is_stream = body.get('stream', False) | ||
|
|
||
| if is_stream: | ||
| return await self._handle_streaming(request, method, path, body) | ||
| else: | ||
| return await self._handle_non_streaming(request, method, path, body) | ||
|
|
||
| async def _handle_non_streaming( | ||
| self, | ||
| request: httpx.Request, | ||
| method: str, | ||
| path: str, | ||
| body: dict[str, Any], | ||
| ) -> httpx.Response: | ||
|
coderabbitai[bot] marked this conversation as resolved.
|
||
| assert self._client.route_impls is not None | ||
Check noticeCode scanning / Bandit Use of assert detected. The enclosed code will be removed when compiling to optimised byte code. Note
Use of assert detected. The enclosed code will be removed when compiling to optimised byte code.
|
||
|
github-advanced-security[bot] marked this conversation as resolved.
Fixed
|
||
|
|
||
| matched_func, path_params, _, _ = find_matching_route( | ||
| method, path, self._client.route_impls | ||
| ) | ||
| body |= path_params | ||
|
coderabbitai[bot] marked this conversation as resolved.
Outdated
|
||
| body = self._client._convert_body(matched_func, body) | ||
|
|
||
| result = await matched_func(**body) | ||
|
|
||
| json_content = json.dumps(convert_pydantic_to_json_value(result)) | ||
| status_code = httpx.codes.OK | ||
|
|
||
| if method.upper() == 'DELETE' and result is None: | ||
| status_code = httpx.codes.NO_CONTENT | ||
| json_content = '' | ||
|
|
||
| return httpx.Response( | ||
| status_code=status_code, | ||
| content=json_content.encode('utf-8'), | ||
| headers={'Content-Type': 'application/json'}, | ||
| request=request, | ||
| ) | ||
|
|
||
| async def _handle_streaming( | ||
| self, | ||
| request: httpx.Request, | ||
| method: str, | ||
| path: str, | ||
| body: dict[str, Any], | ||
| ) -> httpx.Response: | ||
| assert self._client.route_impls is not None | ||
Check noticeCode scanning / Bandit Use of assert detected. The enclosed code will be removed when compiling to optimised byte code. Note
Use of assert detected. The enclosed code will be removed when compiling to optimised byte code.
|
||
|
github-advanced-security[bot] marked this conversation as resolved.
Fixed
github-advanced-security[bot] marked this conversation as resolved.
Fixed
|
||
|
|
||
| func, path_params, _, _ = find_matching_route(method, path, self._client.route_impls) | ||
| body |= path_params | ||
| body = self._client._convert_body(func, body) | ||
|
|
||
| result = await func(**body) | ||
|
|
||
| async def gen() -> AsyncGenerator[bytes, None]: | ||
| async for chunk in result: | ||
| data = json.dumps(convert_pydantic_to_json_value(chunk)) | ||
| yield f'data: {data}\n\n'.encode('utf-8') | ||
|
|
||
| wrapped_gen = preserve_contexts_async_generator(gen(), [PROVIDER_DATA_VAR]) | ||
|
|
||
| return httpx.Response( | ||
| status_code=httpx.codes.OK, | ||
| stream=_AsyncByteStream(wrapped_gen), | ||
| headers={'Content-Type': 'text/event-stream'}, | ||
| request=request, | ||
| ) | ||
Uh oh!
There was an error while loading. Please reload this page.