Skip to content

Commit 00c07a8

Browse files
snimu and claude authored
Support interleaved rollouts with include_sub_llm_in_trajectory=True (#900)
When sub-LLM trajectory steps are stored alongside main-model steps, several base-class methods that read state["trajectory"][-1] would see a sub-LLM step instead of the last main-model step. This adds RLMEnv overrides that filter by trajectory_id so that get_prompt_messages, get_model_response (get_prompt_ids), max_turns_reached, and no_tools_called always reference the correct step. Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 26425be commit 00c07a8

2 files changed

Lines changed: 93 additions & 22 deletions

File tree

tests/test_rlm_env.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1338,14 +1338,15 @@ class TestSubLLMTrajectorySteps:
13381338
async def test_include_sub_llm_in_trajectory_default(self, rlm_env):
13391339
assert rlm_env.include_sub_llm_in_trajectory is False
13401340

1341-
def test_interleaved_disallowed_when_sub_llm_in_trajectory(self):
1341+
def test_interleaved_allowed_when_sub_llm_in_trajectory(self):
13421342
dataset = make_dataset({})
1343-
with pytest.raises(ValueError, match="include_sub_llm_in_trajectory=True"):
1344-
build_env(
1345-
dataset,
1346-
include_sub_llm_in_trajectory=True,
1347-
interleaved_rollouts=True,
1348-
)
1343+
env = build_env(
1344+
dataset,
1345+
include_sub_llm_in_trajectory=True,
1346+
interleaved_rollouts=True,
1347+
)
1348+
assert env.include_sub_llm_in_trajectory is True
1349+
assert env.interleaved_rollouts is True
13491350

13501351
@pytest.mark.asyncio
13511352
async def test_sub_llm_steps_added_to_trajectory(self, rlm_env):

verifiers/envs/experimental/rlm_env.py

Lines changed: 85 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2610,8 +2610,10 @@ class RLMEnv(vf.StatefulToolEnv):
26102610
include_sub_llm_in_trajectory: Whether to include sub-LLM calls as trajectory steps.
26112611
When True, sub-LLM turns are added to the trajectory as TrajectoryStep
26122612
objects with tokens, enabling training on sub-LLM calls. Interleaved
2613-
rollouts are not supported in this mode. When False (default), sub-LLM
2614-
calls happen but are not stored.
2613+
rollouts are supported in this mode; the environment ensures that
2614+
get_prompt_messages, get_model_response, and stop conditions always
2615+
reference the last main-model step rather than a sub-LLM step.
2616+
When False (default), sub-LLM calls happen but are not stored.
26152617
context_warning_threshold: Fraction of max_seq_len at which to warn the model
26162618
to finish (default: 0.80). Only active if max_seq_len is set.
26172619
max_startup_wait_seconds: Maximum seconds to wait for worker startup (default: 120)
@@ -3768,11 +3770,6 @@ async def teardown_executor(self):
37683770
# =========================================================================
37693771

37703772
def set_interleaved_rollouts(self, interleaved_rollouts: bool) -> None:
3771-
if interleaved_rollouts and self.include_sub_llm_in_trajectory:
3772-
raise ValueError(
3773-
"RLMEnv does not support interleaved rollouts when "
3774-
"include_sub_llm_in_trajectory=True. Use branched rollouts instead."
3775-
)
37763773
super().set_interleaved_rollouts(interleaved_rollouts)
37773774

37783775
def update_tool_args(
@@ -3838,12 +3835,6 @@ async def setup_state(self, state: State, **kwargs) -> State:
38383835
"rlm_control_dir_remote", f"/tmp/rlm_{rollout_id}/rlm_control"
38393836
)
38403837

3841-
if self.include_sub_llm_in_trajectory and self.interleaved_rollouts:
3842-
raise ValueError(
3843-
"RLMEnv does not support interleaved rollouts when "
3844-
"include_sub_llm_in_trajectory=True. Use branched rollouts instead."
3845-
)
3846-
38473838
try:
38483839
# 1. Setup interception and register rollout
38493840
state = await self._setup_interception_and_register(state, rollout_id)
@@ -4202,6 +4193,14 @@ async def call_python_repl(self, code: str, state: Any) -> str:
42024193
append_execution_time=True,
42034194
)
42044195

4196+
def _last_main_trajectory_step(self, state: State) -> TrajectoryStep | None:
    """Return the most recent trajectory step recorded for the main model.

    Walks ``state["trajectory"]`` from newest to oldest and returns the
    first step whose ``trajectory_id`` equals ``state["trajectory_id"]``.
    Returns ``None`` when the trajectory is missing/empty or no step matches.
    """
    root_id = state.get("trajectory_id")
    steps = state.get("trajectory", [])
    return next(
        (step for step in reversed(steps) if step.get("trajectory_id") == root_id),
        None,
    )
