Skip to content

Commit c887d60

Browse files
authored
feat(serve): improve OpenAI API compatibility with usage, finish_reas… (#771)
* feat(serve): improve OpenAI API compatibility with usage, finish_reason, and system_fingerprint Add missing OpenAI API fields to m serve chat completion responses: - Add `finish_reason` field to Choice model (defaults to "stop") - Add `usage` field with token counts (prompt_tokens, completion_tokens, total_tokens) extracted from ModelOutputThunk.usage when available - Add `system_fingerprint` field populated from model or provider metadata - Fix bug in model_options extraction (use model_dump().items() instead of iterating request) Includes comprehensive test coverage with 13 unit tests verifying all new fields and edge cases (missing data, partial data, fallback behavior). Signed-off-by: Mark Sturdevant <mark.sturdevant@ibm.com> * Remove made with Bob comment Signed-off-by: Mark Sturdevant <mark.sturdevant@ibm.com> * fix(serve): set system_fingerprint to None per OpenAI spec system_fingerprint should be a backend config hash, not the model name. The model name is already in response.model. Set to None since we don't currently track backend config fingerprints. Signed-off-by: Mark Sturdevant <mark.sturdevant@ibm.com> * fix: remove unused import Signed-off-by: Mark Sturdevant <mark.sturdevant@ibm.com> * fix: calculate total_tokens from prompt + completion when missing When partial usage data is provided (e.g., only prompt_tokens), total_tokens was incorrectly defaulting to 0. Now it's calculated as prompt_tokens + completion_tokens when not explicitly provided. This prevents silent bad values from going out over the OpenAI-compatible API. - Modified cli/serve/app.py to calculate total_tokens when missing - Updated test_usage_with_partial_data to expect correct behavior Signed-off-by: Mark Sturdevant <mark.sturdevant@ibm.com> * test: make completion ID format test less brittle Replace hardcoded length assertion with implementation-agnostic check. The test now validates that the ID has the correct prefix and a non-empty suffix, without coupling to specific length or format details. Signed-off-by: Mark Sturdevant <mark.sturdevant@ibm.com> --------- Signed-off-by: Mark Sturdevant <mark.sturdevant@ibm.com>
1 parent deb9d24 commit c887d60

3 files changed

Lines changed: 281 additions & 2 deletions

File tree

cli/serve/app.py

Lines changed: 33 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,13 @@
1010
import uvicorn
1111
from fastapi import FastAPI
1212

13-
from .models import ChatCompletion, ChatCompletionMessage, ChatCompletionRequest, Choice
13+
from .models import (
14+
ChatCompletion,
15+
ChatCompletionMessage,
16+
ChatCompletionRequest,
17+
Choice,
18+
CompletionUsage,
19+
)
1420

1521
app = FastAPI(
1622
title="M serve OpenAI API Compatible Server",
@@ -40,10 +46,32 @@ async def endpoint(request: ChatCompletionRequest) -> ChatCompletion:
4046
input=request.messages,
4147
requirements=request.requirements,
4248
model_options={
43-
k: v for k, v in request if k not in ["messages", "requirements"]
49+
k: v
50+
for k, v in request.model_dump().items()
51+
if k not in ["messages", "requirements"]
4452
},
4553
)
4654

55+
# Extract usage information from the ModelOutputThunk if available
56+
usage = None
57+
if hasattr(output, "usage") and output.usage is not None:
58+
prompt_tokens = output.usage.get("prompt_tokens", 0)
59+
completion_tokens = output.usage.get("completion_tokens", 0)
60+
# Calculate total_tokens if not provided
61+
total_tokens = output.usage.get(
62+
"total_tokens", prompt_tokens + completion_tokens
63+
)
64+
usage = CompletionUsage(
65+
prompt_tokens=prompt_tokens,
66+
completion_tokens=completion_tokens,
67+
total_tokens=total_tokens,
68+
)
69+
70+
# system_fingerprint represents backend config hash, not model name
71+
# The model name is already in response.model (line 73)
72+
# Leave as None since we don't track backend config fingerprints yet
73+
system_fingerprint = None
74+
4775
return ChatCompletion(
4876
id=completion_id,
4977
model=request.model,
@@ -54,9 +82,12 @@ async def endpoint(request: ChatCompletionRequest) -> ChatCompletion:
5482
message=ChatCompletionMessage(
5583
content=output.value, role="assistant"
5684
),
85+
finish_reason="stop",
5786
)
5887
],
5988
object="chat.completion", # type: ignore
89+
system_fingerprint=system_fingerprint,
90+
usage=usage,
6091
) # type: ignore
6192

6293
endpoint.__name__ = f"chat_{module.__name__}_endpoint"

cli/serve/models.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,23 @@ class Choice(BaseModel):
8282
message: ChatCompletionMessage
8383
"""A chat completion message generated by the model."""
8484

