Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 78 additions & 0 deletions src/inference_endpoint/endpoint_client/adapter_protocol.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Base class for HTTP request adapters."""

import re
from abc import ABC, abstractmethod

from inference_endpoint.core.types import Query, QueryResult


class HttpRequestAdapter(ABC):
"""
Abstract base class for HTTP request adapters.

Adapters convert between internal Query/QueryResult types and
endpoint-specific formats (e.g., OpenAI, custom formats).
"""

# SSE (Server-Sent Events) is an HTTP standard
# Pre-compiled regex for extracting SSE data fields with JSON content
# Matches "data: {json content}" and captures the JSON part
SSE_DATA_PATTERN: re.Pattern[bytes] = re.compile(rb"data:\s*(\{[^\n]+\})")
Comment thread
viraatc marked this conversation as resolved.

@staticmethod
@abstractmethod
def encode_query(query: Query) -> bytes:
"""
Encode a Query to bytes for HTTP transmission.

Args:
query: Input query with prompt and parameters

Returns:
Encoded request bytes ready for HTTP POST
"""
raise NotImplementedError("encode_query not implemented")

@staticmethod
@abstractmethod
def decode_response(response_bytes: bytes, query_id: str) -> QueryResult:
"""
Decode HTTP response bytes to QueryResult.

Args:
response_bytes: Raw bytes from HTTP response
query_id: ID for the query (to associate with result)

Returns:
QueryResult with extracted content
"""
raise NotImplementedError("decode_response not implemented")

@staticmethod
@abstractmethod
def decode_sse_message(json_bytes: bytes) -> str:
"""
Decode SSE message and extract content string.

Args:
json_bytes: Raw JSON bytes from SSE stream

Returns:
Content string from the SSE message
"""
raise NotImplementedError("decode_sse_message not implemented")
14 changes: 14 additions & 0 deletions src/inference_endpoint/endpoint_client/configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
import aiohttp
import zmq

from inference_endpoint.endpoint_client.adapter_protocol import HttpRequestAdapter


@dataclass
class HTTPClientConfig:
Expand Down Expand Up @@ -55,6 +57,18 @@ class HTTPClientConfig:
# - add max-sequence-length to HttpClient config (not per-query), base streaming_buffer_size on it
streaming_buffer_size: int = 128 * 1024 # 128KB buffer for streaming tokens

# Request adapter for Query/Response <-> Payload/Response bytes
adapter: type[HttpRequestAdapter] | None = field(default=None, init=False)

def __post_init__(self):
# set default adapter in __post_init__ to avoid circular dependency
if self.adapter is None:
from inference_endpoint.openai.openai_msgspec_adapter import (
OpenAIMsgspecAdapter,
)

self.adapter = OpenAIMsgspecAdapter


Comment thread
viraatc marked this conversation as resolved.
@dataclass
class SocketConfig:
Expand Down
4 changes: 3 additions & 1 deletion src/inference_endpoint/endpoint_client/http_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,9 @@ def __init__(
self._response_socket: ZMQPullSocket | None = None
self._concurrency_semaphore: asyncio.Semaphore | None = None

self.logger = logging.getLogger(__name__)
logger.info(
f"HTTP endpoint client using adapter: {self.config.adapter.__name__}"
)
Comment thread
viraatc marked this conversation as resolved.

def start(self):
"""Start event loop thread and initialize client."""
Expand Down
45 changes: 22 additions & 23 deletions src/inference_endpoint/endpoint_client/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,6 @@
from typing import Any

import aiohttp
import msgspec
import orjson
import zmq
import zmq.asyncio

Expand All @@ -43,7 +41,6 @@
ZMQConfig,
)
from inference_endpoint.endpoint_client.zmq_utils import ZMQPullSocket, ZMQPushSocket
from inference_endpoint.openai.openai_adapter import OpenAIAdapter, SSEMessage
from inference_endpoint.profiling import profile

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -141,8 +138,8 @@ def __init__(
# Track active request tasks
self._active_tasks: set[asyncio.Task] = set()

# Reusable typed decoder for SSE chunk parsing (struct access faster than dict)
self._sse_decoder: msgspec.json.Decoder = msgspec.json.Decoder(SSEMessage)
# Use adapter type from config
self._adapter = self.http_config.adapter

async def run(self) -> None:
"""Main worker loop - pull requests, execute, push responses."""
Expand Down Expand Up @@ -176,7 +173,6 @@ async def run(self) -> None:
connector=self.tcp_connector,
connector_owner=False, # owned by Worker
skip_auto_headers=self.aiohttp_config.skip_auto_headers,
json_serialize=lambda obj: orjson.dumps(obj).decode("utf-8"),
)

# Signal handlers for graceful shutdown
Expand Down Expand Up @@ -271,15 +267,18 @@ async def _make_http_request(self, query: Query):

url = self.http_config.endpoint_url
headers = query.headers if hasattr(query, "headers") else {}
payload = OpenAIAdapter.to_openai_request(query).model_dump(
mode="json", exclude_unset=True
)

# Issue the request
logging.debug(
f"Making HTTP request to {url} with payload: {payload} and headers: {headers}"
f"Making HTTP request to {url} with payload: {query} and headers: {headers}"
)
async with self._session.post(url, json=payload, headers=headers) as response:

# Encode query to bytes using adapter
payload_bytes = self._adapter.encode_query(query)

# Issue the request with pre-encoded bytes
async with self._session.post(
url, data=payload_bytes, headers=headers
Comment thread
viraatc marked this conversation as resolved.
) as response:
Comment thread
viraatc marked this conversation as resolved.
if response.status != 200:
error_text = await response.text()
await self._handle_error(
Expand All @@ -303,14 +302,14 @@ async def _process_request(self, query: Query) -> None:

@profile
def _parse_sse_chunk(self, buffer: bytes, end_pos: int) -> list[str]:
"""Parse SSE chunk and extract content using msgspec typed decode."""
json_docs = OpenAIAdapter.SSE_DATA_PATTERN.findall(buffer[:end_pos])

"""Parse SSE chunk and extract content using adapter's decoder."""
json_docs = self._adapter.SSE_DATA_PATTERN.findall(buffer[:end_pos])
Comment thread
viraatc marked this conversation as resolved.
Outdated
parsed_contents = []

