Skip to content

Commit a3013f3

Browse files
mikasenghaas and cursoragent
authored and committed
fix(orchestrator): match longest active prefix in interleave_rollout
The first-match-wins loop over active_samples picks the wrong sample when one active prefix is a strict prefix of another.

This can happen after a compaction/rollback step whose prompt is shorter than an existing sample's prefix and whose completion re-generates the same tokens and extends past them: the new sample's prefix then starts with the older sample's prefix, and any later step that extends the new sample also satisfies the slice check against the older one.

When that happens, extend_sample folds the newer sample's generated tokens into the older sample as user-input tokens (mask=False, logprob=0) and leaves the newer sample stale -- a silent Exact-Prefix invariant violation.

Switch to longest-match: it is strictly more specific, and never worse than first-match when only one prefix matches.

Co-authored-by: Cursor <cursoragent@cursor.com>
(cherry picked from commit 0e239d1)
1 parent 758b5ee commit a3013f3

2 files changed

Lines changed: 143 additions & 3 deletions

File tree

src/prime_rl/orchestrator/trajectories.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -406,12 +406,18 @@ def extend_sample(
406406
tokens = prepared_steps[step_idx]
407407
step_prompt_ids = tokens["prompt_ids"]
408408

409-
# Check if this step extends ANY active prefix
409+
# Pick the *longest* matching active prefix. With compaction/rollback,
410+
# one active sample's prefix can be a strict prefix of another (e.g. a
411+
# later sample re-generated tokens that overlap an earlier sample's
412+
# prefix). Both would satisfy the slice check; the shorter would
413+
# silently absorb the longer sample's generated tokens as user input.
410414
matched_idx = None
415+
matched_len = -1
411416
for idx, (prefix_tokens, _, _) in enumerate(active_samples):
412-
if step_prompt_ids[: len(prefix_tokens)] == prefix_tokens:
417+
pl = len(prefix_tokens)
418+
if pl > matched_len and step_prompt_ids[:pl] == prefix_tokens:
413419
matched_idx = idx
414-
break
420+
matched_len = pl
415421

416422
if matched_idx is not None:
417423
# Extension holds - merge into matched sample

tests/unit/orchestrator/test_trajectories.py

Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -749,6 +749,140 @@ def test_interleave_rollout_interleaved_agents(interleaved_agents_trajectory):
749749
assert agent2_sample.completion_logprobs == [-0.5, -0.6]
750750

751751

752+
@pytest.fixture
def prefix_of_prefix_trajectory():
    """
    Rollout in which one active sample's prefix strictly prefixes another's.

    P_X below is the running token prefix (prompt + completions) of sample X.

    - step 0: prompt=[1,2], completion=[3,4]            -> sample A, P_A=[1,2,3,4]
    - step 1: prompt=[1,2,3,4,5], completion=[6]        -> extends A, P_A=[1,2,3,4,5,6]
    - step 2: prompt=[1,2], completion=[3,4,5,6,7]      -> rollback/regenerate; the
      prompt is shorter than P_A so it matches nothing and opens sample B with
      P_B=[1,2,3,4,5,6,7], which starts with P_A.
    - step 3: prompt=[1,2,3,4,5,6,7,8], completion=[9]  -> extends B, but both P_A
      and P_B are token-prefixes of this prompt.

    The longer prefix P_B is the correct match. A first-match-wins search picks
    P_A instead and silently folds B's generated tokens into A as user-input
    tokens (mask=False).
    """

    def build_step(prompt_text, completion_text, prompt_ids, completion_ids, logprobs, trajectory_id):
        # Every step in this scenario has the same shape: an all-zero prompt
        # mask, an all-one completion mask, no truncation, and no reward or
        # advantage. Only the texts, token ids, logprobs and trajectory id vary.
        step_tokens = vf.TrajectoryStepTokens(
            prompt_ids=prompt_ids,
            prompt_mask=[0] * len(prompt_ids),
            completion_ids=completion_ids,
            completion_mask=[1] * len(completion_ids),
            completion_logprobs=logprobs,
            overlong_prompt=False,
            is_truncated=False,
        )
        return vf.TrajectoryStep(
            prompt=prompt_text,
            completion=completion_text,
            response=None,
            tokens=step_tokens,
            reward=None,
            advantage=None,
            is_truncated=False,
            trajectory_id=trajectory_id,
            extras={},
        )

    steps = [
        # Sample A is created, then extended by one new prompt token and one completion token.
        build_step("step 0", "completion 0", [1, 2], [3, 4], [-0.1, -0.2], "traj_A"),
        build_step("step 1", "completion 1", [1, 2, 3, 4, 5], [6], [-0.3], "traj_A"),
        # Rollback: the short prompt re-generates A's tokens and runs one past them.
        build_step(
            "step 2 (rollback)",
            "completion 2",
            [1, 2],
            [3, 4, 5, 6, 7],
            [-0.4, -0.5, -0.6, -0.7, -0.8],
            "traj_B",
        ),
        # Ambiguous step: its prompt starts with both P_A and P_B.
        build_step(
            "step 3 (extends B)",
            "completion 3",
            [1, 2, 3, 4, 5, 6, 7, 8],
            [9],
            [-0.9],
            "traj_B",
        ),
    ]

    return vf.RolloutOutput(
        example_id=2,
        task="test",
        trajectory=steps,
        sampling_args={"temperature": 1.0},
        error=None,
    )
854+
855+
856+
def test_interleave_rollout_picks_longest_matching_prefix(prefix_of_prefix_trajectory):
    """
    A step whose prompt matches two active prefixes must extend the longer one.

    In the fixture, sample A's prefix is a strict prefix of sample B's, so the
    final step satisfies the prefix check against both. Previously the first-
    match-wins loop merged that step into A, turning B's generated tokens into
    user input (mask=False) and leaving B stale.
    """
    samples = interleave_rollout(prefix_of_prefix_trajectory)

    assert samples is not None
    assert len(samples) == 2

    # First sample (A) holds steps 0 and 1 only; step 3 must NOT be merged here.
    from_a = samples[0]
    assert from_a.prompt_ids == [1, 2]
    # [3,4] from step 0's completion, [5] newly added by step 1's prompt,
    # then [6] from step 1's completion.
    assert from_a.completion_ids == [3, 4, 5, 6]
    assert from_a.completion_mask == [True, True, False, True]
    assert from_a.completion_logprobs == [-0.1, -0.2, 0.0, -0.3]

    # Second sample (B) holds steps 2 and 3. Token 7, generated in step 2, has
    # to stay marked as generated rather than being re-classified as input.
    from_b = samples[1]
    assert from_b.prompt_ids == [1, 2]
    # [3,4,5,6,7] from step 2's completion, [8] newly added by step 3's prompt,
    # then [9] from step 3's completion.
    assert from_b.completion_ids == [3, 4, 5, 6, 7, 8, 9]
    assert from_b.completion_mask == [True, True, True, True, True, False, True]
    assert from_b.completion_logprobs == [-0.4, -0.5, -0.6, -0.7, -0.8, 0.0, -0.9]
885+
752886
def test_interleave_rollout_empty_trajectory():
753887
"""Empty trajectory returns None."""
754888
output = vf.RolloutOutput(

0 commit comments

Comments
 (0)