Skip to content

Commit 7e46bed

Browse files
lvhan028cursoragent
andcommitted
test(serve): fix anthropic and responses tests for default_gen_config
Add default_gen_config to fake server contexts and route anthropic endpoint tests through a lightweight client that invokes handlers directly. Co-authored-by: Cursor <cursoragent@cursor.com>
1 parent fdbf098 commit 7e46bed

2 files changed

Lines changed: 81 additions & 7 deletions

File tree

tests/test_lmdeploy/serve/anthropic/test_endpoints.py

Lines changed: 80 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,10 @@
55
from types import SimpleNamespace
66

77
import pytest
8-
from fastapi import FastAPI
9-
from fastapi.testclient import TestClient
8+
from fastapi.encoders import jsonable_encoder
9+
from fastapi.responses import JSONResponse, StreamingResponse
1010

11+
from lmdeploy.serve.anthropic.protocol import MessagesRequest
1112
from lmdeploy.serve.anthropic.router import create_anthropic_router
1213
from lmdeploy.serve.anthropic.streaming import stream_messages_response
1314
from lmdeploy.serve.openai.protocol import DeltaFunctionCall, DeltaMessage, DeltaToolCall, FunctionCall, ToolCall
@@ -141,6 +142,7 @@ def __init__(
141142
logprobs_mode=logprobs_mode,
142143
enable_return_routed_experts=enable_return_routed_experts,
143144
)
145+
self.default_gen_config = {}
144146
self.response_parser_cls = response_parser_cls
145147

146148
def create_session(self, _session_id: int | None = None):
@@ -153,6 +155,79 @@ def get_engine_config(self):
153155
return self.async_engine.backend_config
154156

155157

158+
class _FakeRawRequest:
159+
160+
def __init__(self, headers):
161+
self.headers = headers
162+
163+
async def is_disconnected(self):
164+
return False
165+
166+
167+
class _TestResponse:
168+
169+
def __init__(self, status_code: int, payload=None, body: str = ''):
170+
self.status_code = status_code
171+
self._payload = jsonable_encoder(payload)
172+
self._body = body
173+
174+
def json(self):
175+
return self._payload
176+
177+
def iter_lines(self):
178+
return self._body.splitlines()
179+
180+
181+
class _StreamContext:
182+
183+
def __init__(self, response: _TestResponse):
184+
self.response = response
185+
186+
def __enter__(self):
187+
return self.response
188+
189+
def __exit__(self, *args):
190+
return False
191+
192+
193+
class _AnthropicTestClient:
194+
195+
def __init__(self, server_context):
196+
router = create_anthropic_router(server_context)
197+
self._routes = {route.path: route.endpoint for route in router.routes}
198+
199+
def post(self, path: str, *, headers, json):
200+
return asyncio.run(self._post(path, headers=headers, json=json))
201+
202+
def stream(self, method: str, path: str, *, headers, json):
203+
assert method == 'POST'
204+
return _StreamContext(self.post(path, headers=headers, json=json))
205+
206+
def get(self, path: str):
207+
return asyncio.run(self._get(path))
208+
209+
async def _post(self, path: str, *, headers, json):
210+
endpoint = self._routes[path.split('?', 1)[0]]
211+
result = await endpoint(MessagesRequest(**json), _FakeRawRequest(headers))
212+
return await self._response_from_result(result)
213+
214+
async def _get(self, path: str):
215+
endpoint = self._routes[path.split('?', 1)[0]]
216+
return await self._response_from_result(await endpoint())
217+
218+
async def _response_from_result(self, result):
219+
if isinstance(result, JSONResponse):
220+
return _TestResponse(result.status_code, json.loads(result.body))
221+
if isinstance(result, StreamingResponse):
222+
chunks = []
223+
async for chunk in result.body_iterator:
224+
if isinstance(chunk, bytes):
225+
chunk = chunk.decode()
226+
chunks.append(chunk)
227+
return _TestResponse(result.status_code, body=''.join(chunks))
228+
return _TestResponse(200, result)
229+
230+
156231
class _ToolAndReasoningParser:
157232
tool_parser_cls = object
158233

@@ -217,11 +292,9 @@ def _make_client(response_parser_cls=_BasicParser,
217292
server_context=None,
218293
logprobs_mode='raw_logprobs',
219294
return_context=False):
220-
app = FastAPI()
221295
context = server_context or _FakeServerContext(response_parser_cls=response_parser_cls,
222296
logprobs_mode=logprobs_mode)
223-
app.include_router(create_anthropic_router(context))
224-
client = TestClient(app)
297+
client = _AnthropicTestClient(context)
225298
if return_context:
226299
return client, context
227300
return client
@@ -252,11 +325,11 @@ def test_messages_non_stream():
252325
assert len(context.session_mgr.removed) == 1
253326

254327

255-
def _post_messages(client: TestClient, **overrides):
328+
def _post_messages(client: _AnthropicTestClient, **overrides):
256329
return client.post('/v1/messages', headers=ANTHROPIC_HEADERS, json=_messages_payload(**overrides))
257330

258331

259-
def _stream_messages_body(client: TestClient, **overrides):
332+
def _stream_messages_body(client: _AnthropicTestClient, **overrides):
260333
payload = _messages_payload(**overrides)
261334
payload['stream'] = True
262335
with client.stream('POST', '/v1/messages', headers=ANTHROPIC_HEADERS, json=payload) as response:

tests/test_lmdeploy/serve/openai/responses/conftest.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ class FakeServerContext:
6464

6565
def __init__(self):
6666
self.async_engine = FakeAsyncEngine()
67+
self.default_gen_config = {}
6768
self.sessions = []
6869

6970
def create_session(self, session_id):

0 commit comments

Comments
 (0)