Skip to content

Commit a3e1756

Browse files
authored
Merge pull request #107 from Serverless-Devs/support_reasoning_content
feat(langchain): add support for reasoning content with new reasoning…
2 parents c983c34 + 3d2f50b commit a3e1756

19 files changed

Lines changed: 1863 additions & 18 deletions

agentrun/integration/langchain/model_adapter.py

Lines changed: 42 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,15 @@
22
33
将 CommonModel 包装为 LangChain BaseChatModel。"""
44

5-
import inspect
6-
import json
7-
from typing import Any, List, Optional
5+
from typing import Any
86

97
from agentrun.integration.langchain.message_adapter import (
108
LangChainMessageAdapter,
119
)
1210
from agentrun.integration.utils.adapter import ModelAdapter
1311

12+
_DEEPSEEK_PROVIDER = "deepseek"
13+
1414

1515
class LangChainModelAdapter(ModelAdapter):
1616
"""LangChain 模型适配器 / LangChain Model Adapter
@@ -23,15 +23,51 @@ def __init__(self):
2323

2424
def wrap_model(self, common_model: Any) -> Any:
2525
"""包装 CommonModel 为 LangChain BaseChatModel / LangChain Model Adapter"""
26-
from langchain_openai import ChatOpenAI
27-
2826
info = common_model.get_model_info() # 确保模型可用
27+
provider = (info.provider or "").lower()
28+
29+
if provider == _DEEPSEEK_PROVIDER:
30+
return self._create_reasoning_model(info)
31+
return self._create_openai_model(info)
32+
33+
def _create_reasoning_model(self, info: Any) -> Any:
34+
"""创建支持 reasoning_content 的模型(使用 ChatDeepSeek)"""
35+
try:
36+
from langchain_deepseek import ChatDeepSeek
37+
except ImportError as e:
38+
raise ImportError(
39+
"import langchain_deepseek failed. "
40+
"Install it with: pip install 'agentrun-sdk[langchain]' "
41+
"or pip install 'agentrun-sdk[langgraph]'"
42+
) from e
43+
44+
return ChatDeepSeek(
45+
name=info.model,
46+
model=info.model,
47+
api_key=info.api_key,
48+
api_base=info.base_url,
49+
default_headers=info.headers,
50+
stream_usage=True,
51+
streaming=True,
52+
)
53+
54+
def _create_openai_model(self, info: Any) -> Any:
55+
"""创建标准 OpenAI 兼容模型"""
56+
try:
57+
from langchain_openai import ChatOpenAI
58+
except ImportError as e:
59+
raise ImportError(
60+
"import langchain_openai failed. "
61+
"Install it with: pip install 'agentrun-sdk[langchain]' "
62+
"or pip install 'agentrun-sdk[langgraph]'"
63+
) from e
64+
2965
return ChatOpenAI(
3066
name=info.model,
3167
api_key=info.api_key,
3268
model=info.model,
3369
base_url=info.base_url,
3470
default_headers=info.headers,
3571
stream_usage=True,
36-
streaming=True, # 启用流式输出以支持 token by token
72+
streaming=True,
3773
)

agentrun/integration/langgraph/agent_converter.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -730,9 +730,19 @@ def _convert_astream_events_event(
730730
and not self.has_on_chat_model_stream
731731
):
732732
chunk_data = data.get("chunk", {})
733+
messages = []
733734
if isinstance(chunk_data, dict):
734735
messages = chunk_data.get("messages", [])
736+
elif isinstance(chunk_data, list):
737+
for item in chunk_data:
738+
update = getattr(item, "update", None)
739+
if not isinstance(update, dict):
740+
continue
741+
item_messages = update.get("messages", [])
742+
if isinstance(item_messages, list):
743+
messages.extend(item_messages)
735744

745+
if isinstance(messages, list):
736746
for msg in messages:
737747
content = AgentRunConverter._get_message_content(msg)
738748
if content:

agentrun/model/__model_service_async_template.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -233,4 +233,5 @@ def model_info(self, config: Optional[Config] = None) -> BaseInfo:
233233
base_url=self.provider_settings.base_url,
234234
model=default_model,
235235
headers=cfg.get_headers(),
236+
provider=self.provider,
236237
)

agentrun/model/model_service.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -404,4 +404,5 @@ def model_info(self, config: Optional[Config] = None) -> BaseInfo:
404404
base_url=self.provider_settings.base_url,
405405
model=default_model,
406406
headers=cfg.get_headers(),
407+
provider=self.provider,
407408
)

agentrun/server/agui_protocol.py

