55from types import SimpleNamespace
66
77import pytest
8- from fastapi import FastAPI
9- from fastapi .testclient import TestClient
8+ from fastapi . encoders import jsonable_encoder
9+ from fastapi .responses import JSONResponse , StreamingResponse
1010
11+ from lmdeploy .serve .anthropic .protocol import MessagesRequest
1112from lmdeploy .serve .anthropic .router import create_anthropic_router
1213from lmdeploy .serve .anthropic .streaming import stream_messages_response
1314from lmdeploy .serve .openai .protocol import DeltaFunctionCall , DeltaMessage , DeltaToolCall , FunctionCall , ToolCall
@@ -141,6 +142,7 @@ def __init__(
141142 logprobs_mode = logprobs_mode ,
142143 enable_return_routed_experts = enable_return_routed_experts ,
143144 )
145+ self .default_gen_config = {}
144146 self .response_parser_cls = response_parser_cls
145147
146148 def create_session (self , _session_id : int | None = None ):
@@ -153,6 +155,79 @@ def get_engine_config(self):
153155 return self .async_engine .backend_config
154156
155157
158+ class _FakeRawRequest :
159+
160+ def __init__ (self , headers ):
161+ self .headers = headers
162+
163+ async def is_disconnected (self ):
164+ return False
165+
166+
167+ class _TestResponse :
168+
169+ def __init__ (self , status_code : int , payload = None , body : str = '' ):
170+ self .status_code = status_code
171+ self ._payload = jsonable_encoder (payload )
172+ self ._body = body
173+
174+ def json (self ):
175+ return self ._payload
176+
177+ def iter_lines (self ):
178+ return self ._body .splitlines ()
179+
180+
181+ class _StreamContext :
182+
183+ def __init__ (self , response : _TestResponse ):
184+ self .response = response
185+
186+ def __enter__ (self ):
187+ return self .response
188+
189+ def __exit__ (self , * args ):
190+ return False
191+
192+
193+ class _AnthropicTestClient :
194+
195+ def __init__ (self , server_context ):
196+ router = create_anthropic_router (server_context )
197+ self ._routes = {route .path : route .endpoint for route in router .routes }
198+
199+ def post (self , path : str , * , headers , json ):
200+ return asyncio .run (self ._post (path , headers = headers , json = json ))
201+
202+ def stream (self , method : str , path : str , * , headers , json ):
203+ assert method == 'POST'
204+ return _StreamContext (self .post (path , headers = headers , json = json ))
205+
206+ def get (self , path : str ):
207+ return asyncio .run (self ._get (path ))
208+
209+ async def _post (self , path : str , * , headers , json ):
210+ endpoint = self ._routes [path .split ('?' , 1 )[0 ]]
211+ result = await endpoint (MessagesRequest (** json ), _FakeRawRequest (headers ))
212+ return await self ._response_from_result (result )
213+
214+ async def _get (self , path : str ):
215+ endpoint = self ._routes [path .split ('?' , 1 )[0 ]]
216+ return await self ._response_from_result (await endpoint ())
217+
218+ async def _response_from_result (self , result ):
219+ if isinstance (result , JSONResponse ):
220+ return _TestResponse (result .status_code , json .loads (result .body ))
221+ if isinstance (result , StreamingResponse ):
222+ chunks = []
223+ async for chunk in result .body_iterator :
224+ if isinstance (chunk , bytes ):
225+ chunk = chunk .decode ()
226+ chunks .append (chunk )
227+ return _TestResponse (result .status_code , body = '' .join (chunks ))
228+ return _TestResponse (200 , result )
229+
230+
156231class _ToolAndReasoningParser :
157232 tool_parser_cls = object
158233
@@ -217,11 +292,9 @@ def _make_client(response_parser_cls=_BasicParser,
217292 server_context = None ,
218293 logprobs_mode = 'raw_logprobs' ,
219294 return_context = False ):
220- app = FastAPI ()
221295 context = server_context or _FakeServerContext (response_parser_cls = response_parser_cls ,
222296 logprobs_mode = logprobs_mode )
223- app .include_router (create_anthropic_router (context ))
224- client = TestClient (app )
297+ client = _AnthropicTestClient (context )
225298 if return_context :
226299 return client , context
227300 return client
@@ -252,11 +325,11 @@ def test_messages_non_stream():
252325 assert len (context .session_mgr .removed ) == 1
253326
254327
255- def _post_messages (client : TestClient , ** overrides ):
328+ def _post_messages (client : _AnthropicTestClient , ** overrides ):
256329 return client .post ('/v1/messages' , headers = ANTHROPIC_HEADERS , json = _messages_payload (** overrides ))
257330
258331
259- def _stream_messages_body (client : TestClient , ** overrides ):
332+ def _stream_messages_body (client : _AnthropicTestClient , ** overrides ):
260333 payload = _messages_payload (** overrides )
261334 payload ['stream' ] = True
262335 with client .stream ('POST' , '/v1/messages' , headers = ANTHROPIC_HEADERS , json = payload ) as response :
0 commit comments