Skip to content

Commit 75a3702

Browse files
snimuclaude
andauthored
migrate rlm_env to unified client types (#914)
* migrate rlm_env to unified client types Adapt rlm_env.py to fully use the provider-agnostic types introduced in #897 (unified client interface): - Use flat ToolCall attributes (name, arguments) instead of nested function object dance - Return ToolMessage objects from _call_sub_tool instead of raw dicts - Use Client type annotation instead of Any for client parameters - Pass tool_defs directly to get_model_response instead of via state - Use typed AssistantMessage access in no_tools_called stop condition - Simplify _extract_tokens_from_response (remove dead dict code paths) - Fix SubLLMResult final_content type narrowing for MessageContent union Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> * restore prompt_state tool_defs as safety measure Restore setting prompt_state["tool_defs"] in _call_sub_llm_api alongside the new direct tool_defs kwarg pass. While both paths resolve equivalently through resolve_optional_args, keeping the state key is safer for any code that may read it downstream. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> --------- Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 09099f1 commit 75a3702

1 file changed

Lines changed: 35 additions & 49 deletions

File tree

verifiers/envs/experimental/rlm_env.py

Lines changed: 35 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -53,14 +53,17 @@
5353
from prime_tunnel import Tunnel
5454

5555
import verifiers as vf
56+
from verifiers.clients import Client
5657
from verifiers.envs.experimental.sandbox_mixin import SandboxMixin
5758
from verifiers.envs.sandbox_env import CreateSandboxRequest
5859
from verifiers.types import (
60+
AssistantMessage,
5961
Message,
6062
Messages,
6163
Response,
6264
State,
6365
SystemMessage,
66+
ToolMessage,
6467
TrajectoryStep,
6568
UserMessage,
6669
)
@@ -209,17 +212,12 @@ class SandboxRLMReplSession:
209212
paths: RLMWorkerPaths | None = None
210213

211214

212-
def _extract_tokens_from_response(response: Any) -> tuple[int, int]:
215+
def _extract_tokens_from_response(response: Response | Any) -> tuple[int, int]:
216+
if not response:
217+
return 0, 0
213218
usage = getattr(response, "usage", None)
214-
if not usage and isinstance(response, dict):
215-
usage = response.get("usage")
216219
if not usage:
217220
return 0, 0
218-
if isinstance(usage, dict):
219-
return (
220-
int(usage.get("prompt_tokens", 0) or 0),
221-
int(usage.get("completion_tokens", 0) or 0),
222-
)
223221
return (
224222
int(getattr(usage, "prompt_tokens", 0) or 0),
225223
int(getattr(usage, "completion_tokens", 0) or 0),
@@ -2459,33 +2457,31 @@ def _write_builtin_context(self, context_data: Any, fs_root: str) -> None:
24592457

24602458
async def _call_sub_tool(
24612459
self, tool_name: str, tool_args: dict, tool_call_id: str
2462-
) -> dict:
2463-
"""Execute a sub-agent tool call. Returns tool message dict."""
2460+
) -> ToolMessage:
2461+
"""Execute a sub-agent tool call. Returns tool message."""
24642462
try:
24652463
tool_func = self.sub_tool_map[tool_name]
24662464
result = await maybe_await(tool_func, **tool_args)
2467-
return {
2468-
"role": "tool",
2469-
"content": str(result),
2470-
"tool_call_id": tool_call_id,
2471-
}
2465+
return ToolMessage(
2466+
tool_call_id=tool_call_id,
2467+
content=str(result),
2468+
)
24722469
except Exception as e:
24732470
if self._should_stop_for_error(e):
24742471
raise
2475-
return {
2476-
"role": "tool",
2477-
"content": f"Error: {e}",
2478-
"tool_call_id": tool_call_id,
2479-
}
2472+
return ToolMessage(
2473+
tool_call_id=tool_call_id,
2474+
content=f"Error: {e}",
2475+
)
24802476

24812477
async def _call_sub_llm_api(
24822478
self,
24832479
state: State,
2484-
client: Any,
2480+
client: Client,
24852481
model: str,
24862482
messages: Messages,
2487-
tools: list | None = None,
2488-
) -> Any | None:
2483+
tools: list[vf.Tool] | None = None,
2484+
) -> Response | None:
24892485
"""Make a single sub-LLM API call matching main-model request mode."""
24902486
sampling_args = dict(state.get("sampling_args") or {})
24912487
extra_body = sampling_args.get("extra_body")
@@ -2510,6 +2506,7 @@ async def _call_sub_llm_api(
25102506
cast(Messages, messages),
25112507
client=client,
25122508
model=model,
2509+
tool_defs=tools,
25132510
),
25142511
timeout=self.sub_llm_api_timeout,
25152512
)
@@ -2543,7 +2540,7 @@ def _make_timeout_result(
25432540
)
25442541

