Skip to content

Commit dd9846b

Browse files
committed
fix(mlx): improve reasoning logging and gradio display
1 parent: 7382931 · commit: dd9846b

4 files changed

Lines changed: 249 additions & 86 deletions

File tree

xinference/model/llm/core.py

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -209,12 +209,15 @@ def prepare_parse_reasoning_content(
209209
warnings.warn(
210210
"enable_thinking cannot be disabled for non hybrid model, will be ignored"
211211
)
212+
abilities = self.model_family.model_ability or []
213+
auto_insert_start_tag = "hybrid" not in abilities
212214
# Initialize reasoning parser if model has reasoning ability
213215
self.reasoning_parser = ReasoningParser( # type: ignore
214216
reasoning_content,
215217
self.model_family.reasoning_start_tag, # type: ignore
216218
self.model_family.reasoning_end_tag, # type: ignore
217219
enable_thinking=enable_thinking,
220+
auto_insert_start_tag=auto_insert_start_tag,
218221
)
219222

220223
def prepare_parse_tool_calls(self):

xinference/model/llm/mlx/core.py

Lines changed: 29 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -1231,12 +1231,16 @@ async def async_chat(
12311231

12321232
async def _log_streaming_chunks():
12331233
full_text = ""
1234+
full_reasoning = ""
12341235
async for chunk in chunks: # type: ignore[arg-type]
12351236
choices = chunk.get("choices")
12361237
if choices:
12371238
first = choices[0]
12381239
delta = first.get("delta")
12391240
if isinstance(delta, dict):
1241+
delta_reasoning = delta.get("reasoning_content")
1242+
if isinstance(delta_reasoning, str):
1243+
full_reasoning += delta_reasoning
12401244
delta_text = delta.get("content")
12411245
if delta_text:
12421246
full_text += delta_text
@@ -1245,7 +1249,11 @@ async def _log_streaming_chunks():
12451249
if isinstance(text, str):
12461250
full_text += text
12471251
yield chunk
1248-
logger.debug("[MLX] Full accumulated output: %r", full_text)
1252+
logger.debug(
1253+
"[MLX] Full accumulated output: reasoning=%r, content=%r",
1254+
full_reasoning,
1255+
full_text,
1256+
)
12491257

12501258
return self._async_to_chat_completion_chunks(
12511259
_log_streaming_chunks(),
@@ -1602,14 +1610,29 @@ def chat(
16021610

16031611
def _log_streaming_chunks():
16041612
full_text = ""
1613+
full_reasoning = ""
16051614
for chunk in it:
16061615
choices = chunk.get("choices")
1607-
if choices and choices[0].get("text"):
1608-
text = choices[0]["text"]
1609-
if text:
1610-
full_text += text # type: ignore[arg-type]
1616+
if choices:
1617+
first = choices[0]
1618+
delta = first.get("delta")
1619+
if isinstance(delta, dict):
1620+
delta_reasoning = delta.get("reasoning_content")
1621+
if isinstance(delta_reasoning, str):
1622+
full_reasoning += delta_reasoning
1623+
delta_text = delta.get("content")
1624+
if isinstance(delta_text, str):
1625+
full_text += delta_text
1626+
elif first.get("text"):
1627+
text = first["text"]
1628+
if text:
1629+
full_text += text # type: ignore[arg-type]
16111630
yield chunk
1612-
logger.debug("[MLX] Full accumulated output: %r", full_text)
1631+
logger.debug(
1632+
"[MLX] Full accumulated output: reasoning=%r, content=%r",
1633+
full_reasoning,
1634+
full_text,
1635+
)
16131636

16141637
return self._to_chat_completion_chunks(
16151638
_log_streaming_chunks(), self.reasoning_parser

0 commit comments

Comments
 (0)