|
| 1 | +"""httpx transport that routes OpenAI-compatible requests through a Llama Stack library client.""" |
| 2 | + |
| 3 | +from __future__ import annotations as _annotations |
| 4 | + |
| 5 | +import json |
| 6 | +from collections.abc import AsyncGenerator, AsyncIterator |
| 7 | +from typing import Any |
| 8 | + |
| 9 | +import httpx |
| 10 | +from llama_stack.core.library_client import ( |
| 11 | + AsyncLlamaStackAsLibraryClient, |
| 12 | + convert_pydantic_to_json_value, |
| 13 | +) |
| 14 | +from llama_stack.core.request_headers import ( |
| 15 | + PROVIDER_DATA_VAR, |
| 16 | + request_provider_data_context, |
| 17 | +) |
| 18 | +from llama_stack.core.server.routes import find_matching_route |
| 19 | +from llama_stack.core.utils.context import preserve_contexts_async_generator |
| 20 | + |
| 21 | + |
class _AsyncByteStream(httpx.AsyncByteStream):
    """Wraps an async byte generator as an httpx AsyncByteStream."""

    def __init__(self, gen: AsyncGenerator[bytes, None]) -> None:
        """Store an async generator that yields raw bytes for streaming.

        Args:
            gen: An async generator producing byte chunks to stream.
        """
        self._gen = gen

    async def __aiter__(self) -> AsyncIterator[bytes]:
        """Yield bytes chunks from the wrapped generator.

        Returns:
            An async iterator of bytes fulfilling the httpx.AsyncByteStream contract.
        """
        async for chunk in self._gen:
            yield chunk

    async def aclose(self) -> None:
        """Close the wrapped generator when the response stream is closed.

        The base ``httpx.AsyncByteStream.aclose`` does nothing, so without this
        override a consumer that stops reading early would leave the wrapped
        generator suspended until it is garbage collected. Closing an already
        exhausted generator is a no-op, so this is safe after full consumption.
        """
        # Looked up defensively: the annotation promises an async generator,
        # but a plain async iterator without ``aclose`` should not crash here.
        close_gen = getattr(self._gen, "aclose", None)
        if close_gen is not None:
            await close_gen()
