@@ -201,6 +201,8 @@ async def create_agent(
201201
202202@dataclass
203203class InvokeContext :
204+ thread_id : str
205+
204206 retry : LC_HumanMessage | bool = False
205207 """
206208 Controls whether to retry the agent loop after ainvoke succeeds.
@@ -636,12 +638,6 @@ async def next(r: AgentRequest) -> AgentResponse[Any | None]:
636638 async def invoke (
637639 self , messages : list [BaseMessage ], thread_id : str
638640 ) -> AgentResponse [OutputT ]:
639- # TODO: What if we are passed len(messages) == 0 to invoke?
640- # TODO: What if someone passed call_id that don't have a corresponding id with the response.
641- # Possibly we should do a validation phase of messages here.
642- # TODO: also assert correct ordering, i.e. directly after AIMessage with calls, there is a response
643- # not before or far after.
644-
645641 async def invoke_agent (req : AgentRequest ) -> AgentResponse [Any | None ]:
646642 langchain_msgs = []
647643
@@ -656,7 +652,7 @@ async def invoke_agent(req: AgentRequest) -> AgentResponse[Any | None]:
656652 langchain_msgs .extend ([_map_message_to_langchain (m ) for m in req .messages ])
657653
658654 while True :
659- ctx = InvokeContext ()
655+ ctx = InvokeContext (thread_id = thread_id )
660656 result = await self ._agent .ainvoke (
661657 {"messages" : langchain_msgs },
662658 context = ctx ,
@@ -698,6 +694,7 @@ async def invoke_agent(req: AgentRequest) -> AgentResponse[Any | None]:
698694
699695 result = await self ._with_agent_middleware (invoke_agent )(
700696 AgentRequest (
697+ thread_id = thread_id ,
701698 messages = messages ,
702699 )
703700 )
@@ -1051,38 +1048,48 @@ async def _sdk_handler(request: ModelRequest) -> ModelResponse:
10511048
10521049
10531050def _convert_model_request_from_lc (
1054- request : LC_ModelRequest , model : BaseChatModel
1051+ request : LC_ModelRequest ,
1052+ model : BaseChatModel ,
10551053) -> ModelRequest :
1054+ thread_id = request .runtime .context .thread_id
1055+
10561056 system_message = (
10571057 request .system_message .content .__str__ () if request .system_message else ""
10581058 )
10591059
10601060 return ModelRequest (
10611061 system_message = system_message ,
1062- state = _convert_agent_state_from_langchain (request .state , model ),
1062+ state = _convert_agent_state_from_langchain (request .state , model , thread_id ),
10631063 )
10641064
10651065
10661066def _convert_tool_request_from_lc (
1067- request : LC_ToolCallRequest , model : BaseChatModel
1067+ request : LC_ToolCallRequest ,
1068+ model : BaseChatModel ,
10681069) -> ToolRequest :
1070+ assert isinstance (request .runtime .context , InvokeContext )
1071+ thread_id = request .runtime .context .thread_id
1072+
10691073 tool_call = _map_tool_call_from_langchain (request .tool_call )
10701074 assert isinstance (tool_call , ToolCall ), "Expected tool call"
10711075 return ToolRequest (
10721076 call = tool_call ,
1073- state = _convert_agent_state_from_langchain (request .state , model ),
1077+ state = _convert_agent_state_from_langchain (request .state , model , thread_id ),
10741078 )
10751079
10761080
10771081def _convert_subagent_request_from_lc (
10781082 request : LC_ToolCallRequest ,
10791083 model : BaseChatModel ,
10801084) -> SubagentRequest :
1085+ assert isinstance (request .runtime .context , InvokeContext )
1086+ thread_id = request .runtime .context .thread_id
1087+
10811088 subagent_call = _map_tool_call_from_langchain (request .tool_call )
10821089 assert isinstance (subagent_call , SubagentCall ), "Expected subagent call"
10831090 return SubagentRequest (
10841091 call = subagent_call ,
1085- state = _convert_agent_state_from_langchain (request .state , model ),
1092+ state = _convert_agent_state_from_langchain (request .state , model , thread_id ),
10861093 )
10871094
10881095
@@ -1506,7 +1513,9 @@ async def invoke_agent(
15061513 OutputT | str ,
15071514 SubagentStructuredResult | SubagentTextResult ,
15081515 ]:
1509- result = await agent .invoke ([message ], thread_id = thread_id )
1516+ result = await agent .invoke (
1517+ [message ], thread_id = thread_id or _thread_id_new_uuid ()
1518+ )
15101519
15111520 if agent .output_schema :
15121521 assert result .structured_output is not None
@@ -1555,7 +1564,7 @@ async def invoke_agent_structured(
15551564 result = await agent .invoke_with_data (
15561565 instructions = "Follow the system prompt." ,
15571566 data = content .model_dump (),
1558- thread_id = thread_id ,
1567+ thread_id = thread_id or _thread_id_new_uuid () ,
15591568 )
15601569
15611570 if agent .output_schema :
@@ -1769,7 +1778,7 @@ def _map_message_to_langchain(message: BaseMessage) -> LC_AnyMessage:
17691778
17701779
17711780def _convert_agent_state_from_langchain (
1772- state : LC_AgentState [Any ], model : BaseChatModel
1781+ state : LC_AgentState [Any ], model : BaseChatModel , thread_id : str
17731782) -> AgentState :
17741783 messages = state ["messages" ]
17751784 total_tokens_counter = _get_approximate_token_counter (model )
@@ -1779,6 +1788,7 @@ def _convert_agent_state_from_langchain(
17791788 messages = messages ,
17801789 total_steps = len (messages ),
17811790 token_count = total_tokens ,
1791+ thread_id = thread_id ,
17821792 )
17831793
17841794
@@ -1909,6 +1919,11 @@ def check_tool_name(type: str, name: str) -> None:
19091919 check_call_id ("subagent" , call .id )
19101920 check_tool_name ("subagent" , call .name )
19111921 pending_subagent_calls [call .id ] = call .name
1922+
1923+ if call .thread_id == "" :
1924+ raise _InvalidMessagesException (
1925+ "thread_id should not be an empty string"
1926+ )
19121927 else :
19131928 raise _InvalidMessagesException (
19141929 f"AIMessage contains invalid call type: { type (call )} "
0 commit comments