@@ -2610,8 +2610,10 @@ class RLMEnv(vf.StatefulToolEnv):
26102610 include_sub_llm_in_trajectory: Whether to include sub-LLM calls as trajectory steps.
26112611 When True, sub-LLM turns are added to the trajectory as TrajectoryStep
26122612 objects with tokens, enabling training on sub-LLM calls. Interleaved
2613- rollouts are not supported in this mode. When False (default), sub-LLM
2614- calls happen but are not stored.
2613+ rollouts are supported in this mode; the environment ensures that
2614+ get_prompt_messages, get_model_response, and stop conditions always
2615+ reference the last main-model step rather than a sub-LLM step.
2616+ When False (default), sub-LLM calls happen but are not stored.
26152617 context_warning_threshold: Fraction of max_seq_len at which to warn the model
26162618 to finish (default: 0.80). Only active if max_seq_len is set.
26172619 max_startup_wait_seconds: Maximum seconds to wait for worker startup (default: 120)
@@ -3768,11 +3770,6 @@ async def teardown_executor(self):
37683770 # =========================================================================
37693771
37703772 def set_interleaved_rollouts (self , interleaved_rollouts : bool ) -> None :
3771- if interleaved_rollouts and self .include_sub_llm_in_trajectory :
3772- raise ValueError (
3773- "RLMEnv does not support interleaved rollouts when "
3774- "include_sub_llm_in_trajectory=True. Use branched rollouts instead."
3775- )
37763773 super ().set_interleaved_rollouts (interleaved_rollouts )
37773774
37783775 def update_tool_args (
@@ -3838,12 +3835,6 @@ async def setup_state(self, state: State, **kwargs) -> State:
38383835 "rlm_control_dir_remote" , f"/tmp/rlm_{ rollout_id } /rlm_control"
38393836 )
38403837
3841- if self .include_sub_llm_in_trajectory and self .interleaved_rollouts :
3842- raise ValueError (
3843- "RLMEnv does not support interleaved rollouts when "
3844- "include_sub_llm_in_trajectory=True. Use branched rollouts instead."
3845- )
3846-
38473838 try :
38483839 # 1. Setup interception and register rollout
38493840 state = await self ._setup_interception_and_register (state , rollout_id )
@@ -4202,6 +4193,14 @@ async def call_python_repl(self, code: str, state: Any) -> str:
42024193 append_execution_time = True ,
42034194 )
42044195
4196+ def _last_main_trajectory_step (self , state : State ) -> TrajectoryStep | None :
4197+ """Find the last trajectory step belonging to the main (root) model."""
4198+ main_id = state .get ("trajectory_id" )
4199+ for step in reversed (state .get ("trajectory" , [])):
4200+ if step .get ("trajectory_id" ) == main_id :
4201+ return step
4202+ return None
4203+
42054204 async def add_trajectory_step (self , state : State , trajectory_step : TrajectoryStep ):
42064205 update_rlm_metrics_from_step (state , trajectory_step )
42074206 await super ().add_trajectory_step (state , trajectory_step )
@@ -4282,8 +4281,15 @@ async def get_prompt_messages(self, state: State) -> Messages:
42824281
42834282 return cast (Messages , messages )
42844283 else :
4285- # Subsequent turns: use parent implementation
4286- return await super ().get_prompt_messages (state )
4284+ # Subsequent turns: use last main trajectory step (skip sub-LLM steps)
4285+ last_main = self ._last_main_trajectory_step (state )
4286+ if last_main is None :
4287+ return state ["prompt" ]
4288+ prev_turn_prompt = last_main ["prompt" ]
4289+ prev_turn_completion = last_main ["completion" ]
4290+ messages = concat_messages ([prev_turn_prompt , prev_turn_completion ])
4291+ env_response = await self .env_response (messages , state )
4292+ return concat_messages ([messages , env_response ])
42874293
42884294 async def env_response (
42894295 self , messages : Messages , state : State , ** kwargs
@@ -4294,6 +4300,46 @@ async def env_response(
42944300 state ["final_env_response" ] = tool_messages
42954301 return tool_messages
42964302
4303+ async def get_model_response ( # type: ignore[override]
4304+ self , state : State , prompt : Messages , ** kwargs : Any
4305+ ) -> ModelResponse :
4306+ """Ensure get_prompt_ids sees the last main trajectory step, not a sub-LLM step.
4307+
4308+ In interleaved mode, get_prompt_ids (called from super) reads
4309+ state["trajectory"][-1] to build token-level prompts. After
4310+ env_response adds sub-LLM steps, trajectory[-1] may be a sub-LLM
4311+ step with incompatible tokens. We temporarily move trailing sub-LLM
4312+ steps out of the trajectory for the duration of the super call.
4313+ """
4314+ if not (self .include_sub_llm_in_trajectory and self .interleaved_rollouts ):
4315+ return await super ().get_model_response (state , prompt , ** kwargs )
4316+
4317+ trajectory = state .get ("trajectory" , [])
4318+ if not trajectory :
4319+ return await super ().get_model_response (state , prompt , ** kwargs )
4320+
4321+ main_id = state ["trajectory_id" ]
4322+ if trajectory [- 1 ].get ("trajectory_id" ) == main_id :
4323+ return await super ().get_model_response (state , prompt , ** kwargs )
4324+
4325+ # Find last main step and temporarily move trailing sub-LLM steps aside
4326+ last_main_idx = None
4327+ for i in range (len (trajectory ) - 1 , - 1 , - 1 ):
4328+ if trajectory [i ].get ("trajectory_id" ) == main_id :
4329+ last_main_idx = i
4330+ break
4331+
4332+ if last_main_idx is None :
4333+ return await super ().get_model_response (state , prompt , ** kwargs )
4334+
4335+ trailing = trajectory [last_main_idx + 1 :]
4336+ del trajectory [last_main_idx + 1 :]
4337+ try :
4338+ result = await super ().get_model_response (state , prompt , ** kwargs )
4339+ finally :
4340+ trajectory .extend (trailing )
4341+ return result
4342+
42974343 # =========================================================================
42984344 # Stop Conditions
42994345 # =========================================================================
@@ -4309,6 +4355,30 @@ async def answer_ready(self, state: State) -> bool:
43094355 """Stop when model sets answer['ready'] = True."""
43104356 return "final_answer" in state
43114357
4358+ @vf .stop
4359+ async def max_turns_reached (self , state : State ) -> bool :
4360+ """Count only main-model trajectory steps, not sub-LLM steps."""
4361+ if self .max_turns <= 0 :
4362+ return False
4363+ main_id = state .get ("trajectory_id" )
4364+ count = sum (
4365+ 1 for s in state .get ("trajectory" , []) if s .get ("trajectory_id" ) == main_id
4366+ )
4367+ return count >= self .max_turns
4368+
4369+ @vf .stop
4370+ async def no_tools_called (self , state : State ) -> bool :
4371+ """Check last main-model completion for tool calls, ignoring sub-LLM steps."""
4372+ last_main = self ._last_main_trajectory_step (state )
4373+ if last_main is None :
4374+ return False
4375+ last_message = cast (dict [str , Any ], last_main ["completion" ][- 1 ])
4376+ is_assistant = last_message .get ("role" ) == "assistant"
4377+ no_tool_calls = (
4378+ "tool_calls" not in last_message or last_message ["tool_calls" ] is None
4379+ )
4380+ return is_assistant and no_tool_calls
4381+
43124382 @vf .stop
43134383 async def prompt_too_long (self , state : State ) -> bool :
43144384 """Stop when API returns overlong prompt error."""
0 commit comments