@@ -201,6 +201,8 @@ async def create_agent(
201201
202202@dataclass
203203class InvokeContext :
204+ thread_id : str
205+
204206 retry : LC_HumanMessage | bool = False
205207 """
206208 Controls whether to retry the agent loop after ainvoke succeeds.
@@ -637,12 +639,6 @@ async def next(r: AgentRequest) -> AgentResponse[Any | None]:
637639 async def invoke (
638640 self , messages : list [BaseMessage ], thread_id : str
639641 ) -> AgentResponse [OutputT ]:
640- # TODO: What if we are passed len(messages) == 0 to invoke?
641- # TODO: What if someone passed call_id that don't have a corresponding id with the response.
642- # Possibly we should do a validation phase of messages here.
643- # TODO: also assert correct ordering, i.e. directly after AIMessage with calls, there is a response
644- # not before or far after.
645-
646642 async def invoke_agent (req : AgentRequest ) -> AgentResponse [Any | None ]:
647643 langchain_msgs = []
648644
@@ -657,7 +653,7 @@ async def invoke_agent(req: AgentRequest) -> AgentResponse[Any | None]:
657653 langchain_msgs .extend ([_map_message_to_langchain (m ) for m in req .messages ])
658654
659655 while True :
660- ctx = InvokeContext ()
656+ ctx = InvokeContext (thread_id = thread_id )
661657 result = await self ._agent .ainvoke (
662658 {"messages" : langchain_msgs },
663659 context = ctx ,
@@ -699,6 +695,7 @@ async def invoke_agent(req: AgentRequest) -> AgentResponse[Any | None]:
699695
700696 result = await self ._with_agent_middleware (invoke_agent )(
701697 AgentRequest (
698+ thread_id = thread_id ,
702699 messages = messages ,
703700 )
704701 )
@@ -1054,36 +1051,44 @@ async def _sdk_handler(request: ModelRequest) -> ModelResponse:
10541051def _convert_model_request_from_lc (
10551052 request : LC_ModelRequest , model : BaseChatModel
10561053) -> ModelRequest :
1054+ thread_id = request .runtime .context .thread_id
1055+
10571056 system_message = (
10581057 request .system_message .content .__str__ () if request .system_message else ""
10591058 )
10601059
10611060 return ModelRequest (
10621061 system_message = system_message ,
1063- state = _convert_agent_state_from_langchain (request .state , model ),
1062+ state = _convert_agent_state_from_langchain (request .state , model , thread_id ),
10641063 )
10651064
10661065
10671066def _convert_tool_request_from_lc (
10681067 request : LC_ToolCallRequest , model : BaseChatModel
10691068) -> ToolRequest :
1069+ assert isinstance (request .runtime .context , InvokeContext )
1070+ thread_id = request .runtime .context .thread_id
1071+
10701072 tool_call = _map_tool_call_from_langchain (request .tool_call )
10711073 assert isinstance (tool_call , ToolCall ), "Expected tool call"
10721074 return ToolRequest (
10731075 call = tool_call ,
1074- state = _convert_agent_state_from_langchain (request .state , model ),
1076+ state = _convert_agent_state_from_langchain (request .state , model , thread_id ),
10751077 )
10761078
10771079
10781080def _convert_subagent_request_from_lc (
10791081 request : LC_ToolCallRequest ,
10801082 model : BaseChatModel ,
10811083) -> SubagentRequest :
1084+ assert isinstance (request .runtime .context , InvokeContext )
1085+ thread_id = request .runtime .context .thread_id
1086+
10821087 subagent_call = _map_tool_call_from_langchain (request .tool_call )
10831088 assert isinstance (subagent_call , SubagentCall ), "Expected subagent call"
10841089 return SubagentRequest (
10851090 call = subagent_call ,
1086- state = _convert_agent_state_from_langchain (request .state , model ),
1091+ state = _convert_agent_state_from_langchain (request .state , model , thread_id ),
10871092 )
10881093
10891094
@@ -1508,7 +1513,9 @@ async def invoke_agent(
15081513 OutputT | str ,
15091514 SubagentStructuredResult | SubagentTextResult ,
15101515 ]:
1511- result = await agent .invoke ([message ], thread_id = thread_id )
1516+ result = await agent .invoke (
1517+ [message ], thread_id = thread_id or _thread_id_new_uuid ()
1518+ )
15121519
15131520 if agent .output_schema :
15141521 assert result .structured_output is not None
@@ -1557,7 +1564,7 @@ async def invoke_agent_structured(
15571564 result = await agent .invoke_with_data (
15581565 instructions = "Follow the system prompt." ,
15591566 data = content .model_dump (),
1560- thread_id = thread_id ,
1567+ thread_id = thread_id or _thread_id_new_uuid () ,
15611568 )
15621569
15631570 if agent .output_schema :
@@ -1772,7 +1779,7 @@ def _map_message_to_langchain(message: BaseMessage) -> LC_AnyMessage:
17721779
17731780
17741781def _convert_agent_state_from_langchain (
1775- state : LC_AgentState [Any ], model : BaseChatModel
1782+ state : LC_AgentState [Any ], model : BaseChatModel , thread_id : str
17761783) -> AgentState :
17771784 messages = state ["messages" ]
17781785 total_tokens_counter = _get_approximate_token_counter (model )
@@ -1782,6 +1789,7 @@ def _convert_agent_state_from_langchain(
17821789 messages = messages ,
17831790 total_steps = len (messages ),
17841791 token_count = total_tokens ,
1792+ thread_id = thread_id ,
17851793 )
17861794
17871795
@@ -1912,6 +1920,11 @@ def check_tool_name(type: str, name: str) -> None:
19121920 check_call_id ("subagent" , call .id )
19131921 check_tool_name ("subagent" , call .name )
19141922 pending_subagent_calls [call .id ] = call .name
1923+
1924+ if call .thread_id == "" :
1925+ raise _InvalidMessagesException (
1926+ "thread_id should not be an empty string"
1927+ )
19151928 else :
19161929 raise _InvalidMessagesException (
19171930 f"AIMessage contains invalid call type: { type (call )} "
0 commit comments