Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
128 changes: 79 additions & 49 deletions cli/serve/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,16 @@
import typer
import uvicorn
from fastapi import FastAPI
from fastapi.responses import JSONResponse

from .models import (
ChatCompletion,
ChatCompletionMessage,
ChatCompletionRequest,
Choice,
CompletionUsage,
OpenAIError,
OpenAIErrorResponse,
)

app = FastAPI(
Expand All @@ -35,60 +38,87 @@ def load_module_from_path(path: str):
return module


def create_openai_error_response(
    status_code: int,
    message: str,
    error_type: str,
    param: str | None = None,
    code: str | None = None,
) -> JSONResponse:
    """Build a JSON response whose body follows the OpenAI error format.

    Args:
        status_code: HTTP status code to send with the response.
        message: Human-readable error message.
        error_type: Value for the error's ``type`` field, e.g.
            ``"invalid_request_error"`` or ``"server_error"``.
        param: Name of the request parameter that caused the error, if any.
        code: Optional error code stored in the error's ``code`` field.

    Returns:
        A ``JSONResponse`` whose content is the dumped ``OpenAIErrorResponse``
        (a single top-level ``error`` object).
    """
    error = OpenAIError(message=message, type=error_type, param=param, code=code)
    payload = OpenAIErrorResponse(error=error)
    # mode="json" makes pydantic emit only JSON-compatible values.
    return JSONResponse(
        status_code=status_code, content=payload.model_dump(mode="json")
    )


def make_chat_endpoint(module):
    """Make a chat-completions endpoint backed by a custom module.

    Args:
        module: Object exposing ``__name__`` and a
            ``serve(input=..., requirements=..., model_options=...)`` callable.
            The value returned by ``serve`` must have a ``value`` attribute and
            may have a ``usage`` mapping with token counts.

    Returns:
        An async callable named ``chat_<module.__name__>_endpoint``. It returns
        a ``ChatCompletion`` on success and an OpenAI-style JSON error response
        on failure: 400 when ``module.serve`` raises ``ValueError``, 500 for
        every other failure.
    """
    import logging

    logger = logging.getLogger(__name__)
    endpoint_name = f"chat_{module.__name__}_endpoint"
    # These request fields are passed to serve() explicitly, so they are left
    # out of model_options.
    explicit_fields = ("messages", "requirements")

    def _server_error(exc: Exception):
        """Log *exc* with its traceback and build the 500 error response."""
        logger.error("Unhandled error in %s", endpoint_name, exc_info=exc)
        return create_openai_error_response(
            status_code=500,
            message=f"Internal server error: {exc!s}",
            error_type="server_error",
        )

    # NOTE: the endpoint deliberately has no return annotation. FastAPI derives
    # a response model from the annotation when none is passed explicitly, and
    # this function can return either a ChatCompletion or a JSONResponse.
    async def endpoint(request: ChatCompletionRequest):
        completion_id = f"chatcmpl-{uuid.uuid4().hex[:29]}"
        created_timestamp = int(time.time())

        try:
            output = module.serve(
                input=request.messages,
                requirements=request.requirements,
                model_options={
                    k: v
                    for k, v in request.model_dump().items()
                    if k not in explicit_fields
                },
            )
        except ValueError as e:
            # A ValueError raised by the module is treated as invalid input.
            logger.warning("Invalid request in %s: %s", endpoint_name, e)
            return create_openai_error_response(
                status_code=400,
                message=f"Invalid request: {e!s}",
                error_type="invalid_request_error",
            )
        except Exception as e:
            return _server_error(e)

        # Building the response is kept outside the ValueError handler above:
        # pydantic validation errors are ValueError subclasses, and a malformed
        # module output is a server-side fault, not a bad client request.
        try:
            # Extract usage information from the output if available.
            usage = None
            raw_usage = getattr(output, "usage", None)
            if raw_usage is not None:
                prompt_tokens = raw_usage.get("prompt_tokens", 0)
                completion_tokens = raw_usage.get("completion_tokens", 0)
                usage = CompletionUsage(
                    prompt_tokens=prompt_tokens,
                    completion_tokens=completion_tokens,
                    # Calculate total_tokens if not provided.
                    total_tokens=raw_usage.get(
                        "total_tokens", prompt_tokens + completion_tokens
                    ),
                )

            # system_fingerprint represents a backend config hash, not the
            # model name (which is already reported in the response's `model`
            # field). Left as None since backend config fingerprints are not
            # tracked yet.
            system_fingerprint = None

            return ChatCompletion(
                id=completion_id,
                model=request.model,
                created=created_timestamp,
                choices=[
                    Choice(
                        index=0,
                        message=ChatCompletionMessage(
                            content=output.value, role="assistant"
                        ),
                        finish_reason="stop",
                    )
                ],
                object="chat.completion",  # type: ignore
                system_fingerprint=system_fingerprint,
                usage=usage,
            )  # type: ignore
        except Exception as e:
            # Catch-all for unexpected errors (including AttributeError when
            # the output has no `value`).
            return _server_error(e)

    endpoint.__name__ = endpoint_name
    return endpoint
Expand All @@ -110,7 +140,7 @@ def serve(
route_path,
make_chat_endpoint(module),
methods=["POST"],
response_model=ChatCompletion,
response_model=ChatCompletion | OpenAIErrorResponse,
)
typer.echo(f"Serving {route_path} at http://{host}:{port}")
uvicorn.run(app, host=host, port=port)
28 changes: 25 additions & 3 deletions cli/serve/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ class LogitBias(BaseModel):


class ChatCompletionRequest(BaseModel):
model_config = {"extra": "allow"}

model: str
messages: list[ChatMessage]
requirements: list[str | None] | None = Field(default_factory=list)
Expand All @@ -59,9 +61,6 @@ class ChatCompletionRequest(BaseModel):
# For future/undocumented fields
extra: dict[str, Any] = Field(default_factory=dict)

class Config:
extra = "allow"


# Taken from the OpenAI types: https://github.com/openai/openai-python/blob/main/src/openai/types/chat/chat_completion.py
class ChatCompletionMessage(BaseModel):
Expand Down Expand Up @@ -124,3 +123,26 @@ class ChatCompletion(BaseModel):

usage: CompletionUsage | None = None
"""Usage statistics for the completion request."""


class OpenAIError(BaseModel):
    """Error object carried inside an OpenAI-style error response.

    ``message`` and ``type`` are required; ``param`` and ``code`` default to
    ``None``.
    """

    # Human-readable error message.
    message: str

    # Kind of error, e.g. 'invalid_request_error' or 'server_error'.
    type: str

    # Parameter that caused the error, if applicable.
    param: str | None = None

    # Error code, if applicable.
    code: str | None = None


class OpenAIErrorResponse(BaseModel):
    """Top-level wrapper that holds a single ``OpenAIError`` under ``error``."""

    # The wrapped error object.
    error: OpenAIError
138 changes: 138 additions & 0 deletions test/cli/test_serve_errors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
"""Tests for the OpenAI-compatible serve endpoint."""

from unittest.mock import Mock

import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient

from cli.serve.app import make_chat_endpoint
from cli.serve.models import ChatCompletionRequest, ChatMessage


@pytest.fixture
def mock_module_success():
    """Provide a fake module whose ``serve`` returns a normal output object."""
    # The output carries the response text in ``value`` and no usage info.
    fake_output = Mock(value="Test response", usage=None)

    fake_module = Mock()
    fake_module.__name__ = "test_module"
    fake_module.serve = Mock(return_value=fake_output)
    return fake_module


@pytest.fixture
def mock_module_attribute_error():
    """Provide a fake module whose ``serve`` output has no ``value`` attribute."""
    # An empty spec means the mock exposes no attributes at all, so reading
    # ``value`` from it raises AttributeError.
    bare_output = Mock(spec=[])

    fake_module = Mock()
    fake_module.__name__ = "test_module"
    fake_module.serve = Mock(return_value=bare_output)
    return fake_module


@pytest.fixture
def mock_module_value_error():
    """Provide a fake module whose ``serve`` raises ``ValueError``."""
    failing_serve = Mock(side_effect=ValueError("Invalid input"))

    fake_module = Mock()
    fake_module.__name__ = "test_module"
    fake_module.serve = failing_serve
    return fake_module


@pytest.fixture
def mock_module_generic_error():
    """Provide a fake module whose ``serve`` raises a ``RuntimeError``."""
    failing_serve = Mock(side_effect=RuntimeError("Unexpected error"))

    fake_module = Mock()
    fake_module.__name__ = "test_module"
    fake_module.serve = failing_serve
    return fake_module


@pytest.fixture
def sample_request():
    """Provide a minimal chat completion request with one user message."""
    user_message = ChatMessage(role="user", content="Hello")
    return ChatCompletionRequest(
        model="test-model", messages=[user_message], requirements=None
    )


@pytest.mark.unit
def test_successful_completion(mock_module_success, sample_request):
    """A module that serves normally yields a 200 chat.completion response."""
    route = "/test/completions"
    api = FastAPI()
    api.add_api_route(
        route, make_chat_endpoint(mock_module_success), methods=["POST"]
    )

    resp = TestClient(api).post(route, json=sample_request.model_dump())

    assert resp.status_code == 200
    body = resp.json()
    assert body["choices"][0]["message"]["content"] == "Test response"
    assert body["model"] == "test-model"
    assert "id" in body
    assert body["object"] == "chat.completion"


@pytest.mark.unit
def test_attribute_error_handling(mock_module_attribute_error, sample_request):
    """An output without ``value`` is reported as a 500 server_error."""
    route = "/test/attribute-error"
    api = FastAPI()
    api.add_api_route(
        route, make_chat_endpoint(mock_module_attribute_error), methods=["POST"]
    )

    resp = TestClient(api).post(route, json=sample_request.model_dump())

    assert resp.status_code == 500
    body = resp.json()
    assert "error" in body
    error = body["error"]
    assert error["type"] == "server_error"
    assert "Internal server error" in error["message"]


@pytest.mark.unit
def test_value_error_handling(mock_module_value_error, sample_request):
    """A ``ValueError`` from the module is reported as a 400 invalid request."""
    route = "/test/value-error"
    api = FastAPI()
    api.add_api_route(
        route, make_chat_endpoint(mock_module_value_error), methods=["POST"]
    )

    resp = TestClient(api).post(route, json=sample_request.model_dump())

    assert resp.status_code == 400
    body = resp.json()
    assert "error" in body
    error = body["error"]
    assert error["type"] == "invalid_request_error"
    assert "Invalid request" in error["message"]
    assert "Invalid input" in error["message"]


@pytest.mark.unit
def test_generic_error_handling(mock_module_generic_error, sample_request):
    """Any other exception from the module is reported as a 500 server_error."""
    route = "/test/generic-error"
    api = FastAPI()
    api.add_api_route(
        route, make_chat_endpoint(mock_module_generic_error), methods=["POST"]
    )

    resp = TestClient(api).post(route, json=sample_request.model_dump())

    assert resp.status_code == 500
    body = resp.json()
    assert "error" in body
    error = body["error"]
    assert error["type"] == "server_error"
    assert "Internal server error" in error["message"]
    assert "Unexpected error" in error["message"]


@pytest.mark.unit
def test_endpoint_name_generation(mock_module_success):
    """The endpoint's ``__name__`` is built from the module's name."""
    expected_name = "chat_test_module_endpoint"
    assert make_chat_endpoint(mock_module_success).__name__ == expected_name
Loading