From 42648d2113695061857f63be49a5f13bde197b67 Mon Sep 17 00:00:00 2001 From: Mark Sturdevant Date: Wed, 1 Apr 2026 15:19:23 -0700 Subject: [PATCH 1/3] feat: add OpenAI-compatible error handling to m serve Add proper exception handling to the chat completion endpoint in cli/serve/app.py to prevent unhandled exceptions from crashing the server. Implements OpenAI API error format for the `m serve` endpoint to ensure compatibility with OpenAI client libraries and tools. Signed-off-by: Mark Sturdevant --- cli/serve/app.py | 135 +++++++++++++++++++++------------ cli/serve/models.py | 28 ++++++- test/cli/test_serve_errors.py | 137 ++++++++++++++++++++++++++++++++++ 3 files changed, 248 insertions(+), 52 deletions(-) create mode 100644 test/cli/test_serve_errors.py diff --git a/cli/serve/app.py b/cli/serve/app.py index 50b4777b6..469e58bce 100644 --- a/cli/serve/app.py +++ b/cli/serve/app.py @@ -9,6 +9,7 @@ import typer import uvicorn from fastapi import FastAPI +from fastapi.responses import JSONResponse from .models import ( ChatCompletion, @@ -16,6 +17,8 @@ ChatCompletionRequest, Choice, CompletionUsage, + OpenAIError, + OpenAIErrorResponse, ) app = FastAPI( @@ -35,60 +38,94 @@ def load_module_from_path(path: str): return module +def create_openai_error_response( + status_code: int, message: str, error_type: str, param: str | None = None +) -> JSONResponse: + """Create an OpenAI-compatible error response.""" + error_response = OpenAIErrorResponse( + error=OpenAIError(message=message, type=error_type, param=param) + ) + return JSONResponse( + status_code=status_code, content=error_response.model_dump(mode="json") + ) + + def make_chat_endpoint(module): """Makes a chat endpoint using a custom module.""" - async def endpoint(request: ChatCompletionRequest) -> ChatCompletion: - completion_id = f"chatcmpl-{uuid.uuid4().hex[:29]}" - created_timestamp = int(time.time()) - - output = module.serve( - input=request.messages, - requirements=request.requirements, - model_options={ 
- k: v - for k, v in request.model_dump().items() - if k not in ["messages", "requirements"] - }, - ) - - # Extract usage information from the ModelOutputThunk if available - usage = None - if hasattr(output, "usage") and output.usage is not None: - prompt_tokens = output.usage.get("prompt_tokens", 0) - completion_tokens = output.usage.get("completion_tokens", 0) - # Calculate total_tokens if not provided - total_tokens = output.usage.get( - "total_tokens", prompt_tokens + completion_tokens - ) - usage = CompletionUsage( - prompt_tokens=prompt_tokens, - completion_tokens=completion_tokens, - total_tokens=total_tokens, + async def endpoint(request: ChatCompletionRequest): + try: + completion_id = f"chatcmpl-{uuid.uuid4().hex[:29]}" + created_timestamp = int(time.time()) + + output = module.serve( + input=request.messages, + requirements=request.requirements, + model_options={ + k: v + for k, v in request.model_dump().items() + if k not in ["messages", "requirements"] + }, ) - # system_fingerprint represents backend config hash, not model name - # The model name is already in response.model (line 73) - # Leave as None since we don't track backend config fingerprints yet - system_fingerprint = None - - return ChatCompletion( - id=completion_id, - model=request.model, - created=created_timestamp, - choices=[ - Choice( - index=0, - message=ChatCompletionMessage( - content=output.value, role="assistant" - ), - finish_reason="stop", + # Extract usage information from the ModelOutputThunk if available + usage = None + if hasattr(output, "usage") and output.usage is not None: + prompt_tokens = output.usage.get("prompt_tokens", 0) + completion_tokens = output.usage.get("completion_tokens", 0) + # Calculate total_tokens if not provided + total_tokens = output.usage.get( + "total_tokens", prompt_tokens + completion_tokens ) - ], - object="chat.completion", # type: ignore - system_fingerprint=system_fingerprint, - usage=usage, - ) # type: ignore + usage = CompletionUsage( + 
prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=total_tokens, + ) + + # system_fingerprint represents backend config hash, not model name + # The model name is already in response.model (line 73) + # Leave as None since we don't track backend config fingerprints yet + system_fingerprint = None + + return ChatCompletion( + id=completion_id, + model=request.model, + created=created_timestamp, + choices=[ + Choice( + index=0, + message=ChatCompletionMessage( + content=output.value, role="assistant" + ), + finish_reason="stop", + ) + ], + object="chat.completion", # type: ignore + system_fingerprint=system_fingerprint, + usage=usage, + ) # type: ignore + except AttributeError as e: + # Handle missing 'value' attribute or other attribute errors + return create_openai_error_response( + status_code=500, + message=f"Internal server error: {e!s}", + error_type="server_error", + ) + except ValueError as e: + # Handle validation errors or invalid input + return create_openai_error_response( + status_code=400, + message=f"Invalid request: {e!s}", + error_type="invalid_request_error", + ) + except Exception as e: + # Catch-all for any other unexpected errors + return create_openai_error_response( + status_code=500, + message=f"Internal server error: {e!s}", + error_type="server_error", + ) endpoint.__name__ = f"chat_{module.__name__}_endpoint" return endpoint @@ -110,7 +147,7 @@ def serve( route_path, make_chat_endpoint(module), methods=["POST"], - response_model=ChatCompletion, + response_model=None, # Allow both ChatCompletion and error responses ) typer.echo(f"Serving {route_path} at http://{host}:{port}") uvicorn.run(app, host=host, port=port) diff --git a/cli/serve/models.py b/cli/serve/models.py index 1888e2d00..967ed1684 100644 --- a/cli/serve/models.py +++ b/cli/serve/models.py @@ -36,6 +36,8 @@ class LogitBias(BaseModel): class ChatCompletionRequest(BaseModel): + model_config = {"extra": "allow"} + model: str messages: 
list[ChatMessage] requirements: list[str | None] | None = Field(default_factory=list) @@ -59,9 +61,6 @@ class ChatCompletionRequest(BaseModel): # For future/undocumented fields extra: dict[str, Any] = Field(default_factory=dict) - class Config: - extra = "allow" - # Taking this from OpenAI types https://github.com/openai/openai-python/blob/main/src/openai/types/chat/chat_completion.py, class ChatCompletionMessage(BaseModel): @@ -124,3 +123,26 @@ class ChatCompletion(BaseModel): usage: CompletionUsage | None = None """Usage statistics for the completion request.""" + + +class OpenAIError(BaseModel): + """OpenAI API error object.""" + + message: str + """A human-readable error message.""" + + type: str + """The type of error (e.g., 'invalid_request_error', 'server_error').""" + + param: str | None = None + """The parameter that caused the error, if applicable.""" + + code: str | None = None + """An error code, if applicable.""" + + +class OpenAIErrorResponse(BaseModel): + """OpenAI API error response wrapper.""" + + error: OpenAIError + """The error object.""" diff --git a/test/cli/test_serve_errors.py b/test/cli/test_serve_errors.py new file mode 100644 index 000000000..e245816f5 --- /dev/null +++ b/test/cli/test_serve_errors.py @@ -0,0 +1,137 @@ +"""Tests for the OpenAI-compatible serve endpoint.""" + +from unittest.mock import Mock + +import pytest +from fastapi.testclient import TestClient + +from cli.serve.app import app, make_chat_endpoint +from cli.serve.models import ChatCompletionRequest, ChatMessage + + +@pytest.fixture +def mock_module_success(): + """Create a mock module that returns a successful response.""" + module = Mock() + module.__name__ = "test_module" + output = Mock() + output.value = "Test response" + module.serve = Mock(return_value=output) + return module + + +@pytest.fixture +def mock_module_attribute_error(): + """Create a mock module that raises AttributeError.""" + module = Mock() + module.__name__ = "test_module" + output = Mock(spec=[]) 
# No 'value' attribute + module.serve = Mock(return_value=output) + return module + + +@pytest.fixture +def mock_module_value_error(): + """Create a mock module that raises ValueError.""" + module = Mock() + module.__name__ = "test_module" + module.serve = Mock(side_effect=ValueError("Invalid input")) + return module + + +@pytest.fixture +def mock_module_generic_error(): + """Create a mock module that raises a generic exception.""" + module = Mock() + module.__name__ = "test_module" + module.serve = Mock(side_effect=RuntimeError("Unexpected error")) + return module + + +@pytest.fixture +def sample_request(): + """Create a sample chat completion request.""" + return ChatCompletionRequest( + model="test-model", + messages=[ChatMessage(role="user", content="Hello")], + requirements=None, + ) + + +@pytest.mark.unit +def test_successful_completion(mock_module_success, sample_request): + """Test successful chat completion.""" + endpoint = make_chat_endpoint(mock_module_success) + client = TestClient(app) + + # Add the endpoint to the app + app.add_api_route("/test/completions", endpoint, methods=["POST"]) + + response = client.post("/test/completions", json=sample_request.model_dump()) + + assert response.status_code == 200 + data = response.json() + assert data["choices"][0]["message"]["content"] == "Test response" + assert data["model"] == "test-model" + assert "id" in data + assert data["object"] == "chat.completion" + + +@pytest.mark.unit +def test_attribute_error_handling(mock_module_attribute_error, sample_request): + """Test handling of AttributeError (e.g., missing 'value' attribute).""" + endpoint = make_chat_endpoint(mock_module_attribute_error) + client = TestClient(app) + + app.add_api_route("/test/attribute-error", endpoint, methods=["POST"]) + + response = client.post("/test/attribute-error", json=sample_request.model_dump()) + + assert response.status_code == 500 + data = response.json() + assert "error" in data + assert data["error"]["type"] == 
"server_error" + assert "Internal server error" in data["error"]["message"] + + +@pytest.mark.unit +def test_value_error_handling(mock_module_value_error, sample_request): + """Test handling of ValueError (validation errors).""" + endpoint = make_chat_endpoint(mock_module_value_error) + client = TestClient(app) + + app.add_api_route("/test/value-error", endpoint, methods=["POST"]) + + response = client.post("/test/value-error", json=sample_request.model_dump()) + + assert response.status_code == 400 + data = response.json() + assert "error" in data + assert data["error"]["type"] == "invalid_request_error" + assert "Invalid request" in data["error"]["message"] + assert "Invalid input" in data["error"]["message"] + + +@pytest.mark.unit +def test_generic_error_handling(mock_module_generic_error, sample_request): + """Test handling of generic exceptions.""" + endpoint = make_chat_endpoint(mock_module_generic_error) + client = TestClient(app) + + app.add_api_route("/test/generic-error", endpoint, methods=["POST"]) + + response = client.post("/test/generic-error", json=sample_request.model_dump()) + + assert response.status_code == 500 + data = response.json() + assert "error" in data + assert data["error"]["type"] == "server_error" + assert "Internal server error" in data["error"]["message"] + assert "Unexpected error" in data["error"]["message"] + + +@pytest.mark.unit +def test_endpoint_name_generation(mock_module_success): + """Test that endpoint names are generated correctly.""" + endpoint = make_chat_endpoint(mock_module_success) + assert endpoint.__name__ == "chat_test_module_endpoint" From 6c91522422dd86c923e6c04a830a02686521e32d Mon Sep 17 00:00:00 2001 From: Mark Sturdevant Date: Wed, 1 Apr 2026 16:42:33 -0700 Subject: [PATCH 2/3] fix: fixes for pr review comments * remove unused import * fix FastAPI app route accumulation * remove duplicate error handler * add types for response_model Signed-off-by: Mark Sturdevant --- cli/serve/app.py | 11 ++--------- 
test/cli/test_serve_errors.py | 20 ++++++++++---------- 2 files changed, 12 insertions(+), 19 deletions(-) diff --git a/cli/serve/app.py b/cli/serve/app.py index 469e58bce..df9eeab76 100644 --- a/cli/serve/app.py +++ b/cli/serve/app.py @@ -105,13 +105,6 @@ async def endpoint(request: ChatCompletionRequest): system_fingerprint=system_fingerprint, usage=usage, ) # type: ignore - except AttributeError as e: - # Handle missing 'value' attribute or other attribute errors - return create_openai_error_response( - status_code=500, - message=f"Internal server error: {e!s}", - error_type="server_error", - ) except ValueError as e: # Handle validation errors or invalid input return create_openai_error_response( @@ -120,7 +113,7 @@ async def endpoint(request: ChatCompletionRequest): error_type="invalid_request_error", ) except Exception as e: - # Catch-all for any other unexpected errors + # Catch-all for any unexpected errors (including AttributeError) return create_openai_error_response( status_code=500, message=f"Internal server error: {e!s}", @@ -147,7 +140,7 @@ def serve( route_path, make_chat_endpoint(module), methods=["POST"], - response_model=None, # Allow both ChatCompletion and error responses + response_model=ChatCompletion | OpenAIErrorResponse, ) typer.echo(f"Serving {route_path} at http://{host}:{port}") uvicorn.run(app, host=host, port=port) diff --git a/test/cli/test_serve_errors.py b/test/cli/test_serve_errors.py index e245816f5..ac1397619 100644 --- a/test/cli/test_serve_errors.py +++ b/test/cli/test_serve_errors.py @@ -3,9 +3,10 @@ from unittest.mock import Mock import pytest +from fastapi import FastAPI from fastapi.testclient import TestClient -from cli.serve.app import app, make_chat_endpoint +from cli.serve.app import make_chat_endpoint from cli.serve.models import ChatCompletionRequest, ChatMessage @@ -61,11 +62,10 @@ def sample_request(): @pytest.mark.unit def test_successful_completion(mock_module_success, sample_request): """Test successful chat 
completion.""" + app = FastAPI() endpoint = make_chat_endpoint(mock_module_success) - client = TestClient(app) - - # Add the endpoint to the app app.add_api_route("/test/completions", endpoint, methods=["POST"]) + client = TestClient(app) response = client.post("/test/completions", json=sample_request.model_dump()) @@ -80,10 +80,10 @@ def test_successful_completion(mock_module_success, sample_request): @pytest.mark.unit def test_attribute_error_handling(mock_module_attribute_error, sample_request): """Test handling of AttributeError (e.g., missing 'value' attribute).""" + app = FastAPI() endpoint = make_chat_endpoint(mock_module_attribute_error) - client = TestClient(app) - app.add_api_route("/test/attribute-error", endpoint, methods=["POST"]) + client = TestClient(app) response = client.post("/test/attribute-error", json=sample_request.model_dump()) @@ -97,10 +97,10 @@ def test_attribute_error_handling(mock_module_attribute_error, sample_request): @pytest.mark.unit def test_value_error_handling(mock_module_value_error, sample_request): """Test handling of ValueError (validation errors).""" + app = FastAPI() endpoint = make_chat_endpoint(mock_module_value_error) - client = TestClient(app) - app.add_api_route("/test/value-error", endpoint, methods=["POST"]) + client = TestClient(app) response = client.post("/test/value-error", json=sample_request.model_dump()) @@ -115,10 +115,10 @@ def test_value_error_handling(mock_module_value_error, sample_request): @pytest.mark.unit def test_generic_error_handling(mock_module_generic_error, sample_request): """Test handling of generic exceptions.""" + app = FastAPI() endpoint = make_chat_endpoint(mock_module_generic_error) - client = TestClient(app) - app.add_api_route("/test/generic-error", endpoint, methods=["POST"]) + client = TestClient(app) response = client.post("/test/generic-error", json=sample_request.model_dump()) From 4777f5b3734ef0f922dca4d3804c643e37bf7e1e Mon Sep 17 00:00:00 2001 From: Mark Sturdevant Date: Thu, 2 
Apr 2026 14:35:16 -0700 Subject: [PATCH 3/3] fix: test_serve_errors mock fix for failed CI For the success case, return None for usage, not a mock object. Signed-off-by: Mark Sturdevant --- test/cli/test_serve_errors.py | 1 + 1 file changed, 1 insertion(+) diff --git a/test/cli/test_serve_errors.py b/test/cli/test_serve_errors.py index ac1397619..14632755f 100644 --- a/test/cli/test_serve_errors.py +++ b/test/cli/test_serve_errors.py @@ -17,6 +17,7 @@ def mock_module_success(): module.__name__ = "test_module" output = Mock() output.value = "Test response" + output.usage = None # No usage info in this test module.serve = Mock(return_value=output) return module