| 41 | + |
| 42 | + |
class LlamaStackLibraryTransport(httpx.AsyncBaseTransport):
    """Custom httpx transport that dispatches requests through a Llama Stack library client.

    Instead of making real HTTP calls, this transport routes requests directly
    to the Llama Stack's in-process route handlers via the library client's
    route matching and body conversion logic.

    Only JSON-object (or empty) request bodies are supported. Exceptions raised
    by route matching or by the handlers themselves propagate to the caller
    unchanged; they are not converted into HTTP error responses.
    """

    def __init__(self, client: AsyncLlamaStackAsLibraryClient) -> None:
        """Initialize the transport with a Llama Stack library client.

        Args:
            client: An initialized ``AsyncLlamaStackAsLibraryClient`` whose route
                handlers will receive dispatched requests.
        """
        self._client = client

    async def handle_async_request(self, request: httpx.Request) -> httpx.Response:
        """Dispatch an httpx request to the in-process Llama Stack route handlers.

        Args:
            request: The outgoing httpx request to route.

        Returns:
            An httpx response built from the matched route handler result.

        Raises:
            RuntimeError: If the library client has not been initialized.
            ValueError: If the request body is not valid JSON
                (``json.JSONDecodeError``) or is JSON but not an object.
        """
        if self._client.route_impls is None:
            raise RuntimeError(
                "Llama Stack library client not initialized. Call initialize() first."
            )

        method = request.method

        # ``raw_path`` is the full request target (path *and* query string);
        # only the path component may be handed to the route matcher.
        raw_path, _, _ = request.url.raw_path.partition(b"?")
        path = raw_path.decode("utf-8")

        # ``aread()`` buffers the body if it has not been read yet, whereas
        # ``request.content`` raises ``httpx.RequestNotRead`` in that case.
        content = await request.aread()
        json_body = json.loads(content) if content else {}
        if not isinstance(json_body, dict):
            raise ValueError(
                "Request body must be a JSON object, "
                f"got {type(json_body).__name__}"
            )

        # Query parameters become handler arguments too; JSON body keys win on
        # conflict. Repeated query keys collapse to their first value.
        # NOTE(review): query values arrive as strings and rely on the
        # client's body conversion to coerce them -- confirm for typed params.
        body: dict[str, Any] = {**dict(request.url.params), **json_body}

        headers: dict[str, str] = {
            k.decode("utf-8") if isinstance(k, bytes) else k: (
                v.decode("utf-8") if isinstance(v, bytes) else v
            )
            for k, v in request.headers.raw
        }

        # Fall back to the client's default provider data only when the caller
        # did not send the header itself (under either spelling checked here).
        if self._client.provider_data:
            keys = ["X-LlamaStack-Provider-Data", "x-llamastack-provider-data"]
            if all(key not in headers for key in keys):
                headers["X-LlamaStack-Provider-Data"] = json.dumps(
                    self._client.provider_data
                )

        with request_provider_data_context(headers):
            # Streaming is selected by the JSON body only, never by the query.
            is_stream = json_body.get("stream", False)

            if is_stream:
                return await self._handle_streaming(request, method, path, body)
            return await self._handle_non_streaming(request, method, path, body)

    async def _invoke_route(
        self,
        method: str,
        path: str,
        body: dict[str, Any],
    ) -> Any:
        """Match the route, convert the body, and await the matched handler.

        Args:
            method: The HTTP method (e.g. ``"POST"``).
            path: The URL path (without query string) used for route matching.
            body: The request parameters; path parameters extracted by the
                matcher override same-named keys.

        Returns:
            Whatever the matched route handler returns.

        Raises:
            RuntimeError: If route_impls is not initialized.
        """
        route_impls = self._client.route_impls
        if route_impls is None:
            raise RuntimeError("route_impls is not initialized")

        handler, path_params, _, _ = find_matching_route(method, path, route_impls)
        call_kwargs = self._client._convert_body(  # pylint: disable=protected-access
            handler, {**body, **path_params}
        )
        return await handler(**call_kwargs)

    async def _handle_non_streaming(
        self,
        request: httpx.Request,
        method: str,
        path: str,
        body: dict[str, Any],
    ) -> httpx.Response:
        """Dispatch a non-streaming request to the matched route handler.

        Args:
            request: The original httpx request (attached to the response).
            method: The HTTP method (e.g. ``"POST"``).
            path: The decoded URL path used for route matching.
            body: The parsed JSON request body.

        Returns:
            An httpx.Response containing the JSON-serialized handler result, or
            an empty 204 response for a DELETE whose handler returned ``None``.

        Raises:
            RuntimeError: If route_impls is not initialized.
        """
        result = await self._invoke_route(method, path, body)

        # Decide the status first so a result that will be discarded is never
        # serialized.
        if method.upper() == "DELETE" and result is None:
            status_code = httpx.codes.NO_CONTENT
            json_content = ""
        else:
            status_code = httpx.codes.OK
            json_content = json.dumps(convert_pydantic_to_json_value(result))

        return httpx.Response(
            status_code=status_code,
            content=json_content.encode("utf-8"),
            headers={"Content-Type": "application/json"},
            request=request,
        )

    async def _handle_streaming(
        self,
        request: httpx.Request,
        method: str,
        path: str,
        body: dict[str, Any],
    ) -> httpx.Response:
        """Dispatch a streaming request and return an SSE event-stream response.

        Args:
            request: The original httpx request (attached to the response).
            method: The HTTP method (e.g. ``"POST"``).
            path: The decoded URL path used for route matching.
            body: The parsed JSON request body (must contain ``stream: True``).

        Returns:
            An httpx.Response with a streaming body of SSE-formatted chunks.

        Raises:
            RuntimeError: If route_impls is not initialized.
        """
        result = await self._invoke_route(method, path, body)

        async def gen() -> AsyncGenerator[bytes, None]:
            # Each chunk from the handler becomes one ``data:`` SSE event.
            async for chunk in result:
                data = json.dumps(convert_pydantic_to_json_value(chunk))
                yield f"data: {data}\n\n".encode("utf-8")

        # The body is consumed after handle_async_request has left its
        # provider-data ``with`` block, so the generator is wrapped here, while
        # that context is still active, with PROVIDER_DATA_VAR listed as the
        # variable to preserve.
        wrapped_gen = preserve_contexts_async_generator(gen(), [PROVIDER_DATA_VAR])

        return httpx.Response(
            status_code=httpx.codes.OK,
            stream=_AsyncByteStream(wrapped_gen),
            headers={"Content-Type": "text/event-stream"},
            request=request,
        )
0 commit comments