5353from prime_tunnel import Tunnel
5454
5555import verifiers as vf
56+ from verifiers .clients import Client
5657from verifiers .envs .experimental .sandbox_mixin import SandboxMixin
5758from verifiers .envs .sandbox_env import CreateSandboxRequest
5859from verifiers .types import (
60+ AssistantMessage ,
5961 Message ,
6062 Messages ,
6163 Response ,
6264 State ,
6365 SystemMessage ,
66+ ToolMessage ,
6467 TrajectoryStep ,
6568 UserMessage ,
6669)
@@ -209,17 +212,12 @@ class SandboxRLMReplSession:
209212 paths : RLMWorkerPaths | None = None
210213
211214
212- def _extract_tokens_from_response (response : Any ) -> tuple [int , int ]:
215+ def _extract_tokens_from_response (response : Response | Any ) -> tuple [int , int ]:
216+ if not response :
217+ return 0 , 0
213218 usage = getattr (response , "usage" , None )
214- if not usage and isinstance (response , dict ):
215- usage = response .get ("usage" )
216219 if not usage :
217220 return 0 , 0
218- if isinstance (usage , dict ):
219- return (
220- int (usage .get ("prompt_tokens" , 0 ) or 0 ),
221- int (usage .get ("completion_tokens" , 0 ) or 0 ),
222- )
223221 return (
224222 int (getattr (usage , "prompt_tokens" , 0 ) or 0 ),
225223 int (getattr (usage , "completion_tokens" , 0 ) or 0 ),
@@ -2459,33 +2457,31 @@ def _write_builtin_context(self, context_data: Any, fs_root: str) -> None:
24592457
24602458 async def _call_sub_tool (
24612459 self , tool_name : str , tool_args : dict , tool_call_id : str
2462- ) -> dict :
2463- """Execute a sub-agent tool call. Returns tool message dict ."""
2460+ ) -> ToolMessage :
2461+ """Execute a sub-agent tool call. Returns tool message."""
24642462 try :
24652463 tool_func = self .sub_tool_map [tool_name ]
24662464 result = await maybe_await (tool_func , ** tool_args )
2467- return {
2468- "role" : "tool" ,
2469- "content" : str (result ),
2470- "tool_call_id" : tool_call_id ,
2471- }
2465+ return ToolMessage (
2466+ tool_call_id = tool_call_id ,
2467+ content = str (result ),
2468+ )
24722469 except Exception as e :
24732470 if self ._should_stop_for_error (e ):
24742471 raise
2475- return {
2476- "role" : "tool" ,
2477- "content" : f"Error: { e } " ,
2478- "tool_call_id" : tool_call_id ,
2479- }
2472+ return ToolMessage (
2473+ tool_call_id = tool_call_id ,
2474+ content = f"Error: { e } " ,
2475+ )
24802476
24812477 async def _call_sub_llm_api (
24822478 self ,
24832479 state : State ,
2484- client : Any ,
2480+ client : Client ,
24852481 model : str ,
24862482 messages : Messages ,
2487- tools : list | None = None ,
2488- ) -> Any | None :
2483+ tools : list [ vf . Tool ] | None = None ,
2484+ ) -> Response | None :
24892485 """Make a single sub-LLM API call matching main-model request mode."""
24902486 sampling_args = dict (state .get ("sampling_args" ) or {})
24912487 extra_body = sampling_args .get ("extra_body" )
@@ -2510,6 +2506,7 @@ async def _call_sub_llm_api(
25102506 cast (Messages , messages ),
25112507 client = client ,
25122508 model = model ,
2509+ tool_defs = tools ,
25132510 ),
25142511 timeout = self .sub_llm_api_timeout ,
25152512 )
@@ -2543,7 +2540,7 @@ def _make_timeout_result(
25432540 )
25442541
25452542 async def _run_sub_llm (
2546- self , state : State , client : Any , model : str , messages : Messages
2543+ self , state : State , client : Client , model : str , messages : Messages
25472544 ) -> SubLLMResult :
25482545 """Run a sub-LLM call, with optional tool-calling loop."""
25492546 # Fast path: no tools configured - single LLM call
@@ -2553,8 +2550,10 @@ async def _run_sub_llm(
25532550 return self ._make_timeout_result ([], 0 , 0 , 0 , 0 )
25542551
25552552 prompt_tokens , completion_tokens = _extract_tokens_from_response (response )
2553+ content = response .message .content
2554+ final_content = content if isinstance (content , str ) else ""
25562555 return SubLLMResult (
2557- final_content = response . message . content or "" ,
2556+ final_content = final_content ,
25582557 turns = [
25592558 SubLLMTurn (
25602559 prompt_messages = _clone_messages (messages ),
@@ -2616,8 +2615,9 @@ async def _run_sub_llm(
26162615 )
26172616
26182617 if not tool_calls :
2618+ content = assistant_message .content
26192619 return SubLLMResult (
2620- final_content = assistant_message . content or "" ,
2620+ final_content = content if isinstance ( content , str ) else "" ,
26212621 turns = turns ,
26222622 total_prompt_tokens = total_prompt_tokens ,
26232623 total_completion_tokens = total_completion_tokens ,
@@ -2631,26 +2631,14 @@ async def _run_sub_llm(
26312631 )
26322632
26332633 for tool_call in tool_calls :
2634- function_obj = getattr (tool_call , "function" , None )
2635- tool_name = (
2636- function_obj .name
2637- if function_obj is not None and hasattr (function_obj , "name" )
2638- else getattr (tool_call , "name" , "" )
2639- )
26402634 try :
2641- raw_args = (
2642- function_obj .arguments
2643- if function_obj is not None
2644- and hasattr (function_obj , "arguments" )
2645- else getattr (tool_call , "arguments" , "{}" )
2646- )
2647- tool_args = json .loads (raw_args )
2635+ tool_args = json .loads (tool_call .arguments )
26482636 except json .JSONDecodeError :
26492637 tool_args = {}
26502638 tool_result = await self ._call_sub_tool (
2651- tool_name , tool_args , tool_call .id
2639+ tool_call . name , tool_args , tool_call .id
26522640 )
2653- current_messages .append (from_raw_message ( tool_result ) )
2641+ current_messages .append (tool_result )
26542642
26552643 # Max turns reached - add prompt for final answer and make call without tools
26562644 num_turns += 1
@@ -2686,8 +2674,9 @@ async def _run_sub_llm(
26862674 )
26872675 prompt_tokens , completion_tokens = _extract_tokens_from_response (response )
26882676
2677+ content = response .message .content
26892678 return SubLLMResult (
2690- final_content = response . message . content or "" ,
2679+ final_content = content if isinstance ( content , str ) else "" ,
26912680 turns = turns ,
26922681 total_prompt_tokens = total_prompt_tokens + prompt_tokens ,
26932682 total_completion_tokens = total_completion_tokens + completion_tokens ,
@@ -2862,7 +2851,7 @@ async def _run_sub_llm_request(
28622851 self ,
28632852 * ,
28642853 state_ref : State ,
2865- client : Any ,
2854+ client : Client ,
28662855 sub_model : str ,
28672856 messages : Messages ,
28682857 batch_id : str ,
@@ -3745,12 +3734,9 @@ async def no_tools_called(self, state: State) -> bool:
37453734 last_main = self ._last_main_trajectory_step (state )
37463735 if last_main is None :
37473736 return False
3748- last_message = cast (dict [str , Any ], last_main ["completion" ][- 1 ])
3749- is_assistant = last_message .get ("role" ) == "assistant"
3750- no_tool_calls = (
3751- "tool_calls" not in last_message or last_message ["tool_calls" ] is None
3752- )
3753- return is_assistant and no_tool_calls
3737+ last_message = cast (AssistantMessage , last_main ["completion" ][- 1 ])
3738+ is_assistant = last_message .role == "assistant"
3739+ return is_assistant and not (last_message .tool_calls or [])
37543740
37553741 @vf .stop
37563742 async def prompt_too_long (self , state : State ) -> bool :
0 commit comments