4203+
42054204
async def add_trajectory_step(self, state: State, trajectory_step: TrajectoryStep):
42064205
update_rlm_metrics_from_step(state, trajectory_step)
42074206
await super().add_trajectory_step(state, trajectory_step)
@@ -4282,8 +4281,15 @@ async def get_prompt_messages(self, state: State) -> Messages:
42824281

42834282
return cast(Messages, messages)
42844283
else:
4285-
# Subsequent turns: use parent implementation
4286-
return await super().get_prompt_messages(state)
4284+
# Subsequent turns: use last main trajectory step (skip sub-LLM steps)
4285+
last_main = self._last_main_trajectory_step(state)
4286+
if last_main is None:
4287+
return state["prompt"]
4288+
prev_turn_prompt = last_main["prompt"]
4289+
prev_turn_completion = last_main["completion"]
4290+
messages = concat_messages([prev_turn_prompt, prev_turn_completion])
4291+
env_response = await self.env_response(messages, state)
4292+
return concat_messages([messages, env_response])
42874293

42884294
async def env_response(
42894295
self, messages: Messages, state: State, **kwargs
@@ -4294,6 +4300,46 @@ async def env_response(
42944300
state["final_env_response"] = tool_messages
42954301
return tool_messages
42964302

4303+
async def get_model_response(  # type: ignore[override]
    self, state: State, prompt: Messages, **kwargs: Any
) -> ModelResponse:
    """Call the parent implementation with trailing sub-LLM steps hidden.

    In interleaved mode, get_prompt_ids (called from super) reads
    state["trajectory"][-1] to build token-level prompts. Once sub-LLM
    steps have been appended after the last main-model step, that final
    entry may carry incompatible tokens. To avoid this, every step after
    the last main-model step is removed from the trajectory list for the
    duration of the parent call and appended back afterwards, even if the
    parent call raises.

    The parent is called on the untouched trajectory when either flag
    (``include_sub_llm_in_trajectory`` / ``interleaved_rollouts``) is off,
    when the trajectory is empty, when its last step already belongs to
    the main model, or when it contains no main-model step at all.
    """
    # Index of the first trailing sub-LLM step to hide; None means
    # "nothing to hide, delegate directly".
    split_at = None
    if self.include_sub_llm_in_trajectory and self.interleaved_rollouts:
        steps = state.get("trajectory", [])
        if steps:
            root_id = state["trajectory_id"]
            if steps[-1].get("trajectory_id") != root_id:
                split_at = next(
                    (
                        idx + 1
                        for idx in range(len(steps) - 1, -1, -1)
                        if steps[idx].get("trajectory_id") == root_id
                    ),
                    None,
                )

    if split_at is None:
        return await super().get_model_response(state, prompt, **kwargs)

    # Mutate the list in place so the parent sees the shortened trajectory
    # through ``state``; the finally clause puts the hidden steps back.
    hidden = steps[split_at:]
    del steps[split_at:]
    try:
        return await super().get_model_response(state, prompt, **kwargs)
    finally:
        steps.extend(hidden)
4342+
42974343
# =========================================================================
42984344
# Stop Conditions
42994345
# =========================================================================
@@ -4309,6 +4355,30 @@ async def answer_ready(self, state: State) -> bool:
43094355
"""Stop when model sets answer['ready'] = True."""
43104356
return "final_answer" in state
43114357

4358+
@vf.stop
async def max_turns_reached(self, state: State) -> bool:
    """Stop once the main model has taken ``max_turns`` turns.

    Only trajectory steps whose ``trajectory_id`` matches
    ``state["trajectory_id"]`` are counted, so sub-LLM steps do not count
    towards the limit. A non-positive ``max_turns`` disables the limit.
    """
    limit = self.max_turns
    if limit <= 0:
        return False
    root_id = state.get("trajectory_id")
    main_steps = [
        step
        for step in state.get("trajectory", [])
        if step.get("trajectory_id") == root_id
    ]
    return len(main_steps) >= limit
4368+
4369+
@vf.stop
4370+
async def no_tools_called(self, state: State) -> bool:
4371+
"""Check last main-model completion for tool calls, ignoring sub-LLM steps."""
4372+
last_main = self._last_main_trajectory_step(state)
4373+
if last_main is None:
4374+
return False
4375+
last_message = cast(dict[str, Any], last_main["completion"][-1])
4376+
is_assistant = last_message.get("role") == "assistant"
4377+
no_tool_calls = (
4378+
"tool_calls" not in last_message or last_message["tool_calls"] is None
4379+
)
4380+
return is_assistant and no_tool_calls
4381+
43124382
@vf.stop
43134383
async def prompt_too_long(self, state: State) -> bool:
43144384
"""Stop when API returns overlong prompt error."""

0 commit comments

Comments
 (0)