@@ -201,6 +201,8 @@ async def create_agent(
201201
202202@dataclass
203203class InvokeContext :
204+ thread_id : str
205+
204206 retry : LC_HumanMessage | bool = False
205207 """
206208 Controls whether to retry the agent loop after ainvoke succeeds.
@@ -641,12 +643,6 @@ async def next(r: AgentRequest) -> AgentResponse[Any | None]:
641643 async def invoke (
642644 self , messages : list [BaseMessage ], thread_id : str
643645 ) -> AgentResponse [OutputT ]:
644- # TODO: What if we are passed len(messages) == 0 to invoke?
645- # TODO: What if someone passed call_id that don't have a corresponding id with the response.
646- # Possibly we should do a validation phase of messages here.
647- # TODO: also assert correct ordering, i.e. directly after AIMessage with calls, there is a response
648- # not before or far after.
649-
650646 async def invoke_agent (req : AgentRequest ) -> AgentResponse [Any | None ]:
651647 langchain_msgs = []
652648
@@ -661,7 +657,7 @@ async def invoke_agent(req: AgentRequest) -> AgentResponse[Any | None]:
661657 langchain_msgs .extend ([_map_message_to_langchain (m ) for m in req .messages ])
662658
663659 while True :
664- ctx = InvokeContext ()
660+ ctx = InvokeContext (thread_id = thread_id )
665661 result = await self ._agent .ainvoke (
666662 {"messages" : langchain_msgs },
667663 context = ctx ,
@@ -703,6 +699,7 @@ async def invoke_agent(req: AgentRequest) -> AgentResponse[Any | None]:
703699
704700 result = await self ._with_agent_middleware (invoke_agent )(
705701 AgentRequest (
702+ thread_id = thread_id ,
706703 messages = messages ,
707704 )
708705 )
@@ -1060,36 +1057,44 @@ async def _sdk_handler(request: ModelRequest) -> ModelResponse:
10601057def _convert_model_request_from_lc (
10611058 request : LC_ModelRequest , model : BaseChatModel
10621059) -> ModelRequest :
1060+ thread_id = request .runtime .context .thread_id
1061+
10631062 system_message = (
10641063 request .system_message .content .__str__ () if request .system_message else ""
10651064 )
10661065
10671066 return ModelRequest (
10681067 system_message = system_message ,
1069- state = _convert_agent_state_from_langchain (request .state , model ),
1068+ state = _convert_agent_state_from_langchain (request .state , model , thread_id ),
10701069 )
10711070
10721071
10731072def _convert_tool_request_from_lc (
10741073 request : LC_ToolCallRequest , model : BaseChatModel
10751074) -> ToolRequest :
1075+ assert isinstance (request .runtime .context , InvokeContext )
1076+ thread_id = request .runtime .context .thread_id
1077+
10761078 tool_call = _map_tool_call_from_langchain (request .tool_call )
10771079 assert isinstance (tool_call , ToolCall ), "Expected tool call"
10781080 return ToolRequest (
10791081 call = tool_call ,
1080- state = _convert_agent_state_from_langchain (request .state , model ),
1082+ state = _convert_agent_state_from_langchain (request .state , model , thread_id ),
10811083 )
10821084
10831085
10841086def _convert_subagent_request_from_lc (
10851087 request : LC_ToolCallRequest ,
10861088 model : BaseChatModel ,
10871089) -> SubagentRequest :
1090+ assert isinstance (request .runtime .context , InvokeContext )
1091+ thread_id = request .runtime .context .thread_id
1092+
10881093 subagent_call = _map_tool_call_from_langchain (request .tool_call )
10891094 assert isinstance (subagent_call , SubagentCall ), "Expected subagent call"
10901095 return SubagentRequest (
10911096 call = subagent_call ,
1092- state = _convert_agent_state_from_langchain (request .state , model ),
1097+ state = _convert_agent_state_from_langchain (request .state , model , thread_id ),
10931098 )
10941099
10951100
@@ -1516,7 +1521,9 @@ async def invoke_agent(
15161521 OutputT | str ,
15171522 SubagentStructuredResult | SubagentTextResult ,
15181523 ]:
1519- result = await agent .invoke ([message ], thread_id = thread_id )
1524+ result = await agent .invoke (
1525+ [message ], thread_id = thread_id or _thread_id_new_uuid ()
1526+ )
15201527
15211528 if agent .output_schema :
15221529 assert result .structured_output is not None
@@ -1565,7 +1572,7 @@ async def invoke_agent_structured(
15651572 result = await agent .invoke_with_data (
15661573 instructions = "Follow the system prompt." ,
15671574 data = content .model_dump (),
1568- thread_id = thread_id ,
1575+ thread_id = thread_id or _thread_id_new_uuid () ,
15691576 )
15701577
15711578 if agent .output_schema :
@@ -1780,7 +1787,7 @@ def _map_message_to_langchain(message: BaseMessage) -> LC_AnyMessage:
17801787
17811788
17821789def _convert_agent_state_from_langchain (
1783- state : LC_AgentState [Any ], model : BaseChatModel
1790+ state : LC_AgentState [Any ], model : BaseChatModel , thread_id : str
17841791) -> AgentState :
17851792 messages = state ["messages" ]
17861793 total_tokens_counter = _get_approximate_token_counter (model )
@@ -1790,6 +1797,7 @@ def _convert_agent_state_from_langchain(
17901797 messages = messages ,
17911798 total_steps = len (messages ),
17921799 token_count = total_tokens ,
1800+ thread_id = thread_id ,
17931801 )
17941802
17951803
@@ -1920,6 +1928,11 @@ def check_tool_name(type: str, name: str) -> None:
19201928 check_call_id ("subagent" , call .id )
19211929 check_tool_name ("subagent" , call .name )
19221930 pending_subagent_calls [call .id ] = call .name
1931+
1932+ if call .thread_id == "" :
1933+ raise _InvalidMessagesException (
1934+ "thread_id should not be an empty string"
1935+ )
19231936 else :
19241937 raise _InvalidMessagesException (
19251938 f"AIMessage contains invalid call type: { type (call )} "
0 commit comments