Skip to content

Commit 1f43918

Browse files
committed
fix openai - agentscope compat
1 parent 1f9d63e commit 1f43918

3 files changed

Lines changed: 16 additions & 7 deletions

File tree

ajet/schema/extended_msg.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,9 @@ def __init__(
8888
self.clip = clip
8989
self.tools = tools
9090
self.tool_calls = tool_calls
91+
if not isinstance(self.tool_calls, list):
92+
# agent scope sometimes gives weird type for tool_calls, which is against OpenAI schema
93+
self.tool_calls = list(self.tool_calls)
9194
self.uuid = uuid.uuid4().hex
9295
self.build_from_uuid = build_from_uuid
9396
self.first_message = first_message

ajet/tuner_lib/weight_tuner/experimental/as_oai_model_client.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,8 @@ async def _service_loop(self):
128128
parsed_msg: InterchangeCompletionRequest = pickle.loads(
129129
await asyncio.wait_for(websocket.recv(decode=False), timeout=0.25)
130130
)
131+
if isinstance(parsed_msg, str):
132+
parsed_msg = InterchangeCompletionRequest(**json.loads(parsed_msg))
131133

132134
response = await self.llm_infer(
133135
req=parsed_msg.completion_request,

ajet/tuner_lib/weight_tuner/experimental/as_oai_model_server.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,11 @@ async def coro_task_1_lookup_dict_received__send_loop(key, websocket: WebSocket,
8383
# will be received by:
8484
# ajet/tuner_lib/weight_tuner/experimental/as_oai_model_client.py
8585
# await asyncio.wait_for(websocket.recv(decode=False), timeout=0.25)
86-
await websocket.send_bytes(pickle.dumps(new_req))
86+
try:
87+
await websocket.send_bytes(pickle.dumps(new_req))
88+
except:
89+
# AgentScope sometimes fails the standard OAI schema compliance check for ChatCompletionRequest
90+
await websocket.send_bytes(pickle.dumps(new_req.model_dump_json()))
8791
else:
8892
await asyncio.sleep(0.25)
8993

@@ -243,12 +247,12 @@ async def chat_completions(request: Request, authorization: str = Header(None)):
243247
timeline_uuid = timeline_uuid,
244248
)
245249

246-
# fix Pydantic validation issue for tool_calls field
247-
for msg in int_req.completion_request.messages:
248-
if isinstance(msg, dict) and 'tool_calls' in msg:
249-
tc = msg['tool_calls']
250-
if not isinstance(tc, list):
251-
msg['tool_calls'] = list(tc) if tc else []
250+
# # fix Pydantic validation issue for tool_calls field
251+
# for msg in int_req.completion_request.messages:
252+
# if isinstance(msg, dict) and 'tool_calls' in msg:
253+
# tc = msg['tool_calls']
254+
# if not isinstance(tc, list):
255+
# msg['tool_calls'] = list(tc) if tc else []
252256

253257
ajet_remote_handler_received[key][timeline_uuid] = int_req
254258

0 commit comments

Comments
 (0)