85+
finish_reason: (
86+
Literal["stop", "length", "content_filter", "tool_calls", "function_call"]
87+
| None
88+
) = "stop"
89+
"""The reason the model stopped generating tokens."""
90+
91+
92+
class CompletionUsage(BaseModel):
93+
completion_tokens: int
94+
"""Number of tokens in the generated completion."""
95+
96+
prompt_tokens: int
97+
"""Number of tokens in the prompt."""
98+
99+
total_tokens: int
100+
"""Total number of tokens used in the request (prompt + completion)."""
101+
85102

86103
class ChatCompletion(BaseModel):
87104
id: str
@@ -101,3 +118,9 @@ class ChatCompletion(BaseModel):
101118

102119
object: Literal["chat.completion"]
103120
"""The object type, which is always `chat.completion`."""
121+
122+
system_fingerprint: str | None = None
123+
"""This fingerprint represents the backend configuration that the model runs with."""
124+
125+
usage: CompletionUsage | None = None
126+
"""Usage statistics for the completion request."""

test/cli/test_serve.py

Lines changed: 225 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,225 @@
1+
"""Tests for the m serve OpenAI-compatible API server."""
2+
3+
from unittest.mock import Mock
4+
5+
import pytest
6+
7+
from cli.serve.app import make_chat_endpoint
8+
from cli.serve.models import (
9+
ChatCompletion,
10+
ChatCompletionRequest,
11+
ChatMessage,
12+
CompletionUsage,
13+
)
14+
from mellea.core.base import ModelOutputThunk
15+
16+
17+
@pytest.fixture
18+
def mock_module():
19+
"""Create a mock module with a serve function."""
20+
module = Mock()
21+
module.__name__ = "test_module"
22+
return module
23+
24+
25+
@pytest.fixture
26+
def sample_request():
27+
"""Create a sample ChatCompletionRequest."""
28+
return ChatCompletionRequest(
29+
model="test-model",
30+
messages=[ChatMessage(role="user", content="Hello, world!")],
31+
temperature=0.7,
32+
max_tokens=100,
33+
)
34+
35+
36+
class TestChatEndpoint:
37+
"""Tests for the chat completion endpoint."""
38+
39+
@pytest.mark.asyncio
40+
async def test_basic_completion(self, mock_module, sample_request):
41+
"""Test basic chat completion returns correct structure."""
42+
# Setup mock output
43+
mock_output = ModelOutputThunk("Hello! How can I help you?")
44+
mock_module.serve.return_value = mock_output
45+
46+
# Create endpoint and call it
47+
endpoint = make_chat_endpoint(mock_module)
48+
response = await endpoint(sample_request)
49+
50+
# Verify response structure
51+
assert isinstance(response, ChatCompletion)
52+
assert response.model == "test-model"
53+
assert len(response.choices) == 1
54+
assert response.choices[0].message.content == "Hello! How can I help you?"
55+
assert response.choices[0].message.role == "assistant"
56+
assert response.choices[0].index == 0
57+
58+
@pytest.mark.asyncio
59+
async def test_finish_reason_included(self, mock_module, sample_request):
60+
"""Test that finish_reason is included in the response."""
61+
mock_output = ModelOutputThunk("Test response")
62+
mock_module.serve.return_value = mock_output
63+
64+
endpoint = make_chat_endpoint(mock_module)
65+
response = await endpoint(sample_request)
66+
67+
assert response.choices[0].finish_reason == "stop"
68+
69+
@pytest.mark.asyncio
70+
async def test_usage_field_populated(self, mock_module, sample_request):
71+
"""Test that usage field is populated when available."""
72+
mock_output = ModelOutputThunk("Test response")
73+
mock_output.usage = {
74+
"prompt_tokens": 10,
75+
"completion_tokens": 5,
76+
"total_tokens": 15,
77+
}
78+
mock_module.serve.return_value = mock_output
79+
80+
endpoint = make_chat_endpoint(mock_module)
81+
response = await endpoint(sample_request)
82+
83+
assert response.usage is not None
84+
assert isinstance(response.usage, CompletionUsage)
85+
assert response.usage.prompt_tokens == 10
86+
assert response.usage.completion_tokens == 5
87+
assert response.usage.total_tokens == 15
88+
89+
@pytest.mark.asyncio
90+
async def test_usage_field_none_when_unavailable(self, mock_module, sample_request):
91+
"""Test that usage field is None when not available."""
92+
mock_output = ModelOutputThunk("Test response")
93+
# Don't set usage field
94+
mock_module.serve.return_value = mock_output
95+
96+
endpoint = make_chat_endpoint(mock_module)
97+
response = await endpoint(sample_request)
98+
99+
assert response.usage is None
100+
101+
@pytest.mark.asyncio
102+
async def test_system_fingerprint_always_none(self, mock_module, sample_request):
103+
"""Test that system_fingerprint is always None.
104+
105+
Per OpenAI spec, system_fingerprint represents a hash of backend config,
106+
not the model name. The model name is in response.model.
107+
We don't currently track backend config fingerprints.
108+
"""
109+
mock_output = ModelOutputThunk("Test response")
110+
mock_output.model = "gpt-4-turbo"
111+
mock_output.provider = "openai"
112+
mock_module.serve.return_value = mock_output
113+
114+
endpoint = make_chat_endpoint(mock_module)
115+
response = await endpoint(sample_request)
116+
117+
# system_fingerprint should be None, not the model name
118+
assert response.system_fingerprint is None
119+
# Model name should be in the model field
120+
assert response.model == sample_request.model
121+
122+
@pytest.mark.asyncio
123+
async def test_model_options_passed_correctly(self, mock_module, sample_request):
124+
"""Test that model options are passed to serve function correctly."""
125+
mock_output = ModelOutputThunk("Test response")
126+
mock_module.serve.return_value = mock_output
127+
128+
endpoint = make_chat_endpoint(mock_module)
129+
await endpoint(sample_request)
130+
131+
# Verify serve was called with correct arguments
132+
call_args = mock_module.serve.call_args
133+
assert call_args is not None
134+
assert "model_options" in call_args.kwargs
135+
model_options = call_args.kwargs["model_options"]
136+
137+
# Should include temperature and max_tokens but not messages/requirements
138+
assert "temperature" in model_options
139+
assert model_options["temperature"] == 0.7
140+
assert "max_tokens" in model_options
141+
assert model_options["max_tokens"] == 100
142+
assert "messages" not in model_options
143+
assert "requirements" not in model_options
144+
145+
@pytest.mark.asyncio
146+
async def test_completion_id_format(self, mock_module, sample_request):
147+
"""Test that completion ID follows OpenAI format."""
148+
mock_output = ModelOutputThunk("Test response")
149+
mock_module.serve.return_value = mock_output
150+
151+
endpoint = make_chat_endpoint(mock_module)
152+
response = await endpoint(sample_request)
153+
154+
# Should start with "chatcmpl-" and have a non-empty suffix
155+
assert response.id.startswith("chatcmpl-")
156+
assert len(response.id) > len("chatcmpl-"), "ID should have a suffix"
157+
158+
@pytest.mark.asyncio
159+
async def test_created_timestamp_present(self, mock_module, sample_request):
160+
"""Test that created timestamp is present and reasonable."""
161+
mock_output = ModelOutputThunk("Test response")
162+
mock_module.serve.return_value = mock_output
163+
164+
endpoint = make_chat_endpoint(mock_module)
165+
response = await endpoint(sample_request)
166+
167+
# Should be a Unix timestamp (positive integer)
168+
assert isinstance(response.created, int)
169+
assert response.created > 0
170+
171+
@pytest.mark.asyncio
172+
async def test_object_type_correct(self, mock_module, sample_request):
173+
"""Test that object type is set correctly."""
174+
mock_output = ModelOutputThunk("Test response")
175+
mock_module.serve.return_value = mock_output
176+
177+
endpoint = make_chat_endpoint(mock_module)
178+
response = await endpoint(sample_request)
179+
180+
assert response.object == "chat.completion"
181+
182+
@pytest.mark.asyncio
183+
async def test_usage_with_partial_data(self, mock_module, sample_request):
184+
"""Test that usage handles missing fields gracefully."""
185+
mock_output = ModelOutputThunk("Test response")
186+
# Only provide some fields
187+
mock_output.usage = {
188+
"prompt_tokens": 10
189+
# Missing completion_tokens and total_tokens
190+
}
191+
mock_module.serve.return_value = mock_output
192+
193+
endpoint = make_chat_endpoint(mock_module)
194+
response = await endpoint(sample_request)
195+
196+
assert response.usage is not None
197+
assert response.usage.prompt_tokens == 10
198+
assert response.usage.completion_tokens == 0 # Should default to 0
199+
assert (
200+
response.usage.total_tokens == 10
201+
) # Should be prompt_tokens + completion_tokens
202+
203+
@pytest.mark.asyncio
204+
async def test_all_fields_together(self, mock_module, sample_request):
205+
"""Test that all new fields work together correctly."""
206+
mock_output = ModelOutputThunk("Complete response")
207+
mock_output.usage = {
208+
"prompt_tokens": 20,
209+
"completion_tokens": 10,
210+
"total_tokens": 30,
211+
}
212+
mock_output.model = "gpt-4"
213+
mock_output.provider = "openai"
214+
mock_module.serve.return_value = mock_output
215+
216+
endpoint = make_chat_endpoint(mock_module)
217+
response = await endpoint(sample_request)
218+
219+
# Verify all fields are present
220+
assert response.choices[0].finish_reason == "stop"
221+
assert response.usage is not None
222+
assert response.usage.total_tokens == 30
223+
assert response.system_fingerprint is None # Not tracking backend config
224+
assert response.object == "chat.completion"
225+
assert response.id.startswith("chatcmpl-")

0 commit comments

Comments
 (0)