@@ -201,6 +201,8 @@ async def create_agent(
201201
202202@dataclass
203203class InvokeContext :
204+ thread_id : str
205+
204206 retry : LC_HumanMessage | bool = False
205207 """
206208 Controls whether to retry the agent loop after ainvoke succeeds.
@@ -636,12 +638,6 @@ async def next(r: AgentRequest) -> AgentResponse[Any | None]:
636638 async def invoke (
637639 self , messages : list [BaseMessage ], thread_id : str
638640 ) -> AgentResponse [OutputT ]:
639- # TODO: What if we are passed len(messages) == 0 to invoke?
640- # TODO: What if someone passed call_id that don't have a corresponding id with the response.
641- # Possibly we should do a validation phase of messages here.
642- # TODO: also assert correct ordering, i.e. directly after AIMessage with calls, there is a response
643- # not before or far after.
644-
645641 async def invoke_agent (req : AgentRequest ) -> AgentResponse [Any | None ]:
646642 langchain_msgs = []
647643
@@ -656,7 +652,7 @@ async def invoke_agent(req: AgentRequest) -> AgentResponse[Any | None]:
656652 langchain_msgs .extend ([_map_message_to_langchain (m ) for m in req .messages ])
657653
658654 while True :
659- ctx = InvokeContext ()
655+ ctx = InvokeContext (thread_id = thread_id )
660656 result = await self ._agent .ainvoke (
661657 {"messages" : langchain_msgs },
662658 context = ctx ,
@@ -698,6 +694,7 @@ async def invoke_agent(req: AgentRequest) -> AgentResponse[Any | None]:
698694
699695 result = await self ._with_agent_middleware (invoke_agent )(
700696 AgentRequest (
697+ thread_id = thread_id ,
701698 messages = messages ,
702699 )
703700 )
@@ -1053,36 +1050,44 @@ async def _sdk_handler(request: ModelRequest) -> ModelResponse:
10531050def _convert_model_request_from_lc (
10541051 request : LC_ModelRequest , model : BaseChatModel
10551052) -> ModelRequest :
1053+ thread_id = request .runtime .context .thread_id
1054+
10561055 system_message = (
10571056 request .system_message .content .__str__ () if request .system_message else ""
10581057 )
10591058
10601059 return ModelRequest (
10611060 system_message = system_message ,
1062- state = _convert_agent_state_from_langchain (request .state , model ),
1061+ state = _convert_agent_state_from_langchain (request .state , model , thread_id ),
10631062 )
10641063
10651064
10661065def _convert_tool_request_from_lc (
10671066 request : LC_ToolCallRequest , model : BaseChatModel
10681067) -> ToolRequest :
1068+ assert isinstance (request .runtime .context , InvokeContext )
1069+ thread_id = request .runtime .context .thread_id
1070+
10691071 tool_call = _map_tool_call_from_langchain (request .tool_call )
10701072 assert isinstance (tool_call , ToolCall ), "Expected tool call"
10711073 return ToolRequest (
10721074 call = tool_call ,
1073- state = _convert_agent_state_from_langchain (request .state , model ),
1075+ state = _convert_agent_state_from_langchain (request .state , model , thread_id ),
10741076 )
10751077
10761078
10771079def _convert_subagent_request_from_lc (
10781080 request : LC_ToolCallRequest ,
10791081 model : BaseChatModel ,
10801082) -> SubagentRequest :
1083+ assert isinstance (request .runtime .context , InvokeContext )
1084+ thread_id = request .runtime .context .thread_id
1085+
10811086 subagent_call = _map_tool_call_from_langchain (request .tool_call )
10821087 assert isinstance (subagent_call , SubagentCall ), "Expected subagent call"
10831088 return SubagentRequest (
10841089 call = subagent_call ,
1085- state = _convert_agent_state_from_langchain (request .state , model ),
1090+ state = _convert_agent_state_from_langchain (request .state , model , thread_id ),
10861091 )
10871092
10881093
@@ -1506,7 +1511,9 @@ async def invoke_agent(
15061511 OutputT | str ,
15071512 SubagentStructuredResult | SubagentTextResult ,
15081513 ]:
1509- result = await agent .invoke ([message ], thread_id = thread_id )
1514+ result = await agent .invoke (
1515+ [message ], thread_id = thread_id or _thread_id_new_uuid ()
1516+ )
15101517
15111518 if agent .output_schema :
15121519 assert result .structured_output is not None
@@ -1555,7 +1562,7 @@ async def invoke_agent_structured(
15551562 result = await agent .invoke_with_data (
15561563 instructions = "Follow the system prompt." ,
15571564 data = content .model_dump (),
1558- thread_id = thread_id ,
1565+ thread_id = thread_id or _thread_id_new_uuid () ,
15591566 )
15601567
15611568 if agent .output_schema :
@@ -1769,7 +1776,7 @@ def _map_message_to_langchain(message: BaseMessage) -> LC_AnyMessage:
17691776
17701777
17711778def _convert_agent_state_from_langchain (
1772- state : LC_AgentState [Any ], model : BaseChatModel
1779+ state : LC_AgentState [Any ], model : BaseChatModel , thread_id : str
17731780) -> AgentState :
17741781 messages = state ["messages" ]
17751782 total_tokens_counter = _get_approximate_token_counter (model )
@@ -1779,6 +1786,7 @@ def _convert_agent_state_from_langchain(
17791786 messages = messages ,
17801787 total_steps = len (messages ),
17811788 token_count = total_tokens ,
1789+ thread_id = thread_id ,
17821790 )
17831791
17841792
@@ -1909,6 +1917,11 @@ def check_tool_name(type: str, name: str) -> None:
19091917 check_call_id ("subagent" , call .id )
19101918 check_tool_name ("subagent" , call .name )
19111919 pending_subagent_calls [call .id ] = call .name
1920+
1921+ if call .thread_id == "" :
1922+ raise _InvalidMessagesException (
1923+ "thread_id should not be an empty string"
1924+ )
19121925 else :
19131926 raise _InvalidMessagesException (
19141927 f"AIMessage contains invalid call type: { type (call )} "
0 commit comments