Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion deps/verifiers
Submodule verifiers updated 101 files
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ dion = { git = "https://github.com/samsja/dion.git", rev = "d891eeb" }
transformers = { git = "https://github.com/huggingface/transformers.git", rev = "c1c3424" }
flash-attn-4 = { git = "https://github.com/Dao-AILab/flash-attention.git", subdirectory = "flash_attn/cute", rev = "96bd151" }
prime-pydantic-config = { workspace = true }
vllm-router = { url = "https://github.com/PrimeIntellect-ai/router/releases/download/v0.1.25/vllm_router-0.1.25-cp38-abi3-manylinux_2_28_x86_64.whl" }
vllm-router = { path = "third_party/router/dist/vllm_router-0.1.25-cp38-abi3-linux_x86_64.whl" }
vllm = [
{ url = "https://github.com/PrimeIntellect-ai/prime-rl/releases/download/v0.5.0/vllm-0.21.0+cu129.r42434.pr39568.a106aa6-cp38-abi3-manylinux_2_34_x86_64.whl", marker = "platform_machine == 'x86_64'" },
{ url = "https://github.com/vllm-project/vllm/releases/download/v0.21.0/vllm-0.21.0+cu129-cp38-abi3-manylinux_2_34_aarch64.whl", marker = "platform_machine == 'aarch64'" },
Expand Down
8 changes: 5 additions & 3 deletions src/prime_rl/inference/vllm/routed_experts.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from vllm.outputs import RequestOutput


def serialize_routed_experts(routed_experts: Any) -> dict[str, Any] | None:
def serialize_routed_experts(routed_experts: Any, start: int = 0) -> dict[str, Any] | None:
if routed_experts is None:
return None

Expand All @@ -23,18 +23,20 @@ def serialize_routed_experts(routed_experts: Any) -> dict[str, Any] | None:
return {
"data": pybase64.b64encode(memoryview(compact)).decode("ascii"),
"shape": list(compact.shape),
"start": start,
}


class RoutedExpertsCapture:
def __init__(self, generator: AsyncIterator[RequestOutput]):
def __init__(self, generator: AsyncIterator[RequestOutput], start: int = 0):
self._generator = generator
self._start = start
self.routed_experts: dict[int, dict[str, Any]] = {}

async def __aiter__(self):
async for request_output in self._generator:
for output in request_output.outputs:
encoded = serialize_routed_experts(getattr(output, "routed_experts", None))
encoded = serialize_routed_experts(getattr(output, "routed_experts", None), start=self._start)
if encoded is not None:
self.routed_experts[output.index] = encoded
yield request_output
6 changes: 5 additions & 1 deletion src/prime_rl/inference/vllm/serving_tokens.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,11 @@ async def serve_tokens_full_generator( # type: ignore[override]
# experts surface in the JSON.
capture: _GenerateRoutedExpertsCapture | None = None
if self.model_config.enable_return_routed_experts:
capture = _GenerateRoutedExpertsCapture(result_generator)
start = request.sampling_params.routed_experts_prompt_start
capture = _GenerateRoutedExpertsCapture(
result_generator,
start=start,
)
result_generator = capture

response = await super().serve_tokens_full_generator(
Expand Down
114 changes: 92 additions & 22 deletions src/prime_rl/orchestrator/trajectories.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,11 +244,13 @@ def prepare_step_tokens(step: vf.TrajectoryStep, step_idx: int) -> dict[str, Any
if tokens is not None:
routed_experts_payload = tokens.get("routed_experts")
routed_experts = None
routed_experts_start = None
if routed_experts_payload is not None:
decoded_routed_experts = pybase64.b64decode_as_bytearray(routed_experts_payload["data"])
routed_experts = np.frombuffer(decoded_routed_experts, dtype=np.uint8).reshape(
routed_experts_payload["shape"]
)
routed_experts_start = routed_experts_payload["start"]
Comment thread
cursor[bot] marked this conversation as resolved.

return {
"prompt_ids": list(tokens["prompt_ids"]),
Expand All @@ -257,6 +259,7 @@ def prepare_step_tokens(step: vf.TrajectoryStep, step_idx: int) -> dict[str, Any
"completion_mask": list(map(bool, tokens["completion_mask"])),
"completion_logprobs": list(tokens["completion_logprobs"]),
"routed_experts": routed_experts,
"routed_experts_start": routed_experts_start,
# Renderer-emitted multimodal sidecar (placeholders + per-item
# processed tensors). Populated when the rollout went through
# a multimodal-aware renderer (e.g. Qwen3VLRenderer); absent
Expand All @@ -277,6 +280,12 @@ def prepare_step_tokens(step: vf.TrajectoryStep, step_idx: int) -> dict[str, Any
# Deferred routed_experts state per sample: O(N) chunk list concatenated
# once at finalize, replacing the prior O(N²) per-extension unpack/repack.
sample_routed_state: dict[int, dict[str, Any]] = {}
routed_prefix_states: dict[int, list[tuple[list[int], list[int], dict[str, Any]]]] = {}

# Track (prefix_tokens, sample, step_indices) per active sample. step_indices
# is the explicit list of prepared_steps positions merged into this sample —
# non-contiguous when other agents' steps interleave.
active_samples: list[tuple[list[int], TrainingSample, list[int]]] = []

def make_sample(tokens: dict[str, Any]) -> TrainingSample:
"""Create a new TrainingSample from a trajectory step."""
Expand Down Expand Up @@ -306,9 +315,37 @@ def make_sample(tokens: dict[str, Any]) -> TrainingSample:
# each extension is a no-op append rather than a destructive write.
step_routed = tokens.get("routed_experts")
if step_routed is not None:
routed_start = tokens["routed_experts_start"]
assert routed_start is not None
chunks: list[np.ndarray] = []
running_len = 0
if routed_start > 0:
source_len = routed_start + 1
source_state = None
for prompt_ids, completion_ids, candidate_state in routed_prefix_states[source_len]:
prompt_len = len(prompt_ids)
if (
tokens["prompt_ids"][:prompt_len] == prompt_ids
and tokens["prompt_ids"][prompt_len:source_len] == completion_ids
):
source_state = candidate_state
break
assert source_state is not None
Comment thread
cursor[bot] marked this conversation as resolved.
assert source_state["running_len"] >= routed_start
remaining = routed_start
for chunk in source_state["chunks"]:
if remaining == 0:
break
take = min(remaining, int(chunk.shape[0]))
chunks.append(chunk[:take])
remaining -= take
assert remaining == 0
running_len = routed_start
chunks.append(step_routed)
running_len += int(step_routed.shape[0])
sample_routed_state[id(sample)] = {
"chunks": [step_routed],
"running_len": int(step_routed.shape[0]),
"chunks": chunks,
"running_len": running_len,
}
return sample

Expand Down Expand Up @@ -339,51 +376,79 @@ def extend_sample(

step_routed = tokens.get("routed_experts")
state = sample_routed_state.get(id(sample))
if step_routed is not None and state is not None:
# vLLM doesn't capture a routing decision for the *last* token of any
# request, so the previous step left no entry for token at index
# (prefix_len - 1). The next step's forward pass *did* process that
# token (as part of its prompt) and produced step_routed[prefix_len-1].
# Append that single boundary entry as its own chunk, then append the
# genuinely new entries from this step. No prior bytes touched.
if prefix_len > 0 and prefix_len <= step_routed.shape[0]:
boundary_chunk = step_routed[prefix_len - 1 : prefix_len]
if state is not None:
assert step_routed is not None
if step_routed is not None:
assert state is not None
Comment thread
cursor[bot] marked this conversation as resolved.
assert tokens["routed_experts_start"] == prefix_len - 1
# Delta payloads start at prefix_len - 1. Row 0 fills the boundary
# token missing from the previous request; the rest is the new suffix.
if prefix_len > 0:
boundary_chunk = step_routed[:1]
state["chunks"].append(boundary_chunk)
state["running_len"] += 1
new_chunk = step_routed[prefix_len:]
step_routed = step_routed[1:]
new_chunk = step_routed
state["chunks"].append(new_chunk)
state["running_len"] += int(new_chunk.shape[0])

# Track (prefix_tokens, sample, step_indices) per active sample. step_indices
# is the explicit list of prepared_steps positions merged into this sample —
# non-contiguous when other agents' steps interleave.
active_samples: list[tuple[list[int], TrainingSample, list[int]]] = []

first_tokens = prepared_steps[0]
first_prefix = first_tokens["prompt_ids"] + first_tokens["completion_ids"]
first_sample = make_sample(first_tokens)
active_samples.append((first_prefix, first_sample, [0]))
first_routed_state = sample_routed_state.get(id(first_sample))
if first_routed_state is not None:
routed_prefix_states.setdefault(len(first_prefix), []).append(
(first_tokens["prompt_ids"], first_tokens["completion_ids"], first_routed_state)
)

for step_idx, _step in enumerate(trajectory[1:], start=1):
tokens = prepared_steps[step_idx]
step_prompt_ids = tokens["prompt_ids"]

# Check if this step extends ANY active prefix
# Pick the *longest* matching active prefix. With compaction/rollback,
# one active sample's prefix can be a strict prefix of another (e.g. a
# later sample re-generated tokens that overlap an earlier sample's
# prefix). Both would satisfy the slice check; the shorter would
# silently absorb the longer sample's generated tokens as user input.
matched_idx = None
matched_len = -1
matching_prefix_lens: list[int] = []
for idx, (prefix_tokens, _, _) in enumerate(active_samples):
if step_prompt_ids[: len(prefix_tokens)] == prefix_tokens:
matched_idx = idx
break
pl = len(prefix_tokens)
if step_prompt_ids[:pl] == prefix_tokens:
matching_prefix_lens.append(pl)
if pl > matched_len:
matched_idx = idx
matched_len = pl

if len(matching_prefix_lens) > 1:
# Ambiguous extension: rare, but reachable via compaction/rollback
# where a new sample's prefix happens to start with an older
# sample's prefix. Longest-match is the correct choice; surface
# the ambiguity so we can audit if it shows up in real rollouts.
logger.warning(
f"Ambiguous prefix match at step {step_idx} for example {output['example_id']}: "
f"{len(matching_prefix_lens)} of {len(active_samples)} active prefixes match "
f"(lens={sorted(matching_prefix_lens)}, step_prompt_len={len(step_prompt_ids)}). "
f"Extending the longest (len={matched_len})."
)

if matched_idx is not None:
# Extension holds - merge into matched sample
prefix_tokens, sample, step_indices = active_samples[matched_idx]
extend_sample(sample, len(prefix_tokens), step_idx=step_idx)
new_prefix = tokens["prompt_ids"] + tokens["completion_ids"]
active_samples[matched_idx] = (
tokens["prompt_ids"] + tokens["completion_ids"],
new_prefix,
sample,
step_indices + [step_idx],
)
routed_state = sample_routed_state.get(id(sample))
if routed_state is not None:
routed_prefix_states.setdefault(len(new_prefix), []).append(
(tokens["prompt_ids"], tokens["completion_ids"], routed_state)
)
else:
# No prefix matches - start a new sample
logger.debug(
Expand All @@ -393,6 +458,11 @@ def extend_sample(
new_prefix = tokens["prompt_ids"] + tokens["completion_ids"]
sample = make_sample(tokens)
active_samples.append((new_prefix, sample, [step_idx]))
routed_state = sample_routed_state.get(id(sample))
if routed_state is not None:
routed_prefix_states.setdefault(len(new_prefix), []).append(
(tokens["prompt_ids"], tokens["completion_ids"], routed_state)
)

# Finalize routed_experts for each sample. One concat per sample (O(N) byte
# work) replaces the previous per-step unpack/concat/repack (O(N²)). The
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/inference/test_serving_tokens.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def test_serialize_routed_experts_uses_compact_raw_payload():


def test_generate_response_post_process_replaces_upstream_routed_experts():
compact_routed_experts = {"data": "AQID", "shape": [1, 1, 3]}
compact_routed_experts = {"data": "AQID", "shape": [1, 1, 3], "start": 0}
capture = _GenerateRoutedExpertsCapture(_empty_request_outputs())
capture.routed_experts[0] = compact_routed_experts
response = GenerateResponse(
Expand Down
Loading
Loading