25452542
async def _run_sub_llm(
2546-
self, state: State, client: Any, model: str, messages: Messages
2543+
self, state: State, client: Client, model: str, messages: Messages
25472544
) -> SubLLMResult:
25482545
"""Run a sub-LLM call, with optional tool-calling loop."""
25492546
# Fast path: no tools configured - single LLM call
@@ -2553,8 +2550,10 @@ async def _run_sub_llm(
25532550
return self._make_timeout_result([], 0, 0, 0, 0)
25542551

25552552
prompt_tokens, completion_tokens = _extract_tokens_from_response(response)
2553+
content = response.message.content
2554+
final_content = content if isinstance(content, str) else ""
25562555
return SubLLMResult(
2557-
final_content=response.message.content or "",
2556+
final_content=final_content,
25582557
turns=[
25592558
SubLLMTurn(
25602559
prompt_messages=_clone_messages(messages),
@@ -2616,8 +2615,9 @@ async def _run_sub_llm(
26162615
)
26172616

26182617
if not tool_calls:
2618+
content = assistant_message.content
26192619
return SubLLMResult(
2620-
final_content=assistant_message.content or "",
2620+
final_content=content if isinstance(content, str) else "",
26212621
turns=turns,
26222622
total_prompt_tokens=total_prompt_tokens,
26232623
total_completion_tokens=total_completion_tokens,
@@ -2631,26 +2631,14 @@ async def _run_sub_llm(
26312631
)
26322632

26332633
for tool_call in tool_calls:
2634-
function_obj = getattr(tool_call, "function", None)
2635-
tool_name = (
2636-
function_obj.name
2637-
if function_obj is not None and hasattr(function_obj, "name")
2638-
else getattr(tool_call, "name", "")
2639-
)
26402634
try:
2641-
raw_args = (
2642-
function_obj.arguments
2643-
if function_obj is not None
2644-
and hasattr(function_obj, "arguments")
2645-
else getattr(tool_call, "arguments", "{}")
2646-
)
2647-
tool_args = json.loads(raw_args)
2635+
tool_args = json.loads(tool_call.arguments)
26482636
except json.JSONDecodeError:
26492637
tool_args = {}
26502638
tool_result = await self._call_sub_tool(
2651-
tool_name, tool_args, tool_call.id
2639+
tool_call.name, tool_args, tool_call.id
26522640
)
2653-
current_messages.append(from_raw_message(tool_result))
2641+
current_messages.append(tool_result)
26542642

26552643
# Max turns reached - add prompt for final answer and make call without tools
26562644
num_turns += 1
@@ -2686,8 +2674,9 @@ async def _run_sub_llm(
26862674
)
26872675
prompt_tokens, completion_tokens = _extract_tokens_from_response(response)
26882676

2677+
content = response.message.content
26892678
return SubLLMResult(
2690-
final_content=response.message.content or "",
2679+
final_content=content if isinstance(content, str) else "",
26912680
turns=turns,
26922681
total_prompt_tokens=total_prompt_tokens + prompt_tokens,
26932682
total_completion_tokens=total_completion_tokens + completion_tokens,
@@ -2862,7 +2851,7 @@ async def _run_sub_llm_request(
28622851
self,
28632852
*,
28642853
state_ref: State,
2865-
client: Any,
2854+
client: Client,
28662855
sub_model: str,
28672856
messages: Messages,
28682857
batch_id: str,
@@ -3745,12 +3734,9 @@ async def no_tools_called(self, state: State) -> bool:
37453734
last_main = self._last_main_trajectory_step(state)
37463735
if last_main is None:
37473736
return False
3748-
last_message = cast(dict[str, Any], last_main["completion"][-1])
3749-
is_assistant = last_message.get("role") == "assistant"
3750-
no_tool_calls = (
3751-
"tool_calls" not in last_message or last_message["tool_calls"] is None
3752-
)
3753-
return is_assistant and no_tool_calls
3737+
last_message = cast(AssistantMessage, last_main["completion"][-1])
3738+
is_assistant = last_message.role == "assistant"
3739+
return is_assistant and not (last_message.tool_calls or [])
37543740

37553741
@vf.stop
37563742
async def prompt_too_long(self, state: State) -> bool:

0 commit comments

Comments
 (0)