Skip to content

Commit aff43a4

Browse files
samsja and S1ro1 committed
Implement routed experts delta replay (with branched deltas)
Squashed from origin/r3-delta (tip 5c94833, which extends the earlier 3799bda with 'Support branched routed expert deltas' for cases where the routed-experts payload diverges across siblings in a group). Adapts delta replay to main's deferred routed-experts chunk concat: first step starts at 0; extended steps use prefix_len - 1; row 0 fills the boundary, remaining rows append as the new suffix. Bumps router wheel pin to local-path. Bumps deps/verifiers gitlink to d39cc5876. Co-Authored-By: S1ro1 <matej.sirovatka@gmail.com>
1 parent 6d17559 commit aff43a4

10 files changed

Lines changed: 487 additions & 35 deletions

File tree

Dockerfile.cuda

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,42 @@ ENV DEBIAN_FRONTEND=noninteractive
1818
ENV TZ=Etc/UTC
1919
RUN apt-get update && apt-get install -y --no-install-recommends --force-yes \
2020
build-essential \
21+
autoconf \
22+
automake \
23+
libtool \
24+
pkg-config \
25+
ca-certificates \
2126
curl \
2227
sudo \
2328
git \
2429
ninja-build \
30+
libnuma-dev \
31+
libnl-3-dev \
32+
libnl-route-3-dev \
33+
libibverbs-dev \
34+
librdmacm-dev \
2535
&& apt-get clean autoclean && rm -rf /var/lib/apt/lists/* /tmp/* /var/tmp/*
2636

37+
ARG UCX_VERSION=1.19.1
38+
RUN git clone --depth 1 --branch v${UCX_VERSION} https://github.com/openucx/ucx.git /tmp/ucx \
39+
&& cd /tmp/ucx \
40+
&& ./autogen.sh \
41+
&& ./configure \
42+
--prefix=/opt/ucx \
43+
--enable-shared \
44+
--disable-static \
45+
--disable-doxygen-doc \
46+
--enable-optimizations \
47+
--enable-cma \
48+
--enable-devel-headers \
49+
--enable-mt \
50+
--with-verbs \
51+
--with-cuda=/usr/local/cuda \
52+
--with-ze=no \
53+
&& make -j"$(nproc)" \
54+
&& make install \
55+
&& rm -rf /tmp/ucx
56+
2757
# Download the latest installer
2858
ADD https://astral.sh/uv/install.sh /uv-installer.sh
2959

@@ -49,7 +79,7 @@ COPY examples /app/examples
4979
COPY benchmarks/scripts /app/benchmarks/scripts
5080

5181
RUN --mount=type=cache,target=/app/.cache/uv \
52-
uv sync --extra flash-attn --extra flash-attn-3 --extra flash-attn-cute --extra envs --extra gpt-oss --group mamba-ssm --locked --no-dev
82+
uv sync --extra flash-attn --extra flash-attn-3 --extra flash-attn-cute --extra envs --extra gpt-oss --extra modelexpress --group mamba-ssm --locked --no-dev
5383

5484
# arm64: build flash-attn from source, fix namespace conflicts, apply workarounds
5585
ARG TARGETARCH
@@ -74,8 +104,12 @@ RUN apt-get update && apt-get install -y \
74104
net-tools \
75105
curl \
76106
vim \
107+
libnuma1 \
108+
libnl-3-200 \
109+
libnl-route-3-200 \
77110
libibverbs1 \
78111
ibverbs-providers \
112+
librdmacm1 \
79113
&& apt-get clean autoclean \
80114
&& rm -rf /var/lib/apt/lists/* /tmp/* /var/tmp/*
81115

@@ -96,6 +130,7 @@ ENV PATH="/usr/local/bin:$PATH"
96130
WORKDIR /app
97131
# Copy the application from the builder
98132
COPY --from=builder --chown=appuser:appuser /app /app
133+
COPY --from=builder /opt/ucx /opt/ucx
99134

100135
# Copy and set up entrypoint script
101136
COPY --chown=appuser:appuser scripts/docker-entrypoint.sh /app/docker-entrypoint.sh
@@ -107,6 +142,8 @@ RUN rm /app/.venv/bin/python3.12 && ln -s /usr/local/bin/python /app/.venv/bin/p
107142

108143
# Place executables in the environment at the front of the path
109144
ENV PATH="/app/.venv/bin:$PATH"
145+
ENV UCX_HOME=/opt/ucx
146+
ENV LD_LIBRARY_PATH="/opt/ucx/lib:/opt/ucx/lib/ucx${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}}"
110147

111148
# HuggingFace Hub timeouts (defaults are 10s which causes issues on slow networks)
112149
ENV HF_HUB_ETAG_TIMEOUT=500

deps/verifiers

Submodule verifiers updated 101 files

pyproject.toml

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,10 +84,12 @@ envs = [
8484
"opencode-science",
8585
"opencode-swe",
8686
"reverse-text",
87+
"rlm-swe",
8788
"science-env",
8889
"simpleqa-verified",
8990
"tau2-bench",
9091
"wiki-search",
92+
"wordle",
9193
]
9294
disagg = [
9395
"deep-ep ; platform_machine == 'x86_64'",
@@ -99,6 +101,9 @@ disagg = [
99101
gpt-oss = [
100102
"kernels",
101103
]
104+
modelexpress = [
105+
"modelexpress==0.3.0",
106+
]
102107
quack = [
103108
"quack-kernels>=0.4.1",
104109
]
@@ -134,6 +139,7 @@ members = [
134139
"deps/verifiers/environments/math_python",
135140
"deps/verifiers/environments/reverse_text",
136141
"deps/verifiers/environments/wiki_search",
142+
"deps/verifiers/environments/wordle",
137143
"deps/research-environments/environments/aime2024",
138144
"deps/research-environments/environments/aime2025",
139145
"deps/research-environments/environments/code_env",
@@ -155,6 +161,7 @@ members = [
155161
"deps/research-environments/environments/opencode_math",
156162
"deps/research-environments/environments/opencode_science",
157163
"deps/research-environments/environments/opencode_swe",
164+
"deps/research-environments/environments/rlm_swe",
158165
"deps/research-environments/environments/science_env",
159166
"deps/research-environments/environments/simpleqa_verified",
160167
"deps/research-environments/environments/tau2_bench",
@@ -178,6 +185,22 @@ override-dependencies = [
178185
"openenv-core",
179186
]
180187

188+
# ModelExpress 0.3.0 publishes protobuf<6 metadata, but its generated proto is
189+
# compatible with protobuf 6. prime-sandboxes requires protobuf>=6.31.1; keep
190+
# this capped to the validated protobuf major.
191+
[[tool.uv.dependency-metadata]]
192+
name = "modelexpress"
193+
version = "0.3.0"
194+
requires-dist = [
195+
"grpcio>=1.66.2",
196+
"huggingface_hub>=0.20.0",
197+
"nixl[cu12]",
198+
"numpy>=1.24.0",
199+
"protobuf>=5.27.0,<7.0.0",
200+
"pydantic>=2.0.0",
201+
"torch>=2.6.0",
202+
]
203+
181204
[tool.uv.exclude-newer-package]
182205
# we want latest vllm, remove next patch
183206
vllm = false
@@ -224,10 +247,12 @@ opencode-math = { workspace = true }
224247
opencode-science = { workspace = true }
225248
opencode-swe = { workspace = true }
226249
reverse-text = { workspace = true }
250+
rlm-swe = { workspace = true }
227251
science-env = { workspace = true }
228252
simpleqa-verified = { workspace = true }
229253
tau2-bench = { workspace = true }
230254
wiki-search = { workspace = true }
255+
wordle = { workspace = true }
231256
torch = { index = "pytorch-cu128" }
232257
torchvision = { index = "pytorch-cu128" }
233258
torchaudio = { index = "pytorch-cu128" }
@@ -236,7 +261,7 @@ dion = { git = "https://github.com/samsja/dion.git", rev = "d891eeb" }
236261
transformers = { git = "https://github.com/huggingface/transformers.git", rev = "c1c3424" }
237262
flash-attn-4 = { git = "https://github.com/Dao-AILab/flash-attention.git", subdirectory = "flash_attn/cute", rev = "96bd151" }
238263
prime-pydantic-config = { workspace = true }
239-
vllm-router = { url = "https://github.com/PrimeIntellect-ai/router/releases/download/v0.1.25/vllm_router-0.1.25-cp38-abi3-manylinux_2_28_x86_64.whl" }
264+
vllm-router = { path = "third_party/router/dist/vllm_router-0.1.25-cp38-abi3-linux_x86_64.whl" }
240265
vllm = [
241266
{ url = "https://github.com/PrimeIntellect-ai/prime-rl/releases/download/v0.5.0/vllm-0.21.0+cu129.r42434.pr39568.a106aa6-cp38-abi3-manylinux_2_34_x86_64.whl", marker = "platform_machine == 'x86_64'" },
242267
{ url = "https://github.com/vllm-project/vllm/releases/download/v0.21.0/vllm-0.21.0+cu129-cp38-abi3-manylinux_2_34_aarch64.whl", marker = "platform_machine == 'aarch64'" },

skills/configs/SKILL.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,10 @@ For rollout debugging, enable trainer-side token export under `trainer.experimen
7070

7171
Leave it unset for normal training. When enabled, it exports every sequence from each exporting rank.
7272

73+
## RLM SWE harness args
74+
75+
For `rlm_swe` / `rlm-swe` configs using the composable RLM harness, use current harness kwargs such as `rlm_max_turns`, `rlm_exec_timeout`, `rlm_max_depth`, `summarize_at_tokens`, `rlm_ref`, `local_checkout`, `append_to_system_prompt`, and `rlm_tools`. Do not use the stale `rlm_max_turns_in_context` key with the composable harness; it is not accepted by `rlm_harness`.
76+
7377
## Key files
7478

7579
- `packages/prime-rl-configs/src/prime_rl/` — config classes under `configs/`; `utils/config.py` re-exports `BaseConfig` and `cli`

src/prime_rl/inference/vllm/routed_experts.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from vllm.outputs import RequestOutput
99

1010

11-
def serialize_routed_experts(routed_experts: Any) -> dict[str, Any] | None:
11+
def serialize_routed_experts(routed_experts: Any, start: int = 0) -> dict[str, Any] | None:
1212
if routed_experts is None:
1313
return None
1414

@@ -23,18 +23,20 @@ def serialize_routed_experts(routed_experts: Any) -> dict[str, Any] | None:
2323
return {
2424
"data": pybase64.b64encode(memoryview(compact)).decode("ascii"),
2525
"shape": list(compact.shape),
26+
"start": start,
2627
}
2728

2829

2930
class RoutedExpertsCapture:
30-
def __init__(self, generator: AsyncIterator[RequestOutput]):
31+
def __init__(self, generator: AsyncIterator[RequestOutput], start: int = 0):
3132
self._generator = generator
33+
self._start = start
3234
self.routed_experts: dict[int, dict[str, Any]] = {}
3335

3436
async def __aiter__(self):
3537
async for request_output in self._generator:
3638
for output in request_output.outputs:
37-
encoded = serialize_routed_experts(getattr(output, "routed_experts", None))
39+
encoded = serialize_routed_experts(getattr(output, "routed_experts", None), start=self._start)
3840
if encoded is not None:
3941
self.routed_experts[output.index] = encoded
4042
yield request_output

src/prime_rl/inference/vllm/serving_tokens.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -266,7 +266,11 @@ async def serve_tokens_full_generator( # type: ignore[override]
266266
# experts surface in the JSON.
267267
capture: _GenerateRoutedExpertsCapture | None = None
268268
if self.model_config.enable_return_routed_experts:
269-
capture = _GenerateRoutedExpertsCapture(result_generator)
269+
start = request.sampling_params.routed_experts_prompt_start
270+
capture = _GenerateRoutedExpertsCapture(
271+
result_generator,
272+
start=start,
273+
)
270274
result_generator = capture
271275

272276
response = await super().serve_tokens_full_generator(

src/prime_rl/orchestrator/trajectories.py

Lines changed: 67 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -244,11 +244,13 @@ def prepare_step_tokens(step: vf.TrajectoryStep, step_idx: int) -> dict[str, Any
244244
if tokens is not None:
245245
routed_experts_payload = tokens.get("routed_experts")
246246
routed_experts = None
247+
routed_experts_start = None
247248
if routed_experts_payload is not None:
248249
decoded_routed_experts = pybase64.b64decode_as_bytearray(routed_experts_payload["data"])
249250
routed_experts = np.frombuffer(decoded_routed_experts, dtype=np.uint8).reshape(
250251
routed_experts_payload["shape"]
251252
)
253+
routed_experts_start = routed_experts_payload["start"]
252254

253255
return {
254256
"prompt_ids": list(tokens["prompt_ids"]),
@@ -257,6 +259,7 @@ def prepare_step_tokens(step: vf.TrajectoryStep, step_idx: int) -> dict[str, Any
257259
"completion_mask": list(map(bool, tokens["completion_mask"])),
258260
"completion_logprobs": list(tokens["completion_logprobs"]),
259261
"routed_experts": routed_experts,
262+
"routed_experts_start": routed_experts_start,
260263
# Renderer-emitted multimodal sidecar (placeholders + per-item
261264
# processed tensors). Populated when the rollout went through
262265
# a multimodal-aware renderer (e.g. Qwen3VLRenderer); absent
@@ -277,6 +280,12 @@ def prepare_step_tokens(step: vf.TrajectoryStep, step_idx: int) -> dict[str, Any
277280
# Deferred routed_experts state per sample: O(N) chunk list concatenated
278281
# once at finalize, replacing the prior O(N²) per-extension unpack/repack.
279282
sample_routed_state: dict[int, dict[str, Any]] = {}
283+
routed_prefix_states: dict[int, list[tuple[list[int], list[int], dict[str, Any]]]] = {}
284+
285+
# Track (prefix_tokens, sample, step_indices) per active sample. step_indices
286+
# is the explicit list of prepared_steps positions merged into this sample —
287+
# non-contiguous when other agents' steps interleave.
288+
active_samples: list[tuple[list[int], TrainingSample, list[int]]] = []
280289

281290
def make_sample(tokens: dict[str, Any]) -> TrainingSample:
282291
"""Create a new TrainingSample from a trajectory step."""
@@ -306,9 +315,37 @@ def make_sample(tokens: dict[str, Any]) -> TrainingSample:
306315
# each extension is a no-op append rather than a destructive write.
307316
step_routed = tokens.get("routed_experts")
308317
if step_routed is not None:
318+
routed_start = tokens["routed_experts_start"]
319+
assert routed_start is not None
320+
chunks: list[np.ndarray] = []
321+
running_len = 0
322+
if routed_start > 0:
323+
source_len = routed_start + 1
324+
source_state = None
325+
for prompt_ids, completion_ids, candidate_state in routed_prefix_states[source_len]:
326+
prompt_len = len(prompt_ids)
327+
if (
328+
tokens["prompt_ids"][:prompt_len] == prompt_ids
329+
and tokens["prompt_ids"][prompt_len:source_len] == completion_ids
330+
):
331+
source_state = candidate_state
332+
break
333+
assert source_state is not None
334+
assert source_state["running_len"] >= routed_start
335+
remaining = routed_start
336+
for chunk in source_state["chunks"]:
337+
if remaining == 0:
338+
break
339+
take = min(remaining, int(chunk.shape[0]))
340+
chunks.append(chunk[:take])
341+
remaining -= take
342+
assert remaining == 0
343+
running_len = routed_start
344+
chunks.append(step_routed)
345+
running_len += int(step_routed.shape[0])
309346
sample_routed_state[id(sample)] = {
310-
"chunks": [step_routed],
311-
"running_len": int(step_routed.shape[0]),
347+
"chunks": chunks,
348+
"running_len": running_len,
312349
}
313350
return sample
314351

@@ -339,30 +376,31 @@ def extend_sample(
339376

340377
step_routed = tokens.get("routed_experts")
341378
state = sample_routed_state.get(id(sample))
342-
if step_routed is not None and state is not None:
343-
# vLLM doesn't capture a routing decision for the *last* token of any
344-
# request, so the previous step left no entry for token at index
345-
# (prefix_len - 1). The next step's forward pass *did* process that
346-
# token (as part of its prompt) and produced step_routed[prefix_len-1].
347-
# Append that single boundary entry as its own chunk, then append the
348-
# genuinely new entries from this step. No prior bytes touched.
349-
if prefix_len > 0 and prefix_len <= step_routed.shape[0]:
350-
boundary_chunk = step_routed[prefix_len - 1 : prefix_len]
379+
if state is not None:
380+
assert step_routed is not None
381+
if step_routed is not None:
382+
assert state is not None
383+
assert tokens["routed_experts_start"] == prefix_len - 1
384+
# Delta payloads start at prefix_len - 1. Row 0 fills the boundary
385+
# token missing from the previous request; the rest is the new suffix.
386+
if prefix_len > 0:
387+
boundary_chunk = step_routed[:1]
351388
state["chunks"].append(boundary_chunk)
352389
state["running_len"] += 1
353-
new_chunk = step_routed[prefix_len:]
390+
step_routed = step_routed[1:]
391+
new_chunk = step_routed
354392
state["chunks"].append(new_chunk)
355393
state["running_len"] += int(new_chunk.shape[0])
356394

357-
# Track (prefix_tokens, sample, step_indices) per active sample. step_indices
358-
# is the explicit list of prepared_steps positions merged into this sample —
359-
# non-contiguous when other agents' steps interleave.
360-
active_samples: list[tuple[list[int], TrainingSample, list[int]]] = []
361-
362395
first_tokens = prepared_steps[0]
363396
first_prefix = first_tokens["prompt_ids"] + first_tokens["completion_ids"]
364397
first_sample = make_sample(first_tokens)
365398
active_samples.append((first_prefix, first_sample, [0]))
399+
first_routed_state = sample_routed_state.get(id(first_sample))
400+
if first_routed_state is not None:
401+
routed_prefix_states.setdefault(len(first_prefix), []).append(
402+
(first_tokens["prompt_ids"], first_tokens["completion_ids"], first_routed_state)
403+
)
366404

367405
for step_idx, _step in enumerate(trajectory[1:], start=1):
368406
tokens = prepared_steps[step_idx]
@@ -379,11 +417,17 @@ def extend_sample(
379417
# Extension holds - merge into matched sample
380418
prefix_tokens, sample, step_indices = active_samples[matched_idx]
381419
extend_sample(sample, len(prefix_tokens), step_idx=step_idx)
420+
new_prefix = tokens["prompt_ids"] + tokens["completion_ids"]
382421
active_samples[matched_idx] = (
383-
tokens["prompt_ids"] + tokens["completion_ids"],
422+
new_prefix,
384423
sample,
385424
step_indices + [step_idx],
386425
)
426+
routed_state = sample_routed_state.get(id(sample))
427+
if routed_state is not None:
428+
routed_prefix_states.setdefault(len(new_prefix), []).append(
429+
(tokens["prompt_ids"], tokens["completion_ids"], routed_state)
430+
)
387431
else:
388432
# No prefix matches - start a new sample
389433
logger.debug(
@@ -393,6 +437,11 @@ def extend_sample(
393437
new_prefix = tokens["prompt_ids"] + tokens["completion_ids"]
394438
sample = make_sample(tokens)
395439
active_samples.append((new_prefix, sample, [step_idx]))
440+
routed_state = sample_routed_state.get(id(sample))
441+
if routed_state is not None:
442+
routed_prefix_states.setdefault(len(new_prefix), []).append(
443+
(tokens["prompt_ids"], tokens["completion_ids"], routed_state)
444+
)
396445

397446
# Finalize routed_experts for each sample. One concat per sample (O(N) byte
398447
# work) replaces the previous per-step unpack/concat/repack (O(N²)). The

tests/unit/inference/test_serving_tokens.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ def test_serialize_routed_experts_uses_compact_raw_payload():
7373

7474

7575
def test_generate_response_post_process_replaces_upstream_routed_experts():
76-
compact_routed_experts = {"data": "AQID", "shape": [1, 1, 3]}
76+
compact_routed_experts = {"data": "AQID", "shape": [1, 1, 3], "start": 0}
7777
capture = _GenerateRoutedExpertsCapture(_empty_request_outputs())
7878
capture.routed_experts[0] = compact_routed_experts
7979
response = GenerateResponse(

0 commit comments

Comments
 (0)