try:
for json_doc in json_docs:
msg = self._sse_decoder.decode(json_doc)
parsed_contents.append(msg.choices[0].delta.content)
content = self._adapter.decode_sse_message(json_doc)
parsed_contents.append(content)
except Exception:
# Normal for non-content SSE messages (role, finish_reason, etc)
pass
Comment thread
viraatc marked this conversation as resolved.
Outdated
Expand Down Expand Up @@ -413,10 +412,8 @@ async def _handle_non_streaming_request(self, query: Query) -> None:
"""Handle non-streaming response."""
async for response in self._make_http_request(query):
response_bytes = await response.read()
response_data = orjson.loads(response_bytes)
response_obj = OpenAIAdapter.from_json_response(query.id, response_data)
# Send response back to the main process
await self._response_socket.send(response_obj)
result = self._adapter.decode_response(response_bytes, query.id)
await self._response_socket.send(result)

def shutdown(self, signum: int | None = None, frame: Any | None = None) -> None:
"""Trigger shutdown of worker process."""
Expand Down Expand Up @@ -480,6 +477,8 @@ async def initialize(self) -> None:
)

try:
logger.info(f"Starting {self.http_config.num_workers} worker processes")

# Spawn worker processes
for i in range(self.http_config.num_workers):
worker = self._spawn_worker(i)
Expand All @@ -495,7 +494,7 @@ async def wait_for_all_workers():
worker_id = await readiness_socket.receive()
if worker_id is not None:
ready_count += 1
logger.info(
logger.debug(
f"Worker {worker_id} is ready ({ready_count}/{self.http_config.num_workers})"
)

Expand All @@ -505,7 +504,7 @@ async def wait_for_all_workers():
wait_for_all_workers(),
timeout=self.http_config.worker_initialization_timeout,
)
logger.info(f"All {ready_count} workers are ready")
logger.info(f"{ready_count}/{self.http_config.num_workers} workers ready")
except TimeoutError as e:
raise TimeoutError(
f"Workers failed to initialize within {self.http_config.worker_initialization_timeout} seconds."
Expand Down
93 changes: 43 additions & 50 deletions src/inference_endpoint/openai/openai_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import re
import time

import msgspec
import orjson
from inference_endpoint.core.types import Query, QueryResult
from inference_endpoint.endpoint_client.adapter_protocol import HttpRequestAdapter

from .openai_types_gen import (
ChatCompletionResponseMessage,
Expand Down Expand Up @@ -55,15 +56,33 @@ class SSEMessage(msgspec.Struct):
choices: list[SSEChoice] = msgspec.field(default_factory=list)


class OpenAIAdapter:
class OpenAIAdapter(HttpRequestAdapter):
"""Adapter for OpenAI API."""

# Pre-compiled regex for extracting SSE data fields with JSON content
# Matches "data: {json content}" and captures the JSON part
SSE_DATA_PATTERN = re.compile(rb"data:\s*(\{[^\n]+\})", re.MULTILINE)
@staticmethod
def encode_query(query: Query) -> bytes:
"""Encode a Query to bytes for HTTP transmission."""
request = OpenAIAdapter.to_endpoint_request(query)
return OpenAIAdapter.encode_request(request)

@staticmethod
def decode_response(response_bytes: bytes, query_id: str) -> QueryResult:
"""Decode HTTP response bytes to QueryResult."""
openai_response = OpenAIAdapter.decode_endpoint_response(response_bytes)
return OpenAIAdapter.from_endpoint_response(openai_response, result_id=query_id)

@staticmethod
def decode_sse_message(json_bytes: bytes) -> str:
"""Decode SSE message and extract content string."""
msg = msgspec.json.decode(json_bytes, type=SSEMessage)
return msg.choices[0].delta.content

# ========================================================================
# Internal APIs
# ========================================================================

@staticmethod
def to_openai_request(query: Query) -> CreateChatCompletionRequest:
def to_endpoint_request(query: Query) -> CreateChatCompletionRequest:
"""Convert a Query to an OpenAI request."""
if "prompt" not in query.data:
raise ValueError("prompt not found in json_value")
Expand All @@ -86,33 +105,11 @@ def to_openai_request(query: Query) -> CreateChatCompletionRequest:
return request

@staticmethod
def from_openai_request(request: CreateChatCompletionRequest) -> Query:
"""Convert an OpenAI request to a Query."""
if not request.messages or len(request.messages) == 0:
raise ValueError("Request must contain at least one message")
return Query(
data={
"prompt": request.messages[0].root.content,
"model": request.model,
"stream": request.stream,
},
)

@staticmethod
def from_openai_response(
def from_endpoint_response(
response: CreateChatCompletionResponse,
result_id: str | None = None,
) -> QueryResult:
"""Convert an OpenAI response to a QueryResult.
Args:
response: The OpenAI response to convert.
result_id: If provided, use this as the ID for the QueryResult. Otherwise,
uses the response ID from the OpenAI response. This is useful
since QueryResult is a frozen dataclass, and `id` cannot be changed
after creation. (Default: None)
Returns:
A QueryResult object.
"""
"""Convert an OpenAI response to a QueryResult."""
if not response.choices:
raise ValueError("Response must contain at least one choice")

Expand All @@ -125,26 +122,7 @@ def from_openai_response(
)

@staticmethod
def from_json_response(query_id, response: dict) -> QueryResult:
"""Convert an OpenAI response data to a QueryResult.
Note that this function fixes the fields to be compatible with
OpenAI pydantic definitions. This includes updating the refusal and
logprobs fields to be compatible with the OpenAI pydantic definitions.
Args:
query_id: The ID of the query.
response: The OpenAI response data to convert.
Returns:
A QueryResult object.
"""
response["choices"][0]["message"]["refusal"] = "None"
response["choices"][0]["logprobs"] = {"content": [], "refusal": []}
return OpenAIAdapter.from_openai_response(
CreateChatCompletionResponse(**response, ignore_extra=True),
result_id=query_id,
)

@staticmethod
def to_openai_response(result: QueryResult) -> CreateChatCompletionResponse:
def to_endpoint_response(result: QueryResult) -> CreateChatCompletionResponse:
"""Convert a QueryResult to an OpenAI response."""
return CreateChatCompletionResponse(
id=result.id,
Expand All @@ -163,3 +141,18 @@ def to_openai_response(result: QueryResult) -> CreateChatCompletionResponse:
object=Object7.chat_completion,
service_tier=ServiceTier.auto,
)

@staticmethod
def encode_request(request: CreateChatCompletionRequest) -> bytes:
"""Encode request to JSON bytes using orjson."""
return orjson.dumps(request.model_dump(mode="json"))

@staticmethod
def decode_endpoint_response(response_bytes: bytes) -> CreateChatCompletionResponse:
"""Decode response from JSON bytes using orjson."""
response_dict = orjson.loads(response_bytes)

# Set default values for optional fields if missing
Comment thread
viraatc marked this conversation as resolved.
response_dict["choices"][0]["message"]["refusal"] = "None"
Comment thread
viraatc marked this conversation as resolved.
response_dict["choices"][0]["logprobs"] = {"content": [], "refusal": []}
Comment thread
viraatc marked this conversation as resolved.
return CreateChatCompletionResponse(**response_dict, ignore_extra=True)
Loading
Loading