Lines changed: 131 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
"""
99

1010
from dataclasses import dataclass, field
11+
import json
1112
from typing import (
1213
Any,
1314
AsyncIterator,
@@ -30,6 +31,10 @@
3031
import pydash
3132

3233
from ..utils.helper import merge, MergeOptions
34+
from ..utils.reasoning import (
35+
get_reasoning_content,
36+
is_thinking_enabled_from_env,
37+
)
3338
from .model import (
3439
AgentEvent,
3540
AgentRequest,
@@ -60,6 +65,14 @@ class TextState:
6065
message_id: str = field(default_factory=lambda: str(uuid.uuid4()))
6166

6267

68+
@dataclass
69+
class ReasoningState:
70+
started: bool = False
71+
message_started: bool = False
72+
phase_id: str = field(default_factory=lambda: str(uuid.uuid4()))
73+
message_id: str = field(default_factory=lambda: str(uuid.uuid4()))
74+
75+
6376
@dataclass
6477
class ToolCallState:
6578
name: str = ""
@@ -72,6 +85,7 @@ class ToolCallState:
7285
@dataclass
7386
class StreamStateMachine:
7487
text: TextState = field(default_factory=TextState)
88+
reasoning: ReasoningState = field(default_factory=ReasoningState)
7589
tool_call_states: Dict[str, ToolCallState] = field(default_factory=dict)
7690
tool_result_chunks: Dict[str, List[str]] = field(default_factory=dict)
7791
run_errored: bool = False
@@ -121,6 +135,43 @@ def cache_tool_result_chunk(self, tool_id: str, delta: str) -> None:
121135
def pop_tool_result_chunks(self, tool_id: str) -> str:
122136
return "".join(self.tool_result_chunks.pop(tool_id, []))
123137

138+
def ensure_reasoning_started(self) -> Iterator[str]:
139+
if not self.reasoning.started:
140+
yield _encode_reasoning_event(
141+
"REASONING_START",
142+
messageId=self.reasoning.phase_id,
143+
)
144+
self.reasoning.started = True
145+
if not self.reasoning.message_started:
146+
yield _encode_reasoning_event(
147+
"REASONING_MESSAGE_START",
148+
messageId=self.reasoning.message_id,
149+
role="reasoning",
150+
)
151+
self.reasoning.message_started = True
152+
153+
def end_reasoning_if_open(self) -> Iterator[str]:
154+
if self.reasoning.message_started:
155+
yield _encode_reasoning_event(
156+
"REASONING_MESSAGE_END",
157+
messageId=self.reasoning.message_id,
158+
)
159+
self.reasoning.message_started = False
160+
if self.reasoning.started:
161+
yield _encode_reasoning_event(
162+
"REASONING_END",
163+
messageId=self.reasoning.phase_id,
164+
)
165+
self.reasoning = ReasoningState()
166+
167+
168+
def _encode_reasoning_event(event_type: str, **payload: Any) -> str:
169+
return (
170+
"data: "
171+
+ json.dumps({"type": event_type, **payload}, ensure_ascii=False)
172+
+ "\n\n"
173+
)
174+
124175

125176
class AGUIProtocolHandler(BaseProtocolHandler):
126177
"""AG-UI 协议处理器
@@ -376,6 +427,10 @@ async def _format_stream(
376427
if state.run_errored:
377428
return
378429

430+
# 结束未结束的 reasoning 消息
431+
for sse_data in state.end_reasoning_if_open():
432+
yield sse_data
433+
379434
# 结束所有未结束的工具调用
380435
for sse_data in state.end_all_tools(self._encoder):
381436
yield sse_data
@@ -399,8 +454,6 @@ def _process_event_with_boundaries(
399454
state: StreamStateMachine,
400455
) -> Iterator[str]:
401456
"""处理事件并注入边界事件"""
402-
import json
403-
404457
from ag_ui.core import CustomEvent as AguiCustomEvent
405458
from ag_ui.core import (
406459
RunErrorEvent,
@@ -413,6 +466,8 @@ def _process_event_with_boundaries(
413466
ToolCallStartEvent,
414467
)
415468

469+
thinking_enabled = is_thinking_enabled_from_env()
470+
416471
# RAW 事件直接透传
417472
if event.event == EventType.RAW:
418473
raw_data = event.data.get("raw", "")
@@ -422,9 +477,46 @@ def _process_event_with_boundaries(
422477
yield raw_data
423478
return
424479

480+
if event.event == EventType.REASONING:
481+
if thinking_enabled:
482+
reasoning_content = (
483+
event.data.get("delta")
484+
or get_reasoning_content(event.data)
485+
or ""
486+
)
487+
if reasoning_content:
488+
for sse_data in state.end_text_if_open(self._encoder):
489+
yield sse_data
490+
for sse_data in state.end_all_tools(self._encoder):
491+
yield sse_data
492+
for sse_data in state.ensure_reasoning_started():
493+
yield sse_data
494+
yield _encode_reasoning_event(
495+
"REASONING_MESSAGE_CONTENT",
496+
messageId=state.reasoning.message_id,
497+
delta=reasoning_content,
498+
)
499+
return
500+
425501
# TEXT 事件:在首个 TEXT 前注入 TEXT_MESSAGE_START
426502
# AG-UI 协议要求:发送 TEXT_MESSAGE_START 前必须先结束所有未结束的 TOOL_CALL
427503
if event.event == EventType.TEXT:
504+
addition = self._strip_reasoning_from_addition(
505+
event.addition, thinking_enabled
506+
)
507+
addition_reasoning = get_reasoning_content(event.addition or {})
508+
if thinking_enabled and addition_reasoning:
509+
for sse_data in state.ensure_reasoning_started():
510+
yield sse_data
511+
yield _encode_reasoning_event(
512+
"REASONING_MESSAGE_CONTENT",
513+
messageId=state.reasoning.message_id,
514+
delta=addition_reasoning,
515+
)
516+
517+
for sse_data in state.end_reasoning_if_open():
518+
yield sse_data
519+
428520
for sse_data in state.end_all_tools(self._encoder):
429521
yield sse_data
430522

@@ -435,13 +527,13 @@ def _process_event_with_boundaries(
435527
message_id=state.text.message_id,
436528
delta=event.data.get("delta", ""),
437529
)
438-
if event.addition:
530+
if addition:
439531
event_dict = agui_event.model_dump(
440532
by_alias=True, exclude_none=True
441533
)
442534
event_dict = self._apply_addition(
443535
event_dict,
444-
event.addition,
536+
addition,
445537
event.addition_merge_options,
446538
)
447539
json_str = json.dumps(event_dict, ensure_ascii=False)
@@ -455,6 +547,9 @@ def _process_event_with_boundaries(
455547
tool_id = event.data.get("id", "")
456548
tool_name = event.data.get("name", "")
457549

550+
for sse_data in state.end_reasoning_if_open():
551+
yield sse_data
552+
458553
for sse_data in state.end_text_if_open(self._encoder):
459554
yield sse_data
460555

@@ -491,6 +586,9 @@ def _process_event_with_boundaries(
491586
tool_name = event.data.get("name", "")
492587
tool_args = event.data.get("args", "")
493588

589+
for sse_data in state.end_reasoning_if_open():
590+
yield sse_data
591+
494592
for sse_data in state.end_text_if_open(self._encoder):
495593
yield sse_data
496594

@@ -541,6 +639,9 @@ def _process_event_with_boundaries(
541639
timeout = event.data.get("timeout")
542640
schema = event.data.get("schema")
543641

642+
for sse_data in state.end_reasoning_if_open():
643+
yield sse_data
644+
544645
for sse_data in state.end_text_if_open(self._encoder):
545646
yield sse_data
546647

@@ -601,6 +702,9 @@ def _process_event_with_boundaries(
601702
tool_id = event.data.get("id", "")
602703
tool_name = event.data.get("name", "")
603704

705+
for sse_data in state.end_reasoning_if_open():
706+
yield sse_data
707+
604708
for sse_data in state.end_text_if_open(self._encoder):
605709
yield sse_data
606710

@@ -767,6 +871,29 @@ def _apply_addition(
767871

768872
return merge(event_data, addition, **(merge_options or {}))
769873

874+
def _strip_reasoning_from_addition(
875+
self,
876+
addition: Optional[Dict[str, Any]],
877+
thinking_enabled: bool,
878+
) -> Optional[Dict[str, Any]]:
879+
if not addition:
880+
return addition
881+
882+
stripped = dict(addition)
883+
stripped.pop("reasoning_content", None)
884+
additional_kwargs = stripped.get("additional_kwargs")
885+
if isinstance(additional_kwargs, dict):
886+
additional_kwargs = dict(additional_kwargs)
887+
additional_kwargs.pop("reasoning_content", None)
888+
if additional_kwargs:
889+
stripped["additional_kwargs"] = additional_kwargs
890+
else:
891+
stripped.pop("additional_kwargs", None)
892+
893+
if not thinking_enabled:
894+
return stripped
895+
return stripped or None
896+
770897
async def _error_stream(self, message: str) -> AsyncIterator[str]:
771898
"""生成错误事件流
772899

0 commit comments

Comments
 (0)