Skip to content

Commit ecc15a6

Browse files
authored
fix: add error handling to OpenAI-compatible serve endpoint (#774)
* feat: add OpenAI-compatible error handling to `m serve`. Add proper exception handling to the chat completion endpoint in cli/serve/app.py to prevent unhandled exceptions from crashing the server. Implements the OpenAI API error format for the `m serve` endpoint to ensure compatibility with OpenAI client libraries and tools. Signed-off-by: Mark Sturdevant <mark.sturdevant@ibm.com> * fix: fixes for PR review comments: * remove unused import; * fix FastAPI app route accumulation; * remove duplicate error handler; * add types for response_model. Signed-off-by: Mark Sturdevant <mark.sturdevant@ibm.com> * fix: test_server_errors mock fix for failed CI. For the success case, return None for usage, not a mock object. Signed-off-by: Mark Sturdevant <mark.sturdevant@ibm.com> --------- Signed-off-by: Mark Sturdevant <mark.sturdevant@ibm.com>
1 parent 9d02cde commit ecc15a6

3 files changed

Lines changed: 242 additions & 52 deletions

File tree

cli/serve/app.py

Lines changed: 79 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,16 @@
99
import typer
1010
import uvicorn
1111
from fastapi import FastAPI
12+
from fastapi.responses import JSONResponse
1213

1314
from .models import (
1415
ChatCompletion,
1516
ChatCompletionMessage,
1617
ChatCompletionRequest,
1718
Choice,
1819
CompletionUsage,
20+
OpenAIError,
21+
OpenAIErrorResponse,
1922
)
2023

2124
app = FastAPI(
@@ -35,60 +38,87 @@ def load_module_from_path(path: str):
3538
return module
3639

3740

41+
def create_openai_error_response(
42+
status_code: int, message: str, error_type: str, param: str | None = None
43+
) -> JSONResponse:
44+
"""Create an OpenAI-compatible error response."""
45+
error_response = OpenAIErrorResponse(
46+
error=OpenAIError(message=message, type=error_type, param=param)
47+
)
48+
return JSONResponse(
49+
status_code=status_code, content=error_response.model_dump(mode="json")
50+
)
51+
52+
3853
def make_chat_endpoint(module):
3954
"""Makes a chat endpoint using a custom module."""
4055

41-
async def endpoint(request: ChatCompletionRequest) -> ChatCompletion:
42-
completion_id = f"chatcmpl-{uuid.uuid4().hex[:29]}"
43-
created_timestamp = int(time.time())
44-
45-
output = module.serve(
46-
input=request.messages,
47-
requirements=request.requirements,
48-
model_options={
49-
k: v
50-
for k, v in request.model_dump().items()
51-
if k not in ["messages", "requirements"]
52-
},
53-
)
54-
55-
# Extract usage information from the ModelOutputThunk if available
56-
usage = None
57-
if hasattr(output, "usage") and output.usage is not None:
58-
prompt_tokens = output.usage.get("prompt_tokens", 0)
59-
completion_tokens = output.usage.get("completion_tokens", 0)
60-
# Calculate total_tokens if not provided
61-
total_tokens = output.usage.get(
62-
"total_tokens", prompt_tokens + completion_tokens
63-
)
64-
usage = CompletionUsage(
65-
prompt_tokens=prompt_tokens,
66-
completion_tokens=completion_tokens,
67-
total_tokens=total_tokens,
56+
async def endpoint(request: ChatCompletionRequest):
57+
try:
58+
completion_id = f"chatcmpl-{uuid.uuid4().hex[:29]}"
59+
created_timestamp = int(time.time())
60+
61+
output = module.serve(
62+
input=request.messages,
63+
requirements=request.requirements,
64+
model_options={
65+
k: v
66+
for k, v in request.model_dump().items()
67+
if k not in ["messages", "requirements"]
68+
},
6869
)
6970

70-
# system_fingerprint represents backend config hash, not model name
71-
# The model name is already in response.model (line 73)
72-
# Leave as None since we don't track backend config fingerprints yet
73-
system_fingerprint = None
74-
75-
return ChatCompletion(
76-
id=completion_id,
77-
model=request.model,
78-
created=created_timestamp,
79-
choices=[
80-
Choice(
81-
index=0,
82-
message=ChatCompletionMessage(
83-
content=output.value, role="assistant"
84-
),
85-
finish_reason="stop",
71+
# Extract usage information from the ModelOutputThunk if available
72+
usage = None
73+
if hasattr(output, "usage") and output.usage is not None:
74+
prompt_tokens = output.usage.get("prompt_tokens", 0)
75+
completion_tokens = output.usage.get("completion_tokens", 0)
76+
# Calculate total_tokens if not provided
77+
total_tokens = output.usage.get(
78+
"total_tokens", prompt_tokens + completion_tokens
8679
)
87-
],
88-
object="chat.completion", # type: ignore
89-
system_fingerprint=system_fingerprint,
90-
usage=usage,
91-
) # type: ignore
80+
usage = CompletionUsage(
81+
prompt_tokens=prompt_tokens,
82+
completion_tokens=completion_tokens,
83+
total_tokens=total_tokens,
84+
)
85+
86+
# system_fingerprint represents backend config hash, not model name
87+
# The model name is already in response.model (line 73)
88+
# Leave as None since we don't track backend config fingerprints yet
89+
system_fingerprint = None
90+
91+
return ChatCompletion(
92+
id=completion_id,
93+
model=request.model,
94+
created=created_timestamp,
95+
choices=[
96+
Choice(
97+
index=0,
98+
message=ChatCompletionMessage(
99+
content=output.value, role="assistant"
100+
),
101+
finish_reason="stop",
102+
)
103+
],
104+
object="chat.completion", # type: ignore
105+
system_fingerprint=system_fingerprint,
106+
usage=usage,
107+
) # type: ignore
108+
except ValueError as e:
109+
# Handle validation errors or invalid input
110+
return create_openai_error_response(
111+
status_code=400,
112+
message=f"Invalid request: {e!s}",
113+
error_type="invalid_request_error",
114+
)
115+
except Exception as e:
116+
# Catch-all for any unexpected errors (including AttributeError)
117+
return create_openai_error_response(
118+
status_code=500,
119+
message=f"Internal server error: {e!s}",
120+
error_type="server_error",
121+
)
92122

93123
endpoint.__name__ = f"chat_{module.__name__}_endpoint"
94124
return endpoint
@@ -110,7 +140,7 @@ def serve(
110140
route_path,
111141
make_chat_endpoint(module),
112142
methods=["POST"],
113-
response_model=ChatCompletion,
143+
response_model=ChatCompletion | OpenAIErrorResponse,
114144
)
115145
typer.echo(f"Serving {route_path} at http://{host}:{port}")
116146
uvicorn.run(app, host=host, port=port)

cli/serve/models.py

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,8 @@ class LogitBias(BaseModel):
3636

3737

3838
class ChatCompletionRequest(BaseModel):
39+
model_config = {"extra": "allow"}
40+
3941
model: str
4042
messages: list[ChatMessage]
4143
requirements: list[str | None] | None = Field(default_factory=list)
@@ -59,9 +61,6 @@ class ChatCompletionRequest(BaseModel):
5961
# For future/undocumented fields
6062
extra: dict[str, Any] = Field(default_factory=dict)
6163

62-
class Config:
63-
extra = "allow"
64-
6564

6665
# Taking this from OpenAI types https://github.com/openai/openai-python/blob/main/src/openai/types/chat/chat_completion.py,
6766
class ChatCompletionMessage(BaseModel):
@@ -124,3 +123,26 @@ class ChatCompletion(BaseModel):
124123

125124
usage: CompletionUsage | None = None
126125
"""Usage statistics for the completion request."""
126+
127+
128+
class OpenAIError(BaseModel):
129+
"""OpenAI API error object."""
130+
131+
message: str
132+
"""A human-readable error message."""
133+
134+
type: str
135+
"""The type of error (e.g., 'invalid_request_error', 'server_error')."""
136+
137+
param: str | None = None
138+
"""The parameter that caused the error, if applicable."""
139+
140+
code: str | None = None
141+
"""An error code, if applicable."""
142+
143+
144+
class OpenAIErrorResponse(BaseModel):
145+
"""OpenAI API error response wrapper."""
146+
147+
error: OpenAIError
148+
"""The error object."""

test/cli/test_serve_errors.py

Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
"""Tests for the OpenAI-compatible serve endpoint."""
2+
3+
from unittest.mock import Mock
4+
5+
import pytest
6+
from fastapi import FastAPI
7+
from fastapi.testclient import TestClient
8+
9+
from cli.serve.app import make_chat_endpoint
10+
from cli.serve.models import ChatCompletionRequest, ChatMessage
11+
12+
13+
@pytest.fixture
14+
def mock_module_success():
15+
"""Create a mock module that returns a successful response."""
16+
module = Mock()
17+
module.__name__ = "test_module"
18+
output = Mock()
19+
output.value = "Test response"
20+
output.usage = None # No usage info in this test
21+
module.serve = Mock(return_value=output)
22+
return module
23+
24+
25+
@pytest.fixture
26+
def mock_module_attribute_error():
27+
"""Create a mock module that raises AttributeError."""
28+
module = Mock()
29+
module.__name__ = "test_module"
30+
output = Mock(spec=[]) # No 'value' attribute
31+
module.serve = Mock(return_value=output)
32+
return module
33+
34+
35+
@pytest.fixture
36+
def mock_module_value_error():
37+
"""Create a mock module that raises ValueError."""
38+
module = Mock()
39+
module.__name__ = "test_module"
40+
module.serve = Mock(side_effect=ValueError("Invalid input"))
41+
return module
42+
43+
44+
@pytest.fixture
45+
def mock_module_generic_error():
46+
"""Create a mock module that raises a generic exception."""
47+
module = Mock()
48+
module.__name__ = "test_module"
49+
module.serve = Mock(side_effect=RuntimeError("Unexpected error"))
50+
return module
51+
52+
53+
@pytest.fixture
54+
def sample_request():
55+
"""Create a sample chat completion request."""
56+
return ChatCompletionRequest(
57+
model="test-model",
58+
messages=[ChatMessage(role="user", content="Hello")],
59+
requirements=None,
60+
)
61+
62+
63+
@pytest.mark.unit
64+
def test_successful_completion(mock_module_success, sample_request):
65+
"""Test successful chat completion."""
66+
app = FastAPI()
67+
endpoint = make_chat_endpoint(mock_module_success)
68+
app.add_api_route("/test/completions", endpoint, methods=["POST"])
69+
client = TestClient(app)
70+
71+
response = client.post("/test/completions", json=sample_request.model_dump())
72+
73+
assert response.status_code == 200
74+
data = response.json()
75+
assert data["choices"][0]["message"]["content"] == "Test response"
76+
assert data["model"] == "test-model"
77+
assert "id" in data
78+
assert data["object"] == "chat.completion"
79+
80+
81+
@pytest.mark.unit
82+
def test_attribute_error_handling(mock_module_attribute_error, sample_request):
83+
"""Test handling of AttributeError (e.g., missing 'value' attribute)."""
84+
app = FastAPI()
85+
endpoint = make_chat_endpoint(mock_module_attribute_error)
86+
app.add_api_route("/test/attribute-error", endpoint, methods=["POST"])
87+
client = TestClient(app)
88+
89+
response = client.post("/test/attribute-error", json=sample_request.model_dump())
90+
91+
assert response.status_code == 500
92+
data = response.json()
93+
assert "error" in data
94+
assert data["error"]["type"] == "server_error"
95+
assert "Internal server error" in data["error"]["message"]
96+
97+
98+
@pytest.mark.unit
99+
def test_value_error_handling(mock_module_value_error, sample_request):
100+
"""Test handling of ValueError (validation errors)."""
101+
app = FastAPI()
102+
endpoint = make_chat_endpoint(mock_module_value_error)
103+
app.add_api_route("/test/value-error", endpoint, methods=["POST"])
104+
client = TestClient(app)
105+
106+
response = client.post("/test/value-error", json=sample_request.model_dump())
107+
108+
assert response.status_code == 400
109+
data = response.json()
110+
assert "error" in data
111+
assert data["error"]["type"] == "invalid_request_error"
112+
assert "Invalid request" in data["error"]["message"]
113+
assert "Invalid input" in data["error"]["message"]
114+
115+
116+
@pytest.mark.unit
117+
def test_generic_error_handling(mock_module_generic_error, sample_request):
118+
"""Test handling of generic exceptions."""
119+
app = FastAPI()
120+
endpoint = make_chat_endpoint(mock_module_generic_error)
121+
app.add_api_route("/test/generic-error", endpoint, methods=["POST"])
122+
client = TestClient(app)
123+
124+
response = client.post("/test/generic-error", json=sample_request.model_dump())
125+
126+
assert response.status_code == 500
127+
data = response.json()
128+
assert "error" in data
129+
assert data["error"]["type"] == "server_error"
130+
assert "Internal server error" in data["error"]["message"]
131+
assert "Unexpected error" in data["error"]["message"]
132+
133+
134+
@pytest.mark.unit
135+
def test_endpoint_name_generation(mock_module_success):
136+
"""Test that endpoint names are generated correctly."""
137+
endpoint = make_chat_endpoint(mock_module_success)
138+
assert endpoint.__name__ == "chat_test_module_endpoint"

0 commit comments

Comments
 (0)