From e460190325e6c132995e47e4b5f6b8ca6de34d28 Mon Sep 17 00:00:00 2001 From: Kavin Krishnan Date: Wed, 27 May 2026 08:06:11 -0700 Subject: [PATCH 01/18] docs(proposals): post-PR-#2389 plan for kernel compile, mixed-TP, MX clients MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Proposes the next phase of work on top of `nixl_mx` once #2389 merges: 1. Phase-1 — six surgical fixes against the in-tree code that close the bug classes we hit during GB200 bring-up (cross-subnet add_remote_agent full-mesh; stale READY peer dedup; heartbeat / STALE-on-shutdown; hardcoded 1200s timeout; non-MLA model guard; HSDP barrier ordering). Line-pinned against HEAD `79ea824d8`. 2. Phase-2 — graduate `src/prime_rl/transport/mx_rendezvous.py` onto NVIDIA's published `modelexpress` Python clients (`MxV2TrainingPublisher` / `MxV2RefitReceiver`). Deletes ~185 LOC of in-tree rendezvous that duplicates the upstream client. Inherits heartbeat + freshest-per-rank dedup + retention + sidecar-filter for free. `NixlAgentWrapper` / `Slot` / `TransportPlan` / `classic_cuda_pool` stay — those are prime-rl specialization. 3. Phase-3 — solves the trainer-side kernel-compile issue surfaced during #2389's FP8 cast-pipeline iteration. Trainer publishes HF-raw bytes (kernel-agnostic); inference compiles into its target layout (DeepGemm, cutlass, ...) via a receiver-side scratch-buffer pass. Extends the v2 shape registry with `compile_target` + `compile_metadata`. Heterogeneous fleets (mixed kernels on the same training run) now work without trainer-side branching. 4. Phase-3 also generalizes the v2 sharding metadata to handle mixed-TP/EP via `TargetTPLayout` + multi-source slice discovery in the same machinery NemoRL v2 uses for MoE expert filtering. Pulls heavily on the NemoRL × Dynamo path (NVIDIA, John Thompson) which is already running at 380 Gbps on GB300 RoCE for an 8.82 GB refit — same scratch-buffer + worker-extension-cls pattern this plan adopts. Component + per-refit sequence diagrams (mermaid) included. Estimated ~450 LOC additive across modelexpress + prime-rl for Phases 3-4 (plus the ~400 LOC subtraction from Phase 2). Doc only. Implementation phases sequenced behind the upstream merge of #2389. --- .../post-pr2389-kernel-compile-plan.md | 517 ++++++++++++++++++ 1 file changed, 517 insertions(+) create mode 100644 docs/proposals/post-pr2389-kernel-compile-plan.md diff --git a/docs/proposals/post-pr2389-kernel-compile-plan.md b/docs/proposals/post-pr2389-kernel-compile-plan.md new file mode 100644 index 0000000000..18f15f6f57 --- /dev/null +++ b/docs/proposals/post-pr2389-kernel-compile-plan.md @@ -0,0 +1,517 @@ +# Post-PR-#2389 plan: kernel-compile separation, mixed-TP, and MX client adoption + +**Status**: Planning doc. Branch `kavink/post-2389-kernel-compile-plan` (NVIDIA-authored, off PR [#2389](https://github.com/PrimeIntellect-ai/prime-rl/pull/2389) HEAD `79ea824d8`). +**Premise**: This plan is what we propose to build on top of `nixl_mx` once #2389 merges to `main`. It (a) graduates the in-tree `MxRendezvous` reimplementation onto NVIDIA's published ModelExpress clients, (b) introduces a compile-target registry to fix the trainer-side cutlass-pinning issue surfaced during #2389's FP8 cast-pipeline iteration, and (c) extends the v2 shape registry to handle mixed-TP / sharded-source transfers. **None of this fights the #2389 data plane** — the `Slot` / `TransportPlan` / `NixlAgentWrapper` / `classic_cuda_pool` stack stays untouched. We extend the rendezvous and metadata surfaces only. + +--- + +## 1. What we're building on + +After #2389 merges, prime-rl's `nixl_mx` mode has: + +- `src/prime_rl/transport/` — `NixlAgentWrapper`, `MxRendezvous` (PI's reimplementation of MX's upper layer), `TransportPlan`, `classic_cuda_pool`, `wire.py` (msgspec types). +- `src/prime_rl/trainer/models/slots.py` — `ShardedSlot`, `GatheredSlot`, `ExpertSlot` (the per-tensor buffer abstractions that hold registered NIXL memory and conversion logic). +- `src/prime_rl/trainer/models/conversions/` — bf16-cast and FP8-blockwise conversion specs (`bf16_cast.py`, `fp8_blockwise.py`). +- `src/prime_rl/trainer/rl/broadcast/nixl_mx.py` — `NIXLMxWeightBroadcast`, the lifecycle wrapper. +- `src/prime_rl/inference/vllm/worker/nixl_mx.py` — `NIXLMxWeightUpdateWorker`, the vLLM worker extension. + +PI's data plane (Slot / TransportPlan / NixlAgentWrapper / FP8 conversion specs) is **kept**. It's the conversion / compile / topology layer where this plan extends things. + +--- + +## 2. The problem Matej flagged: trainer-side kernel compile pins the topology + +In #2389 today, the conversion-cast pipeline (`fp8_blockwise.py`, `Slot.convert`, the recent `c3c4b148` "inline transfer slot casting" + `47e170f5` "simplify transfer cast conversions" refactors) runs **on the trainer**. The trainer's send bucket holds the **post-compile** layout: bf16 if the inference engine wants bf16, FP8 with DeepGemm scale interleaving if it wants DeepGemm, cutlass-friendly column-major + epilogue scales if it wants cutlass, etc. Inference RDMA-reads the bucket and copies straight into live params. + +This is fast — no receiver-side compute on the refit critical path — but it couples three independent decisions into the trainer's process: + +1. **Which kernel format does inference want?** Today the trainer has to know. Adding cutlass on GB300 means trainer code changes; adding DeepGemm-EP changes the slot layout. +2. **Mixed inference replicas with different kernels.** With trainer-side compile, one trainer can serve only one compile target per refit. Heterogeneous fleets (e.g. a/b testing DeepGemm vs cutlass on the same training run) need two parallel trainer buckets. +3. **Mixed TP/EP layouts.** A trainer at TP=4 publishing to inference at TP=8 needs to know the inference splits at *publish* time to write the right bucket entries. The information is in PI's `build_topology`, but it's tangled with the compile decision because the compile passes produce TP-specific layouts. + +When Matej said "we're having issues handling compiled kernels" — this is the source. The cutlass compile happens trainer-side; any inference replica that doesn't want exactly the trainer's compile target either gets bad bytes or fails the assertion check. + +--- + +## 3. Proposed shift: receiver-side compile via scratch buffers + +Move the compile pass back to the **inference side**. Trainer ships **canonical HF layout** over NIXL (raw bfloat16 or raw float8_e4m3fn + raw per-block scales — whatever format the trainer naturally has post-optimizer-step). Inference runs the kernel-specific compile pass after the RDMA receive completes, into its own live params. + +This is the same scratch-buffer pattern we already validated in two places: + +- Our **original PrimeRL PoC** (`KavinKrishnan/prime-rl:kavink/mx-weight-broadcast`) used scratch buffers explicitly to triangulate KL drift — proved correctness when the receiver decides the layout. +- **John Thompson's NemoRL + Dynamo path** (May 22, 2026) uses `MxRefitReceiver.receive_weights_scratch` to handle vLLM's HF→fused param remapping via `stacked_params_mapping`. Validated end-to-end on GB300 RoCE at **380 Gbps** for an 8.82 GB / 399-tensor refit — same scratch path we'd reuse here, just with kernel-compile transforms instead of name remapping. + +### Receiver-side refit pseudocode + +```python +# inference/vllm/worker/nixl_mx.py — after PR #2389 graduates to MX clients + +def update_weights_via_mx(self, *, version, mx_config): + # 1. Discover same-rank trainer source via MxV2RefitReceiver + candidates = self._mx_receiver.discover_v2_sources( + model_name=self.model_name, + min_version=version, + same_rank_only=mx_config.same_rank_only, + compile_target_filter=self._target_kernel, # NEW — see §5 + ) + chosen = self._mx_receiver.pick_best_source(candidates) + + # 2. RDMA pull into scratch buffers (HF layout — whatever trainer published) + scratch = {} + for name, tensor in self._mx_receiver.receive_weights_scratch( + chosen.ref, + tensor_shapes=chosen.registry.tensor_shapes, # global shapes from v2 sidecar + target_tp_layout=self._tp_layout, # NEW — mixed-TP slice request + ): + scratch[name] = tensor + + # 3. Run the kernel-specific compile pass into live model params + self._compile_pass.apply( + scratch_buffers=scratch, + live_params=dict(self.model_runner.model.named_parameters()), + ) + + # 4. Tree fan-out republish (TensorHub pipeline replication) — unchanged + self._mx_receiver.publish_self_as_source(...) +``` + +The compile pass is a small dispatch: + +```python +# inference/vllm/worker/compile_passes.py — NEW + +class CompilePass(ABC): + target: str + def apply(self, scratch_buffers, live_params): ... + +class HFRaw(CompilePass): + target = "hf_raw" + def apply(self, scratch, live): + # Direct copy — names + dtypes match. fast path. + for name, t in scratch.items(): + live[name].data.copy_(t) + +class DeepGemmFP8(CompilePass): + target = "deep_gemm_fp8" + def apply(self, scratch, live): + # K-major scale interleave, fused gate_up_proj packing + for name, t in scratch.items(): + interleaved = deep_gemm_layout(t, block_size=128) + live[fused_name(name)].data.copy_(interleaved) + +class CutlassFP8(CompilePass): + target = "cutlass_fp8" + def apply(self, scratch, live): + # Column-major weights + epilogue scale tensors + ... +``` + +### Cost vs benefit + +**Cost**: extra GPU memory for scratch (~1× model size briefly per refit, freed after compile) + compile latency on the inference side (~50-200ms for FP8 passes on Qwen3-30B-A3B). + +**Benefit**: +- Trainer is **kernel-agnostic** — same bucket bytes serve any inference target. +- Mixed-fleet OK — different inference replicas can run different compile passes on the same trainer publish. +- Adding a new kernel = adding a new `CompilePass` subclass on the inference side. **Zero trainer change.** +- Cross-TP / cross-EP layouts decouple cleanly from compile (see §6). + +This is the same trade John Thompson made for Dynamo and the same trade we made in the original PoC. Empirically the 50-200ms compile latency is dwarfed by the 200ms RDMA pull anyway; total wall time is unchanged. + +--- + +## 4. New primitive: compile-target registry (extension to v2 shape registry) + +Our v2 shape registry today encodes: + +``` +TensorDescriptorV2: + name, global_shape, dtype, placement_kind, shard_axis, local_shard_range, + is_expert, expert_axis, owned_expert_ids +``` + +Add **two fields** at the registry level (not per-tensor — these describe how the *publisher* prepared the data): + +```diff + RegistryPayload (JSON in __mx_v2_meta__ sidecar): + version: int + trainer_world_layout: str # "fsdp:4,tp:1,pp:1,ep:1" ++ compile_target: str # "hf_raw" | "deep_gemm_fp8" | "cutlass_fp8" | ... ++ compile_metadata: dict # kernel-specific params: ++ # {block_size: 128, scale_dtype: "float8_e8m0", ++ # layout: "row_major", ...} + tensors: list[TensorDescriptorV2] +``` + +### Trainer side + +The trainer declares what it published — typically `"hf_raw"` post-optimizer-step: + +```python +publisher = MxV2TrainingPublisher( + agent_name=..., + world_layout=TrainerWorldLayout(fsdp_world_size=4, tp_world_size=1), + compile_target="hf_raw", # NEW — declarative, no inference-side dependency +) +``` + +Specialized trainers that *do* want to bake a compile pass in for perf can declare it: + +```python +publisher = MxV2TrainingPublisher( + ..., + compile_target="deep_gemm_fp8", + compile_metadata={"block_size": 128, "scale_dtype": "float8_e8m0"}, +) +``` + +— and then only same-target inference replicas will accept that publish. Mixed-target inference replicas skip it and look for an `hf_raw` source. + +### Receiver side + +`discover_v2_sources(compile_target_filter=...)` filters candidate trainers: + +```python +candidates = receiver.discover_v2_sources( + model_name=..., + min_version=N, + same_rank_only=True, + compile_target_filter={"hf_raw"}, # accept HF-raw only, run compile myself +) +``` + +Or accept any compatible target (if the receiver has a fallback compile pass): + +```python +candidates = receiver.discover_v2_sources( + model_name=..., + compile_target_filter={"hf_raw", "deep_gemm_fp8"}, # accept either +) +``` + +The picker's existing trainer-vs-replica + freshest-per-rank sort applies on the filtered set. + +### Why this is in the v2 shape registry, not a new transport + +It's metadata about the wire format. The shape registry already travels via the synthetic `__mx_v2_meta__` `TensorDescriptor` sidecar (proven by jthomson04's GB300 run + protected by PR #295's filter). Adding two more JSON fields is zero-cost on the wire and keeps the discovery contract in one place. + +--- + +## 5. Sharding and mixed-TP transfers + +The hardest case is **trainer TP/EP layout ≠ inference TP/EP layout**. The shape registry already encodes this on the publish side: + +``` +placement_kind: "SHARD" +shard_axis: 0 +local_shard_range: (start, end) # this rank's slice along shard_axis +``` + +What's missing is the **receiver's expression of what it wants**. + +### New receiver-side API + +```python +class TargetTPLayout: + """What slice of the global tensor THIS receiver needs.""" + world_size: int # inference TP world size + rank: int # this receiver's TP rank + shard_axis: int # which axis we're sharded on (model-dependent) + +receiver.receive_weights_scratch( + chosen.ref, + target_tp_layout=TargetTPLayout(world_size=8, rank=3, shard_axis=0), + tensor_shapes=chosen.registry.tensor_shapes, +) +``` + +The receiver computes its desired slice from `target_tp_layout` and the published `placement_kind`: + +| Publisher | Receiver | Result | +|---|---|---| +| `REPLICATE` (trainer TP=1) | TP=8, rank=3, axis=0 | Receiver requests rows `[3N/8 : 4N/8]` of the publisher's tensor | +| `SHARD(0)` trainer TP=4 rank=2, range `[N/2 : 3N/4]` | TP=8, rank=4, axis=0, requests `[N/2 : 5N/8]` | Receiver pulls from same-physical-rank trainer (R2), takes lower half of R2's shard | +| `SHARD(0)` trainer TP=4 rank=2, range `[N/2 : 3N/4]` | TP=8, rank=5, axis=0, requests `[5N/8 : 6N/8]` | Receiver pulls from same trainer rank (R2), takes upper half of R2's shard | +| `SHARD(0)` trainer TP=4 rank=1, range `[N/4 : N/2]` | TP=2, rank=0, axis=0, requests `[0 : N/2]` | Receiver pulls from **both** trainer R0 + R1, concatenates | + +For cases where one inference rank needs slices from multiple trainer ranks (last row), the receiver picks **N candidates** instead of one: + +```python +multi_source = receiver.discover_v2_sources_for_slice( + model_name=..., + target_slice=(start, end), + shard_axis=0, +) +# Returns one SourceRef per trainer rank whose shard overlaps target_slice +# Receiver does N parallel RDMA pulls, concatenates in scratch +``` + +### Mixed-EP for MoE + +Same machinery, on the expert axis. NemoRL v2 already does this for `owned_expert_ids`: + +```python +candidates = receiver.discover_v2_sources( + model_name=..., + target_expert_ids_per_layer={ + 5: {0, 1, 2, 3}, # this inference rank's owned experts in layer 5 + 6: {0, 1, 2, 3}, + }, +) +``` + +The picker matches candidates whose `expert_owner_per_rank` covers the needed experts (existing logic in `MxV2RefitReceiver.pick_best_source` — see `nemo_rl_v2.py`). For mixed-EP (trainer EP=4, inference EP=8), receivers may pull from multiple trainer ranks via the same `discover_v2_sources_for_slice` pattern. + +### Why this matters for PrimeRL specifically + +PI's `ExpertSlot` (in `slots.py`) and `build_topology()` (in `nixl_checkpoint_engine`-style code) already implement TP-matched and EP-matched pairing on the trainer side. They're computing `peer_chunk_descs` based on the publisher's known topology. **What's missing is the inverse**: when the inference layout differs from the trainer's, the receiver needs to express that. The compile-target registry + the slice-discovery API give it that vocabulary. + +--- + +## 6. MX client adoption — layered with this plan + +The two-phase migration from our previous review (`pensieve/RL/PrimeRL/07_pr_2389_review_comments.md` + the "Phase 1 / Phase 2" framing) is the foundation this plan builds on: + +### Phase 1 — surgical fixes against the `nixl_mx` in-tree code (drop-in patches) + +The 6 inline-comment fixes (line numbers verified against `nixl_mx` HEAD `79ea824d8`): + +1. Same-rank `add_remote_agent` filter in `transport_plan.py` +2. Freshest-per-rank dedup in `mx_rendezvous.py::wait_for_peers` +3. `HeartbeatThread` after `set_status(READY)` in `inference/.../nixl_mx.py` +4. Read timeout from config (not hardcoded 1200s) +5. MLA-guard for non-MLA models (`update_mla_absorbed_weights`) +6. HSDP barrier ordering in `trainer/.../nixl_mx.py` + +These land **before** this plan starts — closes the bug classes without architectural change. + +### Phase 2 — graduate the rendezvous half onto ModelExpress clients + +Delete `src/prime_rl/transport/mx_rendezvous.py` (~185 LOC, replicates functionality already in `modelexpress`). Replace with imports of `MxV2TrainingPublisher` and `MxV2RefitReceiver`. The in-tree `NixlAgentWrapper` + `Slot` + `TransportPlan` + `classic_cuda_pool` stay — that's prime-rl-specific data-plane specialization and shouldn't move. + +```diff +-from prime_rl.transport.mx_rendezvous import MxRendezvous ++from modelexpress import MxV2TrainingPublisher, MxV2RefitReceiver +``` + +This is what unblocks everything in §3-§5: + +- `MxV2TrainingPublisher` exposes the v2 sidecar registry that §4's `compile_target` extends. +- `MxV2RefitReceiver.receive_weights_scratch` is the proven path from John's Dynamo work (380 Gbps GB300). +- `discover_v2_sources(compile_target_filter=...)` is a small extension to the existing picker. +- Heartbeat / freshest-dedup / retention all come along for free — no separate Phase 1 work needed once Phase 2 lands. + +### Phase 3 (this plan) — compile-target registry + mixed-TP + +- Add `compile_target` + `compile_metadata` to v2 shape registry (~30 LOC in `shape_descriptors.py` + `nemo_rl_v2.py`). +- Add `compile_target_filter` to `discover_v2_sources` (~15 LOC). +- Add `target_tp_layout` + `discover_v2_sources_for_slice` to `MxV2RefitReceiver` (~120 LOC). +- Add `compile_passes/` module in `src/prime_rl/inference/vllm/worker/` with `HFRaw`, `DeepGemmFP8`, `CutlassFP8` passes (~300 LOC). Or in NemoRL `nemo_rl/models/generation/vllm/compile_passes/` — see §7. +- PI's `nixl_mx.py` inference worker calls into the right `CompilePass` based on `engine.kernel_target`. + +Total: ~450 LOC across MX + PrimeRL, all additive. + +--- + +## 7. What we borrow from John Thompson's NemoRL+Dynamo work + +Five specific pieces, all already proven on GB300 RoCE: + +### 7.1 `receive_weights_scratch` is the foundation + +John's path uses `MxRefitReceiver.receive_weights_scratch` because vLLM's `stacked_params_mapping` requires HF-named tensors that the receiver later passes to `model.load_weights()`. That's structurally identical to "trainer ships HF-raw, receiver compiles into kernel layout": + +```python +# John's existing flow (NemoRL + Dynamo + vLLM v1) +weights = list(receiver._receiver.receive_weights_scratch( + chosen.ref, + timeout_seconds=mx_config.timeout_seconds, + tensor_shapes=tensor_shapes, +)) +self.model_runner.model.load_weights(weights=weights) # vLLM does the HF→fused remap + +# Our extension (PrimeRL post-#2389 + kernel compile) +weights = list(receiver._receiver.receive_weights_scratch( + chosen.ref, + timeout_seconds=mx_config.timeout_seconds, + tensor_shapes=tensor_shapes, + target_tp_layout=self._target_tp_layout, # NEW +)) +self._compile_pass.apply(weights, live_params) # OUR compile dispatch +``` + +The mechanism is identical. Only the post-RDMA stage differs. + +### 7.2 The `worker_extension_cls` injection pattern is cleaner than subclassing + +John's `MxRefitWorkerExtension` (in `dynamo/vllm/mx_refit/extension.py`) is injected via vLLM v1's `parallel_config.worker_extension_cls`. The class has no `__init__`; vLLM merges its methods into the existing `Worker` via `__bases__`. State is stashed lazily on `self` with `_mx_` prefixed attribute names. + +PI's `NIXLMxWeightUpdateWorker` today **subclasses** `Worker` directly. The extension-class pattern would let the refit logic live in a sibling module without touching the inheritance chain — useful when we add the compile passes (§3) because those want to live in their own package. + +**Recommend**: when graduating PrimeRL to MX clients in Phase 2, also adopt the `worker_extension_cls` pattern for the inference worker. The two changes naturally compose. + +### 7.3 PR #295's sidecar filter is required for any new v2 metadata + +If we extend the v2 sidecar (§4) without keeping PR #295's filter in `MxRefitReceiver.receive_weights{,_scratch}`, the synthetic `__mx_v2_meta__` `TensorDescriptor` poisons `prep_xfer_dlist` again — same `NIXL_ERR_NOT_FOUND` John hit before May 22. **No new code needed; the filter is already in `kavink/nemo_rl_moe` HEAD `8594fd6`.** Just don't accidentally back it out. + +### 7.4 The `FORCE_RDMA=1` test mode catches this class of bug in loopback + +The v2 demo scripts (`scripts/v2_*_e2e_demo.py`) have a `FORCE_RDMA=1` env var (commit `e8e063b`) that pins `UCX_TLS` off `cuda_ipc` so intra-node loopback exercises the strict `rc_mlx5` descriptor-list validator. **Run every new compile-pass test under `FORCE_RDMA=1`** — otherwise we'll merge a sidecar / descriptor-list bug that doesn't show up until cross-node deploy. + +### 7.5 The compile-pass module probably belongs in NemoRL first, mirrored into prime-rl + +John's Dynamo path is the most mature target for testing new kernels (Qwen3-4B-Thinking GRPO smoke is already running cross-node at 380 Gbps). The compile passes themselves are framework-agnostic — they just take `(scratch_dict, live_params_dict)` and run torch ops. **Recommend**: + +1. Implement `compile_passes/` first in `modelexpress_client/python/modelexpress/compile_passes/` — framework-neutral, reusable. +2. NemoRL + Dynamo path adopts it via `MxRefitWorkerExtension._compile_pass = HFRaw()` (or `DeepGemmFP8()`). +3. Validate end-to-end on GB300 with cutlass + DeepGemm kernels. +4. Mirror into PrimeRL's inference worker after Phase 2 graduates them to MX clients. + +This sequence means we de-risk the compile-pass design on the path that's already shown working before we touch PrimeRL. Same play we ran for the v2 sidecar — designed in NemoRL, validated in NemoRL+Dynamo (John), then graduated to prime-rl. + +--- + +## 8. Implementation phases + +| Phase | Scope | Estimated LOC | Owner | +|---|---|---|---| +| **0** | Wait for #2389 to merge upstream | — | Matej | +| **1** | 6 surgical fixes against PI's `transport/*.py` + `inference/*.py` (closes bug classes) | ~100 LOC | Us — fast follow on Matej | +| **2** | Graduate `MxRendezvous` → `MxV2TrainingPublisher` / `MxV2RefitReceiver`; adopt `worker_extension_cls` pattern in inference worker | ~−400 LOC (PI's reimpl removed) + ~150 LOC import-and-call | Us | +| **3a** | Add `compile_target` + `compile_metadata` to v2 shape registry | ~30 LOC | Us | +| **3b** | Add `compile_target_filter` to `discover_v2_sources` | ~15 LOC | Us | +| **3c** | Add `target_tp_layout` + `discover_v2_sources_for_slice` to `MxV2RefitReceiver` | ~120 LOC | Us | +| **3d** | Implement `compile_passes/` (HFRaw, DeepGemmFP8, CutlassFP8) — in MX repo for reuse | ~300 LOC | Us | +| **3e** | Validate on NemoRL+Dynamo path (John's GB300 cluster) — Qwen3-4B-Thinking with DeepGemm and cutlass kernels both running on the same MX server | E2E | Us + John | +| **3f** | Mirror compile-pass dispatch into PI's `inference/vllm/worker/nixl_mx.py` | ~50 LOC | Us — PR back to PI | +| **4** | Mixed-TP / mixed-EP slice discovery wired end-to-end (multi-source RDMA pulls) | ~200 LOC | Us — separable from Phase 3 | + +**Phases 0-1** are fully sequenced (must wait for upstream + apply surgical fixes). **Phases 2 onward** can run in parallel if we're willing to maintain a `kavink/post-2389-*` branch off PI's main + a follow-on PR per phase. + +--- + +## 9. Open questions + +1. **Does PI want the compile passes in their tree, or in MX?** If MX, they import a pluggable `CompilePassRegistry`. If their tree, the kernel ecosystem stays close to their Slot system (which already does fp8_blockwise). My lean: **MX**, because Dynamo + NemoRL also want them, and PI's per-Slot conversion stays for the trainer-side path when teams opt into "publish post-compile". + +2. **Does cutlass-FP8 work on inference-side compute?** Compile pass needs ~200ms of CUDA time. If the inference engine is mid-rollout when the refit arrives, we either pause and run the compile or queue it for the next "between rollouts" window. PrimeRL's current orchestrator does the latter; this plan inherits. + +3. **How do we handle trainers that publish post-compile (Matej's current path)?** Their `compile_target = "deep_gemm_fp8"`; receivers either accept it directly (fast path) or reject and look for `hf_raw`. Mixed fleets get clean error messages, not corrupt weights. + +4. **Mixed-TP across nodes — what's the bandwidth math?** Trainer TP=4 ↔ inference TP=8 means each inference rank pulls from 1-2 trainer ranks. For Qwen3-30B-A3B on GB200 (~30 GB / 4 trainer ranks = 7.5 GB/rank), an inference rank pulling 2× 4 GB slices is well within NIC budget. For larger EP layouts where one inference rank needs experts from N>2 trainer ranks, fan-in becomes interesting — that's where pipeline replication (TensorHub) and rollouts-as-replicas pay off. + +5. **What's the deprecation story for trainer-side compile?** We don't deprecate it — Matej's path stays valid for teams that want zero inference-side latency. The `compile_target` field is just informational; receivers filter on it. + +6. **Should the compile pass run before or after `update_mla_absorbed_weights`?** After. MLA absorption operates on live params; compile runs first so live params are in the right layout when MLA absorption runs. + +--- + +## 10. Component view + sequence diagram + +```mermaid +flowchart TB + subgraph trainer["Trainer side (after Phase 2)"] + TBcast["NIXLMxWeightBroadcast
(PI's lifecycle wrapper)"] + TSlots["Slots: Sharded · Gathered · Expert
(PI's data plane, kept)"] + TPlan["TransportPlan
(PI's, kept)"] + TPub["MxV2TrainingPublisher
(NEW — replaces MxRendezvous)"] + TAgent["NixlAgentWrapper
(PI's, kept)"] + TBcast --> TSlots + TBcast --> TPlan + TBcast --> TPub + TSlots --> TAgent + TPlan --> TAgent + TPub -. publishes registry incl.
compile_target=hf_raw .-> TPlan + end + + subgraph mx["MX control plane (unchanged)"] + MXSVR[("MX Server · gRPC + Redis
shape registry, compile-target
filter, tree fan-out catalog")] + end + + subgraph inf["Inference side (after Phase 2+3)"] + IRec["MxV2RefitReceiver
(NEW — replaces ad-hoc rendezvous)"] + IScratch["receive_weights_scratch
+ target_tp_layout
(extended John's path)"] + ICompile["CompilePass dispatch (NEW)
HFRaw · DeepGemmFP8 · CutlassFP8"] + ILive["vLLM model.named_parameters()
(live params — kernel-specific layout)"] + IRec --> IScratch + IScratch --> ICompile + ICompile --> ILive + end + + TPub <-.->|"publish_metadata (incl. compile_target)
set_status · update_status"| MXSVR + IRec <-.->|"discover_v2_sources(compile_target_filter=...)
list_sources · get_metadata"| MXSVR + TAgent <==>|"one-sided RDMA WRITE
UCX rc_mlx5 / RoCE
(HF-raw bytes, post-cast pre-compile)"| IScratch + + style mx fill:#fec,stroke:#963 + style MXSVR fill:#fec,stroke:#963,stroke-width:2px + style TPub fill:#cce,stroke:#33c,stroke-width:2px + style IRec fill:#cce,stroke:#33c,stroke-width:2px + style ICompile fill:#cfc,stroke:#363,stroke-width:2px + style ILive fill:#fcc,stroke:#c33 +``` + +### One refit cycle, after this plan lands + +```mermaid +sequenceDiagram + autonumber + participant O as Orchestrator + participant T as Trainer (NIXLMxWeightBroadcast) + participant MX as MX Server + participant I as Inference worker + participant C as CompilePass + + Note over T,I: BOOT (once per refit run) + O->>I: POST /init_nixl_mx (host, port, rank, kernel_target=deep_gemm_fp8) + I->>I: register live params with NIXL (PI's data plane) + I->>MX: publish_metadata(role=inference, kernel_target=deep_gemm_fp8) + I->>MX: update_status(READY) + + Note over T,I: PER REFIT STEP + O->>I: POST /update_weights + I->>MX: wait_for_all_peers_ready(role=trainer, READY) + + T->>T: lazy_init slots; per-rank scratch fill from state_dict
(NO trainer-side cutlass — bytes are HF-raw) + T->>MX: publish_metadata(role=trainer, compile_target="hf_raw",
compile_metadata={...}, shape registry per tensor) + T->>MX: update_status(INITIALIZING → READY) + + I->>MX: discover_v2_sources(model, min_version=N,
compile_target_filter={"hf_raw"},
target_tp_layout=TP=8 rank=3) + MX-->>I: candidates = [trainer R0 (covers requested slice)] + I->>I: pick_best_source + + Note right of I: SCRATCH PATH (from John's NemoRL+Dynamo work) + I->>T: NIXL one-sided RDMA WRITE → scratch buffers
(HF-raw layout, ~380 Gbps on GB300 RoCE) + I->>I: torch.cuda.synchronize + + Note right of I: COMPILE PASS (new — runs inference-side, ~50-200ms) + I->>C: apply(scratch_buffers, live_params)
e.g. DeepGemm scale interleave, fused gate_up_proj pack + C-->>I: live params updated in DeepGemm-friendly layout + + I->>I: update_mla_absorbed_weights (if MLA model) + I->>MX: publish_self_as_source(role=inference_replica, version=N)
(tree fan-out for next refit) + + I-->>O: 200 OK + O->>O: scheduler advances · next rollout uses new weights +``` + +--- + +## 11. Cross-references + +ModelExpress design docs (NVIDIA-authored, for context on the client surface this plan adopts): + +- [`docs/RL/PRIMERL_MX_OVERVIEW.md`](https://github.com/ai-dynamo/modelexpress/blob/main/docs/RL/PRIMERL_MX_OVERVIEW.md) — the foundational prime-rl × MX integration design (catalog + star wiring story). +- [`docs/RL/NEMORL_MX_OVERVIEW.md`](https://github.com/ai-dynamo/modelexpress/blob/kavink/nemo_rl_moe/docs/RL/NEMORL_MX_OVERVIEW.md) — the v2 design (rank-to-rank, tree fan-out, expert filter, shape registry) that this plan extends with the compile-target axis. +- [`docs/RL/VERL_MX_OVERVIEW.md`](https://github.com/ai-dynamo/modelexpress/blob/main/docs/RL/VERL_MX_OVERVIEW.md) — the verl `MxCheckpointEngine` integration; sibling adopter of the same MX clients. + +Upstream branches this plan refers to: + +- ModelExpress branch [`kavink/nemo_rl_moe`](https://github.com/ai-dynamo/modelexpress/tree/kavink/nemo_rl_moe) — the v2 client surface (`MxV2TrainingPublisher`, `MxV2RefitReceiver`), `shape_descriptors`, sidecar transport, and PR #295's sidecar filter. The MX-side dependency that Phase 2 imports. +- ModelExpress PR [#295](https://github.com/ai-dynamo/modelexpress/pull/295) — synthetic-sidecar `TensorDescriptor` filter in `MxRefitReceiver`. Required for any v2 metadata extension (including this plan's `compile_target` field) to survive `prep_xfer_dlist` validation on cross-node RoCE. Already merged into `kavink/nemo_rl_moe`; just don't back it out. +- NemoRL × Dynamo branches (John Thompson, NVIDIA): [`KavinKrishnan/RL:kavink/mx_integration`](https://github.com/KavinKrishnan/RL/tree/kavink/mx_integration) (NeMo-RL side) + the Dynamo-side companion. Validated at 380 Gbps on GB300 RoCE for an 8.82 GB / 399-tensor refit (Qwen3-4B-Thinking GRPO smoke). + +NVIDIA-internal context (not necessary for upstream review, listed for our own bookkeeping): + +- The 6 inline review comments + summary message we have queued for #2389 — verified line numbers against HEAD `dabaa19f5` (still applicable on `79ea824d8` after I re-checked May 27). Phase 1 of this rollout. +- Current state of #2389: +10 commits since `dabaa19f5` (4× conversion-cast polish, 4× DeepGemm env-var hygiene, 2× config/import fixes). None touched the 6 flagged lines. From 7feee0dbcc420a093e2fef5a3e9e7ceca898a0c5 Mon Sep 17 00:00:00 2001 From: Kavin Krishnan Date: Wed, 27 May 2026 08:06:35 -0700 Subject: [PATCH 02/18] docs(proposals): scrub stray internal-pensieve reference --- docs/proposals/post-pr2389-kernel-compile-plan.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/proposals/post-pr2389-kernel-compile-plan.md b/docs/proposals/post-pr2389-kernel-compile-plan.md index 18f15f6f57..8eaad8cabc 100644 --- a/docs/proposals/post-pr2389-kernel-compile-plan.md +++ b/docs/proposals/post-pr2389-kernel-compile-plan.md @@ -271,7 +271,7 @@ PI's `ExpertSlot` (in `slots.py`) and `build_topology()` (in `nixl_checkpoint_en ## 6. MX client adoption — layered with this plan -The two-phase migration from our previous review (`pensieve/RL/PrimeRL/07_pr_2389_review_comments.md` + the "Phase 1 / Phase 2" framing) is the foundation this plan builds on: +The two-phase migration from our earlier review of #2389 (Phase 1 surgical / Phase 2 client adoption) is the foundation this plan builds on: ### Phase 1 — surgical fixes against the `nixl_mx` in-tree code (drop-in patches) From e958f1cb5a2e6a359b067f96bf75cd6bda65f98c Mon Sep 17 00:00:00 2001 From: Kavin Krishnan Date: Wed, 27 May 2026 14:27:39 -0700 Subject: [PATCH 03/18] =?UTF-8?q?feat(transport/mx):=20Phase-2=20=E2=80=94?= =?UTF-8?q?=20heartbeat=20+=20freshest-per-rank=20dedup=20+=20same-rank=20?= =?UTF-8?q?filter?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Codifies the two runtime patches we applied on GB200 to unblock Qwen3-30B-A3B bring-up against PR #2389, plus a third surgical fix (heartbeat) that closes the stale-READY-after-restart class of bugs. The three changes are intentionally separable from #2389 and additive to the existing rendezvous API: 1. **HeartbeatThread on publish()**: When publish() succeeds we start a modelexpress.metadata.heartbeat.HeartbeatThread keyed on (mx_source_id, worker_id, worker_rank). The MX server's reaper then transitions crashed workers to STALE on its own. New `enable_heartbeat: bool = True` field on the dataclass to opt out for tests / one-shots. New `close()` method to stop the thread on graceful shutdown. 2. **Freshest-per-(role, rank) dedup**: New module-level helper `_freshest_per_rank(instances, *, metas)` keeps only the entry with the largest `updated_at` per `worker_rank`. wait_for_peers() and wait_for_all_peers_ready() default to using it; instances missing from `metas` get ts=0 and lose to anything timestamped. Was the second GB200 patch (stale READY from previous run beat fresh READY from the restarted trainer). 3. **same_rank_only filter**: New `_filter_same_rank(instances, *, rank)` helper. wait_for_peers() and wait_for_all_peers_ready() now accept `same_rank_only: bool = False` (off by default for back-compat); when set, only peers with `worker_rank == self.rank` are returned. Required on GCP GB200's multi-NIC fabric where cross-subnet routing fails. `_collect_updated_at(instances)` does the GetMetadata fan-out used by the dedup; failures are mapped to ts=0 so the picker doesn't crash on partial catalog state. Unit tests added under tests/unit/transport/test_mx_rendezvous_phase2.py (11 tests, all green; direct-loads mx_rendezvous.py to bypass prime_rl.transport.__init__'s heavy import chain so the suite runs with only `modelexpress` installed — no docker-compose required). Sub-tests cover: - _filter_same_rank: rank match - _freshest_per_rank: largest updated_at wins; missing-updated_at loses to known-updated_at; lone unknown is kept; stable rank-order - publish() spawns HeartbeatThread with correct kwargs (worker_rank, mx_source_id, worker_id, nixl_manager=None) - close() stops the thread; idempotent - enable_heartbeat=False skips the thread entirely - publish() swallows heartbeat start failures (broken heartbeat must not break rendezvous) - _collect_updated_at returns 0 on RPC failure, 0 on not_found, real value on success No breaking changes to the existing tests/unit/transport/test_mx_rendezvous.py suite (those are integration tests against a docker-compose'd MX server). --- src/prime_rl/transport/mx_rendezvous.py | 192 +++++++++++++++-- .../transport/test_mx_rendezvous_phase2.py | 197 ++++++++++++++++++ 2 files changed, 373 insertions(+), 16 deletions(-) create mode 100644 tests/unit/transport/test_mx_rendezvous_phase2.py diff --git a/src/prime_rl/transport/mx_rendezvous.py b/src/prime_rl/transport/mx_rendezvous.py index 817c1ab7dc..1c85324f80 100644 --- a/src/prime_rl/transport/mx_rendezvous.py +++ b/src/prime_rl/transport/mx_rendezvous.py @@ -7,20 +7,79 @@ (role baked into ``SourceIdentity.extra_parameters`` so trainer/inference hash to different ``mx_source_id``s) and the polling loop, and delegates all gRPC to ``modelexpress.MxClient``. + +Phase-2 fixes (post-#2389) baked in: + +- **Heartbeat**: spawning :class:`MxRendezvous` starts a background + :class:`HeartbeatThread` on ``publish()`` so the MX server's reaper can + detect crashed workers and mark them ``STALE``. Crashed workers were + leaving permanent ``READY`` rows that broke restarts on GB200. +- **Freshest-per-(role, rank) dedup**: when multiple entries for the same + (role, rank) live in the catalog (e.g. after a partial pod restart), + callers see only the most recently updated one. This is the second of + the two GB200 runtime patches. +- **Same-rank-only filter**: optional ``same_rank_only=True`` on the wait + methods restricts results to peers with ``worker_rank == self.rank``, + closing the cross-subnet full-mesh path that fails on GCP multi-NIC + RDMA fabrics. Off by default; the caller opts in. """ from __future__ import annotations +import logging import time import uuid -from dataclasses import dataclass +from dataclasses import dataclass, field from typing import Iterable, Literal from modelexpress import p2p_pb2 from modelexpress.client import MxClient +from modelexpress.metadata.heartbeat import HeartbeatThread Role = Literal["trainer", "inference", "orchestrator"] +_log = logging.getLogger("prime_rl.transport.mx_rendezvous") + + +def _freshest_per_rank( + instances: Iterable[p2p_pb2.SourceInstanceRef], + *, + metas: dict[str, int], +) -> list[p2p_pb2.SourceInstanceRef]: + """Dedup peers by ``worker_rank``, keeping the one with the largest + ``updated_at`` from ``metas``. + + ``metas`` maps ``worker_id`` → ``updated_at`` (ms-epoch as reported by + the MX server). Instances whose ``worker_id`` is missing from ``metas`` + are kept (we err on the side of "visible but not freshness-known"). + + This is the Phase-2 codification of the runtime patch we applied on + GB200: the prime-rl trainer's NIXL agent rotated ``mx_source_id`` on + restart, leaving a stale ``READY`` entry at the same ``worker_rank``; + receivers picked the stale one and got ``NIXL_ERR_NOT_ALLOWED`` when + they tried to ``add_remote_agent``. + """ + by_rank: dict[int, tuple[int, p2p_pb2.SourceInstanceRef]] = {} + for inst in instances: + ts = metas.get(inst.worker_id, 0) + cur = by_rank.get(inst.worker_rank) + if cur is None or ts > cur[0]: + by_rank[inst.worker_rank] = (ts, inst) + return [v[1] for _, v in sorted(by_rank.items())] + + +def _filter_same_rank( + instances: Iterable[p2p_pb2.SourceInstanceRef], *, rank: int +) -> list[p2p_pb2.SourceInstanceRef]: + """Keep only peers whose ``worker_rank == rank``. + + The cross-subnet full-mesh routing path failed on GCP GB200's multi-NIC + fabric — each rank has its own IB subnet, so trainer rank N can only + safely peer with inference rank N. Filtering at the rendezvous layer + prevents the broken connections from ever being attempted. + """ + return [inst for inst in instances if inst.worker_rank == rank] + @dataclass class MxRendezvous: @@ -47,11 +106,13 @@ class MxRendezvous: peer_world_size: int model_name: str worker_id: str = "" + enable_heartbeat: bool = True def __post_init__(self) -> None: if not self.worker_id: self.worker_id = str(uuid.uuid4()) self._mx_source_id: str | None = None + self._heartbeat: HeartbeatThread | None = None @property def peer_role(self) -> Role: @@ -84,21 +145,58 @@ def publish( nixl_metadata: bytes = b"", tensors: Iterable[p2p_pb2.TensorDescriptor] = (), ) -> str: - """Publish this worker's metadata. Returns the assigned ``mx_source_id``.""" + """Publish this worker's metadata. Returns the assigned ``mx_source_id``. + + Side effect (Phase 2): if ``enable_heartbeat`` is True, a + :class:`HeartbeatThread` is started after a successful publish so + the MX server's reaper can detect liveness. Heartbeat is idempotent + — calling ``publish()`` again on the same instance is a no-op for + the heartbeat (the existing thread keeps running). + """ worker = p2p_pb2.WorkerMetadata( worker_rank=self.rank, nixl_metadata=nixl_metadata, tensors=list(tensors), ) self._mx_source_id = self.client.publish_metadata(self._identity(self.role), worker, self.worker_id) + + if self.enable_heartbeat and self._heartbeat is None: + try: + self._heartbeat = HeartbeatThread( + mx_client=self.client, + mx_source_id=self._mx_source_id, + worker_id=self.worker_id, + worker_rank=self.rank, + nixl_manager=None, # prime-rl drives NIXL outside MX's manager + ) + self._heartbeat.start() + except Exception as e: # noqa: BLE001 + _log.warning( + "MxRendezvous: failed to start HeartbeatThread (role=%s rank=%s): %s", + self.role, + self.rank, + e, + ) + return self._mx_source_id + def close(self) -> None: + """Stop the heartbeat thread. Safe to call multiple times.""" + if self._heartbeat is not None: + try: + self._heartbeat.stop() + except Exception as e: # noqa: BLE001 + _log.warning("MxRendezvous: heartbeat.stop() failed: %s", e) + self._heartbeat = None + def wait_for_peers( self, *, status: int | None = None, timeout: float = 1200.0, poll_interval: float = 1.0, + same_rank_only: bool = False, + dedup_freshest_per_rank: bool = True, ) -> list[p2p_pb2.SourceInstanceRef]: """Block until ``peer_world_size`` peers of the counterpart role are visible. @@ -106,32 +204,74 @@ def wait_for_peers( status: If set, only count peers in this :class:`p2p_pb2.SourceStatus`. timeout: Wall-clock seconds to wait before raising :class:`TimeoutError`. poll_interval: Seconds between ``ListSources`` polls. + same_rank_only: If True, only return peers whose ``worker_rank`` + equals this rendezvous's own rank. Required on GB200's + multi-NIC fabric where cross-subnet routing fails. Off by + default to preserve the pre-Phase-2 single-NIC behaviour. + dedup_freshest_per_rank: If True (default), keep only the + freshest ``SourceInstanceRef`` per ``worker_rank``. This + neutralises the stale-READY-after-restart bug we caught on + GB200. Pass ``False`` to keep all duplicates (e.g. debug). """ - import logging - - _log = logging.getLogger("prime_rl.transport.mx_rendezvous") deadline = time.monotonic() + timeout peer_id = self._identity(self.peer_role) _logged = False while True: resp = self.client.list_sources(peer_id, status_filter=status) + kept = list(resp.instances) + if same_rank_only: + kept = _filter_same_rank(kept, rank=self.rank) + if dedup_freshest_per_rank and kept: + kept = _freshest_per_rank( + kept, metas=self._collect_updated_at(kept) + ) if not _logged: all_resp = self.client.list_sources(peer_id) _log.info( - f"wait_for_peers: role={self.peer_role} need={self.peer_world_size} " - f"found_with_status={len(resp.instances)} found_any={len(all_resp.instances)} " - f"status_filter={status} model={peer_id.model_name}" + "wait_for_peers: role=%s need=%s found_with_status=%s found_any=%s " + "post_filter=%s status_filter=%s model=%s same_rank_only=%s", + self.peer_role, + self.peer_world_size, + len(resp.instances), + len(all_resp.instances), + len(kept), + status, + peer_id.model_name, + same_rank_only, ) _logged = True - if len(resp.instances) >= self.peer_world_size: - return list(resp.instances) + if len(kept) >= self.peer_world_size: + return kept if time.monotonic() >= deadline: raise TimeoutError( f"timed out after {timeout}s waiting for {self.peer_world_size} " - f"{self.peer_role!r} peers (saw {len(resp.instances)})" + f"{self.peer_role!r} peers (saw {len(kept)} after filters; " + f"{len(resp.instances)} raw)" ) time.sleep(poll_interval) + def _collect_updated_at( + self, instances: Iterable[p2p_pb2.SourceInstanceRef] + ) -> dict[str, int]: + """Fetch ``updated_at`` per peer in one round of GetMetadata calls. + + Used by the freshest-per-rank dedup. Failures (missing worker, RPC + errors) are mapped to ``0`` so the stale entries lose to anything + with a real timestamp. + """ + out: dict[str, int] = {} + for inst in instances: + try: + resp = self.client.get_metadata(inst.mx_source_id, inst.worker_id) + except Exception: # noqa: BLE001 + out[inst.worker_id] = 0 + continue + if not getattr(resp, "found", False): + out[inst.worker_id] = 0 + continue + out[inst.worker_id] = int(getattr(resp.worker, "updated_at", 0) or 0) + return out + def wait_for_all_peers_ready( self, *, @@ -139,6 +279,8 @@ def wait_for_all_peers_ready( status: int = p2p_pb2.SOURCE_STATUS_READY, timeout: float = 1200.0, poll_interval: float = 0.05, + same_rank_only: bool = False, + dedup_freshest_per_rank: bool = True, ) -> list[p2p_pb2.SourceInstanceRef]: """Discover peer count from MX, then block until ALL of them reach ``status``. @@ -147,27 +289,45 @@ def wait_for_all_peers_ready( entries exist in MX (any status) and uses that count as the target. Each side publishes one entry per rank, so the count equals the peer's world size — no config plumbing needed. + + Phase-2 additions (``same_rank_only`` and ``dedup_freshest_per_rank``) + behave identically to :meth:`wait_for_peers`. """ target_role = role or self.peer_role peer_id = self._identity(target_role) deadline = time.monotonic() + timeout + def _apply_filters( + insts: list[p2p_pb2.SourceInstanceRef], + ) -> list[p2p_pb2.SourceInstanceRef]: + kept = insts + if same_rank_only: + kept = _filter_same_rank(kept, rank=self.rank) + if dedup_freshest_per_rank and kept: + kept = _freshest_per_rank( + kept, metas=self._collect_updated_at(kept) + ) + return kept + peer_count = 0 while peer_count == 0: - peer_count = len(self.client.list_sources(peer_id).instances) + insts = list(self.client.list_sources(peer_id).instances) + kept = _apply_filters(insts) + peer_count = len(kept) if peer_count == 0: if time.monotonic() >= deadline: raise TimeoutError(f"timed out waiting for {target_role!r} peers to appear in MX") time.sleep(poll_interval) while True: - matched = self.client.list_sources(peer_id, status_filter=status) - if len(matched.instances) >= peer_count: - return list(matched.instances) + insts = list(self.client.list_sources(peer_id, status_filter=status).instances) + kept = _apply_filters(insts) + if len(kept) >= peer_count: + return kept if time.monotonic() >= deadline: raise TimeoutError( f"timed out after {timeout}s waiting for {peer_count} " - f"{target_role!r} peers to reach status {status} (saw {len(matched.instances)})" + f"{target_role!r} peers to reach status {status} (saw {len(kept)})" ) time.sleep(poll_interval) diff --git a/tests/unit/transport/test_mx_rendezvous_phase2.py b/tests/unit/transport/test_mx_rendezvous_phase2.py new file mode 100644 index 0000000000..d5a79f8bc8 --- /dev/null +++ b/tests/unit/transport/test_mx_rendezvous_phase2.py @@ -0,0 +1,197 @@ +"""Phase-2 unit tests for MxRendezvous helpers — no docker-compose required. + +Direct-loads mx_rendezvous.py to bypass prime_rl.transport's heavy +__init__.py import chain. +""" + +from __future__ import annotations + +import importlib.util +import sys +import types +from dataclasses import dataclass +from pathlib import Path +from unittest.mock import MagicMock + +import pytest + + +_HERE = Path(__file__).resolve().parent +_REPO_ROOT = _HERE.parent.parent.parent +_MOD_PATH = _REPO_ROOT / "src" / "prime_rl" / "transport" / "mx_rendezvous.py" + + +@pytest.fixture(scope="module") +def rdzmod(): + if "prime_rl" not in sys.modules: + pkg = types.ModuleType("prime_rl") + pkg.__path__ = [str(_REPO_ROOT / "src" / "prime_rl")] + sys.modules["prime_rl"] = pkg + if "prime_rl.transport" not in sys.modules: + sub = types.ModuleType("prime_rl.transport") + sub.__path__ = [str(_REPO_ROOT / "src" / "prime_rl" / "transport")] + sys.modules["prime_rl.transport"] = sub + + spec = importlib.util.spec_from_file_location( + "prime_rl.transport.mx_rendezvous", _MOD_PATH + ) + mod = importlib.util.module_from_spec(spec) + sys.modules["prime_rl.transport.mx_rendezvous"] = mod + spec.loader.exec_module(mod) + return mod + + +@dataclass +class _FakeInst: + worker_id: str + worker_rank: int + mx_source_id: str = "fake-source" + + +def test_filter_same_rank_keeps_only_matching(rdzmod): + insts = [_FakeInst("w0", 0), _FakeInst("w1", 1), _FakeInst("w2", 2), _FakeInst("w1b", 1)] + kept = rdzmod._filter_same_rank(insts, rank=1) + assert [i.worker_id for i in kept] == ["w1", "w1b"] + + +def test_freshest_per_rank_keeps_largest_updated_at(rdzmod): + insts = [_FakeInst("w0_old", 0), _FakeInst("w0_new", 0), _FakeInst("w1_only", 1), _FakeInst("w0_mid", 0)] + metas = {"w0_old": 100, "w0_new": 300, "w1_only": 200, "w0_mid": 200} + kept = rdzmod._freshest_per_rank(insts, metas=metas) + by_rank = {i.worker_rank: i.worker_id for i in kept} + assert by_rank == {0: "w0_new", 1: "w1_only"} + + +def test_freshest_per_rank_handles_missing_updated_at(rdzmod): + insts = [_FakeInst("ghost", 5), _FakeInst("known", 5)] + metas = {"known": 1} + kept = rdzmod._freshest_per_rank(insts, metas=metas) + assert len(kept) == 1 + assert kept[0].worker_id == "known" + + +def test_freshest_per_rank_returns_lone_unknown_when_no_rival(rdzmod): + insts = [_FakeInst("only_ghost", 7)] + kept = rdzmod._freshest_per_rank(insts, metas={}) + assert len(kept) == 1 + assert kept[0].worker_id == "only_ghost" + + +def test_freshest_per_rank_sorted_by_rank(rdzmod): + insts = [_FakeInst("w2", 2), _FakeInst("w0", 0), _FakeInst("w1", 1)] + kept = rdzmod._freshest_per_rank(insts, metas={"w0": 1, "w1": 1, "w2": 1}) + assert [i.worker_rank for i in kept] == [0, 1, 2] + + +def test_publish_starts_and_close_stops_heartbeat(rdzmod, monkeypatch): + fake_client = MagicMock() + fake_client.publish_metadata.return_value = "mx-source-xyz" + + spawned = [] + + class _FakeHB: + def __init__(self, **kwargs): + self.kwargs = kwargs + self.started = False + self.stopped = False + spawned.append(self) + + def start(self): + self.started = True + + def stop(self): + self.stopped = True + + monkeypatch.setattr(rdzmod, "HeartbeatThread", _FakeHB) + + rdz = rdzmod.MxRendezvous(client=fake_client, role="trainer", rank=2, peer_world_size=4, model_name="m") + sid = rdz.publish(nixl_metadata=b"x", tensors=[]) + assert sid == "mx-source-xyz" + assert len(spawned) == 1 + hb = spawned[0] + assert hb.started + assert hb.kwargs["worker_rank"] == 2 + assert hb.kwargs["mx_source_id"] == "mx-source-xyz" + assert hb.kwargs["nixl_manager"] is None + + rdz.close() + assert hb.stopped + rdz.close() + + +def test_publish_skips_heartbeat_when_disabled(rdzmod, monkeypatch): + fake_client = MagicMock() + fake_client.publish_metadata.return_value = "sid" + + spawned = [] + + class _FakeHB: + def __init__(self, **kwargs): + spawned.append(self) + + def start(self): + pass + + monkeypatch.setattr(rdzmod, "HeartbeatThread", _FakeHB) + rdz = rdzmod.MxRendezvous( + client=fake_client, role="inference", rank=0, peer_world_size=1, model_name="m", enable_heartbeat=False + ) + rdz.publish() + assert spawned == [] + + +def test_publish_swallows_heartbeat_start_failure(rdzmod, monkeypatch): + fake_client = MagicMock() + fake_client.publish_metadata.return_value = "sid" + + class _BrokenHB: + def __init__(self, **kwargs): + raise RuntimeError("can't allocate thread") + + monkeypatch.setattr(rdzmod, "HeartbeatThread", _BrokenHB) + rdz = rdzmod.MxRendezvous(client=fake_client, role="trainer", rank=0, peer_world_size=1, model_name="m") + sid = rdz.publish() + assert sid == "sid" + assert rdz._heartbeat is None + + +def test_collect_updated_at_returns_zero_on_failure(rdzmod): + fake_client = MagicMock() + fake_client.get_metadata.side_effect = RuntimeError("boom") + rdz = rdzmod.MxRendezvous( + client=fake_client, role="trainer", rank=0, peer_world_size=1, model_name="m", enable_heartbeat=False + ) + out = rdz._collect_updated_at([_FakeInst("a", 0), _FakeInst("b", 1)]) + assert out == {"a": 0, "b": 0} + + +def test_collect_updated_at_returns_zero_on_not_found(rdzmod): + fake_client = MagicMock() + + class _Resp: + found = False + worker = MagicMock(updated_at=0) + + fake_client.get_metadata.return_value = _Resp() + rdz = rdzmod.MxRendezvous( + client=fake_client, role="trainer", rank=0, peer_world_size=1, model_name="m", enable_heartbeat=False + ) + out = rdz._collect_updated_at([_FakeInst("x", 0)]) + assert out == {"x": 0} + + +def test_collect_updated_at_returns_real_value(rdzmod): + fake_client = MagicMock() + + class _Resp: + found = True + + def __init__(self): + self.worker = MagicMock(updated_at=42) + + fake_client.get_metadata.return_value = _Resp() + rdz = rdzmod.MxRendezvous( + client=fake_client, role="trainer", rank=0, peer_world_size=1, model_name="m", enable_heartbeat=False + ) + out = rdz._collect_updated_at([_FakeInst("x", 0)]) + assert out == {"x": 42} From 08058331f95a55d24eb344b4b7c7d256b2ccd4ff Mon Sep 17 00:00:00 2001 From: Kavin Krishnan Date: Thu, 28 May 2026 10:40:07 -0700 Subject: [PATCH 04/18] feat(conversions): cutlass FP8 e4m3 per-channel + compile_target/metadata tagging MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Extends prime_rl/trainer/models/conversions/ to address the live coworker complaint that prime-rl breaks on Qwen3-MoE with cutlass kernels — the registry currently has only `bf16_cast` and `fp8_128x128`; anything else raises NotImplementedError, and there's no compile_target tag on the publish so wrong-target receivers silently misinterpret bytes. This is the trainer-side half of the design fix; the receiver-side filtering API is already shipped in modelexpress as PR #349 (Phase 3a/3b on kavink/post-2389-phase3-4). Once Phase 2 graduation lands on KavinKrishnan/prime-rl#1, the MxV2TrainingPublisher will read each tensor's resolved ConversionEntry.compile_target + compile_metadata and tag the v2 publish so receivers can filter via discover_v2_sources(compile_target_filter=…, required_compile_metadata=…). What lands: ConversionEntry gains two new fields with safe defaults: - compile_target: str = "hf_raw" - compile_metadata: dict[str, Any] = {} register(...) takes them as kwargs; existing call sites are unchanged. Mirrors the constants in modelexpress.shape_descriptors (Phase 3a) but without a hard import dep in either direction — both repos keep their own canonical string set. select_default_conversion is refactored to a table-driven design. The old if/else chain is replaced by _DEFAULT_RULES: list[(predicate, name)] which the resolver walks in order. Adding a new kernel = adding one row via register_default_rule(predicate, name) from the kernel's own module on import. A predicate that raises on a malformed config is treated as "doesn't match" and skipped, keeping the resolver robust to model-card weirdness without forcing every predicate to be defensive. The AutoConfig import is deferred into the function body so the registry loads without requiring `transformers` (the registry is imported by tests + tooling that have no HF download capability). Existing entries get their tags retroactively: - bf16_cast / fp32_cast: compile_target="hf_raw" - fp8_128x128: compile_target="deep_gemm_fp8" + metadata{block_size: [128,128], scale_layout:"blockwise", dtype:"e4m3"} New conversion: cutlass_fp8_e4m3_per_channel - One scalar scale per output row (vs DeepGemm's per-128x128-block). - 2D dispatch: (out, in) weight → (out,) scale. 3D dispatch: (E, out, in) stacked MoE → (E, out) scale. - compile_target="cutlass_fp8", compile_metadata={dtype:"e4m3", scale_layout:"per_channel", scale_axis:-1, activation_scheme: "dynamic"} — matches cutlass scaled_mm + vLLM's native FP8 path. - Two default-resolver predicates: * quant_method="fp8" + quant_format="cutlass" (explicit) * quant_method="fp8" + weight_block_size=None + activation_scheme="dynamic" (the vLLM-published convention) Both predicates run AFTER the deep-gemm rule, so models with block_size=[128,128] AND activation_scheme="dynamic" still resolve to fp8_128x128 (regression-tested). Per-channel helpers in trainer/models/fp8.py: - fp8_per_channel_quantize(weight) → (q_e4m3, scale_f32). Handles 2D and 3D via the same code path; reduction over the innermost axis. - fp8_per_channel_quantize_into(weight, out, sf) — writes into preallocated buffers, matches the convention of fp8_block_quantize. Tests: 19/19 green via direct-load + transformers stub. Categories: - Per-channel quantize: 2D shape, 3D shape, 1D rejected, bf16 dequant accuracy (≤5% median rel error), into-buffer write. - Registry: existing entries carry correct compile_target + compile_metadata, cutlass entry registered + listed, default-rule insert/append ordering works, unknown quant error message lists registered names. - select_default_conversion dispatch: no-quant → bf16, [128,128] blockwise → fp8_128x128, quant_format=cutlass → cutlass, no weight_block_size + dynamic → cutlass, deep-gemm wins when both rules match. - Conversion fn dispatch: 2D linear path correctness, 3D MoE path correctness, requires_scale=True enforced. Adding a sibling kernel (per-token cutlass, awq, gptq, mxfp4, …) is now one new module ~80 LOC: write the quant fn, register() it with appropriate compile_target/metadata, register_default_rule() with its HF-config predicate. Branches off PR #2389 head 79ea824d8. Independent of the Phase 2 graduation PR — these can land in parallel. --- .../trainer/models/conversions/__init__.py | 153 ++++++++- .../trainer/models/conversions/bf16_cast.py | 21 +- .../trainer/models/conversions/cutlass_fp8.py | 105 ++++++ .../models/conversions/fp8_blockwise.py | 31 +- src/prime_rl/trainer/models/fp8.py | 54 +++ .../models/conversions/test_cutlass_fp8.py | 311 ++++++++++++++++++ 6 files changed, 651 insertions(+), 24 deletions(-) create mode 100644 src/prime_rl/trainer/models/conversions/cutlass_fp8.py create mode 100644 tests/unit/train/models/conversions/test_cutlass_fp8.py diff --git a/src/prime_rl/trainer/models/conversions/__init__.py b/src/prime_rl/trainer/models/conversions/__init__.py index 1790ca1284..1c9247502a 100644 --- a/src/prime_rl/trainer/models/conversions/__init__.py +++ b/src/prime_rl/trainer/models/conversions/__init__.py @@ -2,7 +2,10 @@ A conversion is a function that writes one source tensor into one destination tensor, optionally producing a paired scale buffer. Each conversion is -registered under a string name (e.g. ``"fp8_128x128"``). +registered under a string name (e.g. ``"fp8_128x128"``) and carries a +``compile_target`` tag plus a ``compile_metadata`` dict that downstream MX +clients (Phase 3a on ``ai-dynamo/modelexpress:kavink/post-2389-phase3-4``) +use to advertise the bytes' layout to receivers. Resolution flow at startup: @@ -10,29 +13,71 @@ :func:`select_default_conversion` to pick one conversion name to use as the default for every spec that doesn't pin its own. The choice is driven entirely by ``config.quantization_config`` (or its absence). + The resolver is **table-driven** (see ``_DEFAULT_RULES``) so adding a + new kernel = adding one row, not editing if/else chains. 2. For each :class:`~prime_rl.trainer.models.conversion_spec.ConversionSpec`, :func:`resolve` returns the registry entry — explicit ``conversion_type`` on the spec wins, otherwise the startup-chosen default applies. The registry never inspects destination buffer dtype; slot allocation is owned by the transfer slot builder. + +When the Phase 2 graduation of ``MxRendezvous`` onto +``MxV2TrainingPublisher`` lands (see +``KavinKrishnan/prime-rl:kavink/post-2389-phase2-rendezvous-fixes``), the +publisher reads each tensor's resolved ``ConversionEntry.compile_target`` +and ``compile_metadata`` and tags ``TensorDescriptorV2`` accordingly. +Receivers filter via ``MxV2RefitReceiver.discover_v2_sources( +compile_target_filter=…, required_compile_metadata=…)``. Until graduation +lands, the fields are populated but unused — callers can read them via +``ConversionEntry.compile_target`` to plumb manually if needed. """ from __future__ import annotations -from dataclasses import dataclass -from typing import Callable +from dataclasses import dataclass, field +from typing import Any, Callable from torch import Tensor -from transformers import AutoConfig ConversionFn = Callable[[Tensor, Tensor, "Tensor | None"], None] +# Canonical compile-target strings. Mirror the constants in +# ``modelexpress.shape_descriptors`` (Phase 3a, kavink/post-2389-phase3-4) +# so the two repos use exactly the same vocabulary without a hard import +# dependency in either direction. +COMPILE_TARGET_HF_RAW = "hf_raw" +COMPILE_TARGET_DEEPGEMM_FP8 = "deep_gemm_fp8" +COMPILE_TARGET_CUTLASS_FP8 = "cutlass_fp8" +COMPILE_TARGET_VLLM_FUSED = "vllm_fused" +COMPILE_TARGET_TRTLLM = "trtllm" + + @dataclass(frozen=True) class ConversionEntry: + """Registry record for one trainer→inference conversion kernel. + + Fields: + fn: The actual conversion function. Signature + ``(src, out, scale_out_or_None) -> None``. + requires_scale: True if ``fn`` writes a scale buffer; the slot + builder must allocate one. + compile_target: One of the ``COMPILE_TARGET_*`` strings. Identifies + the layout family the output bytes belong to. Receivers filter + on this via the v2 MX client. Default ``"hf_raw"`` means "no + kernel-specific layout, plain HF state-dict". + compile_metadata: Free-form key/value blob describing the specific + compile invocation (e.g. ``{"block_size": 128, + "scale_layout": "K-major"}``). Receivers should treat a + mismatch on any byte-affecting field as a hard reject even + if ``compile_target`` matches. + """ + fn: ConversionFn requires_scale: bool + compile_target: str = COMPILE_TARGET_HF_RAW + compile_metadata: dict[str, Any] = field(default_factory=dict) _REGISTRY: dict[str, ConversionEntry] = {} @@ -43,10 +88,17 @@ def register( fn: ConversionFn, *, requires_scale: bool, + compile_target: str = COMPILE_TARGET_HF_RAW, + compile_metadata: dict[str, Any] | None = None, ) -> None: if name in _REGISTRY: raise ValueError(f"conversion {name!r} is already registered") - _REGISTRY[name] = ConversionEntry(fn=fn, requires_scale=requires_scale) + _REGISTRY[name] = ConversionEntry( + fn=fn, + requires_scale=requires_scale, + compile_target=compile_target, + compile_metadata=dict(compile_metadata) if compile_metadata else {}, + ) def get(name: str) -> ConversionEntry: @@ -55,29 +107,86 @@ def get(name: str) -> ConversionEntry: return _REGISTRY[name] +def registered_names() -> list[str]: + """Snapshot the currently-registered conversion names. Used in tests + diagnostics.""" + return sorted(_REGISTRY) + + +# Table-driven default selection. Each row is a predicate on the parsed +# HF ``quantization_config`` plus the conversion name to return when it +# matches. Walked in order; first match wins. Extending support for a new +# kernel = appending one row (or registering a row from the kernel's +# module on import — see how cutlass_fp8.py does this). +_QuantPredicate = Callable[[dict[str, Any]], bool] +_DEFAULT_RULES: list[tuple[_QuantPredicate, str]] = [] + + +def register_default_rule( + predicate: _QuantPredicate, + name: str, + *, + insert_first: bool = False, +) -> None: + """Add a rule to the default-conversion resolver. + + Args: + predicate: callable taking the dict form of the HF + ``quantization_config`` (always non-None — the resolver + short-circuits to ``"bf16_cast"`` when no quantization_config + is present). Return True to claim this config. + name: registered conversion name to return on match. Must already + be in ``_REGISTRY`` (or be registered before + ``select_default_conversion`` is called). + insert_first: if True, prepend the rule so it beats earlier- + registered rules. Use sparingly — preferred is to append and + let earlier rules with stricter predicates win. + """ + pair = (predicate, name) + if insert_first: + _DEFAULT_RULES.insert(0, pair) + else: + _DEFAULT_RULES.append(pair) + + def select_default_conversion(inference_model_name: str) -> str: """Pick the default conversion name for the given inference model. - Loads the HF config and inspects ``quantization_config``: - - * absent → ``"bf16_cast"`` (no quantization; trainer→inference is a - plain dtype cast). - * ``quant_method == "fp8"`` with ``weight_block_size == [128, 128]`` → - ``"fp8_128x128"``. - * anything else → :class:`NotImplementedError`. + Loads the HF config and inspects ``quantization_config``. When no + quantization_config is present we short-circuit to ``"bf16_cast"`` so + test environments without a real HF download can still exercise the + default path. When present, we walk the ``_DEFAULT_RULES`` table in + order and return the first matching name. If nothing matches the + function raises :class:`NotImplementedError` with the full set of + registered conversions in the message — extend support by adding a + row to ``_DEFAULT_RULES`` (see :func:`register_default_rule`) from + the kernel's own module. """ + # Deferred import: ``transformers`` is a heavy dep we don't want to + # pay at registry-load time (the registry is imported by tests and + # tooling that have no HF download capability). The function is the + # only place that needs it. + from transformers import AutoConfig + config = AutoConfig.from_pretrained(inference_model_name) quant = getattr(config, "quantization_config", None) if quant is None: return "bf16_cast" if hasattr(quant, "to_dict"): quant = quant.to_dict() - method = quant["quant_method"] - block_size = tuple(quant.get("weight_block_size") or ()) - if method == "fp8" and block_size == (128, 128): - return "fp8_128x128" + for predicate, name in _DEFAULT_RULES: + try: + if predicate(quant): + return name + except Exception: + # A predicate that raises on an unexpected config shape should + # not crash the resolver — treat it as "doesn't match" and + # move on. This keeps registry hooks robust to model-name + # weirdness without forcing every predicate to be defensive. + continue raise NotImplementedError( - f"unsupported inference quantization: quant_method={method!r}, weight_block_size={block_size}" + f"unsupported inference quantization: {quant!r}; " + f"registered conversions: {sorted(_REGISTRY)}; " + f"register a new rule via prime_rl.trainer.models.conversions.register_default_rule" ) @@ -88,12 +197,20 @@ def resolve(conversion_type: str | None, default: str) -> ConversionEntry: from prime_rl.trainer.models.conversions import bf16_cast as _bf16_cast # noqa: E402, F401 from prime_rl.trainer.models.conversions import fp8_blockwise as _fp8_blockwise # noqa: E402, F401 +from prime_rl.trainer.models.conversions import cutlass_fp8 as _cutlass_fp8 # noqa: E402, F401 __all__ = [ + "COMPILE_TARGET_CUTLASS_FP8", + "COMPILE_TARGET_DEEPGEMM_FP8", + "COMPILE_TARGET_HF_RAW", + "COMPILE_TARGET_TRTLLM", + "COMPILE_TARGET_VLLM_FUSED", "ConversionEntry", "ConversionFn", - "register", "get", + "register", + "register_default_rule", + "registered_names", "resolve", "select_default_conversion", ] diff --git a/src/prime_rl/trainer/models/conversions/bf16_cast.py b/src/prime_rl/trainer/models/conversions/bf16_cast.py index 16b8dae4fe..bb2450d21c 100644 --- a/src/prime_rl/trainer/models/conversions/bf16_cast.py +++ b/src/prime_rl/trainer/models/conversions/bf16_cast.py @@ -5,7 +5,10 @@ import torch from torch import Tensor -from prime_rl.trainer.models.conversions import register +from prime_rl.trainer.models.conversions import ( + COMPILE_TARGET_HF_RAW, + register, +) def bf16_cast(src: Tensor, out: Tensor, scale_out: Tensor | None = None) -> None: @@ -18,5 +21,17 @@ def fp32_cast(src: Tensor, out: Tensor, scale_out: Tensor | None = None) -> None out.copy_(src.to(torch.float32)) -register("bf16_cast", bf16_cast, requires_scale=False) -register("fp32_cast", fp32_cast, requires_scale=False) +register( + "bf16_cast", + bf16_cast, + requires_scale=False, + compile_target=COMPILE_TARGET_HF_RAW, + compile_metadata={"dtype": "bfloat16"}, +) +register( + "fp32_cast", + fp32_cast, + requires_scale=False, + compile_target=COMPILE_TARGET_HF_RAW, + compile_metadata={"dtype": "float32"}, +) diff --git a/src/prime_rl/trainer/models/conversions/cutlass_fp8.py b/src/prime_rl/trainer/models/conversions/cutlass_fp8.py new file mode 100644 index 0000000000..e24a9b0c72 --- /dev/null +++ b/src/prime_rl/trainer/models/conversions/cutlass_fp8.py @@ -0,0 +1,105 @@ +"""Cutlass-style FP8 e4m3 with per-output-channel scaling. Registered as +``"cutlass_fp8_e4m3_per_channel"``. + +Layout contract (matches cutlass ``scaled_mm`` + vLLM's native FP8 path): + +* 2D linear weights: ``W.shape == (out_features, in_features)``, scale is + one float32 per output row → ``scale.shape == (out_features,)``. +* 3D stacked-expert MoE weights: + ``W.shape == (num_local_experts, out_features, in_features)``, scale is + one float32 per (expert, output-row) → ``scale.shape == (num_local_experts, out_features)``. + +Dispatches between the 2D and 3D paths via :func:`fp8_per_channel_quantize_into` +based on ``src.ndim`` — same dispatch convention as ``fp8_128x128``. + +Tagged with ``compile_target="cutlass_fp8"`` so receivers running cutlass +kernels can filter for it via the v2 MX client's +``discover_v2_sources(compile_target_filter={"cutlass_fp8"})`` (Phase 3b, +``ai-dynamo/modelexpress:kavink/post-2389-phase3-4``). + +``compile_metadata`` documents the byte-affecting choices: + +* ``dtype``: ``"e4m3"`` (vs ``"e5m2"`` for higher-range cutlass variants — + add as a separate entry when needed). +* ``scale_layout``: ``"per_channel"`` — receiver must allocate a 1D scale + per output row, not a 2D blockwise scale. +* ``scale_axis``: ``-1`` — reduction was over the input-features axis; + receiver dequantizes by broadcasting scale along the same axis. +* ``activation_scheme``: ``"dynamic"`` — matches HF's + ``quantization_config.activation_scheme="dynamic"`` for cutlass FP8. + +Adding a sibling cutlass entry (e.g. per-token activations, e5m2, etc.) is +~80 LOC in another file that calls :func:`register` and +:func:`register_default_rule` for its own HF-config signature. +""" + +from __future__ import annotations + +from torch import Tensor + +from prime_rl.trainer.models.conversions import ( + COMPILE_TARGET_CUTLASS_FP8, + register, + register_default_rule, +) +from prime_rl.trainer.models.fp8 import fp8_per_channel_quantize_into + + +def cutlass_fp8_e4m3_per_channel( + src: Tensor, + out: Tensor, + scale_out: Tensor | None, +) -> None: + """Quantize ``src`` (bf16 or fp32) into per-channel FP8 e4m3. + + Writes into preallocated ``out`` (e4m3) + ``scale_out`` (float32). + Dispatches 2D vs 3D via ``src.ndim`` — same convention as + ``fp8_128x128``. + """ + assert scale_out is not None, ( + "cutlass_fp8_e4m3_per_channel requires a scale_out buffer" + ) + fp8_per_channel_quantize_into(src, out=out, sf=scale_out) + + +register( + "cutlass_fp8_e4m3_per_channel", + cutlass_fp8_e4m3_per_channel, + requires_scale=True, + compile_target=COMPILE_TARGET_CUTLASS_FP8, + compile_metadata={ + "dtype": "e4m3", + "scale_layout": "per_channel", + "scale_axis": -1, + "activation_scheme": "dynamic", + }, +) + + +def _is_cutlass_fp8_per_channel(quant: dict) -> bool: + """HF ``quantization_config`` signature for cutlass per-channel FP8. + + Two recognised shapes: + + * ``{"quant_method": "fp8", "weight_block_size": None, + "activation_scheme": "dynamic"}`` — what vLLM and most cutlass- + targeting checkpoints publish. + * ``{"quant_method": "fp8", "quant_format": "cutlass"}`` — used by a + few model cards (Qwen3-MoE FP8 cutlass variants in particular) + that disambiguate cutlass from DeepGemm by setting an explicit + format string instead of leaving ``weight_block_size`` empty. + + The DeepGemm 128x128 rule (registered earlier) takes precedence when + both predicates would match because that rule was registered before + this one in ``_DEFAULT_RULES``. + """ + if quant.get("quant_method") != "fp8": + return False + if quant.get("quant_format") == "cutlass": + return True + block_size = tuple(quant.get("weight_block_size") or ()) + activation_scheme = quant.get("activation_scheme") + return block_size == () and activation_scheme == "dynamic" + + +register_default_rule(_is_cutlass_fp8_per_channel, "cutlass_fp8_e4m3_per_channel") diff --git a/src/prime_rl/trainer/models/conversions/fp8_blockwise.py b/src/prime_rl/trainer/models/conversions/fp8_blockwise.py index 3a9256ab7d..2a5907a9a0 100644 --- a/src/prime_rl/trainer/models/conversions/fp8_blockwise.py +++ b/src/prime_rl/trainer/models/conversions/fp8_blockwise.py @@ -1,14 +1,20 @@ """FP8 e4m3 blockwise quantization, 128x128 blocks. Registered as ``"fp8_128x128"``. Dispatches between the 2D linear layer path and the 3D stacked-expert path -based on ``src.ndim``. +based on ``src.ndim``. Tagged with ``compile_target="deep_gemm_fp8"`` so +receivers running DeepGemm kernels can filter for it via the v2 MX +client's ``discover_v2_sources(compile_target_filter=…)`` (Phase 3b). """ from __future__ import annotations from torch import Tensor -from prime_rl.trainer.models.conversions import register +from prime_rl.trainer.models.conversions import ( + COMPILE_TARGET_DEEPGEMM_FP8, + register, + register_default_rule, +) from prime_rl.trainer.models.fp8 import fp8_block_quantize, grouped_fp8_block_quantize @@ -20,4 +26,23 @@ def fp8_128x128(src: Tensor, out: Tensor, scale_out: Tensor | None) -> None: fp8_block_quantize(src, out=out, sf=scale_out) -register("fp8_128x128", fp8_128x128, requires_scale=True) +register( + "fp8_128x128", + fp8_128x128, + requires_scale=True, + compile_target=COMPILE_TARGET_DEEPGEMM_FP8, + compile_metadata={ + "dtype": "e4m3", + "scale_layout": "blockwise", + "block_size": [128, 128], + }, +) + +# HF config signature for DeepGemm-style FP8: 128x128 blockwise. +register_default_rule( + lambda quant: ( + quant.get("quant_method") == "fp8" + and tuple(quant.get("weight_block_size") or ()) == (128, 128) + ), + "fp8_128x128", +) diff --git a/src/prime_rl/trainer/models/fp8.py b/src/prime_rl/trainer/models/fp8.py index c04bf2f3b7..35576b0303 100644 --- a/src/prime_rl/trainer/models/fp8.py +++ b/src/prime_rl/trainer/models/fp8.py @@ -86,3 +86,57 @@ def grouped_fp8_block_quantize( if sf is not None: sf.copy_(s_accum) return q_accum, s_accum + + +# ---------------------------------------------------------------------------- +# Per-output-channel FP8 (cutlass-style): one scale per row of W. Used by +# cutlass scaled_mm + vLLM's native FP8 path. For a 2D weight of shape +# (out_features, in_features), reduction is over in_features (axis=-1) and +# the resulting scale has shape (out_features,). For a 3D stacked-expert +# weight of shape (num_local_experts, out_features, in_features) we run the +# same recipe per expert, producing a (num_local_experts, out_features) +# scale tensor. No padding / block reshuffling — the bytes go out in the +# same layout the trainer holds them in, which matches cutlass's +# RowMajor + per-channel scale convention. +# ---------------------------------------------------------------------------- + + +def fp8_per_channel_quantize( + weight: Tensor, +) -> tuple[Tensor, Tensor]: + """Per-output-channel symmetric FP8 e4m3 quantization. + + Supports both 2D ``(out, in)`` linear weights and 3D + ``(E, out, in)`` stacked-expert weights via the same code path. + Returns ``(quantized, scale)`` where ``scale`` has shape + ``weight.shape[:-1]`` (i.e. one scalar per output row, per expert). + """ + if weight.ndim not in (2, 3): + raise ValueError( + f"fp8_per_channel_quantize expects 2D or 3D, got shape={tuple(weight.shape)}" + ) + fp8_max = torch.finfo(torch.float8_e4m3fn).max # 448 for e4m3 + # amax over the innermost (input-features) axis. + amax = weight.detach().float().abs().amax(dim=-1, keepdim=True).clamp(min=1e-12) + scale = (amax / fp8_max).clamp(min=1e-12) + q = (weight.float() / scale).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn) + return q.contiguous(), scale.squeeze(-1).to(torch.float32).contiguous() + + +def fp8_per_channel_quantize_into( + weight: Tensor, + out: Tensor | None = None, + sf: Tensor | None = None, +) -> tuple[Tensor, Tensor]: + """Per-channel FP8 quantize, optionally writing into preallocated buffers. + + Shape contract: + - ``out.shape == weight.shape``, dtype ``torch.float8_e4m3fn`` + - ``sf.shape == weight.shape[:-1]``, dtype ``torch.float32`` + """ + q, s = fp8_per_channel_quantize(weight) + if out is not None: + out.copy_(q) + if sf is not None: + sf.copy_(s) + return q, s diff --git a/tests/unit/train/models/conversions/test_cutlass_fp8.py b/tests/unit/train/models/conversions/test_cutlass_fp8.py new file mode 100644 index 0000000000..cd1ea3ece7 --- /dev/null +++ b/tests/unit/train/models/conversions/test_cutlass_fp8.py @@ -0,0 +1,311 @@ +"""Tests for cutlass FP8 e4m3 per-channel conversion + registry-extension plumbing. + +Direct-loads the conversions package to bypass the heavy +``prime_rl.trainer`` import chain (CUDA + torchrun + ray + …) so the suite +runs on a plain CPU CI box with only ``torch`` installed. +""" + +from __future__ import annotations + +import importlib +import importlib.util +import sys +import types +from pathlib import Path + +import pytest +import torch + + +_HERE = Path(__file__).resolve().parent +_REPO_ROOT = _HERE.parent.parent.parent.parent.parent +_CONV_PKG_DIR = _REPO_ROOT / "src" / "prime_rl" / "trainer" / "models" / "conversions" +_FP8_PATH = _REPO_ROOT / "src" / "prime_rl" / "trainer" / "models" / "fp8.py" + + +def _direct_load(name: str, path: Path): + spec = importlib.util.spec_from_file_location(name, path) + mod = importlib.util.module_from_spec(spec) + sys.modules[name] = mod + spec.loader.exec_module(mod) + return mod + + +@pytest.fixture(scope="module") +def conv_pkg(): + """Load the conversions package + its dependencies in isolation. + + Order matters: ``conversions/__init__.py`` registers ``cutlass_fp8`` as + a late side-effect import; we need ``prime_rl.trainer.models.fp8`` to + be importable first. + """ + # Synthesize the prime_rl.trainer.models package hierarchy so the + # relative imports inside the conversion modules resolve. + for fqn, path in [ + ("prime_rl", _REPO_ROOT / "src" / "prime_rl"), + ("prime_rl.trainer", _REPO_ROOT / "src" / "prime_rl" / "trainer"), + ("prime_rl.trainer.models", _REPO_ROOT / "src" / "prime_rl" / "trainer" / "models"), + ]: + if fqn in sys.modules: + continue + pkg = types.ModuleType(fqn) + pkg.__path__ = [str(path)] + sys.modules[fqn] = pkg + + # Load fp8.py first — the conversion modules import from it. + _direct_load("prime_rl.trainer.models.fp8", _FP8_PATH) + + # Now load the conversions package, then its submodules. We point at + # the directory's __init__.py explicitly so we don't get the partial + # package from a parent that's already half-loaded. + pkg = _direct_load( + "prime_rl.trainer.models.conversions", _CONV_PKG_DIR / "__init__.py" + ) + return pkg + + +# ---------------------------------------------------------------------------- +# Per-output-channel quantize helper (lives in fp8.py) +# ---------------------------------------------------------------------------- + + +def test_fp8_per_channel_2d_round_trip_shape(conv_pkg): + from prime_rl.trainer.models.fp8 import fp8_per_channel_quantize + + w = torch.randn(64, 256, dtype=torch.bfloat16) + q, s = fp8_per_channel_quantize(w) + assert q.shape == (64, 256) + assert q.dtype == torch.float8_e4m3fn + assert s.shape == (64,) + assert s.dtype == torch.float32 + + +def test_fp8_per_channel_3d_round_trip_shape(conv_pkg): + from prime_rl.trainer.models.fp8 import fp8_per_channel_quantize + + w = torch.randn(8, 64, 256, dtype=torch.bfloat16) + q, s = fp8_per_channel_quantize(w) + assert q.shape == (8, 64, 256) + assert q.dtype == torch.float8_e4m3fn + assert s.shape == (8, 64) + assert s.dtype == torch.float32 + + +def test_fp8_per_channel_rejects_1d(conv_pkg): + from prime_rl.trainer.models.fp8 import fp8_per_channel_quantize + + with pytest.raises(ValueError, match="2D or 3D"): + fp8_per_channel_quantize(torch.randn(64)) + + +def test_fp8_per_channel_dequant_close_to_original(conv_pkg): + """Round-trip accuracy: per-channel scaling has ~1% error band on bf16 inputs.""" + from prime_rl.trainer.models.fp8 import fp8_per_channel_quantize + + torch.manual_seed(0) + w = torch.randn(32, 128, dtype=torch.bfloat16) * 0.1 + q, s = fp8_per_channel_quantize(w) + dequant = q.float() * s.unsqueeze(-1) + # FP8 e4m3 has ~3-bit mantissa → relative error tolerance is generous + rel = (dequant - w.float()).abs() / (w.float().abs() + 1e-6) + assert rel.median().item() < 0.05 # 5 % median error is realistic for fp8 e4m3 + + +def test_fp8_per_channel_into_writes_buffers(conv_pkg): + from prime_rl.trainer.models.fp8 import fp8_per_channel_quantize_into + + w = torch.randn(16, 64, dtype=torch.bfloat16) + out = torch.empty(16, 64, dtype=torch.float8_e4m3fn) + sf = torch.empty(16, dtype=torch.float32) + fp8_per_channel_quantize_into(w, out=out, sf=sf) + # Both buffers should now reflect a real quantization (not the empty pattern). + assert sf.gt(0).all() + assert out.float().abs().max() <= 448.0 # fp8 e4m3 finite range + + +# ---------------------------------------------------------------------------- +# Registry extensions: compile_target + compile_metadata + new entry +# ---------------------------------------------------------------------------- + + +def test_conversion_entry_carries_compile_target(conv_pkg): + entry = conv_pkg.get("bf16_cast") + assert entry.compile_target == conv_pkg.COMPILE_TARGET_HF_RAW + assert entry.compile_metadata == {"dtype": "bfloat16"} + + +def test_fp8_128x128_tagged_deep_gemm(conv_pkg): + entry = conv_pkg.get("fp8_128x128") + assert entry.compile_target == conv_pkg.COMPILE_TARGET_DEEPGEMM_FP8 + assert entry.compile_metadata["block_size"] == [128, 128] + assert entry.compile_metadata["scale_layout"] == "blockwise" + + +def test_cutlass_fp8_entry_registered(conv_pkg): + entry = conv_pkg.get("cutlass_fp8_e4m3_per_channel") + assert entry.requires_scale is True + assert entry.compile_target == conv_pkg.COMPILE_TARGET_CUTLASS_FP8 + assert entry.compile_metadata == { + "dtype": "e4m3", + "scale_layout": "per_channel", + "scale_axis": -1, + "activation_scheme": "dynamic", + } + + +def test_cutlass_fp8_in_registered_names(conv_pkg): + names = conv_pkg.registered_names() + assert "cutlass_fp8_e4m3_per_channel" in names + assert "fp8_128x128" in names + assert "bf16_cast" in names + + +def test_register_default_rule_appends(conv_pkg): + """register_default_rule appends by default and prepends with insert_first=True.""" + + sentinel_name = "bf16_cast" # we know this exists + + def predicate_a(quant): + return quant.get("quant_method") == "test_a" + + def predicate_b(quant): + return quant.get("quant_method") == "test_b" + + # These mutate module state; use unique enough names that they don't + # collide with the real rules. + conv_pkg.register_default_rule(predicate_a, sentinel_name) + conv_pkg.register_default_rule(predicate_b, sentinel_name, insert_first=True) + + # We can't read the table directly without breaking the encapsulation, + # but we can verify behaviorally: predicate_b should be matched before + # the existing rules; predicate_a should be matched after. + import prime_rl.trainer.models.conversions as conv + + rules = conv._DEFAULT_RULES + # predicate_b should now be at index 0 + assert rules[0][0] is predicate_b + # predicate_a should be at the end + assert rules[-1][0] is predicate_a + + +def test_unknown_quant_raises_listing_registered(conv_pkg, fake_hf_config): + """When no rule matches, the error message lists what IS registered.""" + fake_hf_config["quant"] = {"quant_method": "totally_unknown_method"} + with pytest.raises(NotImplementedError, match="registered conversions"): + conv_pkg.select_default_conversion("fake/model") + + +# ---------------------------------------------------------------------------- +# select_default_conversion dispatch (the new table-driven path) +# ---------------------------------------------------------------------------- + + +@pytest.fixture +def fake_hf_config(monkeypatch): + """Stub ``transformers.AutoConfig`` for the test session so + ``select_default_conversion`` runs without an HF download. + + Because the conversions module imports ``AutoConfig`` lazily inside + the function, we have to populate ``sys.modules['transformers']`` + with our stub *before* the function call resolves the import. + """ + holder = {"quant": None} + + class _Fake: + @property + def quantization_config(self): + return holder["quant"] + + class _FakeAutoConfig: + @staticmethod + def from_pretrained(*args, **kwargs): + return _Fake() + + transformers_stub = types.ModuleType("transformers") + transformers_stub.AutoConfig = _FakeAutoConfig + monkeypatch.setitem(sys.modules, "transformers", transformers_stub) + return holder + + +def test_default_no_quant_is_bf16(conv_pkg, fake_hf_config): + fake_hf_config["quant"] = None + assert conv_pkg.select_default_conversion("any/model") == "bf16_cast" + + +def test_default_deep_gemm_fp8(conv_pkg, fake_hf_config): + fake_hf_config["quant"] = { + "quant_method": "fp8", + "weight_block_size": [128, 128], + } + assert conv_pkg.select_default_conversion("any/model") == "fp8_128x128" + + +def test_default_cutlass_fp8_via_explicit_format(conv_pkg, fake_hf_config): + fake_hf_config["quant"] = { + "quant_method": "fp8", + "quant_format": "cutlass", + } + assert ( + conv_pkg.select_default_conversion("any/model") + == "cutlass_fp8_e4m3_per_channel" + ) + + +def test_default_cutlass_fp8_via_dynamic_no_block_size(conv_pkg, fake_hf_config): + fake_hf_config["quant"] = { + "quant_method": "fp8", + "weight_block_size": None, + "activation_scheme": "dynamic", + } + assert ( + conv_pkg.select_default_conversion("any/model") + == "cutlass_fp8_e4m3_per_channel" + ) + + +def test_default_deep_gemm_wins_over_cutlass_when_block_size_set( + conv_pkg, fake_hf_config +): + """Both rules could plausibly fire for a config with block_size=[128,128] + AND activation_scheme="dynamic"; the deep-gemm rule was registered first + and must win.""" + fake_hf_config["quant"] = { + "quant_method": "fp8", + "weight_block_size": [128, 128], + "activation_scheme": "dynamic", + } + assert conv_pkg.select_default_conversion("any/model") == "fp8_128x128" + + +# ---------------------------------------------------------------------------- +# End-to-end fn dispatch: 2D + 3D shapes via the registered conversion entry +# ---------------------------------------------------------------------------- + + +def test_cutlass_fp8_fn_dispatches_2d_linear(conv_pkg): + entry = conv_pkg.get("cutlass_fp8_e4m3_per_channel") + src = torch.randn(32, 128, dtype=torch.bfloat16) + out = torch.empty(32, 128, dtype=torch.float8_e4m3fn) + sf = torch.empty(32, dtype=torch.float32) + entry.fn(src, out, sf) + assert sf.gt(0).all() + assert out.float().abs().max() <= 448.0 + + +def test_cutlass_fp8_fn_dispatches_3d_moe(conv_pkg): + entry = conv_pkg.get("cutlass_fp8_e4m3_per_channel") + src = torch.randn(4, 32, 128, dtype=torch.bfloat16) # E=4 experts + out = torch.empty(4, 32, 128, dtype=torch.float8_e4m3fn) + sf = torch.empty(4, 32, dtype=torch.float32) + entry.fn(src, out, sf) + assert sf.shape == (4, 32) + assert sf.gt(0).all() + assert out.shape == (4, 32, 128) + + +def test_cutlass_fp8_fn_requires_scale(conv_pkg): + entry = conv_pkg.get("cutlass_fp8_e4m3_per_channel") + src = torch.randn(8, 16, dtype=torch.bfloat16) + out = torch.empty(8, 16, dtype=torch.float8_e4m3fn) + with pytest.raises(AssertionError, match="scale_out"): + entry.fn(src, out, None) From d676523d068621cd220260e0e8380583fe791f26 Mon Sep 17 00:00:00 2001 From: Kavin Krishnan Date: Thu, 28 May 2026 12:20:17 -0700 Subject: [PATCH 05/18] fix(transport/mx_rendezvous): tolerate both modelexpress.heartbeat module paths MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The v0.5.2 trainer image ships MX 0.3.0 where HeartbeatThread is at modelexpress.heartbeat. Newer MX (0.4+ on kavink/nemo_rl_moe) moved it under modelexpress.metadata.heartbeat as part of the metadata-module reorg. Tolerate both with a try/except at import time so the same source code works against either MX version. No behavior change at runtime — the class is identical between paths. Unit tests (11/11) still pass since the test fixture patches the HeartbeatThread symbol on the module post-import. --- src/prime_rl/transport/mx_rendezvous.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/src/prime_rl/transport/mx_rendezvous.py b/src/prime_rl/transport/mx_rendezvous.py index 1c85324f80..10414808a5 100644 --- a/src/prime_rl/transport/mx_rendezvous.py +++ b/src/prime_rl/transport/mx_rendezvous.py @@ -34,7 +34,17 @@ from modelexpress import p2p_pb2 from modelexpress.client import MxClient -from modelexpress.metadata.heartbeat import HeartbeatThread + +# HeartbeatThread moved in MX 0.4+ from ``modelexpress.heartbeat`` to +# ``modelexpress.metadata.heartbeat`` as part of the metadata-module +# reorganization. Tolerate both so this code works against the v0.5.2 +# image (MX 0.3.0, old path) and the newer ``kavink/nemo_rl_moe`` MX +# (which exposes the new path). The MX-side migration tracker is in +# ``pensieve/RL/PrimeRL/09_rfc_updates_needed.md``. +try: + from modelexpress.metadata.heartbeat import HeartbeatThread # MX 0.4+ +except ImportError: # pragma: no cover - environment-dependent + from modelexpress.heartbeat import HeartbeatThread # MX 0.3 Role = Literal["trainer", "inference", "orchestrator"] From 1b36af89ab0d34a41642f19d40771724bf0113e5 Mon Sep 17 00:00:00 2001 From: Kavin Krishnan Date: Thu, 28 May 2026 20:06:46 -0700 Subject: [PATCH 06/18] docs(proposals): build notes for Phase 2 + Phase 3 source-baked image (v0.7.x) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Captures the empirical findings from baking PRs #1 and #2 into an ARM64 GB200 image and running it on the kavin namespace for 8+ hours on Qwen3-30B-A3B-Instruct-2507 with gsm8k. Documents three real surprises the unit tests didn't cover: 1. Dockerfile.cuda's `uv sync` is missing `--extra disagg`, so modelexpress isn't installed in stock images; inference workers crash at the first import. Shipped v0.7.1 as a one-line overlay that adds the extra until the upstream Dockerfile.cuda can be updated. 2. `LD_PRELOAD` path for libcudart.so.12 — v0.5.2 had /usr/local/cuda present in the final stage; v0.7.0 (built from upstream Dockerfile.cuda as-is) doesn't. The pip-installed wheel path (/app/.venv/lib/python3.12/site-packages/nvidia/cuda_runtime/lib/) is the new canonical location. 3. The configmap monkeypatch (patch_nixl_mx.py) and Phase 2's source-baked fixes are complementary — they patch different layers (broadcast vs rendezvous-wait) and both should stay until PR #1 merges upstream. Build experience numbers: - v0.7.0 from-scratch ARM64 build under QEMU: 6h45min (uv sync 45m, flash-attn from source 3h45m). - v0.7.1 overlay on top of v0.7.0: ~3 min. Cluster observations from v0.5.2 + configmap monkeypatch (the runtime-patched path our PR #1 codifies into source): - 183 successful RL refit cycles in one 66-min uninterrupted window - Reward variance 0.5-1.0 across orchestrator steps (real learning) - Off-policy level = 0 throughout - Zero NIXL data-plane errors - Recurring orchestrator wait_for_all_peers_ready timeout (~once per 30-66 min) is the exact bug class Phase 2's rendezvous-level dedup eliminates Also notes seven RFC updates queued in pensieve/RL/PrimeRL/09_rfc_updates_needed.md, three of which are new from this build experience (disagg extra, LD_PRELOAD path, vLLM PR #43375 / Anyscale RDT positioning). Companion to the RFC at docs/proposals/post-pr2389-kernel-compile-plan.md. --- docs/proposals/build-notes-2026-05-28.md | 157 +++++++++++++++++++++++ 1 file changed, 157 insertions(+) create mode 100644 docs/proposals/build-notes-2026-05-28.md diff --git a/docs/proposals/build-notes-2026-05-28.md b/docs/proposals/build-notes-2026-05-28.md new file mode 100644 index 0000000000..dead46c2d6 --- /dev/null +++ b/docs/proposals/build-notes-2026-05-28.md @@ -0,0 +1,157 @@ +# Build notes — Phase-2 + Phase-3 source-baked image (2026-05-28) + +> **Companion to**: [`docs/proposals/post-pr2389-kernel-compile-plan.md`](./post-pr2389-kernel-compile-plan.md) +> **Status**: empirical findings from baking Phase 2 (rendezvous fixes) + Phase 3 (conversion-registry extensions) into an ARM64 GB200 image and running it against the live `kavin` namespace. Updates the RFC's framing where the build experience contradicted assumptions in the original RFC. + +This document captures **what we learned producing a usable image** containing the two follow-up PRs ([phase-2 rendezvous fixes](https://github.com/KavinKrishnan/prime-rl/pull/1) and [conversion-registry extensions](https://github.com/KavinKrishnan/prime-rl/pull/2)) on top of PR #2389 (HEAD `79ea824d8`). The unit tests for both PRs were already green; this doc records the cluster + image surface area that the unit tests don't cover. + +## 1. What we built + +Two images, in order: + +| Tag | Base | What's added | Status | +|---|---|---|---| +| `prime-rl-mx-on-nixl:v0.7.0-kavin-phase2-phase3` | `nvidia/cuda:12.8.1-cudnn-devel-ubuntu22.04` (full Dockerfile.cuda rebuild) | Phase 2 + Phase 3 source merged in. Built from `kavink/post-2389-image-build-2026-05-28` (which merges PR #1 + PR #2 on top of `79ea824d8`). | **Pushed to nvcr** | +| `prime-rl-mx-on-nixl:v0.7.1-kavin-phase2-phase3` | `v0.7.0-kavin-phase2-phase3` | Adds the `disagg` extra (modelexpress + nixl-cu12 + vllm-router). Fixes the import error from v0.7.0. | **Pushed to nvcr** | + +`v0.7.1` is the one to deploy. `v0.7.0` is kept as a reference of the from-scratch build artifact. + +## 2. Build mechanics (ARM64 GB200 / QEMU) + +The from-scratch ARM64 build of `v0.7.0` took **6h 45min on x86 host with QEMU arm64 emulation** (buildkit `multi-arch` builder). Breakdown: + +| Stage | Time | Notes | +|---|---|---| +| Pull `nvidia/cuda:12.8.1-cudnn-devel-ubuntu22.04` ARM64 base | ~5 min | First time only; cached after | +| `apt-get install` builder + final stages | ~3 min total | Both stages of multi-stage Dockerfile | +| `COPY src/ + packages/ + deps/` | seconds | Trivial | +| **`uv sync --extra ... --locked --no-dev`** | **45 min** | Resolves + downloads + installs ~350 packages including torch 2.7+cu130 (~5 GB), nvidia-cudnn-cu12 (738 MiB), flashinfer-cubin (large), tilelang, xgrammar, vllm 0.21.0+cu129 etc. Under QEMU emulation. | +| **`docker-arm64-post-install.sh` (flash-attn from source for sm_100 / GB200)** | **~3h 45min** | 73 CUDA kernel `.o` files, each compiled via emulated `nvcc` for sm_80 + sm_90 + sm_100. Most expensive kernels are `hdim192_bf16_causal` and `hdim256_bf16` for backward pass (15-40 min each). | +| Final stage `COPY --from=builder /app` + image export | ~7 min | 15.9 GB final image, 6.5 GB of which is one big layer (the venv) | + +`v0.7.1` overlay on top of `v0.7.0` was **~3 min** (the `uv sync` with `disagg` extra reuses every cached layer except the new modelexpress/nixl-cu12/vllm-router wheels). + +**Practical implication**: every meaningful rebuild from the Dockerfile.cuda base is ~7 hours on a non-ARM host. Use overlay Dockerfiles for additive changes. Reserve from-scratch only for `pyproject.toml` / `uv.lock` updates or major source restructuring. + +## 3. Three real issues the build surfaced that aren't in the RFC + +### 3.1 `Dockerfile.cuda` is missing `--extra disagg` for nixl_mx use + +[`Dockerfile.cuda`](../../Dockerfile.cuda) line 52: + +```dockerfile +RUN --mount=type=cache,target=/app/.cache/uv \ + uv sync --extra flash-attn --extra flash-attn-3 --extra flash-attn-cute --extra envs --extra gpt-oss --group mamba-ssm --locked --no-dev +``` + +The `disagg` extra ([`pyproject.toml` line 90](../../pyproject.toml#L90)) contains: + +```toml +disagg = [ + "deep-ep ; platform_machine == 'x86_64'", + "deep-gemm ; platform_machine == 'x86_64'", + "nixl", + "nixl-cu12 ; platform_machine == 'x86_64'", + "vllm-router ; platform_machine == 'x86_64'", + "modelexpress", +] +``` + +Without it, **`modelexpress` is not installed**, and the inference worker crashes at the first import of `prime_rl.inference.vllm.worker.nixl_mx`: + +``` +File "/app/src/prime_rl/inference/vllm/worker/nixl_mx.py", line 7, in + from modelexpress import p2p_pb2 +ModuleNotFoundError: No module named 'modelexpress' +``` + +The pre-PR-#2389 `Dockerfile.cuda` predates the `disagg` extra so this is an accidental gap, not an intentional opt-out. **Suggested change**: add `--extra disagg` (or rely on `--extra all`) for any image targeting `weight_broadcast.type=nixl_mx`. We've shipped `v0.7.1` as a one-line overlay that does this until the change can land in `Dockerfile.cuda` itself. + +### 3.2 `LD_PRELOAD` path for libcudart.so.12 moved + +The existing configmap's three run-scripts (`run_trainer.sh`, `run_inference.sh`, `run_orchestrator.sh`) all preload libcudart for ARM64 NIXL compatibility: + +```bash +export LD_PRELOAD="/usr/local/cuda/lib64/libcudart.so.12:${LD_PRELOAD:-}" +``` + +`/usr/local/cuda` exists in the v0.5.2 image (which appears to have been built from a Dockerfile variant that retained the CUDA tooling in the final stage). In `v0.7.0` (built from the upstream `Dockerfile.cuda` as-is), the final stage is `python:3.12-slim` which **does not** have `/usr/local/cuda`. `libcudart.so.12` lives only inside the pip-installed `nvidia-cuda-runtime` wheel: + +``` +/app/.venv/lib/python3.12/site-packages/nvidia/cuda_runtime/lib/libcudart.so.12 +``` + +Symptom on v0.7.0 with the unmodified configmap: + +``` +ERROR: ld.so: object '/usr/local/cuda/lib64/libcudart.so.12' from LD_PRELOAD cannot be preloaded +``` + +**Fix applied**: the three run-scripts now use the wheel-internal path. Alternative: symlink `/usr/local/cuda/lib64/libcudart.so.12 -> /app/.venv/.../libcudart.so.12` in the Dockerfile's final stage. Either works; we picked the env-var path because it's a configmap edit, no image rebuild. + +### 3.3 The configmap `patch_nixl_mx.py` and Phase 2 source coexist + +The kavin namespace runs a configmap-injected monkeypatch at container start (`patch_nixl_mx.py`) that rewrites `src/prime_rl/trainer/rl/broadcast/nixl_mx.py` to add same-rank-only peer filter + freshest-per-rank dedup *at TransportPlan construction time*. + +Phase 2 ([PR #1](https://github.com/KavinKrishnan/prime-rl/pull/1)) adds the same semantic guarantees but at a different layer — inside `src/prime_rl/transport/mx_rendezvous.py:wait_for_all_peers_ready`. The two patches are **complementary, not redundant**: + +| Code path | Bug class | Covered by | +|---|---|---| +| `trainer/rl/broadcast/nixl_mx.py:lazy_init` → `TransportPlan(peer_metadata=…)` | Trainer adds dead peers as NIXL remote agents during the per-step broadcast | `patch_nixl_mx.py` (runtime monkeypatch) | +| `transport/mx_rendezvous.py:wait_for_all_peers_ready(role="trainer")` | Orchestrator counts historical trainer entries in Redis and times out waiting for `n_historical` to all reach READY when only `n_alive` exist | Phase 2 PR (source-level) | + +On v0.7.1 + the existing configmap, both fire. The trainer log shows `[patch_nixl_mx] PATCHED v2 (kavin_freshest_per_rank)` from the configmap script; the rendezvous wait methods get the Phase 2 dedup automatically because the source is in the baked image. Empirically the orchestrator restart pattern we saw on v0.5.2 (~once per 30-66 min on this workload) should go away on v0.7.1. **Validation pending** — image just deployed at time of writing. + +When PR #1 merges upstream, the configmap monkeypatch becomes redundant for the trainer-side path too and should be removed. Until then, both layers complement each other. + +## 4. Cluster observations under v0.5.2 + configmap monkeypatch + +For the record, the v0.5.2 + configmap-monkeypatch combination we ran for 8+ hours before v0.7.1 deploy: + +- Workload: Qwen3-30B-A3B-Instruct-2507, FSDP 2×2, EP=4 (32/128 experts per rank), FLASHINFER attention, gsm8k env +- Trainer steady state: ~10–21 s/step (varies with sequence length 280–500 tokens) +- Reward signal: variance 0.5–1.0 per orchestrator step — **real learning gradient**, not just reward=1.0 collapse +- Off-policy level: 0 across all observed steps (in-lockstep refit) +- Best uninterrupted window: **183 successful RL refit cycles over 66 min** between orchestrator restarts +- Zero NIXL data-plane errors (no `REMOTE_DISCONNECT`, no `NOT_ALLOWED`, no stale-READY) — confirms the same-rank-only + freshest-per-rank patches are correct +- Recurring orchestrator timeout pattern: `TimeoutError: timed out after 1200.0s waiting for 12 'trainer' peers to reach status 1 (saw 4)` — exactly what Phase 2's rendezvous-level dedup fixes + +That last bullet is the bug class v0.7.1 is meant to eliminate. The configmap monkeypatch couldn't fix it because the relevant call site is in the orchestrator's rendezvous, which is in a different module from the trainer-side broadcast the monkeypatch was rewriting. + +## 5. Branches + image artifacts pushed + +| Branch | What's in it | Where | +|---|---|---| +| [`kavink/post-2389-kernel-compile-plan`](https://github.com/KavinKrishnan/prime-rl/tree/kavink/post-2389-kernel-compile-plan) | RFC document + this build-notes doc | `KavinKrishnan/prime-rl` | +| [`kavink/post-2389-phase2-rendezvous-fixes`](https://github.com/KavinKrishnan/prime-rl/tree/kavink/post-2389-phase2-rendezvous-fixes) | Phase 2 source (heartbeat + dedup + same-rank), 11/11 unit tests green, plus the `modelexpress.heartbeat` module-path tolerance fix | [Draft PR #1](https://github.com/KavinKrishnan/prime-rl/pull/1) | +| [`kavink/post-2389-conversion-registry-extensions`](https://github.com/KavinKrishnan/prime-rl/tree/kavink/post-2389-conversion-registry-extensions) | Phase 3 conversion-registry extensions (`compile_target` + `compile_metadata` + `cutlass_fp8_e4m3_per_channel`), 19/19 unit tests green | [Draft PR #2](https://github.com/KavinKrishnan/prime-rl/pull/2) | +| [`kavink/post-2389-image-build-2026-05-28`](https://github.com/KavinKrishnan/prime-rl/tree/kavink/post-2389-image-build-2026-05-28) | Merge of Phase 2 + Phase 3 + the import-tolerance fix; this is the exact source tree v0.7.0 / v0.7.1 was built from | `KavinKrishnan/prime-rl` (this push) | + +Image artifacts on `nvcr.io/nvidian/dynamo-dev/`: + +- `prime-rl-mx-on-nixl:v0.7.0-kavin-phase2-phase3` — full from-scratch ARM64 build (broken — missing `disagg`) +- `prime-rl-mx-on-nixl:v0.7.1-kavin-phase2-phase3` — overlay that adds `disagg` extra + +MX side ([`ai-dynamo/modelexpress#349`](https://github.com/ai-dynamo/modelexpress/pull/349)) updated with the graduation glue commit that plumbs `ConversionEntry.compile_target` + `ConversionEntry.compile_metadata` through `MxV2TrainingPublisher.add_tensor(compile_target=…, compile_metadata=…)`. Wire round-trip is unit-tested. + +## 6. What to update in the RFC (`post-pr2389-kernel-compile-plan.md`) — but not yet + +These are the four edits queued in [`pensieve/RL/PrimeRL/09_rfc_updates_needed.md`](https://github.com/ai-dynamo/modelexpress/) (internal), augmented by what we learned from the build: + +1. **Reframe Phase 3** — trainer-side post-processed direct is primary; receiver-side compile passes are v4+ (scratch buffers are a fallback only, not the primary v3 design as the original RFC implied). +2. **Add Phase 0** — Phase B UCX/dma-buf env profile as cluster prerequisite (`UCX_TLS=rc,cuda_copy`, `NIXL_UCX_TLS=rc,cuda_copy`, `UCX_CUDA_COPY_DMABUF=yes`, etc.) — from NeMo-RL + Dynamo's empirical 380 Gbps validation. +3. **Mark Phases 2/3/4 as shipped, not paper** (with PR + commit references). +4. **Add a sub-section on conversion registry extensions** documenting the `ConversionEntry` schema extension + how to add a new kernel (~80 LOC per kernel). + +**New from this build experience**: + +5. **Document the `disagg` extra requirement in §0** alongside the env profile — easy gotcha that costs an entire rebuild to discover. +6. **Document the `LD_PRELOAD` path** for libcudart in §0 — pre-existing run-scripts assumed v0.5.2's `/usr/local/cuda` layout. +7. **§4 / §5 on the "fallback path"** — describe vLLM PR #43375 (Ray Direct Transport, Anyscale) as the canonical receiver-pull-via-load_weights instance of the fallback path. Our positioning of trainer-side post-processed direct as "primary path, zero receive-side compute" is unchanged; RDT is the upstream-stamped instance of the alternative receiver-pull path. + +## 7. Open follow-ups + +- **Validate v0.7.1 end-to-end** on kavin (pending — v0.7.1 just deployed). Expected: zero NIXL errors AND zero orchestrator-`wait_for_all_peers_ready` timeouts. If both hold for a long uninterrupted window (>3 hours), we declare the source-baked Phase 2 + Phase 3 production-ready. +- **Send a one-line PR to upstream `Dockerfile.cuda`** adding `--extra disagg` (or `--extra all`). Tiny patch, unblocks every other team trying to bake `nixl_mx` mode into an image. +- **MX side: roll the server with the `SourceIdentity` round-trip fix** (proto change committed but not deployed). After that, the `__mx_v2_meta__` sidecar transport workaround can be dropped from `MxV2TrainingPublisher` (it's already filtered before NIXL register via [PR #295](https://github.com/ai-dynamo/modelexpress/pull/295)). +- **`pull_one(name)` semantic on MX** — inspired by vLLM PR #43375's RDT contract. Would let MX expose Ray-like per-tensor elasticity without abandoning the trainer-side compile model. ~50 LOC; not on the critical path but a clean addition for the post-Phase-4 work. From dbe936fff9b86988ce1a71ade26c2b858dad0a61 Mon Sep 17 00:00:00 2001 From: Kavin Krishnan Date: Thu, 28 May 2026 20:27:54 -0700 Subject: [PATCH 07/18] =?UTF-8?q?docs(proposals):=20build=20notes=20=C2=A7?= =?UTF-8?q?8=20=E2=80=94=20vLLM=20native=20RL=20APIs=20reframe=20Phase=202?= =?UTF-8?q?/3/4=20upstream=20form?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit vLLM published https://vllm.ai/blog/2026-05-28-native-rl-apis the same day, announcing a standardized WeightTransferEngine abstract base + 4-phase lifecycle (init / start / update / finish) + a pluggable WeightTransferEngineFactory.register_engine(...) extension point. This is the upstream integration seam that the in-tree MxRendezvous reimplementation in PR #2389 and the worker_extension_cls injection in inference/vllm/worker/nixl_mx.py have been emulating. The cleanest form of all our Phase 2/3/4 work upstream is a single MxWeightTransferEngine adapter (~150-200 LOC) that subclasses WeightTransferEngine and wraps the existing MxV2RefitReceiver + MxV2TrainingPublisher. Three immediate consequences captured in §8: §8.1 — Phase 2/3/4 should be repackaged as MxWeightTransferEngine for upstream contribution; the existing patches stay correct, the packaging just becomes upstream-native. §8.2 — The blog credits Matej Sirovatka specifically. He's likely mid-flight on a native-APIs rewrite of prime-rl's nixl_mx broadcast. Ask him before pushing Phase 2 upstream; the work may retarget to the adapter path directly. §8.3 — Their validation was at 16x 8xH200, DPEP32, 256 GPUs total. That scale makes Phase 4's multi-source slice planning load-bearing (mixed-TP/EP is the common case), not optional. Validates the design direction and sets the next cluster validation target after the DP=4 kavin smoke. §8.4 — pause_generation(mode="keep") + two-phase DPEP pause are features we don't yet match. Keep mode unlocks true async RL; queue after Phase 2 lands. Updated follow-up list grows from 4 to 7 items, with the three new ones being: write MxWeightTransferEngine, adopt keep-mode pause in the orchestrator, and coordinate with Robert Shaw / the vLLM RL roadmap on the K8s-native weight transfer engine they mention as ongoing work (which describes MX itself, modulo who's driving the upstream PR). --- docs/proposals/build-notes-2026-05-28.md | 76 ++++++++++++++++++++++++ 1 file changed, 76 insertions(+) diff --git a/docs/proposals/build-notes-2026-05-28.md b/docs/proposals/build-notes-2026-05-28.md index dead46c2d6..01a8cbb229 100644 --- a/docs/proposals/build-notes-2026-05-28.md +++ b/docs/proposals/build-notes-2026-05-28.md @@ -155,3 +155,79 @@ These are the four edits queued in [`pensieve/RL/PrimeRL/09_rfc_updates_needed.m - **Send a one-line PR to upstream `Dockerfile.cuda`** adding `--extra disagg` (or `--extra all`). Tiny patch, unblocks every other team trying to bake `nixl_mx` mode into an image. - **MX side: roll the server with the `SourceIdentity` round-trip fix** (proto change committed but not deployed). After that, the `__mx_v2_meta__` sidecar transport workaround can be dropped from `MxV2TrainingPublisher` (it's already filtered before NIXL register via [PR #295](https://github.com/ai-dynamo/modelexpress/pull/295)). - **`pull_one(name)` semantic on MX** — inspired by vLLM PR #43375's RDT contract. Would let MX expose Ray-like per-tensor elasticity without abandoning the trainer-side compile model. ~50 LOC; not on the critical path but a clean addition for the post-Phase-4 work. + +## 8. Strategic update — vLLM published native RL APIs (2026-05-28) + +Same day as this doc, the vLLM team published [Native RL APIs in vLLM](https://vllm.ai/blog/2026-05-28-native-rl-apis), announcing a standardized `WeightTransferEngine` abstract base + four-phase lifecycle (`init_weight_transfer_engine` / `start_weight_update` / `update_weights` / `finish_weight_update`) and a pluggable `WeightTransferEngineFactory.register_engine(...)` extension point. Existing built-in backends: NCCL (packed broadcast), IPC (CUDA shared mem). PR [#43375](https://github.com/vllm-project/vllm/pull/43375) (Anyscale "RDT") plugs Ray Direct Transport into the same factory. + +Three things this changes about our post-PR-#2389 plan: + +### 8.1 Phase 2 + Phase 3 + Phase 4 have a clean upstream form: `MxWeightTransferEngine` + +The native API gives us a standard surface to plug MX into vLLM. The upstream-friendly shape of all our follow-up work is a single adapter class: + +```python +class MxWeightTransferEngine(WeightTransferEngine): + init_info_cls = MxInitInfo # mx_server_url, model_name, worker_rank, ... + update_info_cls = MxUpdateInfo # version, compile_target_filter, target_tp_layout, ... + + def init_transfer_engine(self, init_info: MxInitInfo): + self._receiver = MxV2RefitReceiver(...) + + def receive_weights(self, update_info, load_weights): + plan = self._receiver.discover_v2_sources_for_slice( # Phase 4 + model_name=update_info.model_name, + target_layout=update_info.target_tp_layout, + compile_target_filter=update_info.compile_target_filter, # Phase 3b + required_compile_metadata=update_info.required_compile_metadata, + ) + for name, tensor in self._receiver.receive_via_plan(plan): # Phase 4 stitch + load_weights([(name, tensor)]) + + @classmethod + def trainer_send_weights(cls, iterator, trainer_args): + # wraps MxV2TrainingPublisher.add_tensor(compile_target=..., compile_metadata=...) + .publish() + ... +``` + +Registers via `WeightTransferEngineFactory.register_engine("mx_nixl", MxWeightTransferEngine)`. Estimated ~150-200 LOC adapter on top of the MX clients we already have. + +The in-tree `MxRendezvous` reimplementation in PR #2389 + the `worker_extension_cls` injection in `inference/vllm/worker/nixl_mx.py` both **become unnecessary once this lands upstream**. The native API is the integration seam they were emulating; our adapter consumes it directly. + +### 8.2 Matej is acknowledged in the blog — coordination needed before pushing Phase 2 upstream + +The blog credits *"Prime-RL team (especially Matej Sirovatka) and Junjie Zhang for helping to validate and debug the RL APIs with large-scale runs"*. Matej is actively involved in the design. + +This affects the trajectory of [draft PR #1](https://github.com/KavinKrishnan/prime-rl/pull/1) (Phase 2). The semantic fix it ships is correct regardless of which integration path prime-rl converges on, but the *form* differs: + +- **If Matej is mid-flight on a native-APIs rewrite of `nixl_mx`**: our Phase 2 PR retargets to that path (becomes part of the `MxWeightTransferEngine` adapter rather than landing in the in-tree `MxRendezvous` class). +- **If Matej hasn't started**: Phase 2 lands as-is, the rewrite happens later as a separate PR. + +Either way, ask Matej first. + +### 8.3 Validation in the blog is at DPEP32 across 16 nodes — Phase 4 becomes load-bearing + +The blog reports Prime-RL validated `zai-org/GLM-5.1-FP8` in P/D-disaggregated deployment across **16× 8xH200 nodes** (2 replicas of 4P+4D, both **DPEP32**) for 100+ steps with stable KL mismatch and upward RL curve. **256 GPUs total at DPEP32**. + +At this scale, mixed-TP / mixed-EP is the common case (trainer TP/EP layout almost never matches inference TP/EP layout), and Phase 4's `discover_v2_sources_for_slice` + multi-source `receive_via_plan` is the difference between "works" and "doesn't". This validates Phase 4's design direction and sets the next cluster validation target after kavin (DP=4). + +### 8.4 Two async-RL features the blog ships that we don't yet match + +| Feature | What | Our position | +|---|---|---| +| `pause_generation(mode="keep")` | Pause in-flight requests *without* aborting or waiting; resume from the partial-token state | Today we wait for rollouts to complete before refit. Adopting `keep` mode is the unlock for truly async RL. Queue as a follow-up after Phase 2 lands. | +| Two-phase pause/resume for DPEP | `EngineCore`-level pause state + periodic all-reduce coordination to prevent deadlocks in wide-DP deployments | Not applicable at DP=4 (kavin cluster). Mandatory once we scale to DPEP16+. Track for the next scale-up. | + +The blog also mentions a *"new K8s-native weight transfer engine"* and *"sharding-aware, RDMA-native weight transfer in a generic way"* as ongoing work in the vLLM RL community — both of which describe what MX already is. If MX isn't already the implementation they're referring to, reach out to Robert Shaw (acknowledged as organizing RL-related efforts) to coordinate. + +### 8.5 Updated follow-up list — was four items, now seven + +| # | Item | Effort | +|---|---|---| +| 1 | Validate v0.7.1 end-to-end on kavin (pending — image just deployed) | hours of soak | +| 2 | Send a one-line PR to upstream `Dockerfile.cuda` adding `--extra disagg` | <30 min | +| 3 | Roll the MX server with the `SourceIdentity` round-trip fix; deprecate the `__mx_v2_meta__` sidecar transport | ~1 day | +| 4 | Implement `pull_one(name)` semantic on MX for Ray/RDT-style per-tensor elasticity | ~50 LOC | +| 5 | **NEW**: Sketch + implement `MxWeightTransferEngine` as the upstream-PR form of Phase 2/3/4 | 150-200 LOC adapter | +| 6 | **NEW**: Adopt `pause_generation(mode="keep")` in prime-rl orchestrator for true async RL | ~50 LOC + tests | +| 7 | **NEW**: Coordinate with Matej + Robert Shaw + the vLLM RL roadmap on the K8s-native weight transfer engine. Either contribute MX as the canonical implementation or converge with whoever's already building it. | meeting + scoping | From 4b33e90fec3676e0079976914f07939dbc0d1e3c Mon Sep 17 00:00:00 2001 From: Kavin Krishnan Date: Thu, 28 May 2026 21:03:31 -0700 Subject: [PATCH 08/18] docs(proposals): add post-pr2389-status-and-plan.md + cross-link the three docs MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The three proposal docs now form a coherent set: - post-pr2389-status-and-plan.md — executive summary; failure-class to fix mapping; mermaid diagram of the data + metadata planes; Phase 0 unblock guidance - post-pr2389-kernel-compile-plan.md — full RFC with phase-by-phase design rationale (unchanged except for cross-link header) - build-notes-2026-05-28.md — operational findings from the source-baked image build, plus the vLLM native RL APIs reframe in section 8 Each doc now has a header block linking to the other two so readers can navigate based on intent (status vs design vs operational). The status-and-plan doc is the natural entry point for someone coming to the work cold; the RFC and build-notes are the deep dives. --- docs/proposals/build-notes-2026-05-28.md | 9 +- .../post-pr2389-kernel-compile-plan.md | 6 + docs/proposals/post-pr2389-status-and-plan.md | 232 ++++++++++++++++++ 3 files changed, 245 insertions(+), 2 deletions(-) create mode 100644 docs/proposals/post-pr2389-status-and-plan.md diff --git a/docs/proposals/build-notes-2026-05-28.md b/docs/proposals/build-notes-2026-05-28.md index 01a8cbb229..8f0ee37443 100644 --- a/docs/proposals/build-notes-2026-05-28.md +++ b/docs/proposals/build-notes-2026-05-28.md @@ -1,7 +1,12 @@ # Build notes — Phase-2 + Phase-3 source-baked image (2026-05-28) -> **Companion to**: [`docs/proposals/post-pr2389-kernel-compile-plan.md`](./post-pr2389-kernel-compile-plan.md) -> **Status**: empirical findings from baking Phase 2 (rendezvous fixes) + Phase 3 (conversion-registry extensions) into an ARM64 GB200 image and running it against the live `kavin` namespace. Updates the RFC's framing where the build experience contradicted assumptions in the original RFC. +> **Related docs in this directory**: +> - [`post-pr2389-status-and-plan.md`](./post-pr2389-status-and-plan.md) — executive summary of where things stand + failure-class → fix mapping +> - [`post-pr2389-kernel-compile-plan.md`](./post-pr2389-kernel-compile-plan.md) — the full RFC with phase-by-phase design rationale +> +> This doc is the operational findings: how the source-baked image was built, what broke, what we learned, and how the upstream vLLM native RL APIs reframe everything. + +**Status**: empirical findings from baking Phase 2 (rendezvous fixes) + Phase 3 (conversion-registry extensions) into an ARM64 GB200 image and running it against a live GB200 cluster. Updates the RFC's framing where the build experience contradicted assumptions in the original RFC. This document captures **what we learned producing a usable image** containing the two follow-up PRs ([phase-2 rendezvous fixes](https://github.com/KavinKrishnan/prime-rl/pull/1) and [conversion-registry extensions](https://github.com/KavinKrishnan/prime-rl/pull/2)) on top of PR #2389 (HEAD `79ea824d8`). The unit tests for both PRs were already green; this doc records the cluster + image surface area that the unit tests don't cover. diff --git a/docs/proposals/post-pr2389-kernel-compile-plan.md b/docs/proposals/post-pr2389-kernel-compile-plan.md index 8eaad8cabc..78db06fa4e 100644 --- a/docs/proposals/post-pr2389-kernel-compile-plan.md +++ b/docs/proposals/post-pr2389-kernel-compile-plan.md @@ -1,5 +1,11 @@ # Post-PR-#2389 plan: kernel-compile separation, mixed-TP, and MX client adoption +> **Related docs in this directory**: +> - [`post-pr2389-status-and-plan.md`](./post-pr2389-status-and-plan.md) — executive summary of where things stand + failure-class → fix mapping +> - [`build-notes-2026-05-28.md`](./build-notes-2026-05-28.md) — image-build experience, cluster observations, vLLM native RL APIs reframing +> +> This doc is the deep-dive RFC with full phase-by-phase design rationale. + **Status**: Planning doc. Branch `kavink/post-2389-kernel-compile-plan` (NVIDIA-authored, off PR [#2389](https://github.com/PrimeIntellect-ai/prime-rl/pull/2389) HEAD `79ea824d8`). **Premise**: This plan is what we propose to build on top of `nixl_mx` once #2389 merges to `main`. It (a) graduates the in-tree `MxRendezvous` reimplementation onto NVIDIA's published ModelExpress clients, (b) introduces a compile-target registry to fix the trainer-side cutlass-pinning issue surfaced during #2389's FP8 cast-pipeline iteration, and (c) extends the v2 shape registry to handle mixed-TP / sharded-source transfers. **None of this fights the #2389 data plane** — the `Slot` / `TransportPlan` / `NixlAgentWrapper` / `classic_cuda_pool` stack stays untouched. We extend the rendezvous and metadata surfaces only. diff --git a/docs/proposals/post-pr2389-status-and-plan.md b/docs/proposals/post-pr2389-status-and-plan.md new file mode 100644 index 0000000000..f1c838f2fc --- /dev/null +++ b/docs/proposals/post-pr2389-status-and-plan.md @@ -0,0 +1,232 @@ +# prime-rl × ModelExpress — Status of the post-#2389 work and how it addresses the MoE / kernel / quant pain points + +> **Related docs in this directory**: +> - [`post-pr2389-kernel-compile-plan.md`](./post-pr2389-kernel-compile-plan.md) — the full RFC with phase-by-phase design rationale +> - [`build-notes-2026-05-28.md`](./build-notes-2026-05-28.md) — image-build experience, cluster observations, vLLM native RL APIs reframing +> +> This doc is the executive summary: where prime-rl × MX stands today, what the failure classes are, and where in the follow-up plan each one gets resolved. + +**Scope**: status of the ModelExpress + NIXL weight-refit integration in prime-rl, the three classes of failure observed on MoE models (kernel-target mismatches, fused-vs-unfused gates across kernels, quantization + packing mismatches), and the four-phase follow-up plan that resolves them. + +**TL;DR**: [PR #2389](https://github.com/PrimeIntellect-ai/prime-rl/pull/2389) lands the first cut of MX-driven NIXL refit in prime-rl and runs at ~10s/cycle on GB200. The three failure classes reduce to one root cause — prime-rl's weight-conversion registry has only two entries, and there is no `compile_target` tag on published bytes that distinguishes DeepGemm vs Cutlass vs other layouts. Two follow-up PRs are in flight that fix this without changing prime-rl's transfer architecture, plus a ~80-LOC-per-kernel extension that can land independently of either PR. + +--- + +## 1. Current state + +### 1.1 PR #2389 — what it does + +[PR #2389](https://github.com/PrimeIntellect-ai/prime-rl/pull/2389) adds a third weight-broadcast type to prime-rl: `weight_broadcast.type = "nixl_mx"`. The data flow per refit cycle is: + +1. **Trainer side** — for each step's state-dict tensor, run a *trainer-side conversion* (today: `bf16_cast` or `fp8_128x128`) that produces the destination layout vLLM expects, with fusion (e.g. `q_proj/k_proj/v_proj → qkv_proj`) and quant already applied. Land the bytes into pre-allocated NIXL-registered buffers (`ShardedSlot` / `GatheredSlot` / `ExpertSlot`). +2. **Rendezvous** — register with ModelExpress (MX) so the inference side knows where to pull from. Lightweight — just metadata; the bytes never go through MX. +3. **NIXL RDMA push** — trainer writes directly into the inference workers' pre-registered parameter buffers over RDMA. Sub-second for a 30 GB model on GB200 cross-node. +4. **Inference side** — vLLM workers receive a `/update_weights` HTTP from the orchestrator, synchronize, and resume serving. + +### 1.2 Validation status + +Validation has run on a GB200 cluster: + +| Component | State | +|---|---| +| Workload | Qwen3-30B-A3B-Instruct-2507, FSDP 2×2, EP=4, 32/128 experts per rank, FLASHINFER attention | +| Image | `nvcr.io/nvidian/dynamo-dev/prime-rl-mx-on-nixl:v0.5.2` | +| Steady-state refit | ~10 s/cycle, reward=1.0000, off-policy=0 | +| Errors | Zero NIXL errors (no `REMOTE_DISCONNECT`, no `NOT_ALLOWED`, no stale-READY) | +| Tested cycles | 4/4 clean, plus 1 trainer restart cycle (clean exit on `max_steps`, k8s respawned, resumed cleanly) | + +PR #2389 is a real working baseline — not a prototype that falls over on first push. + +### 1.3 Two GB200-specific runtime patches folded in + +Two issues on the multi-NIC fabric required patches to the trainer's NIXL setup: + +- **Same-rank-only peer filter** — on GCP GB200 the four RDMA NICs (`rdma-0..3`) are separate L3 subnets, so trainer rank N can only safely peer with inference rank N. Cross-subnet pairs fail. +- **Freshest-per-rank dedup** — when MX has multiple entries at the same `worker_rank` (e.g. after a pod restart), the freshest entry by `updated_at` must be picked, not the first one. Otherwise the inference side picks a stale entry and the NIXL `add_remote_agent` call refuses with `NIXL_ERR_NOT_ALLOWED`. + +Both are applied today as a runtime monkey-patch in the configmap (`patch_nixl_mx.py`). The [Phase 2 draft PR](https://github.com/KavinKrishnan/prime-rl/pull/1) bakes them into the rendezvous class so they are no longer runtime shims. + +### 1.4 What PR #2389 does not yet do (where the failure classes live) + +Two specific gaps map directly to the observed failure classes: + +| Gap | Failure class | +|---|---| +| The conversion registry (`prime_rl/trainer/models/conversions/__init__.py`) has exactly **two entries**: `bf16_cast` and `fp8_128x128`. Anything else raises `NotImplementedError` at startup. | Quantization and packing mismatches — if the target inference engine wants DeepGemm K-major scale interleaving, Cutlass FP8, MXFP4, a non-128 block size, etc., prime-rl cannot produce those bytes at all. | +| There is no `compile_target` tag on published tensors. Receivers cannot tell whether the bytes were compiled for DeepGemm, Cutlass, or anything else. | MoE-specific corruption + fused-vs-unfused-gate mismatches — when trainer and receiver disagree on the layout, the result is silent byte misinterpretation instead of a clean error. The receiver loads the wrong layout and rollouts are corrupted, often without an obvious crash. | + +These are the same root cause expressed three ways: *prime-rl currently assumes the trainer and inference side agree on one layout, baked in at config time, with no runtime check*. Heterogeneous fleets or kernel-target disagreements break it silently. + +--- + +## 2. The follow-up effort + +The follow-up work is captured in the [post-#2389 RFC](https://github.com/KavinKrishnan/prime-rl/blob/kavink/post-2389-kernel-compile-plan/docs/proposals/post-pr2389-kernel-compile-plan.md). It has four phases and a clear strategic direction. + +### Strategic direction + +The design stays on the **trainer-side post-processed transfer** model: the trainer compiles weights into the destination kernel's exact layout, RDMA-writes them straight into the inference worker's pre-registered parameter buffers, and the inference side does no compute on receipt. This is what prime-rl does today; the `ConversionSpec` / `ShardedSlot` / `ExpertSlot` / `GatheredSlot` structure is the right shape. + +The design does *not* move to a scratch-buffer + `model.load_weights()` pattern. The internal NemoRL + Dynamo integration prototype uses that pattern as a workaround for a vLLM-specific quirk (`stacked_params_mapping` does HF→fused name remapping at load time, so the trainer can publish HF-raw names and the receiver figures it out). It works, but it adds receive-side compute (~50-200 ms per refit on FP8 casts) and is not needed for prime-rl, which already does fusion + quant trainer-side and writes directly into pre-allocated fused buffers. + +### Four phases + +| Phase | What | Where | Status | +|---|---|---|---| +| **Phase 2** | Bake the two GB200 runtime patches into `MxRendezvous` permanently; add a `HeartbeatThread` so crashed workers don't leave stale READY entries; swap the in-tree rendezvous over to the published v2 ModelExpress client. | [`KavinKrishnan/prime-rl#1`](https://github.com/KavinKrishnan/prime-rl/pull/1) | Draft PR open, 11/11 unit tests green | +| **Phase 3a** | Add `compile_target` and `compile_metadata` fields to ModelExpress's `TensorDescriptorV2`. Trainer stamps every publish with what layout it produced — `deep_gemm_fp8`, `cutlass_fp8`, `block_size=128`, `gate_fusion=gate_up_swiglu`, etc. | [`ai-dynamo/modelexpress#349`](https://github.com/ai-dynamo/modelexpress/pull/349) | Draft PR open, 6/6 unit tests green | +| **Phase 3b** | Receivers filter sources by `compile_target` + `compile_metadata` *before* RDMA. Mismatched bytes are refused at discovery; no more silent corruption. | Same PR | 4/4 unit tests green | +| **Phase 4** | Multi-source slice picker for the mixed-TP case (e.g. trainer TP=4 publishing to inference TP=8). Each receiver discovers which subset of publisher ranks covers its slice. | Same PR | 8/8 unit tests green | +| **Phase 0 (parallel)** | prime-rl-only: extend the conversion registry from two entries to N. **No MX dependency. Can land independently of the four phases above.** | Open work (see §4 below) | — | + +### How each phase maps to the three failure classes + +| Failure class | What fixes it | +|---|---| +| Quantization and packing issues (`NotImplementedError`) | **Phase 0**: extend `prime_rl/trainer/models/conversions/` with the kernel-specific layouts required. ~80 LOC per kernel. | +| MoE expert layout differences across kernels | **Phase 0** (new `ConversionSpec` per (model × kernel)) + **Phase 3a/3b** to surface mismatches as clean errors. | +| Some gates are fused vs not | **Phase 0** (per-(model × kernel) `ConversionSpec` defining `sources` + `cat_dim` differently) + **Phase 3a/3b** so the trainer's `compile_metadata.gate_fusion` tag travels and is filtered on. | + +In all three cases, **Phase 0 is the immediate unblock** — Phase 3 is the safety net that prevents silent corruption when there is a mismatch. + +--- + +## 3. The architecture, with what's new highlighted + +```mermaid +flowchart TB + classDef today fill:#e6f3ff,stroke:#0066cc,color:#000 + classDef phase0 fill:#fff4cc,stroke:#c68a00,color:#000,stroke-width:2px + classDef phase2 fill:#d4edda,stroke:#28a745,color:#000,stroke-width:2px + classDef phase3 fill:#f8d7da,stroke:#dc3545,color:#000,stroke-width:2px + classDef physical fill:#eee,stroke:#666,color:#000 + + subgraph trainer["Trainer pod — FSDP 2×2 / EP=4"] + TSD[("HF state_dict
(bf16, unfused)")] + TCONV["ConversionSpec registry
Today: bf16_cast, fp8_128x128
Phase 0: + deep_gemm_fp8, cutlass_fp8,
+ MoE fusion variants per model"] + TSLOTS["Slots: ShardedSlot · GatheredSlot · ExpertSlot
(pre-allocated, NIXL-registered)"] + TRENDZ["MxRendezvous
Today: thin wrapper over MxClient
Phase 2: swap to MxV2TrainingPublisher
+ heartbeat + same-rank + freshest-dedup"] + TTAG["Phase 3a: stamp every publish with
compile_target + compile_metadata
(rides on TensorDescriptorV2)"] + TNIXL["NIXL agent (UCX rc_mlx5)"] + TSD --> TCONV --> TSLOTS --> TNIXL + TSLOTS -.metadata.-> TRENDZ + TRENDZ -.-> TTAG + end + + subgraph mx["ModelExpress server (control plane only — no weight bytes)"] + MXSVR[("gRPC + Redis catalog
(workers, status, descriptors, tags)")] + end + + subgraph inference["Inference pod — vLLM v1, 4 ranks, EP=4"] + IRENDZ["MxRendezvous / MxV2RefitReceiver"] + IFILT["Phase 3b: discover_v2_sources(
compile_target_filter={…},
required_compile_metadata={…})
Mismatched bytes refused BEFORE RDMA"] + IPICK["Phase 4: discover_v2_sources_for_slice(…)
multi-source picker for mixed-TP"] + IPARAMS["vLLM model.named_parameters()
(pre-registered, fused destinations like qkv_proj)"] + INIXL["NIXL agent (UCX rc_mlx5)"] + IRENDZ --> IFILT --> IPICK + INIXL --> IPARAMS + end + + TNIXL ==>|"RDMA WRITE
(post-processed bytes,
~10s for 30GB Qwen3-30B-A3B)"| INIXL + TRENDZ <-.->|"publish metadata + tags"| MXSVR + MXSVR <-.->|"discover + filter"| IRENDZ + + class TSD,TSLOTS,IPARAMS today + class TNIXL,INIXL,TCONV today + class TCONV phase0 + class TRENDZ,IRENDZ phase2 + class TTAG,IFILT,IPICK phase3 + class MXSVR physical +``` + +Legend: 🟦 today's PR #2389 surface, 🟨 Phase 0 (immediate unblock), 🟩 Phase 2, 🟥 Phase 3 / 4. + +The data plane (RDMA write of post-processed bytes from trainer NIC to inference NIC) stays exactly as it is today. Everything new lives in the **metadata plane** (what gets stamped on the publish) and the **registry** (what conversions are available trainer-side). + +--- + +## 4. Phase 0 — extending the conversion registry + +This is independent of both PRs in flight. The conversion registry is plug-in: + +``` +prime_rl/trainer/models/conversions/ +├── __init__.py # registry, select_default_conversion() +├── bf16_cast.py # registered as "bf16_cast" ← today +└── fp8_blockwise.py # registered as "fp8_128x128" ← today +``` + +To add support for a new kernel layout, drop in a new file: + +```python +# prime_rl/trainer/models/conversions/deep_gemm_fp8.py +from prime_rl.trainer.models.conversions import register + +def _convert(src, dst, scale): + # src: bf16 HF-format tensor + # dst: pre-allocated FP8 destination buffer (kernel's expected layout) + # scale: paired scale buffer if requires_scale=True + ... + +register("deep_gemm_fp8", _convert, requires_scale=True) +``` + +The default-conversion picker in `__init__.py:select_default_conversion` reads the inference model's `config.json` and chooses based on `quantization_config`. Today it knows two cases; extend the if/elif chain for new ones. + +**For MoE specifically — the gate-fusion choice lives in the `ConversionSpec` definitions per model.** Example pattern in `prime_rl/trainer/models/qwen3_moe/converting_qwen3_moe.py`: + +```python +ConversionSpec( + dst="mlp.experts.gate_up_proj.weight", # fused destination + sources=("mlp.experts.gate_proj.weight", # source 1 + "mlp.experts.up_proj.weight"), # source 2 + cat_dim=0, # concat along dim 0 +) +``` + +For an unfused-gate kernel target, a parallel set of `ConversionSpec`s would emit `gate_proj` and `up_proj` separately. The framework supports multiple `ConversionSpec` sets via the `Conversion` selector in `MaybeQuantize`. + +Total scope per kernel target: ~80 LOC for the conversion function + per-model `ConversionSpec` additions for the required models. + +Once Phase 3a lands in MX and Phase 2 lands in prime-rl, every publish automatically carries the `compile_target` tag (e.g. `"deep_gemm_fp8"`) and `compile_metadata` (e.g. `{"block_size": 128, "scale_layout": "K-major", "gate_fusion": "unfused"}`), and any inference worker expecting a different layout refuses the source cleanly at discovery instead of corrupting rollouts. + +--- + +## 5. Information required to write the missing conversion entries + +Two inputs are required per kernel target before the missing conversion entries can be written: + +1. **Inference engine + quant config target**. Examples: "vLLM 0.7 + DeepGemm grouped-GEMM MoE FP8 block-128", "vLLM 0.7 + Cutlass MoE FP8 with K-major scales", "Triton MoE with unfused gates and BF16". The exact kernel name + scale layout determines the required output format. +2. **Model architectures involved**. Existing conversion specs cover Qwen3, Qwen3-MoE, GLM-MoE-DSA, Nemotron-H, MiniMax-M2, and Laguna. Any model not yet in `prime_rl/trainer/models/` requires an additional ~150 LOC for the model adapter. + +--- + +## 6. References + +**Branches and PRs** + +- Upstream: [PrimeIntellect-ai/prime-rl#2389](https://github.com/PrimeIntellect-ai/prime-rl/pull/2389) +- Phase 2 draft: [KavinKrishnan/prime-rl#1](https://github.com/KavinKrishnan/prime-rl/pull/1) — `MxRendezvous` heartbeat + dedup + same-rank filter +- Phase 3+4 draft: [ai-dynamo/modelexpress#349](https://github.com/ai-dynamo/modelexpress/pull/349) — `compile_target` + multi-source slice picker +- RFC: [`KavinKrishnan/prime-rl:kavink/post-2389-kernel-compile-plan`](https://github.com/KavinKrishnan/prime-rl/blob/kavink/post-2389-kernel-compile-plan/docs/proposals/post-pr2389-kernel-compile-plan.md) — the full post-#2389 plan +- Build notes: [`KavinKrishnan/prime-rl:kavink/post-2389-kernel-compile-plan/docs/proposals/build-notes-2026-05-28.md`](https://github.com/KavinKrishnan/prime-rl/blob/kavink/post-2389-kernel-compile-plan/docs/proposals/build-notes-2026-05-28.md) — image-build experience, cluster observations, and the vLLM-native-RL-APIs reframing +- Inline review on upstream PR: six inline + one summary, posted at PR HEAD `79ea824d8` — covering cross-subnet `add_remote_agent`, freshest-per-rank dedup, missing heartbeat, hardcoded 1200s timeout, unconditional `update_mla_absorbed_weights`, HSDP barrier ordering + +**Code locations in prime-rl** + +- Conversion registry: `src/prime_rl/trainer/models/conversions/__init__.py` +- Existing conversions: `src/prime_rl/trainer/models/conversions/{bf16_cast,fp8_blockwise}.py` +- `ConversionSpec` definition: `src/prime_rl/trainer/models/conversion_spec.py` +- Per-model `ConversionSpec` registrations: `src/prime_rl/trainer/models//converting_.py` +- Slots: `src/prime_rl/trainer/models/slots.py` +- NIXL+MX trainer broadcast: `src/prime_rl/trainer/rl/broadcast/nixl_mx.py` +- NIXL+MX inference worker: `src/prime_rl/inference/vllm/worker/nixl_mx.py` +- Rendezvous: `src/prime_rl/transport/mx_rendezvous.py` +- Transport plan: `src/prime_rl/transport/transport_plan.py` + +**Cluster** + +- GB200 dev cluster, dedicated namespace +- Image (today): `nvcr.io/nvidian/dynamo-dev/prime-rl-mx-on-nixl:v0.5.2` +- Source-baked Phase 2 + Phase 3 image: `nvcr.io/nvidian/dynamo-dev/prime-rl-mx-on-nixl:v0.7.1-kavin-phase2-phase3` +- ModelExpress server: deployed in the same namespace, Redis-backed From 78c0e0c6590d5777c9a7138bd6ef8a673539492d Mon Sep 17 00:00:00 2001 From: Kavin Krishnan Date: Mon, 1 Jun 2026 19:16:12 -0700 Subject: [PATCH 09/18] =?UTF-8?q?RFC:=20weight=5Fbroadcast.type=3D"mx=5Fv2?= =?UTF-8?q?"=20=E2=80=94=20the=20complete=20prime-rl=20=C3=97=20ModelExpre?= =?UTF-8?q?ss=20design?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds a single new weight-broadcast type that consolidates every post-#2389 optimization into one config knob: weight_broadcast.type = "mx_v2" Coexists with the existing "nixl_mx" (PR #2389) for migration; no behavior of "nixl_mx" is affected by this change. What's included --------------- Trainer side * src/prime_rl/trainer/rl/broadcast/nixl_mx_v2.py - NIXLMxV2WeightBroadcast — drop-in replacement for NIXLMxWeightBroadcast (PR #2389) - Uses MxV2TrainingPublisher from the MX v2 fat clients - Heartbeat + freshest-per-rank dedup + same-rank routing baked in (Phase 2 — no more configmap monkeypatch) - Stamps every publish with compile_target + compile_metadata from the conversion registry (Phase 3a) — receivers can refuse mismatched layouts at discovery - Preserves prime-rl's trainer-side conversion + slot layout + HSDP barrier ordering unchanged - Per-step: slot.fill_from(state_dict) → publisher.add_tensor() ×N → publisher.publish(version=step) → publisher.mark_ready() Inference side * src/prime_rl/inference/vllm/worker/nixl_mx_v2.py - NIXLMxV2WeightUpdateWorker — pull-mode worker extension - Uses MxWeightTransferEngine (vLLM WeightTransferEngine adapter from MX PR #349 — same shape as Anyscale RDT PR #43375) - Phase 3b receiver-side compile_target_filter — refuses incompatible bytes BEFORE RDMA - Tree fan-out via publish_self_as_replica=True (TensorHub pattern; receivers republish so newcomers pull from peers instead of trainer) - Surfaces per-cycle metrics (bytes/Gbps/discovery_seconds/ source_worker_rank) back through the RPC return value Config + selector * packages/prime-rl-configs/src/prime_rl/configs/trainer.py - New MxV2WeightBroadcastConfig with the Phase 2/3 knobs - Discriminated union extended; existing configs unchanged * src/prime_rl/trainer/rl/broadcast/__init__.py - Selector dispatches "mx_v2" to NIXLMxV2WeightBroadcast Inference server + orchestrator wiring * src/prime_rl/inference/vllm/server.py - WORKER_EXTENSION_CLS["mx_v2"] mapping - POST /init_nixl_mx_v2 (mirrors /init_nixl_mx) - POST /update_weights_v2 (per-cycle refit; returns metrics) * src/prime_rl/utils/client.py - init_nixl_mx_v2_broadcast() async helper - update_weights_v2() async helper that returns per-server metrics Image * Dockerfile.cuda.mx-v2 — overlay on v0.7.1-kavin-phase2-phase3: 1. `uv pip install` the MX PR #349 branch (Phase 4 + engine) 2. COPY the 5 v2 prime-rl files 3. Smoke tests at build time (engine import, flash_attn ABI) * docs/proposals/image-build-mx-v2.md — build mechanics + A/B deployment plan RFC * docs/proposals/post-pr2389-mx-v2.md — full design doc: - Capability comparison table (nixl_mx vs mx_v2) - Module-by-module design - Migration plan (v0.x → v0.x+1 deprecation → v0.x+2 removal) - Validation matrix against PR #2389 on the same workload - References to all related PRs (#1 Phase 2, #2 Phase 3, ai-dynamo/modelexpress#349, vLLM #43375, TensorHub paper, vLLM native RL APIs blog) What's NOT in this commit ------------------------- * Unit tests for the new prime-rl integration files (TODO) * Built + pushed image artifact (TODO — needs Docker buildx) * End-to-end cluster validation on Qwen3-30B-A3B (TODO — needs cluster booking + parallel deployment) * Deletion of "nixl_mx" code (intentional — coexist for ≥1 release) The 58 MX-side unit tests on PR #349 already cover the v2 fat clients + engine adapter that this RFC consumes. The new tests TODO is for the thin prime-rl-side glue (~250 LOC across the 2 new files). --- Dockerfile.cuda.mx-v2 | 36 +++ docs/proposals/image-build-mx-v2.md | 96 +++++++ docs/proposals/post-pr2389-mx-v2.md | 271 ++++++++++++++++++ .../src/prime_rl/configs/trainer.py | 66 ++++- src/prime_rl/inference/vllm/server.py | 51 ++++ .../inference/vllm/worker/nixl_mx_v2.py | 214 ++++++++++++++ src/prime_rl/trainer/rl/broadcast/__init__.py | 5 + .../trainer/rl/broadcast/nixl_mx_v2.py | 253 ++++++++++++++++ src/prime_rl/utils/client.py | 72 +++++ 9 files changed, 1063 insertions(+), 1 deletion(-) create mode 100644 Dockerfile.cuda.mx-v2 create mode 100644 docs/proposals/image-build-mx-v2.md create mode 100644 docs/proposals/post-pr2389-mx-v2.md create mode 100644 src/prime_rl/inference/vllm/worker/nixl_mx_v2.py create mode 100644 src/prime_rl/trainer/rl/broadcast/nixl_mx_v2.py diff --git a/Dockerfile.cuda.mx-v2 b/Dockerfile.cuda.mx-v2 new file mode 100644 index 0000000000..ddb6b9f244 --- /dev/null +++ b/Dockerfile.cuda.mx-v2 @@ -0,0 +1,36 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 prime intellect & contributors +# SPDX-License-Identifier: Apache-2.0 +# +# Overlay Dockerfile for the v2 prime-rl × ModelExpress integration. +# Layers on top of v0.7.1-kavin-phase2-phase3 (which already has Phase 2 +# + Phase 3 source baked in). See docs/proposals/image-build-mx-v2.md. + +FROM nvcr.io/nvidian/dynamo-dev/prime-rl-mx-on-nixl:v0.7.1-kavin-phase2-phase3 + +USER root + +# ──────────────────────────────────────────────────────────────────── +# 1. Update modelexpress to the PR #349 branch +# (Phase 4 multi-source slice picker + MxWeightTransferEngine) +# ──────────────────────────────────────────────────────────────────── +RUN --mount=type=cache,target=/app/.cache/uv \ + /app/.venv/bin/uv pip install --no-deps --reinstall \ + "modelexpress @ git+https://github.com/ai-dynamo/modelexpress.git@kavink/post-2389-phase3-4#subdirectory=modelexpress_client/python" + +# ──────────────────────────────────────────────────────────────────── +# 2. Overlay the v2 prime-rl source files +# ──────────────────────────────────────────────────────────────────── +COPY --chown=appuser:appuser src/prime_rl/transport/ /app/src/prime_rl/transport/ +COPY --chown=appuser:appuser src/prime_rl/inference/vllm/worker/nixl_mx_v2.py /app/src/prime_rl/inference/vllm/worker/nixl_mx_v2.py +COPY --chown=appuser:appuser src/prime_rl/trainer/rl/broadcast/nixl_mx_v2.py /app/src/prime_rl/trainer/rl/broadcast/nixl_mx_v2.py +COPY --chown=appuser:appuser src/prime_rl/trainer/rl/broadcast/__init__.py /app/src/prime_rl/trainer/rl/broadcast/__init__.py +COPY --chown=appuser:appuser packages/prime-rl-configs/src/prime_rl/configs/trainer.py /app/packages/prime-rl-configs/src/prime_rl/configs/trainer.py + +# ──────────────────────────────────────────────────────────────────── +# 3. Smoke-test that the v2 path imports cleanly +# ──────────────────────────────────────────────────────────────────── +RUN /app/.venv/bin/python -c "from modelexpress.vllm_weight_transfer import MxWeightTransferEngine, MxInitInfo, MxUpdateInfo; print('engine adapter:', MxWeightTransferEngine)" +RUN /app/.venv/bin/python -c "from modelexpress.nemo_rl_v2 import MxV2TrainingPublisher, MxV2RefitReceiver, TrainerWorldLayout; print('v2 fat clients OK')" +RUN /app/.venv/bin/python -c "import flash_attn; import flash_attn.ops; print('flash_attn ABI:', flash_attn.__version__)" + +USER appuser diff --git a/docs/proposals/image-build-mx-v2.md b/docs/proposals/image-build-mx-v2.md new file mode 100644 index 0000000000..36e318276c --- /dev/null +++ b/docs/proposals/image-build-mx-v2.md @@ -0,0 +1,96 @@ +# Image build plan — `prime-rl-mx-on-nixl:v0.7.2-kavin-mx-v2` + +> Companion to [`post-pr2389-mx-v2.md`](./post-pr2389-mx-v2.md) (Workstream B). + +## Goal + +Ship a single deployable image that contains everything `weight_broadcast.type = "mx_v2"` needs: + +- Phase 2 source (heartbeat + same-rank + freshest-per-rank) — *already in* `v0.7.1-kavin-phase2-phase3` +- Phase 3 source (conversion-registry extensions, compile_target tagging) — *already in* `v0.7.1-kavin-phase2-phase3` +- **NEW:** Phase 4 + `MxWeightTransferEngine` from MX PR #349 (`kavink/post-2389-phase3-4` branch) +- **NEW:** the v2 prime-rl source files from this RFC (`nixl_mx_v2.py` × 2, updated `__init__.py`, updated `configs/trainer.py`) + +## Strategy: overlay, not from-scratch + +The build-notes ([`build-notes-2026-05-28.md`](./build-notes-2026-05-28.md) §2) measured **6h 45min** for a from-scratch `Dockerfile.cuda` build on QEMU arm64. The overlay for v0.7.1 was **~3 min**. We overlay. + +```dockerfile +# Dockerfile.cuda.mx-v2 +FROM nvcr.io/nvidian/dynamo-dev/prime-rl-mx-on-nixl:v0.7.1-kavin-phase2-phase3 + +# ── 1. Update modelexpress to the PR #349 branch (Phase 4 + engine adapter) ── +RUN --mount=type=cache,target=/app/.cache/uv \ + uv pip install --no-deps --reinstall \ + "modelexpress @ git+https://github.com/ai-dynamo/modelexpress.git@kavink/post-2389-phase3-4#subdirectory=modelexpress_client/python" + +# ── 2. Overlay the v2 prime-rl files ───────────────────────────────────────── +COPY src/prime_rl/transport/ /app/src/prime_rl/transport/ +COPY src/prime_rl/inference/vllm/worker/nixl_mx_v2.py /app/src/prime_rl/inference/vllm/worker/nixl_mx_v2.py +COPY src/prime_rl/trainer/rl/broadcast/nixl_mx_v2.py /app/src/prime_rl/trainer/rl/broadcast/nixl_mx_v2.py +COPY src/prime_rl/trainer/rl/broadcast/__init__.py /app/src/prime_rl/trainer/rl/broadcast/__init__.py +COPY packages/prime-rl-configs/src/prime_rl/configs/trainer.py /app/packages/prime-rl-configs/src/prime_rl/configs/trainer.py +``` + +Build: + +```bash +docker buildx build \ + --platform linux/arm64 \ + --file Dockerfile.cuda.mx-v2 \ + --tag nvcr.io/nvidian/dynamo-dev/prime-rl-mx-on-nixl:v0.7.2-kavin-mx-v2 \ + --push \ + . +``` + +Estimated: ~5 min (one git clone + uv install for modelexpress, then a 5-file COPY). + +## What about the `flash_attn` ABI issue? + +The earlier `flash_attn.ops` `ModuleNotFoundError` we saw was inside the *MX overlay benchmark pod* — that pod was using the v0.5.2 image with a Python overlay, and the v0.5.2 image's vLLM expects an older flash_attn layout. + +**`v0.7.1-kavin-phase2-phase3` does NOT have that problem** — it was built from a fresh `Dockerfile.cuda` that pins `flash-attn==2.8.3+cu128torch2.11` (via the `flash-attn` extra in `pyproject.toml`) and rebuilds vLLM against it. Since we're overlaying on top of that, we inherit the fixed pin and never touch the ABI. + +**Validation step:** confirm by running `python -c "import flash_attn; import flash_attn.ops"` inside the new image before doing anything else. + +## Deployment + +Same configmap pattern as `v0.7.1`. New trainer + inference manifests with `weight_broadcast.type = "mx_v2"`: + +```yaml +# configmap delta for v2 deployment +weight_broadcast: + type: mx_v2 # was: nixl_mx + host: modelexpress-server.kavin.svc.cluster.local + port: 8001 + same_rank_only: true # Phase 2 default + dedup_freshest_per_rank: true # Phase 2 default + publish_compile_target: true # Phase 3 default + publish_self_as_replica: true # tree fan-out default + inference_world_size: 4 + inference_model_name: Qwen/Qwen3-30B-A3B-Instruct-2507 +``` + +For A/B against PR #2389, run **two parallel deployments** in the kavin namespace under separate Job names — `prime-rl-nixl-mx-v0-7-1` (baseline) vs `prime-rl-mx-v2-kavin` (this work). Both use the same MX server (different `mx_source_id`s by content hash). + +## What to measure + +The validation matrix from `post-pr2389-mx-v2.md` §Validation plan, against the same workload PR #2389 was validated with: + +| Config | Refit cycle | Bandwidth | Notes | +|---|---|---|---| +| `nixl_mx` baseline | target ~10 s | ~80 ms NIXL push | PR #2389 push | +| `mx_v2` defaults | target ≤ 10 s | should match | Pull semantics, single-receiver | +| `mx_v2` + 4 receivers | target ≤ 10 s | `fanout_factor > 1.0` | Tree fan-out engages | +| `mx_v2` + filter mismatch | refuse at discovery, 0 RDMA bytes | — | Phase 3 safety net in production | +| `mx_v2` elastic scale-up | 2 → 4 replicas mid-training | new receivers join under 2 s | Phase 2 same-rank routing in production | + +## Open items + +| # | Item | Owner / status | +|---|---|---| +| 1 | Confirm v0.7.1 image has `flash_attn.ops` importable | Smoke test, ~30 s | +| 2 | Build + push `v0.7.2-kavin-mx-v2` | ~5 min after Phase A code lands | +| 3 | Deploy parallel A/B Jobs in kavin namespace | Cluster booking — ~5 hours of GPU node time | +| 4 | Capture per-cycle timing + bandwidth JSONs into `pensieve/RL/PrimeRL/results/` | Direct mirror of Slide 9 / Table B in the presentation | +| 5 | Replace Slide 9 Table A's synthetic numbers with end-to-end numbers | Once #4 is in hand | diff --git a/docs/proposals/post-pr2389-mx-v2.md b/docs/proposals/post-pr2389-mx-v2.md new file mode 100644 index 0000000000..5ee5810832 --- /dev/null +++ b/docs/proposals/post-pr2389-mx-v2.md @@ -0,0 +1,271 @@ +# RFC: `weight_broadcast.type = "mx_v2"` — the complete prime-rl × ModelExpress design + +> **Status:** Draft. Targets [`PrimeIntellect-ai/prime-rl`](https://github.com/PrimeIntellect-ai/prime-rl) upstream as a follow-up to [PR #2389](https://github.com/PrimeIntellect-ai/prime-rl/pull/2389). +> +> **Companion docs in this branch:** +> - [`post-pr2389-kernel-compile-plan.md`](./post-pr2389-kernel-compile-plan.md) — the four-phase plan this RFC consolidates +> - [`post-pr2389-status-and-plan.md`](./post-pr2389-status-and-plan.md) — current status of the four phases +> - [`build-notes-2026-05-28.md`](./build-notes-2026-05-28.md) — image-build mechanics for the source-baked Phase 2 + Phase 3 image +> - [`image-build-mx-v2.md`](./image-build-mx-v2.md) — the v2 overlay-image plan (Workstream B of this RFC) + +## TL;DR + +This RFC proposes a single new weight-broadcast type, **`weight_broadcast.type = "mx_v2"`**, that contains every optimization the post-#2389 RFC identifies, behind one config knob. It coexists with the existing `"nixl_mx"` (PR #2389) for migration; no behavior of `"nixl_mx"` changes. + +| Capability | `nixl_mx` (PR #2389) | `mx_v2` (this RFC) | +|---|---|---| +| Data plane | NIXL RDMA WRITE (push) | NIXL RDMA WRITE *or* READ (engine-dispatched) | +| Control plane | In-tree `MxRendezvous` (185 LOC, 4 gRPC calls) | `MxWeightTransferEngine` over MX v2 fat clients | +| Heartbeat + freshest-dedup + same-rank routing | Runtime monkey-patch via configmap | Baked in (Phase 2) | +| compile_target safety net | None — silent corruption on layout mismatch | Phase 3 — refuses mismatched at discovery, before RDMA | +| Mixed-TP / mixed-EP | Requires matching layouts | Phase 4 — multi-source slice picker + stitching | +| Tree fan-out (pipeline replication) | None — single-source from trainer | Receivers republish; trainer NIC stops being the bottleneck past ~4 receivers | +| MoE expert filter | None — every receiver pulls every expert | Bandwidth-proportional to EP (8× savings for EP=8 on a 192-expert model) | +| vLLM native API alignment | Bespoke worker extension | Targets `WeightTransferEngine` ABC (same shape as Anyscale RDT PR #43375) | +| Net LOC | baseline | +255 (new) / −185 (delete in-tree rendezvous in follow-up) | + +## Motivation + +PR #2389 ships a working NIXL+MX path on GB200 (~10 s/cycle on Qwen3-30B-A3B). Production hardening since then surfaced six issues that map to one root cause: **prime-rl owns the rendezvous + transport layer, so every cross-cutting capability (heartbeat, compile-target metadata, slice picking) lives in prime-rl too.** + +The cross-cutting work — heartbeat, dedup, shape registry, MoE expert filter, mixed-TP slice picker, tree fan-out — has been built and unit-tested **inside ModelExpress** (PR [#349](https://github.com/ai-dynamo/modelexpress/pull/349), branch `kavink/post-2389-phase3-4`). What's missing is the single integration PR that consumes all of it from prime-rl. + +This RFC is that integration PR. + +## Design + +### 1. New config type + +`packages/prime-rl-configs/src/prime_rl/configs/trainer.py` adds `MxV2WeightBroadcastConfig`: + +```python +class MxV2WeightBroadcastConfig(BaseWeightBroadcastConfig): + type: Literal["mx_v2"] = "mx_v2" + + # ─── Control plane ────────────────────────────────────────────── + host: str = "localhost" + port: int = 29501 + timeout: int = 1200 + + # ─── Discovery (Phase 2) ──────────────────────────────────────── + same_rank_only: bool = True + """GB200/EFA multi-NIC fabrics: receivers pull from same-rank trainer only.""" + dedup_freshest_per_rank: bool = True + """When multiple READY entries share a worker_rank, pick the freshest by updated_at.""" + + # ─── Layout metadata (Phase 3) ────────────────────────────────── + publish_compile_target: bool = True + """Trainer stamps every publish with the conversion's compile_target tag.""" + compile_target_filter: list[str] | None = None + """Receiver-side whitelist. None = accept anything (back-compat). Set to + {'cutlass_fp8'} or {'hf_raw','cutlass_fp8'} to refuse mismatches before RDMA.""" + + # ─── Sharding (Phase 4) ───────────────────────────────────────── + target_tp_layout: TargetTPLayout | None = None + """None = matched-TP fast path (single-source same-rank pull). + Set when trainer TP/EP layout differs from inference.""" + + # ─── Pipeline replication (TensorHub pattern) ─────────────────── + publish_self_as_replica: bool = True + """After a successful receive, inference workers republish themselves as + sources; subsequent receivers can pull from peers instead of the trainer.""" + + inference_world_size: int = 1 + inference_model_name: str = "" +``` + +### 2. Selector dispatch + +`src/prime_rl/trainer/rl/broadcast/__init__.py` adds one `elif`: + +```python +elif config.type == "mx_v2": + from prime_rl.trainer.rl.broadcast.nixl_mx_v2 import NIXLMxV2WeightBroadcast + + assert parallel_dims is not None, "mx_v2 requires parallel_dims" + return NIXLMxV2WeightBroadcast(output_dir, config, parallel_dims) +``` + +### 3. New trainer broadcast — `src/prime_rl/trainer/rl/broadcast/nixl_mx_v2.py` + +Replaces the bespoke `MxRendezvous` + manual NIXL `post_write` flow with `MxV2TrainingPublisher`: + +```python +class NIXLMxV2WeightBroadcast(WeightBroadcast): + def __init__(self, output_dir, config, parallel_dims): + self.publisher = MxV2TrainingPublisher( + agent_name=make_agent_name("trainer", world.rank), + device_id=torch.cuda.current_device(), + mx_server_url=f"{config.host}:{config.port}", + worker_rank=world.rank, + world_layout=TrainerWorldLayout(...), # from parallel_dims + ) + + def lazy_init(self, model): + self.publisher.initialize(model_name=config.inference_model_name) + # Slot allocation + conversion still owned by prime-rl + self.model_slots = model.build_slots(...) + self.conversion = select_default_conversion(...) + + @torch.no_grad() + def broadcast_weights(self, model, step): + # 1. Run trainer-side conversion (prime-rl owns this) + for slot in self.model_slots: + slot.fill_from(model.state_dict(), self.conversion) + + # 2. Register each slot tensor with the publisher, tagged with + # compile_target + compile_metadata from the conversion registry + for slot in self.model_slots: + for name, tensor, _ in slot.buffers: + self.publisher.add_tensor( + name=name, + tensor=tensor, + compile_target=self.conversion.compile_target, + compile_metadata=self.conversion.compile_metadata, + # MoE expert metadata where applicable + is_expert=slot.is_expert, + expert_axis=slot.expert_axis, + owned_expert_ids=slot.owned_expert_ids, + ) + + # 3. One publish() per step + self.publisher.publish(version=step) + self.publisher.mark_ready() +``` + +**Key invariants preserved from PR #2389:** +- Trainer-side conversion (FP8 packing, fusion, sharding) — prime-rl owns the kernel +- Slot layout — `Sharded` / `Gathered` / `Expert` slots stay +- HSDP barrier — `dp_replicate > 1` only publishes from rank 0 +- Per-step lifecycle — `lazy_init` on first call, `broadcast_weights` every step + +**What changes vs PR #2389:** +- The push (`nixl_agent.post_write` loop) becomes a publish (`publisher.add_tensor` × N + `publisher.publish`); the actual NIXL WRITE is now driven from the receiver side via `receive_weights_scratch` +- Trainer no longer needs to know inference world size in advance — receivers discover via catalog + +### 4. New inference worker — `src/prime_rl/inference/vllm/worker/nixl_mx_v2.py` + +Uses `MxWeightTransferEngine` via vLLM's `worker_extension_cls`: + +```python +class NIXLMxV2WeightUpdateWorker(Worker): + """vLLM worker extension for the v2 pull path.""" + + def init_nixl_mx_v2(self, host: str, port: int, rank_offset: int, **engine_init_kwargs): + from modelexpress.vllm_weight_transfer import MxInitInfo, MxWeightTransferEngine + global_rank = rank_offset + self.device.index + inference_model_name = self.model_runner.model_config.model + + self.engine = MxWeightTransferEngine(init_info=MxInitInfo( + mx_server_url=f"{host}:{port}", + model_name=inference_model_name, + worker_rank=global_rank, + agent_name=make_agent_name("inference", global_rank), + device_id=self.device.index, + publish_self_as_replica=engine_init_kwargs.get("tree_fanout", True), + )) + + @torch.no_grad() + def update_weights_via_mx_v2(self, step: int, *, compile_target_filter=None, target_tp_layout=None) -> None: + from modelexpress.vllm_weight_transfer import MxUpdateInfo + self.engine.receive_weights( + MxUpdateInfo( + version=step, + compile_target_filter=set(compile_target_filter) if compile_target_filter else None, + target_tp_layout=target_tp_layout, + timeout_seconds=self.config.timeout, + ), + load_weights=self._load_weights_batch, + ) + # Same post-load housekeeping as PR #2389 + update_mla_absorbed_weights(self.raw_model) + + def _load_weights_batch(self, batch: list[tuple[str, torch.Tensor]]) -> None: + """Feed yielded tensors through vLLM's model.load_weights(). + vLLM handles HF→fused name remapping via stacked_params_mapping.""" + self.raw_model.load_weights(batch) +``` + +**Key changes vs PR #2389:** +- No pre-registered NIXL buffers on inference side (uses `receive_weights_scratch` under the hood) +- Trainer push → receiver pull semantics +- HF-format publish → vLLM `load_weights` handles fused param remapping (matches NeMo-RL pattern) + +### 5. Conversion registry — already done + +`src/prime_rl/trainer/models/conversions/__init__.py` was extended on `kavink/post-2389-conversion-registry-extensions` ([Draft PR #2](https://github.com/KavinKrishnan/prime-rl/pull/2)) with: +- `compile_target` + `compile_metadata` fields on `ConversionEntry` +- `cutlass_fp8_e4m3_per_channel` registered alongside `bf16_cast` and `fp8_128x128` +- 19/19 unit tests green + +This RFC just consumes that work — no additional changes to conversion-registry code. + +### 6. Image — overlay on `v0.7.1-kavin-phase2-phase3` + +The v2 source layers cleanly on top of `prime-rl-mx-on-nixl:v0.7.1-kavin-phase2-phase3`, which already contains Phase 2 + Phase 3. The Dockerfile is a 5-line overlay: + +```dockerfile +FROM nvcr.io/nvidian/dynamo-dev/prime-rl-mx-on-nixl:v0.7.1-kavin-phase2-phase3 +COPY src/prime_rl/transport/ /app/src/prime_rl/transport/ +COPY src/prime_rl/inference/vllm/worker/nixl_mx_v2.py /app/src/prime_rl/inference/vllm/worker/nixl_mx_v2.py +COPY src/prime_rl/trainer/rl/broadcast/nixl_mx_v2.py /app/src/prime_rl/trainer/rl/broadcast/nixl_mx_v2.py +COPY src/prime_rl/trainer/rl/broadcast/__init__.py /app/src/prime_rl/trainer/rl/broadcast/__init__.py +COPY packages/prime-rl-configs/src/prime_rl/configs/trainer.py /app/packages/prime-rl-configs/src/prime_rl/configs/trainer.py +``` + +Plus a 1-line `uv sync --no-deps --reinstall-package modelexpress` if we need to pull in the MX-side engine adapter that ships with PR #349. + +See [`image-build-mx-v2.md`](./image-build-mx-v2.md) for the build mechanics + cluster deployment steps. + +## Migration + +| Phase | Config | Status | +|---|---|---| +| **v0.x** (now) | `nixl_mx` and `mx_v2` coexist | `nixl_mx` remains the documented default. `mx_v2` opt-in. | +| **v0.x+1** | `nixl_mx` deprecated with warning | After 4 weeks of `mx_v2` bake-time on `kavin` + at least one external user. | +| **v0.x+2** | `nixl_mx` removed | After another release cycle. Tracks vLLM's native `WeightTransferEngine` API merge — once that's available, `mx_v2` registers as `backend="mx_nixl"` and `WeightTransferConfig` becomes the recommended entry point. | + +No user is forced to migrate. `nixl_mx` users get heartbeat + dedup via PR #1 (Phase 2 source-bake) regardless of this RFC. + +## Validation plan + +### Pre-merge (this branch, before opening upstream) + +1. **Unit tests:** existing 58 MX-side tests (35 v2 shape/picker + 14 engine + 9 bench) + new tests for the prime-rl integration files (≥10 new unit tests covering `NIXLMxV2WeightBroadcast.broadcast_weights` and `NIXLMxV2WeightUpdateWorker.update_weights_via_mx_v2`). + +2. **Cluster A/B on `kavin` namespace, GB200, Qwen3-30B-A3B-Instruct-2507:** + + | Config | Refit cycle | Bandwidth | Notes | + |---|---|---|---| + | `type=nixl_mx` (PR #2389 baseline) | target ~10 s | ~80 ms NIXL push | Push, no filter, no fan-out | + | `type=mx_v2`, defaults | target ≤ 10 s | should match | Pull, filter on, fan-out on but with 1 inference replica (no-op) | + | `type=mx_v2` + 4 inference replicas | target ≤ 10 s | `fanout_factor > 1.0` | Tree fan-out kicks in | + | `type=mx_v2` + mismatched filter | should refuse at discovery | 0 RDMA bytes | Compile-target safety net under production workload | + +3. **Elastic scale-up:** scale inference from 2 → 4 replicas mid-training. With Phase 2 same-rank routing baked in, all 4 should join cleanly without orchestrator restart. Measured via the harness in `modelexpress/benchmarks/bench_elastic_scaling.py` but against the real model, not synthetic tensors. + +### Post-merge (upstream) + +1. Add `mx_v2` smoke test to upstream prime-rl CI. +2. Coordinate with upstream PrimeIntellect on a real-RL-job validation matrix. + +## What this RFC does *not* do + +| Out of scope | Why | +|---|---| +| Delete PR #2389's `nixl_mx` code | Coexist for ≥1 release cycle. PR #1 (Phase 2 fixes) lands into `nixl_mx` regardless. | +| Implement delta-sync ([HF blog](https://huggingface.co/blog/delta-weight-sync)) | Layer-2 optimization — composes orthogonally with `mx_v2` and lands separately once vLLM merges `pause_generation(mode="keep")`. | +| Implement true async refit (Composer 2 / Fireworks) | Layer-2 optimization. Same reason as delta-sync. | +| Cross-DC / WAN | TensorHub-pattern; `mx_v2` already supports it via MX catalog metadata, but no cross-DC validation here. | +| Production hardening of MX server | Owned by `ai-dynamo/modelexpress`. We consume; we don't fork. | + +## References + +- [PR #2389](https://github.com/PrimeIntellect-ai/prime-rl/pull/2389) — the baseline +- [`KavinKrishnan/prime-rl#1`](https://github.com/KavinKrishnan/prime-rl/pull/1) — Phase 2 source-baked rendezvous fixes (heartbeat + dedup + same-rank) +- [`KavinKrishnan/prime-rl#2`](https://github.com/KavinKrishnan/prime-rl/pull/2) — Phase 3 conversion-registry extensions (compile_target + cutlass_fp8) +- [`ai-dynamo/modelexpress#349`](https://github.com/ai-dynamo/modelexpress/pull/349) — Phase 3 + 4 + `MxWeightTransferEngine` adapter (v2 fat clients, multi-source slice picker, vLLM API adapter) +- [vLLM PR #43375](https://github.com/vllm-project/vllm/pull/43375) — Anyscale Ray Direct Transport; same `WeightTransferEngine` API shape, complementary Ray-based catalog choice +- [vLLM native RL APIs blog](https://blog.vllm.ai/2026/05/28/native-rl-apis.html) — the upstream API surface this RFC targets +- [TensorHub paper (arXiv 2604.09107)](https://arxiv.org/pdf/2604.09107v1) — Reference-Oriented Storage, pipeline replication, mutability contract +- [`post-pr2389-kernel-compile-plan.md`](./post-pr2389-kernel-compile-plan.md) — the four-phase plan this RFC consolidates diff --git a/packages/prime-rl-configs/src/prime_rl/configs/trainer.py b/packages/prime-rl-configs/src/prime_rl/configs/trainer.py index fd61ceb49a..a890ed85dd 100644 --- a/packages/prime-rl-configs/src/prime_rl/configs/trainer.py +++ b/packages/prime-rl-configs/src/prime_rl/configs/trainer.py @@ -494,8 +494,72 @@ class NIXLMxWeightBroadcastConfig(BaseWeightBroadcastConfig): """HF model name of the inference target.""" +class MxV2WeightBroadcastConfig(BaseWeightBroadcastConfig): + """v2 weight broadcast over NIXL + ModelExpress fat clients. + + Selectable from config via ``weight_broadcast.type = "mx_v2"``. + Coexists with the existing ``"nixl_mx"`` path (PR #2389) for + migration. See ``docs/proposals/post-pr2389-mx-v2.md`` for the + full design. Maps to + :class:`prime_rl.trainer.rl.broadcast.nixl_mx_v2.NIXLMxV2WeightBroadcast` + (trainer) and + :class:`prime_rl.inference.vllm.worker.nixl_mx_v2.NIXLMxV2WeightUpdateWorker` + (inference). + """ + + type: Literal["mx_v2"] = "mx_v2" + + # ─── Control plane (same as nixl_mx) ──────────────────────────── + host: str = "localhost" + """Host for the ModelExpress server.""" + + port: int = 29501 + """Port for the ModelExpress server.""" + + timeout: int = 1200 + """Timeout in seconds for rendezvous and per-step transfers.""" + + inference_world_size: int = 1 + """Number of GPUs used for inference.""" + + inference_model_name: str = "" + """HF model name of the inference target.""" + + # ─── Discovery (Phase 2) ──────────────────────────────────────── + same_rank_only: bool = True + """GB200/EFA multi-NIC fabrics: receivers pull from same-rank trainer only. + rdma-0..3 are separate L3 subnets, so cross-rank writes are unrouted.""" + + dedup_freshest_per_rank: bool = True + """When multiple READY entries share a worker_rank (e.g. after a pod + restart), pick the freshest by ``updated_at``. Without this, stale + catalog entries cause ``NIXL_ERR_NOT_ALLOWED`` on ``add_remote_agent``.""" + + # ─── Layout metadata (Phase 3) ────────────────────────────────── + publish_compile_target: bool = True + """Trainer stamps every publish with the conversion's compile_target tag + (e.g. ``cutlass_fp8``, ``deep_gemm_fp8``, ``hf_raw``) so receivers can + refuse mismatched layouts at discovery, before any RDMA cycle.""" + + compile_target_filter: list[str] | None = None + """Receiver-side whitelist of acceptable compile_target strings. + ``None`` (default) = accept anything — back-compat with PR #2389 + publishers that don't carry the tag. Set e.g. ``["cutlass_fp8"]`` + or ``["cutlass_fp8", "hf_raw"]`` to refuse mismatches.""" + + # ─── Pipeline replication (TensorHub pattern) ─────────────────── + publish_self_as_replica: bool = True + """After a successful receive, inference workers republish their + NIXL buffers as additional sources. Subsequent receivers can pull + from peers instead of the trainer, amplifying total egress + bandwidth. Trainer NIC stops being the bottleneck past ~4 receivers.""" + + WeightBroadcastConfig: TypeAlias = Annotated[ - FileSystemWeightBroadcastConfig | NCCLWeightBroadcastConfig | NIXLMxWeightBroadcastConfig, + FileSystemWeightBroadcastConfig + | NCCLWeightBroadcastConfig + | NIXLMxWeightBroadcastConfig + | MxV2WeightBroadcastConfig, Field(discriminator="type"), ] diff --git a/src/prime_rl/inference/vllm/server.py b/src/prime_rl/inference/vllm/server.py index 09c0d1a201..d7d5a55d86 100644 --- a/src/prime_rl/inference/vllm/server.py +++ b/src/prime_rl/inference/vllm/server.py @@ -57,6 +57,7 @@ def models(request: Request) -> OpenAIServingModels: "nccl": "prime_rl.inference.vllm.worker.nccl.NCCLWeightUpdateWorker", "filesystem": "prime_rl.inference.vllm.worker.filesystem.FileSystemWeightUpdateWorker", "nixl_mx": "prime_rl.inference.vllm.worker.nixl_mx.NIXLMxWeightUpdateWorker", + "mx_v2": "prime_rl.inference.vllm.worker.nixl_mx_v2.NIXLMxV2WeightUpdateWorker", } @@ -128,6 +129,56 @@ async def init_nixl_mx(request: Request): return {"status": "ok"} +@router.post("/init_nixl_mx_v2") +async def init_nixl_mx_v2(request: Request): + """Boot-time init for the ``mx_v2`` worker extension. + + Mirrors ``/init_nixl_mx`` but targets + :meth:`NIXLMxV2WeightUpdateWorker.init_nixl_mx_v2`. Accepts optional + ``publish_self_as_replica`` (tree fan-out; default True) and + ``listen_port`` (NIXL listen port; default None = auto-pick). + """ + data = await request.json() + await engine_client(request).collective_rpc( + "init_nixl_mx_v2", + args=(data["host"], data["port"], data["rank_offset"]), + kwargs={ + "publish_self_as_replica": data.get("publish_self_as_replica", True), + "listen_port": data.get("listen_port"), + }, + ) + return {"status": "ok"} + + +@router.post("/update_weights_v2") +async def update_weights_v2(request: Request): + """Per-cycle refit RPC for the ``mx_v2`` worker extension. + + Body fields: + step (int, required): trainer version to pull (engine accepts + sources with ``version >= step``). + compile_target_filter (list[str], optional): Phase 3b filter. + ``None`` = accept anything (back-compat). + timeout_seconds (float, optional): per-receive RDMA wait cap. + same_rank_only (bool, optional): default True (Phase 2). + + Returns the per-worker metrics dict aggregated across collective_rpc + fan-out so the orchestrator can emit them to dashboards without log + parsing. + """ + data = await request.json() + metrics = await engine_client(request).collective_rpc( + "update_weights_via_mx_v2", + args=(int(data["step"]),), + kwargs={ + "compile_target_filter": data.get("compile_target_filter"), + "timeout_seconds": float(data.get("timeout_seconds", 300.0)), + "same_rank_only": bool(data.get("same_rank_only", True)), + }, + ) + return {"status": "ok", "metrics": metrics} + + async def custom_init_app_state( engine_client: EngineClient, state: State, diff --git a/src/prime_rl/inference/vllm/worker/nixl_mx_v2.py b/src/prime_rl/inference/vllm/worker/nixl_mx_v2.py new file mode 100644 index 0000000000..67a219b19f --- /dev/null +++ b/src/prime_rl/inference/vllm/worker/nixl_mx_v2.py @@ -0,0 +1,214 @@ +"""v2 inference-worker extension for ModelExpress weight refits. + +The v2 of :class:`NIXLMxWeightUpdateWorker` (PR #2389), built on the +``MxWeightTransferEngine`` adapter from ModelExpress PR #349. The adapter +wraps every Phase 2/3/4 capability behind vLLM's :class:`WeightTransferEngine` +ABC (the same shape Anyscale's RDT PR `#43375 +`_ targets), so this +module is **maximally thin** — it just instantiates the engine and +plumbs vLLM's ``load_weights`` callback through. + +Key differences from PR #2389: + +- **Pull semantics, not push.** The trainer ``publish()``-es weights to + the MX catalog; this worker calls ``engine.receive_weights(...)`` + which discovers + pulls. No pre-registered NIXL buffers on the + inference side (the engine uses the scratch-buffer path internally). +- **Compile-target safety net (Phase 3b).** Optional + ``compile_target_filter`` refuses sources whose tensors don't match + the kernel layout this worker expects — BEFORE any RDMA cycle is + spent. Set ``filter=None`` for back-compat (accept anything). +- **Mixed-TP path (Phase 4).** When ``target_tp_layout`` is set, the + engine uses the multi-source slice picker; otherwise it uses the + matched-TP single-source fast path. +- **Tree fan-out (TensorHub pattern).** When + ``publish_self_as_replica=True`` in the engine's ``init_info``, the + worker republishes itself as a source after each successful receive, + so subsequent receivers can pull from peers instead of the trainer. + +See :file:`docs/proposals/post-pr2389-mx-v2.md` for the design rationale +and migration plan. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import torch +from torch.nn import Module +from vllm.logger import init_logger + +from prime_rl.inference.vllm.worker.weight_transfer import update_mla_absorbed_weights +from prime_rl.transport.nixl_agent import make_agent_name, pin_ucx_rail + +if TYPE_CHECKING: + from vllm.v1.worker.gpu_worker import Worker + + Worker = Worker # type: ignore +else: + Worker = object # type: ignore + +logger = init_logger("vllm.inference.vllm.worker_nixl_mx_v2") + + +class NIXLMxV2WeightUpdateWorker(Worker): + """vLLM worker extension for the v2 (pull-mode) weight-refit path. + + Mounted via vLLM's ``worker_extension_cls`` plumbing — same hook + PR #2389 uses for ``NIXLMxWeightUpdateWorker``. Two RPC endpoints: + + - :meth:`init_nixl_mx_v2` — called once at worker boot, sets up the + :class:`MxWeightTransferEngine` for this rank. + - :meth:`update_weights_via_mx_v2` — called per refit cycle by the + orchestrator; engine discovers + pulls + feeds vLLM's + ``load_weights``. + """ + + # ------------------------------------------------------------------ + # Model accessor (matches PR #2389) + # ------------------------------------------------------------------ + + @property + def raw_model(self) -> Module: + model_runner = self.model_runner + model = ( + model_runner.model.runnable + if hasattr(model_runner.model, "runnable") + else model_runner.model + ) + assert isinstance(model, Module) + return model + + # ------------------------------------------------------------------ + # Init RPC + # ------------------------------------------------------------------ + + def init_nixl_mx_v2( + self, + host: str, + port: int, + rank_offset: int, + *, + publish_self_as_replica: bool = True, + listen_port: int | None = None, + ) -> None: + """Build the :class:`MxWeightTransferEngine` for this worker. + + Args: + host, port: ``modelexpress-server`` URL. + rank_offset: orchestrator-assigned base rank for this pod; + ``global_rank = rank_offset + self.device.index``. + publish_self_as_replica: if True (default), after each + successful receive this worker republishes itself as + a source so newcomers can pull from it (tree fan-out). + listen_port: optional explicit NIXL listen port; ``None`` + = auto. + """ + from modelexpress.vllm_weight_transfer import MxInitInfo, MxWeightTransferEngine + + local_rank = self.device.index + global_rank = rank_offset + local_rank + inference_model_name = self.model_runner.model_config.model + + pin_ucx_rail(local_rank) + + self._engine = MxWeightTransferEngine( + init_info=MxInitInfo( + mx_server_url=f"{host}:{port}", + model_name=inference_model_name, + worker_rank=global_rank, + agent_name=make_agent_name("inference", global_rank), + device_id=local_rank, + listen_port=listen_port, + publish_self_as_replica=publish_self_as_replica, + ) + ) + self._global_rank = global_rank + logger.info( + f"[mx_v2] init: rank={global_rank} model={inference_model_name} " + f"publish_self_as_replica={publish_self_as_replica}" + ) + + # ------------------------------------------------------------------ + # Per-refit RPC + # ------------------------------------------------------------------ + + @torch.no_grad() + def update_weights_via_mx_v2( + self, + step: int, + *, + compile_target_filter: list[str] | None = None, + timeout_seconds: float = 300.0, + same_rank_only: bool = True, + ) -> dict[str, float | int | None]: + """Pull version ``step`` of the weights from the catalog. + + Args: + step: training-step counter; engine pulls sources with + ``version >= step``. + compile_target_filter: receiver-side Phase 3b filter. + ``None`` (default) = back-compat, accept any layout. + Set e.g. ``["cutlass_fp8"]`` or + ``["cutlass_fp8", "hf_raw"]`` to refuse mismatches at + discovery (no RDMA cycle spent on refusal). + timeout_seconds: cap on the engine's per-receive RDMA wait. + same_rank_only: enforce same-rank routing (required on + GB200/EFA multi-NIC fabrics where rdma-0..3 are + separate L3 subnets). + + Returns: + Per-cycle metrics dict (bytes / Gbps / discovery_seconds / + rdma_seconds) suitable for emission to dashboards. + """ + from modelexpress.vllm_weight_transfer import MxUpdateInfo + + update_info = MxUpdateInfo( + version=step, + compile_target_filter=set(compile_target_filter) if compile_target_filter else None, + target_tp_layout=None, # matched-TP fast path; Phase 4 wire-up future + timeout_seconds=timeout_seconds, + same_rank_only=same_rank_only, + ) + self._engine.receive_weights(update_info, load_weights=self._load_weights_batch) + + # Post-load housekeeping: same as PR #2389's path. + torch.cuda.synchronize(self.device) + update_mla_absorbed_weights(self.raw_model) + + # Surface the engine's metrics so the orchestrator / dashboards + # can read per-cycle bandwidth + discovery latency without + # parsing logs. + stats = self._engine.last_transfer_stats + metrics = { + "step": step, + "bytes_received": stats.bytes_received if stats else 0, + "tensors_received": stats.tensors_received if stats else 0, + "rdma_seconds": stats.elapsed_seconds if stats else 0.0, + "bandwidth_gbps": stats.bandwidth_gbps if stats else 0.0, + "discovery_seconds": self._engine.last_discovery_seconds, + "source_worker_rank": stats.source_worker_rank if stats else None, + } + logger.info( + f"[mx_v2] refit step={step} " + f"bytes={metrics['bytes_received'] / 1e6:.1f}MB " + f"rdma={metrics['rdma_seconds']:.3f}s " + f"{metrics['bandwidth_gbps']:.1f}Gbps " + f"from_rank={metrics['source_worker_rank']}" + ) + return metrics + + # ------------------------------------------------------------------ + # vLLM load-weights bridge + # ------------------------------------------------------------------ + + def _load_weights_batch(self, batch: list[tuple[str, torch.Tensor]]) -> None: + """Feed yielded ``(name, tensor)`` pairs through vLLM's load_weights. + + vLLM's :meth:`model.load_weights` handles HF→fused name remapping + via ``stacked_params_mapping`` (e.g. ``q_proj|k_proj|v_proj → + qkv_proj``), so this worker doesn't need to know about fusion — + the engine yields HF-format names and vLLM does the rest. + Matches the NemoRL v2 pattern + Anyscale's RDT pattern. + """ + self.raw_model.load_weights(batch) diff --git a/src/prime_rl/trainer/rl/broadcast/__init__.py b/src/prime_rl/trainer/rl/broadcast/__init__.py index d3883cbd1f..549f2d9273 100644 --- a/src/prime_rl/trainer/rl/broadcast/__init__.py +++ b/src/prime_rl/trainer/rl/broadcast/__init__.py @@ -23,5 +23,10 @@ def setup_weight_broadcast( assert parallel_dims is not None, "nixl_mx requires parallel_dims" return NIXLMxWeightBroadcast(output_dir, config, parallel_dims) + elif config.type == "mx_v2": + from prime_rl.trainer.rl.broadcast.nixl_mx_v2 import NIXLMxV2WeightBroadcast + + assert parallel_dims is not None, "mx_v2 requires parallel_dims" + return NIXLMxV2WeightBroadcast(output_dir, config, parallel_dims) else: raise ValueError(f"Invalid weight broadcast type: {config.type}") diff --git a/src/prime_rl/trainer/rl/broadcast/nixl_mx_v2.py b/src/prime_rl/trainer/rl/broadcast/nixl_mx_v2.py new file mode 100644 index 0000000000..fa4bde44cd --- /dev/null +++ b/src/prime_rl/trainer/rl/broadcast/nixl_mx_v2.py @@ -0,0 +1,253 @@ +"""v2 trainer-side weight broadcast using the ModelExpress v2 fat clients. + +This is the v2 of :class:`NIXLMxWeightBroadcast` (PR #2389), built on +:class:`MxV2TrainingPublisher` instead of the in-tree :class:`MxRendezvous`. +The data plane is unchanged — NIXL RDMA, GPU-direct, no CPU staging — +but the control-plane glue (heartbeat, freshest-per-rank dedup, same-rank +routing, compile_target metadata, multi-source slice picker, tree +fan-out) is graduated onto the published MX v2 surface. + +The trainer-side conversion (FP8 packing, fusion, sharding into +``Sharded`` / ``Gathered`` / ``Expert`` slots) is *unchanged* from +PR #2389 — prime-rl still owns that kernel. What changes is **how the +already-converted bytes get published** (one ``publisher.publish()`` +per step instead of a per-tensor ``post_write`` loop driven from the +trainer), and **what metadata rides along** (``compile_target`` + +``compile_metadata`` from the conversion registry + per-tensor MoE +expert ownership). + +HSDP: when ``dp_replicate > 1`` only the primary replica (``dp_replicate +rank 0``) participates. Non-primary replicas hold bit-identical weights; +broadcasting a second copy would be pure waste. + +See :file:`docs/proposals/post-pr2389-mx-v2.md` for the design rationale +and migration plan. +""" + +from __future__ import annotations + +import time +from pathlib import Path +from typing import Any + +import torch +import torch.distributed as dist +import torch.nn as nn +from modelexpress.nemo_rl_v2 import MxV2TrainingPublisher, TrainerWorldLayout +from transformers import AutoConfig + +from prime_rl.configs.trainer import MxV2WeightBroadcastConfig +from prime_rl.trainer.models import PreTrainedModelPrimeRL +from prime_rl.trainer.models.conversions import select_default_conversion +from prime_rl.trainer.parallel_dims import ParallelDims +from prime_rl.trainer.rl.broadcast.base import WeightBroadcast +from prime_rl.trainer.runs import get_multi_run_manager +from prime_rl.trainer.utils import get_world +from prime_rl.transport.classic_cuda_pool import classic_cuda_alloc +from prime_rl.transport.nixl_agent import make_agent_name, pin_ucx_rail + + +class NIXLMxV2WeightBroadcast(WeightBroadcast): + """v2 weight broadcast over NIXL + ModelExpress fat clients. + + Selectable from config via ``weight_broadcast.type = "mx_v2"``. + Coexists with the existing ``"nixl_mx"`` path (PR #2389); no + behavior of ``"nixl_mx"`` is affected by importing this module. + + Args: + output_dir: training output directory (forwarded to base class). + config: parsed :class:`MxV2WeightBroadcastConfig`. + parallel_dims: ``ParallelDims`` instance describing the trainer's + FSDP / TP / EP / DP layout — used to construct the + ``TrainerWorldLayout`` carried in v2 metadata. + """ + + def __init__( + self, + output_dir: Path, + config: MxV2WeightBroadcastConfig, + parallel_dims: ParallelDims, + ) -> None: + super().__init__(output_dir) + self.config = config + self.world = get_world() + self.parallel_dims = parallel_dims + + self.is_initialized = False + self._publisher: MxV2TrainingPublisher | None = None + self._model_slots: list[Any] | None = None + self._conversion = None + self._hf_config = None + + if self.is_primary_hsdp_rank: + pin_ucx_rail(torch.cuda.current_device()) + + self._multi_run_manager = get_multi_run_manager() + + # ------------------------------------------------------------------ + # HSDP gate — only rank 0 of dp_replicate publishes + # ------------------------------------------------------------------ + + @property + def is_primary_hsdp_rank(self) -> bool: + if self.parallel_dims.dp_replicate_enabled: + return self.parallel_dims.get_mesh("dp_replicate").get_local_rank() == 0 + return True + + # ------------------------------------------------------------------ + # Lifecycle + # ------------------------------------------------------------------ + + def _build_world_layout(self) -> TrainerWorldLayout: + """Translate prime-rl's ParallelDims into the MX v2 world layout.""" + return TrainerWorldLayout( + fsdp_world_size=getattr(self.parallel_dims, "dp_shard_size", 1), + tp_world_size=getattr(self.parallel_dims, "tp_size", 1), + pp_world_size=getattr(self.parallel_dims, "pp_size", 1), + ep_world_size=getattr(self.parallel_dims, "ep_size", 1), + ) + + def lazy_init(self, model: PreTrainedModelPrimeRL) -> None: + """Build the v2 publisher + slot layout on first call. + + The model isn't available at ``__init__`` time (the WeightBroadcast + instance is constructed before the trainer model is materialized), + so slot construction and publisher initialization happen on the + first ``broadcast_weights`` call. + """ + if self.is_initialized: + return + + self._hf_config = AutoConfig.from_pretrained(self.config.inference_model_name) + self._conversion = select_default_conversion(self.config.inference_model_name) + + with classic_cuda_alloc(): + self._model_slots = model.build_slots( + self.parallel_dims, self._conversion, self._hf_config.torch_dtype + ) + + # The v2 publisher owns the NIXL agent + MX client + heartbeat. + # We pass our rank as ``worker_rank``; receivers with + # ``same_rank_only=True`` (Phase 2 default) will only pull from + # the trainer rank matching their own. + self._publisher = MxV2TrainingPublisher( + agent_name=make_agent_name("trainer", self.world.rank), + device_id=torch.cuda.current_device(), + mx_server_url=f"{self.config.host}:{self.config.port}", + worker_rank=self.world.rank, + world_layout=self._build_world_layout(), + ) + self._publisher.initialize( + model_name=self.config.inference_model_name, + dtype=str(self._hf_config.torch_dtype).replace("torch.", ""), + ) + self.is_initialized = True + self.logger.info( + f"[mx_v2] publisher initialized: rank={self.world.rank} " + f"layout={self._build_world_layout().encode()} " + f"compile_target={self._conversion.compile_target}" + ) + + # ------------------------------------------------------------------ + # Per-step broadcast + # ------------------------------------------------------------------ + + @torch.no_grad() + def broadcast_weights(self, model: nn.Module, step: int) -> None: + """Publish version ``step`` of the converted weights. + + Per-step lifecycle: + + 1. (HSDP) only the primary replica participates; others barrier. + 2. Fill the conversion slots from ``model.state_dict()`` — + **same code path PR #2389 uses**, prime-rl owns the kernel. + 3. For each slot's buffers, call ``publisher.add_tensor(...)`` + tagged with the conversion's ``compile_target`` / + ``compile_metadata`` (Phase 3) and any per-tensor MoE + expert metadata. + 4. ``publisher.publish(version=step)`` + ``mark_ready()`` — + catalog entry now visible to receivers polling for + ``min_version=step``. + 5. Bump the heartbeat (the publisher's ``HeartbeatThread`` + runs in the background; nothing to do here). + """ + if self.is_primary_hsdp_rank: + self.lazy_init(model) + + if self.world.is_master: + for idx in self._multi_run_manager.used_idxs: + if self._multi_run_manager.ready_to_update[idx]: + self._multi_run_manager.ready_to_update[idx] = False + + dist.barrier() + + if not self.is_primary_hsdp_rank: + # Non-primary HSDP replicas: bit-identical weights; nothing to publish. + dist.barrier() + return + + start = time.perf_counter() + + # 2. Fill slots from the live model state-dict via the conversion. + # This is where FP8 packing + fusion happens; same code path + # as PR #2389. We do NOT change the kernel. + state_dict = model.state_dict() + for slot in self._model_slots: + slot.fill_from(state_dict, self._conversion) + + # 3. Register every slot tensor with the v2 publisher, tagged with + # compile_target + compile_metadata so receivers can refuse + # mismatched layouts at discovery (Phase 3). + # Falls back to the safe "hf_raw" default when: + # - publish_compile_target=False (caller opts out), or + # - the conversion is on an older registry without the + # compile_target/compile_metadata fields (graceful + # degradation; back-compat with PR #2389 conversions). + if self.config.publish_compile_target: + compile_target = getattr(self._conversion, "compile_target", "hf_raw") + compile_metadata = getattr(self._conversion, "compile_metadata", None) + else: + compile_target = "hf_raw" + compile_metadata = None + + n_tensors = 0 + for slot in self._model_slots: + slot_is_expert = bool(getattr(slot, "is_expert", False)) + slot_expert_axis = int(getattr(slot, "expert_axis", 0)) + slot_owned_experts = tuple(getattr(slot, "owned_expert_ids", ())) + for buf_key, tensor, _ in slot.buffers: + self._publisher.add_tensor( + name=buf_key, + tensor=tensor, + is_expert=slot_is_expert, + expert_axis=slot_expert_axis, + owned_expert_ids=slot_owned_experts, + compile_target=compile_target, + compile_metadata=compile_metadata, + ) + n_tensors += 1 + + # 4. Publish + mark READY in one shot. + mx_source_id = self._publisher.publish(version=step) + self._publisher.mark_ready() + + elapsed = time.perf_counter() - start + self.logger.info( + f"[mx_v2] publish step={step} tensors={n_tensors} " + f"compile_target={compile_target} mx_source_id={mx_source_id} " + f"elapsed={elapsed:.3f}s" + ) + + dist.barrier() + + # ------------------------------------------------------------------ + # Teardown + # ------------------------------------------------------------------ + + def shutdown(self) -> None: + if self._publisher is not None: + try: + self._publisher.shutdown() + finally: + self._publisher = None + self.is_initialized = False diff --git a/src/prime_rl/utils/client.py b/src/prime_rl/utils/client.py index d633bbbba0..308f9d7fdd 100644 --- a/src/prime_rl/utils/client.py +++ b/src/prime_rl/utils/client.py @@ -539,3 +539,75 @@ async def _init(admin_client: AsyncClient, rank_offset: int) -> None: response.raise_for_status() await asyncio.gather(*[_init(admin_client, i * gpus_per_server) for i, admin_client in enumerate(admin_clients)]) + + +async def init_nixl_mx_v2_broadcast( + admin_clients: list[AsyncClient], + host: str, + port: int, + inference_world_size: int, + *, + publish_self_as_replica: bool = True, + listen_port: int | None = None, +) -> None: + """Initialize the ``mx_v2`` (pull-mode) receivers on inference servers. + + Mirrors :func:`init_nixl_mx_broadcast` but targets the v2 worker + extension (``NIXLMxV2WeightUpdateWorker``) which uses the published + :class:`MxWeightTransferEngine` adapter instead of the in-tree + :class:`MxRendezvous`. + """ + logger = get_logger() + gpus_per_server = inference_world_size // len(admin_clients) + + logger.info( + f"Initializing NIXL+MX v2 broadcast: {len(admin_clients)} servers, " + f"inference_world_size={inference_world_size}, gpus_per_server={gpus_per_server}, " + f"publish_self_as_replica={publish_self_as_replica}" + ) + + async def _init(admin_client: AsyncClient, rank_offset: int) -> None: + response = await admin_client.post( + "/init_nixl_mx_v2", + json={ + "host": host, + "port": port, + "rank_offset": rank_offset, + "publish_self_as_replica": publish_self_as_replica, + "listen_port": listen_port, + }, + ) + response.raise_for_status() + + await asyncio.gather(*[_init(admin_client, i * gpus_per_server) for i, admin_client in enumerate(admin_clients)]) + + +async def update_weights_v2( + admin_clients: list[AsyncClient], + step: int, + *, + compile_target_filter: list[str] | None = None, + timeout_seconds: float = 300.0, + same_rank_only: bool = True, +) -> list[dict]: + """Drive a v2 (pull-mode) refit on all inference servers. + + Mirrors the existing ``/update_weights`` poke but for the + ``mx_v2`` worker path. Returns the per-server metrics dicts so the + orchestrator can emit per-cycle timing to its dashboards. + """ + + async def _update(admin_client: AsyncClient) -> dict: + response = await admin_client.post( + "/update_weights_v2", + json={ + "step": int(step), + "compile_target_filter": compile_target_filter, + "timeout_seconds": float(timeout_seconds), + "same_rank_only": bool(same_rank_only), + }, + ) + response.raise_for_status() + return response.json() + + return list(await asyncio.gather(*[_update(c) for c in admin_clients])) From df1b81a4577b566b1e5f4f9dd4c4089b5f007564 Mon Sep 17 00:00:00 2001 From: Kavin Krishnan Date: Tue, 2 Jun 2026 15:51:23 -0700 Subject: [PATCH 10/18] test: add unit tests for mx_v2 worker + broadcast + selector + small log cleanup MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds 24 unit tests covering the new weight_broadcast.type="mx_v2" path: tests/unit/train/rl/test_nixl_mx_v2.py (10 tests) tests/unit/inference/vllm/worker/test_nixl_mx_v2_worker.py (6 tests) tests/unit/inference/vllm/test_mx_v2_server_endpoints.py (3 + 5 gated) What's covered -------------- Trainer broadcast (NIXLMxV2WeightBroadcast): * Construction doesn't eagerly initialize the publisher * is_primary_hsdp_rank gates correctly for the 3 cases (no-HSDP, HSDP-primary, HSDP-non-primary) * lazy_init builds MxV2TrainingPublisher with the right TrainerWorldLayout, mx_server_url, and model_name * lazy_init is idempotent on repeated calls * broadcast_weights threads compile_target + compile_metadata into every publisher.add_tensor call when publish_compile_target=True * broadcast_weights falls back to "hf_raw" when publish_compile_target=False (back-compat default) * broadcast_weights threads is_expert / expert_axis / owned_expert_ids for MoE slots correctly * Non-primary HSDP ranks skip publish entirely * Each slot's fill_from is invoked with the resolved conversion * shutdown() is idempotent Inference worker (NIXLMxV2WeightUpdateWorker): * init_nixl_mx_v2 constructs the right MxInitInfo and pins UCX rail * publish_self_as_replica False propagates correctly * update_weights_via_mx_v2 constructs the right MxUpdateInfo and calls engine.receive_weights with the load_weights callback * No compile_target_filter passes None (back-compat) * _load_weights_batch forwards to raw_model.load_weights (HF→fused name remap via vLLM's stacked_params_mapping) * Metrics dict is well-formed even when engine.last_transfer_stats is None (early-cycle robustness) Server-side glue: * WORKER_EXTENSION_CLS table has "mx_v2" entry pointing at NIXLMxV2WeightUpdateWorker * Existing nccl / filesystem / nixl_mx entries preserved * Trainer-side selector __init__.py routes mx_v2 to the new broadcast * 5 HTTP endpoint + orchestrator-client tests gated to CI (need full prime-rl install — they skip locally cleanly with explanation) Test pattern ----------- Uses importlib.util.spec_from_file_location + sys.modules stubs so the tests run anywhere torch + pytest is available (no need for the full prime-rl venv install). Same pattern as the MX-side test_vllm_weight_transfer.py tests on PR #349. Local result: 19 passed, 5 skipped, no failures. Source fix ---------- nixl_mx_v2.py:_build_world_layout() was being called twice in lazy_init (once for the publisher's world_layout arg, once for the log message). Refactored to call once and bind to a local. Pure correctness + efficiency cleanup, no behavior change. --- .../trainer/rl/broadcast/nixl_mx_v2.py | 5 +- tests/unit/inference/vllm/__init__.py | 0 .../vllm/test_mx_v2_server_endpoints.py | 417 +++++++++++++++ tests/unit/inference/vllm/worker/__init__.py | 0 .../vllm/worker/test_nixl_mx_v2_worker.py | 274 ++++++++++ tests/unit/train/rl/test_nixl_mx_v2.py | 477 ++++++++++++++++++ 6 files changed, 1171 insertions(+), 2 deletions(-) create mode 100644 tests/unit/inference/vllm/__init__.py create mode 100644 tests/unit/inference/vllm/test_mx_v2_server_endpoints.py create mode 100644 tests/unit/inference/vllm/worker/__init__.py create mode 100644 tests/unit/inference/vllm/worker/test_nixl_mx_v2_worker.py create mode 100644 tests/unit/train/rl/test_nixl_mx_v2.py diff --git a/src/prime_rl/trainer/rl/broadcast/nixl_mx_v2.py b/src/prime_rl/trainer/rl/broadcast/nixl_mx_v2.py index fa4bde44cd..dfbc8679e2 100644 --- a/src/prime_rl/trainer/rl/broadcast/nixl_mx_v2.py +++ b/src/prime_rl/trainer/rl/broadcast/nixl_mx_v2.py @@ -130,12 +130,13 @@ def lazy_init(self, model: PreTrainedModelPrimeRL) -> None: # We pass our rank as ``worker_rank``; receivers with # ``same_rank_only=True`` (Phase 2 default) will only pull from # the trainer rank matching their own. + world_layout = self._build_world_layout() self._publisher = MxV2TrainingPublisher( agent_name=make_agent_name("trainer", self.world.rank), device_id=torch.cuda.current_device(), mx_server_url=f"{self.config.host}:{self.config.port}", worker_rank=self.world.rank, - world_layout=self._build_world_layout(), + world_layout=world_layout, ) self._publisher.initialize( model_name=self.config.inference_model_name, @@ -144,7 +145,7 @@ def lazy_init(self, model: PreTrainedModelPrimeRL) -> None: self.is_initialized = True self.logger.info( f"[mx_v2] publisher initialized: rank={self.world.rank} " - f"layout={self._build_world_layout().encode()} " + f"layout={world_layout.encode()} " f"compile_target={self._conversion.compile_target}" ) diff --git a/tests/unit/inference/vllm/__init__.py b/tests/unit/inference/vllm/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/unit/inference/vllm/test_mx_v2_server_endpoints.py b/tests/unit/inference/vllm/test_mx_v2_server_endpoints.py new file mode 100644 index 0000000000..18b53af10a --- /dev/null +++ b/tests/unit/inference/vllm/test_mx_v2_server_endpoints.py @@ -0,0 +1,417 @@ +"""Unit tests for the ``mx_v2`` server-side glue. + +Three pieces tested here: + +1. The ``WORKER_EXTENSION_CLS["mx_v2"]`` entry in server.py — i.e. that + the worker-extension selector points at our new worker extension class. +2. The new HTTP endpoints ``/init_nixl_mx_v2`` and ``/update_weights_v2`` + on server.py — verified to forward to the right ``collective_rpc`` + method names with the right kwargs. +3. The orchestrator-side helpers ``init_nixl_mx_v2_broadcast`` and + ``update_weights_v2`` in client.py — verified to POST to the right + endpoints with the right JSON body. + +Plus the trainer-side selector dispatch (``setup_weight_broadcast`` for +``config.type == "mx_v2"``). + +We use ``importlib.util.spec_from_file_location`` to load each target +file against a stubbed dep graph, so the test runs anywhere torch + +pytest is present without prime-rl needing to be installed. +""" + +from __future__ import annotations + +import importlib.util +import sys +import types +from pathlib import Path +from unittest.mock import AsyncMock, MagicMock + +import pytest + + +_PRIME_RL_ROOT = Path(__file__).resolve().parents[4] +_SERVER_FILE = ( + _PRIME_RL_ROOT / "src" / "prime_rl" / "inference" / "vllm" / "server.py" +) +_CLIENT_FILE = _PRIME_RL_ROOT / "src" / "prime_rl" / "utils" / "client.py" +_BROADCAST_INIT_FILE = ( + _PRIME_RL_ROOT + / "src" + / "prime_rl" + / "trainer" + / "rl" + / "broadcast" + / "__init__.py" +) + + +# ---------------------------------------------------------------------------- +# 1. WORKER_EXTENSION_CLS table — read directly from the source AST so we +# don't have to install the package or stub anywhere near as much +# ---------------------------------------------------------------------------- + + +def _extract_worker_extension_cls(): + """Parse server.py and pull out the WORKER_EXTENSION_CLS dict literal. + + Avoids the import-graph problem entirely — we only need the table. + """ + import ast + + src = _SERVER_FILE.read_text() + tree = ast.parse(src) + for node in ast.walk(tree): + if isinstance(node, ast.Assign) and any( + isinstance(t, ast.Name) and t.id == "WORKER_EXTENSION_CLS" + for t in node.targets + ): + return { + key.value: value.value + for key, value in zip(node.value.keys, node.value.values) + if isinstance(key, ast.Constant) and isinstance(value, ast.Constant) + } + raise RuntimeError("WORKER_EXTENSION_CLS not found in server.py") + + +def test_worker_extension_cls_table_has_mx_v2_entry(): + table = _extract_worker_extension_cls() + assert "mx_v2" in table + assert ( + table["mx_v2"] + == "prime_rl.inference.vllm.worker.nixl_mx_v2.NIXLMxV2WeightUpdateWorker" + ) + + +def test_worker_extension_cls_table_preserves_existing_backends(): + """Adding mx_v2 must not have removed nccl / filesystem / nixl_mx.""" + table = _extract_worker_extension_cls() + assert "nccl" in table + assert "filesystem" in table + assert "nixl_mx" in table + # And nixl_mx vs mx_v2 are two distinct worker classes. + assert table["nixl_mx"] != table["mx_v2"] + + +# ---------------------------------------------------------------------------- +# 2. Server endpoints — load via spec_from_file_location with stubs +# ---------------------------------------------------------------------------- + + +def _install_server_stubs(): + """Stub the heavy server.py deps (vLLM, FastAPI bits, prime_rl imports). + + Just enough to let server.py's module-level statements run; we only need + to call the two new endpoint coroutines. + """ + # FastAPI bits + fake_request_cls = type("Request", (), {}) + fake_apirouter_cls = MagicMock(name="APIRouter") + fake_apirouter = MagicMock(name="apirouter_inst") + # Make APIRouter.post / get return identity decorators so the @router.post + # decorators in server.py work without registering anything. + fake_apirouter.post = lambda *a, **kw: (lambda f: f) + fake_apirouter.get = lambda *a, **kw: (lambda f: f) + fake_apirouter_cls.return_value = fake_apirouter + fake_jsonresponse = MagicMock(name="JSONResponse") + sys.modules["fastapi"] = types.SimpleNamespace( + Request=fake_request_cls, APIRouter=fake_apirouter_cls + ) + sys.modules["fastapi.responses"] = types.SimpleNamespace( + JSONResponse=fake_jsonresponse + ) + + # vllm bits + sys.modules["vllm"] = types.SimpleNamespace() + sys.modules["vllm.engine"] = types.SimpleNamespace() + sys.modules["vllm.engine.protocol"] = types.SimpleNamespace( + EngineClient=type("EngineClient", (), {}) + ) + sys.modules["vllm.entrypoints"] = types.SimpleNamespace() + sys.modules["vllm.entrypoints.openai"] = types.SimpleNamespace() + sys.modules["vllm.entrypoints.openai.api_server"] = types.SimpleNamespace( + State=type("State", (), {}), + init_app_state=MagicMock(), + run_headless=MagicMock(), + ) + sys.modules["vllm.entrypoints.openai.protocol"] = types.SimpleNamespace( + LoadLoRAAdapterRequest=type("LoadLoRAAdapterRequest", (), {}), + ErrorResponse=type("ErrorResponse", (), {}), + ) + sys.modules["vllm.utils"] = types.SimpleNamespace(FlexibleArgumentParser=MagicMock()) + # prime_rl deps used at top of server.py + sys.modules.setdefault("prime_rl", types.ModuleType("prime_rl")) + sys.modules.setdefault("prime_rl.utils", types.ModuleType("prime_rl.utils")) + sys.modules["prime_rl.utils.logger"] = types.SimpleNamespace( + get_logger=MagicMock(return_value=MagicMock(name="logger")), + setup_logger=MagicMock(), + ) + + # PrimeRlServingTokens etc. + sys.modules["prime_rl.inference"] = types.ModuleType("prime_rl.inference") + sys.modules["prime_rl.inference.vllm"] = types.ModuleType("prime_rl.inference.vllm") + sys.modules["prime_rl.inference.vllm.serving_tokens"] = types.SimpleNamespace( + PrimeRlServingTokens=type("PrimeRlServingTokens", (), {}) + ) + + +@pytest.fixture +def server_mod(): + """Load server.py with stubs in place.""" + # Wipe cached state + for k in list(sys.modules.keys()): + if k.startswith("prime_rl") or k.startswith("vllm") or k.startswith("fastapi"): + del sys.modules[k] + + _install_server_stubs() + + spec = importlib.util.spec_from_file_location( + "_test_server_under_test", _SERVER_FILE + ) + mod = importlib.util.module_from_spec(spec) + try: + spec.loader.exec_module(mod) + except Exception as e: + pytest.skip( + f"server.py imports too much to stub cleanly: {e}; this test " + f"runs in CI where prime-rl IS installed" + ) + yield mod + + +@pytest.mark.asyncio +async def test_init_nixl_mx_v2_endpoint_dispatches_collective_rpc(server_mod): + fake_client = MagicMock() + fake_client.collective_rpc = AsyncMock() + fake_request = MagicMock() + fake_request.json = AsyncMock( + return_value={ + "host": "modelexpress-server.kavin.svc.cluster.local", + "port": 8001, + "rank_offset": 4, + "publish_self_as_replica": True, + "listen_port": None, + } + ) + + orig = getattr(server_mod, "engine_client", None) + server_mod.engine_client = lambda r: fake_client + try: + result = await server_mod.init_nixl_mx_v2(fake_request) + finally: + if orig is not None: + server_mod.engine_client = orig + + assert result == {"status": "ok"} + fake_client.collective_rpc.assert_called_once_with( + "init_nixl_mx_v2", + args=( + "modelexpress-server.kavin.svc.cluster.local", + 8001, + 4, + ), + kwargs={"publish_self_as_replica": True, "listen_port": None}, + ) + + +@pytest.mark.asyncio +async def test_update_weights_v2_endpoint_dispatches_collective_rpc(server_mod): + fake_metrics = [ + {"step": 42, "bytes_received": 536_870_912, "bandwidth_gbps": 52.4} + ] + fake_client = MagicMock() + fake_client.collective_rpc = AsyncMock(return_value=fake_metrics) + fake_request = MagicMock() + fake_request.json = AsyncMock( + return_value={ + "step": 42, + "compile_target_filter": ["cutlass_fp8"], + "timeout_seconds": 180.0, + "same_rank_only": True, + } + ) + + orig = getattr(server_mod, "engine_client", None) + server_mod.engine_client = lambda r: fake_client + try: + result = await server_mod.update_weights_v2(fake_request) + finally: + if orig is not None: + server_mod.engine_client = orig + + assert result == {"status": "ok", "metrics": fake_metrics} + fake_client.collective_rpc.assert_called_once_with( + "update_weights_via_mx_v2", + args=(42,), + kwargs={ + "compile_target_filter": ["cutlass_fp8"], + "timeout_seconds": 180.0, + "same_rank_only": True, + }, + ) + + +@pytest.mark.asyncio +async def test_update_weights_v2_endpoint_defaults(server_mod): + fake_client = MagicMock() + fake_client.collective_rpc = AsyncMock(return_value=[]) + fake_request = MagicMock() + fake_request.json = AsyncMock(return_value={"step": 1}) + + orig = getattr(server_mod, "engine_client", None) + server_mod.engine_client = lambda r: fake_client + try: + await server_mod.update_weights_v2(fake_request) + finally: + if orig is not None: + server_mod.engine_client = orig + + kwargs = fake_client.collective_rpc.call_args.kwargs["kwargs"] + assert kwargs["compile_target_filter"] is None + assert kwargs["timeout_seconds"] == 300.0 + assert kwargs["same_rank_only"] is True + + +# ---------------------------------------------------------------------------- +# 3. Orchestrator-side helpers — load client.py with stubs +# ---------------------------------------------------------------------------- + + +def _install_client_stubs(): + sys.modules.setdefault("prime_rl", types.ModuleType("prime_rl")) + sys.modules.setdefault("prime_rl.utils", types.ModuleType("prime_rl.utils")) + sys.modules["prime_rl.utils.logger"] = types.SimpleNamespace( + get_logger=MagicMock(return_value=MagicMock(name="logger")), + setup_logger=MagicMock(), + ) + # httpx AsyncClient stub — client.py imports it + sys.modules["httpx"] = types.SimpleNamespace( + AsyncClient=type("AsyncClient", (), {}) + ) + + +@pytest.fixture +def client_mod(): + for k in list(sys.modules.keys()): + if k.startswith("prime_rl") or k == "httpx": + del sys.modules[k] + _install_client_stubs() + spec = importlib.util.spec_from_file_location( + "_test_client_under_test", _CLIENT_FILE + ) + mod = importlib.util.module_from_spec(spec) + try: + spec.loader.exec_module(mod) + except Exception as e: + pytest.skip( + f"client.py imports too much to stub cleanly: {e}; this test " + f"runs in CI where prime-rl IS installed" + ) + yield mod + + +@pytest.mark.asyncio +async def test_init_nixl_mx_v2_broadcast_posts_to_all_servers(client_mod): + """POSTs /init_nixl_mx_v2 with rank_offset = i * gpus_per_server per server.""" + admin_clients = [] + for _ in range(3): + c = MagicMock() + resp = MagicMock() + resp.raise_for_status = MagicMock() + c.post = AsyncMock(return_value=resp) + admin_clients.append(c) + + await client_mod.init_nixl_mx_v2_broadcast( + admin_clients, + host="mx-server", + port=8001, + inference_world_size=12, + publish_self_as_replica=True, + listen_port=None, + ) + + # gpus_per_server = 12 // 3 = 4 → rank_offsets 0, 4, 8 + expected_offsets = [0, 4, 8] + for c, expected_offset in zip(admin_clients, expected_offsets): + c.post.assert_called_once() + args, kwargs = c.post.call_args + assert args[0] == "/init_nixl_mx_v2" + body = kwargs["json"] + assert body["host"] == "mx-server" + assert body["port"] == 8001 + assert body["rank_offset"] == expected_offset + assert body["publish_self_as_replica"] is True + + +@pytest.mark.asyncio +async def test_update_weights_v2_posts_step_and_returns_metrics(client_mod): + fake_servers = [] + expected_responses = [ + {"status": "ok", "metrics": [{"step": 5, "bandwidth_gbps": 50.0}]}, + {"status": "ok", "metrics": [{"step": 5, "bandwidth_gbps": 48.0}]}, + ] + for resp_body in expected_responses: + c = MagicMock() + resp = MagicMock() + resp.raise_for_status = MagicMock() + resp.json = MagicMock(return_value=resp_body) + c.post = AsyncMock(return_value=resp) + fake_servers.append(c) + + results = await client_mod.update_weights_v2( + fake_servers, + step=5, + compile_target_filter=["cutlass_fp8"], + timeout_seconds=180.0, + same_rank_only=True, + ) + + assert results == expected_responses + for c in fake_servers: + args, kwargs = c.post.call_args + assert args[0] == "/update_weights_v2" + body = kwargs["json"] + assert body["step"] == 5 + assert body["compile_target_filter"] == ["cutlass_fp8"] + assert body["timeout_seconds"] == 180.0 + assert body["same_rank_only"] is True + + +# ---------------------------------------------------------------------------- +# 4. Trainer-side selector dispatch — verify __init__.py routes mx_v2 correctly +# ---------------------------------------------------------------------------- + + +def test_broadcast_init_dispatches_mx_v2_via_ast(): + """The selector in broadcast/__init__.py routes config.type == "mx_v2" + to NIXLMxV2WeightBroadcast. Parse the source directly to avoid the heavy + import graph.""" + import ast + + src = _BROADCAST_INIT_FILE.read_text() + tree = ast.parse(src) + + # Find the setup_weight_broadcast function + func = next( + node + for node in ast.walk(tree) + if isinstance(node, ast.FunctionDef) and node.name == "setup_weight_broadcast" + ) + + # Find the elif branch with `config.type == "mx_v2"` + mx_v2_branch_found = False + for node in ast.walk(func): + if isinstance(node, ast.Compare): + # Detect `config.type == "mx_v2"` + if ( + len(node.comparators) == 1 + and isinstance(node.comparators[0], ast.Constant) + and node.comparators[0].value == "mx_v2" + ): + mx_v2_branch_found = True + break + assert mx_v2_branch_found, "mx_v2 dispatch branch not found in selector" + + # And the branch references NIXLMxV2WeightBroadcast + assert "NIXLMxV2WeightBroadcast" in src + assert "from prime_rl.trainer.rl.broadcast.nixl_mx_v2" in src diff --git a/tests/unit/inference/vllm/worker/__init__.py b/tests/unit/inference/vllm/worker/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/unit/inference/vllm/worker/test_nixl_mx_v2_worker.py b/tests/unit/inference/vllm/worker/test_nixl_mx_v2_worker.py new file mode 100644 index 0000000000..78a339bc6b --- /dev/null +++ b/tests/unit/inference/vllm/worker/test_nixl_mx_v2_worker.py @@ -0,0 +1,274 @@ +"""Unit tests for ``NIXLMxV2WeightUpdateWorker``. + +Same pattern as ``test_nixl_mx_v2.py`` — load the production module via +``importlib.util.spec_from_file_location`` against a fully-stubbed +dependency graph (vLLM, modelexpress, prime_rl.transport, etc.), so the +test runs anywhere torch + pytest is present. + +The worker has two RPC entry points: + +- ``init_nixl_mx_v2(host, port, rank_offset, *, publish_self_as_replica, listen_port)`` +- ``update_weights_via_mx_v2(step, *, compile_target_filter, timeout_seconds, same_rank_only)`` + +We verify init-info construction, update-info construction, the +``load_weights`` callback path, metrics-dict shape, and the post-load +``update_mla_absorbed_weights`` hook. +""" + +from __future__ import annotations + +import importlib.util +import sys +import types +from pathlib import Path +from unittest.mock import MagicMock + +import pytest + + +_PRIME_RL_ROOT = Path(__file__).resolve().parents[5] # prime-rl root +_WORKER_FILE = ( + _PRIME_RL_ROOT + / "src" + / "prime_rl" + / "inference" + / "vllm" + / "worker" + / "nixl_mx_v2.py" +) + + +def _install_stubs(): + """Insert fake modules so the worker file imports cleanly.""" + mocks: dict[str, MagicMock] = {} + + # ─── modelexpress.vllm_weight_transfer ────────────────────────────── + fake_engine_cls = MagicMock(name="MxWeightTransferEngine_cls") + fake_engine = MagicMock(name="MxWeightTransferEngine_inst") + fake_stats = types.SimpleNamespace( + bytes_received=536_870_912, + tensors_received=64, + elapsed_seconds=0.082, + bandwidth_gbps=52.4, + discovery_seconds=0.014, + source_worker_rank=0, + ) + fake_engine.last_transfer_stats = fake_stats + fake_engine.last_discovery_seconds = 0.014 + fake_engine_cls.return_value = fake_engine + + fake_init_info_cls = MagicMock(name="MxInitInfo") + fake_update_info_cls = MagicMock(name="MxUpdateInfo") + + sys.modules["modelexpress"] = types.SimpleNamespace() + sys.modules["modelexpress.vllm_weight_transfer"] = types.SimpleNamespace( + MxWeightTransferEngine=fake_engine_cls, + MxInitInfo=fake_init_info_cls, + MxUpdateInfo=fake_update_info_cls, + ) + mocks["engine_cls"] = fake_engine_cls + mocks["engine"] = fake_engine + mocks["init_info_cls"] = fake_init_info_cls + mocks["update_info_cls"] = fake_update_info_cls + mocks["stats"] = fake_stats + + # ─── vllm.logger ──────────────────────────────────────────────────── + sys.modules["vllm"] = types.SimpleNamespace() + sys.modules["vllm.logger"] = types.SimpleNamespace( + init_logger=lambda name: MagicMock(name=f"logger({name})") + ) + # vllm.v1.worker.gpu_worker only used inside TYPE_CHECKING so no stub needed + + # ─── prime_rl.inference.vllm.worker.weight_transfer ──────────────── + fake_update_mla = MagicMock(name="update_mla_absorbed_weights") + sys.modules.setdefault("prime_rl", types.ModuleType("prime_rl")) + sys.modules["prime_rl.inference"] = types.ModuleType("prime_rl.inference") + sys.modules["prime_rl.inference.vllm"] = types.ModuleType("prime_rl.inference.vllm") + sys.modules["prime_rl.inference.vllm.worker"] = types.ModuleType( + "prime_rl.inference.vllm.worker" + ) + pkg_wt = types.ModuleType("prime_rl.inference.vllm.worker.weight_transfer") + pkg_wt.update_mla_absorbed_weights = fake_update_mla + # `build_expert_map` is imported by the OLD worker (nixl_mx.py) — not by + # nixl_mx_v2 — so we don't need to stub it for this test. Add a no-op + # in case the test imports the broadcast __init__ which may pull it in. + pkg_wt.build_expert_map = MagicMock(name="build_expert_map", return_value={}) + sys.modules["prime_rl.inference.vllm.worker.weight_transfer"] = pkg_wt + mocks["update_mla"] = fake_update_mla + + # ─── prime_rl.transport.nixl_agent ────────────────────────────────── + fake_make_agent_name = MagicMock(return_value="vllm-inference-r0") + fake_pin_ucx_rail = MagicMock() + sys.modules["prime_rl.transport"] = types.ModuleType("prime_rl.transport") + pkg_na = types.ModuleType("prime_rl.transport.nixl_agent") + pkg_na.make_agent_name = fake_make_agent_name + pkg_na.pin_ucx_rail = fake_pin_ucx_rail + sys.modules["prime_rl.transport.nixl_agent"] = pkg_na + mocks["make_agent_name"] = fake_make_agent_name + mocks["pin_ucx_rail"] = fake_pin_ucx_rail + + return mocks + + +@pytest.fixture +def worker_mod(): + """Load nixl_mx_v2.py worker under fully-stubbed deps.""" + for k in list(sys.modules.keys()): + if k.startswith("prime_rl") or k == "modelexpress" or k.startswith("modelexpress."): + del sys.modules[k] + if k == "vllm" or k.startswith("vllm."): + del sys.modules[k] + + mocks = _install_stubs() + + import torch + if not hasattr(torch.cuda, "synchronize"): + torch.cuda.synchronize = MagicMock() + original_synchronize = torch.cuda.synchronize + torch.cuda.synchronize = MagicMock() + + spec = importlib.util.spec_from_file_location( + "_test_nixl_mx_v2_worker_under_test", _WORKER_FILE + ) + mod = importlib.util.module_from_spec(spec) + spec.loader.exec_module(mod) + + try: + yield (mod, mocks) + finally: + torch.cuda.synchronize = original_synchronize + + +def _make_worker(mod, *, model_name="bench/synthetic-1.5B", device_index=0): + """Build a worker with a mocked vLLM Worker context. + + The `raw_model` property does `assert isinstance(model, Module)` so the + inner model has to be a real ``torch.nn.Module`` subclass. We attach a + `load_weights` method onto it so tests can spy on the callback path. + """ + import torch.nn as nn + + class FakeInnerModel(nn.Module): + def __init__(self): + super().__init__() + self.load_weights = MagicMock(name="load_weights") + + worker = mod.NIXLMxV2WeightUpdateWorker() + worker.device = MagicMock() + worker.device.index = device_index + runner = MagicMock(name="ModelRunner") + runner.model_config = MagicMock(model=model_name) + fake_inner = FakeInnerModel() + runner.model = MagicMock() + runner.model.runnable = fake_inner + worker.model_runner = runner + return worker + + +def test_init_nixl_mx_v2_builds_engine_with_correct_init_info(worker_mod): + mod, mocks = worker_mod + worker = _make_worker(mod, device_index=2) + worker.init_nixl_mx_v2( + host="modelexpress-server.kavin.svc.cluster.local", + port=8001, + rank_offset=4, + publish_self_as_replica=True, + listen_port=None, + ) + + mocks["init_info_cls"].assert_called_once() + init_kwargs = mocks["init_info_cls"].call_args.kwargs + assert ( + init_kwargs["mx_server_url"] + == "modelexpress-server.kavin.svc.cluster.local:8001" + ) + assert init_kwargs["worker_rank"] == 4 + 2 + assert init_kwargs["agent_name"] == "vllm-inference-r0" + assert init_kwargs["device_id"] == 2 + assert init_kwargs["publish_self_as_replica"] is True + + mocks["engine_cls"].assert_called_once() + eng_kwargs = mocks["engine_cls"].call_args.kwargs + assert "init_info" in eng_kwargs + mocks["pin_ucx_rail"].assert_called_once_with(2) + assert worker._global_rank == 6 + + +def test_init_nixl_mx_v2_respects_publish_self_as_replica_false(worker_mod): + mod, mocks = worker_mod + worker = _make_worker(mod) + worker.init_nixl_mx_v2( + host="x", port=8001, rank_offset=0, publish_self_as_replica=False + ) + init_kwargs = mocks["init_info_cls"].call_args.kwargs + assert init_kwargs["publish_self_as_replica"] is False + + +def test_update_weights_via_mx_v2_dispatches_engine_receive(worker_mod): + mod, mocks = worker_mod + worker = _make_worker(mod) + worker.init_nixl_mx_v2(host="x", port=8001, rank_offset=0) + + metrics = worker.update_weights_via_mx_v2( + 42, + compile_target_filter=["cutlass_fp8", "hf_raw"], + timeout_seconds=180.0, + same_rank_only=True, + ) + + mocks["update_info_cls"].assert_called_once() + upd_kwargs = mocks["update_info_cls"].call_args.kwargs + assert upd_kwargs["version"] == 42 + assert upd_kwargs["compile_target_filter"] == {"cutlass_fp8", "hf_raw"} + assert upd_kwargs["timeout_seconds"] == 180.0 + assert upd_kwargs["same_rank_only"] is True + assert upd_kwargs["target_tp_layout"] is None + + mocks["engine"].receive_weights.assert_called_once() + call = mocks["engine"].receive_weights.call_args + assert "load_weights" in call.kwargs + + mocks["update_mla"].assert_called_once() + + assert metrics["step"] == 42 + assert metrics["bytes_received"] == 536_870_912 + assert metrics["tensors_received"] == 64 + assert metrics["bandwidth_gbps"] == pytest.approx(52.4) + assert metrics["discovery_seconds"] == pytest.approx(0.014) + assert metrics["source_worker_rank"] == 0 + + +def test_update_weights_via_mx_v2_no_filter_passes_none(worker_mod): + mod, mocks = worker_mod + worker = _make_worker(mod) + worker.init_nixl_mx_v2(host="x", port=8001, rank_offset=0) + worker.update_weights_via_mx_v2(1, compile_target_filter=None) + upd_kwargs = mocks["update_info_cls"].call_args.kwargs + assert upd_kwargs["compile_target_filter"] is None + + +def test_load_weights_batch_feeds_through_vllm_model_load_weights(worker_mod): + mod, _ = worker_mod + worker = _make_worker(mod) + captured_batches = [] + worker.raw_model.load_weights = MagicMock( + side_effect=lambda batch: captured_batches.append(batch) + ) + batch_1 = [("model.layers.0.weight", "TENSOR1")] + batch_2 = [("model.layers.1.weight", "TENSOR2"), ("a", "T3")] + worker._load_weights_batch(batch_1) + worker._load_weights_batch(batch_2) + assert captured_batches == [batch_1, batch_2] + + +def test_update_weights_via_mx_v2_metrics_safe_when_stats_none(worker_mod): + mod, mocks = worker_mod + mocks["engine"].last_transfer_stats = None + worker = _make_worker(mod) + worker.init_nixl_mx_v2(host="x", port=8001, rank_offset=0) + + metrics = worker.update_weights_via_mx_v2(1) + assert metrics["bytes_received"] == 0 + assert metrics["tensors_received"] == 0 + assert metrics["bandwidth_gbps"] == 0.0 + assert metrics["source_worker_rank"] is None diff --git a/tests/unit/train/rl/test_nixl_mx_v2.py b/tests/unit/train/rl/test_nixl_mx_v2.py new file mode 100644 index 0000000000..e9bbfab4f2 --- /dev/null +++ b/tests/unit/train/rl/test_nixl_mx_v2.py @@ -0,0 +1,477 @@ +"""Unit tests for ``NIXLMxV2WeightBroadcast``. + +These tests exercise the per-step orchestration logic — slot fill, +publisher add_tensor threading, compile_target tagging, MoE expert +metadata threading, and HSDP barrier gating — without requiring CUDA, +NIXL, a live MX server, or a real model. + +We use ``importlib.util.spec_from_file_location`` to load the production +``nixl_mx_v2.py`` against a fully-stubbed dependency graph (same pattern +as MX-side ``test_vllm_weight_transfer.py``). The test is therefore +runnable anywhere torch + pytest is present, without prime-rl needing to +be installed as a package. +""" + +from __future__ import annotations + +import importlib.util +import sys +import types +from pathlib import Path +from unittest.mock import MagicMock + +import pytest + + +_PRIME_RL_ROOT = Path(__file__).resolve().parents[4] # prime-rl root +_BROADCAST_FILE = ( + _PRIME_RL_ROOT + / "src" + / "prime_rl" + / "trainer" + / "rl" + / "broadcast" + / "nixl_mx_v2.py" +) + + +# ---------------------------------------------------------------------------- +# Stub the prime_rl + modelexpress + transformers + torch.distributed +# dependency graph the broadcast module needs at import time +# ---------------------------------------------------------------------------- + + +def _install_stubs(): + """Insert fake modules into sys.modules so importing nixl_mx_v2 succeeds. + Returns a dict of the live mocks for test inspection.""" + mocks: dict[str, MagicMock] = {} + + # ─── modelexpress.nemo_rl_v2 ──────────────────────────────────────── + fake_publisher_cls = MagicMock(name="MxV2TrainingPublisher_cls") + fake_publisher = MagicMock(name="MxV2TrainingPublisher_inst") + fake_publisher.publish.return_value = "abcd1234efgh5678" + fake_publisher.mark_ready.return_value = True + fake_publisher_cls.return_value = fake_publisher + + fake_layout_cls = MagicMock(name="TrainerWorldLayout_cls") + fake_layout = MagicMock(name="TrainerWorldLayout_inst") + fake_layout.encode.return_value = "fsdp:1,tp:1,pp:1,ep:1" + fake_layout_cls.return_value = fake_layout + + sys.modules["modelexpress"] = types.SimpleNamespace() + sys.modules["modelexpress.nemo_rl_v2"] = types.SimpleNamespace( + MxV2TrainingPublisher=fake_publisher_cls, + TrainerWorldLayout=fake_layout_cls, + ) + mocks["publisher_cls"] = fake_publisher_cls + mocks["publisher"] = fake_publisher + mocks["layout_cls"] = fake_layout_cls + mocks["layout"] = fake_layout + + # ─── transformers.AutoConfig ──────────────────────────────────────── + fake_auto_config = MagicMock(name="AutoConfig") + fake_auto_config.from_pretrained.return_value = MagicMock( + torch_dtype="torch.bfloat16" + ) + sys.modules["transformers"] = types.SimpleNamespace(AutoConfig=fake_auto_config) + mocks["auto_config"] = fake_auto_config + + # ─── prime_rl.configs.trainer.MxV2WeightBroadcastConfig ───────────── + fake_config_cls = MagicMock(name="MxV2WeightBroadcastConfig_cls") + pkg_configs = types.ModuleType("prime_rl.configs") + pkg_configs_trainer = types.ModuleType("prime_rl.configs.trainer") + pkg_configs_trainer.MxV2WeightBroadcastConfig = fake_config_cls + sys.modules.setdefault("prime_rl", types.ModuleType("prime_rl")) + sys.modules["prime_rl.configs"] = pkg_configs + sys.modules["prime_rl.configs.trainer"] = pkg_configs_trainer + + # ─── prime_rl.trainer.models.PreTrainedModelPrimeRL ───────────────── + fake_pretrained_cls = MagicMock(name="PreTrainedModelPrimeRL_cls") + pkg_trainer_models = types.ModuleType("prime_rl.trainer.models") + pkg_trainer_models.PreTrainedModelPrimeRL = fake_pretrained_cls + sys.modules["prime_rl.trainer"] = types.ModuleType("prime_rl.trainer") + sys.modules["prime_rl.trainer.models"] = pkg_trainer_models + + # ─── prime_rl.trainer.models.conversions.select_default_conversion ── + fake_conversion = types.SimpleNamespace( + compile_target="cutlass_fp8", + compile_metadata={"block_size": 128, "scale_layout": "per_channel"}, + ) + fake_select_conversion = MagicMock(return_value=fake_conversion) + pkg_conv = types.ModuleType("prime_rl.trainer.models.conversions") + pkg_conv.select_default_conversion = fake_select_conversion + sys.modules["prime_rl.trainer.models.conversions"] = pkg_conv + mocks["conversion"] = fake_conversion + mocks["select_conversion"] = fake_select_conversion + + # ─── prime_rl.trainer.parallel_dims.ParallelDims ──────────────────── + fake_parallel_dims_cls = MagicMock(name="ParallelDims_cls") + pkg_pd = types.ModuleType("prime_rl.trainer.parallel_dims") + pkg_pd.ParallelDims = fake_parallel_dims_cls + sys.modules["prime_rl.trainer.parallel_dims"] = pkg_pd + + # ─── prime_rl.trainer.rl.broadcast.base.WeightBroadcast ───────────── + class FakeWeightBroadcast: + def __init__(self, output_dir, *args, **kwargs): + self.output_dir = output_dir + # Mimic real base class — set logger so subclass can use it. + self.logger = MagicMock(name="logger") + + pkg_trainer_rl = types.ModuleType("prime_rl.trainer.rl") + pkg_trainer_rl_broadcast = types.ModuleType("prime_rl.trainer.rl.broadcast") + pkg_broadcast_base = types.ModuleType("prime_rl.trainer.rl.broadcast.base") + pkg_broadcast_base.WeightBroadcast = FakeWeightBroadcast + sys.modules["prime_rl.trainer.rl"] = pkg_trainer_rl + sys.modules["prime_rl.trainer.rl.broadcast"] = pkg_trainer_rl_broadcast + sys.modules["prime_rl.trainer.rl.broadcast.base"] = pkg_broadcast_base + mocks["base_cls"] = FakeWeightBroadcast + + # ─── prime_rl.trainer.runs.get_multi_run_manager ──────────────────── + fake_run_manager = types.SimpleNamespace(used_idxs=[], ready_to_update={}) + fake_get_multi_run_manager = MagicMock(return_value=fake_run_manager) + pkg_runs = types.ModuleType("prime_rl.trainer.runs") + pkg_runs.get_multi_run_manager = fake_get_multi_run_manager + sys.modules["prime_rl.trainer.runs"] = pkg_runs + mocks["run_manager"] = fake_run_manager + + # ─── prime_rl.trainer.utils.get_world ────────────────────────────── + fake_world = types.SimpleNamespace(rank=0, is_master=True) + fake_get_world = MagicMock(return_value=fake_world) + pkg_utils = types.ModuleType("prime_rl.trainer.utils") + pkg_utils.get_world = fake_get_world + sys.modules["prime_rl.trainer.utils"] = pkg_utils + mocks["world"] = fake_world + + # ─── prime_rl.transport.classic_cuda_pool / nixl_agent ────────────── + class FakeAlloc: + def __enter__(self): + return None + + def __exit__(self, *args): + return False + + pkg_transport = types.ModuleType("prime_rl.transport") + pkg_transport_classic = types.ModuleType("prime_rl.transport.classic_cuda_pool") + pkg_transport_classic.classic_cuda_alloc = lambda: FakeAlloc() + pkg_transport_nixl_agent = types.ModuleType("prime_rl.transport.nixl_agent") + pkg_transport_nixl_agent.make_agent_name = MagicMock(return_value="trainer-0") + pkg_transport_nixl_agent.pin_ucx_rail = MagicMock() + sys.modules["prime_rl.transport"] = pkg_transport + sys.modules["prime_rl.transport.classic_cuda_pool"] = pkg_transport_classic + sys.modules["prime_rl.transport.nixl_agent"] = pkg_transport_nixl_agent + + return mocks + + +@pytest.fixture +def broadcast_mod(): + """Load nixl_mx_v2.py under fully-stubbed deps. Yields (module, mocks).""" + # Wipe any stale modules so each test gets a fresh patched graph. + for k in list(sys.modules.keys()): + if k.startswith("prime_rl") or k == "modelexpress" or k.startswith("modelexpress."): + del sys.modules[k] + if k == "transformers": + del sys.modules[k] + + mocks = _install_stubs() + + # Patch torch.cuda + torch.distributed before loading the module. + import torch + torch.cuda.current_device = MagicMock(return_value=0) + if hasattr(torch.distributed, "barrier"): + original_barrier = torch.distributed.barrier + torch.distributed.barrier = MagicMock() + else: + original_barrier = None + torch.distributed.barrier = MagicMock() + + spec = importlib.util.spec_from_file_location( + "_test_nixl_mx_v2_under_test", _BROADCAST_FILE + ) + mod = importlib.util.module_from_spec(spec) + spec.loader.exec_module(mod) + + try: + yield (mod, mocks) + finally: + if original_barrier is not None: + torch.distributed.barrier = original_barrier + + +# ---------------------------------------------------------------------------- +# Helpers +# ---------------------------------------------------------------------------- + + +def _make_config(**overrides): + defaults = dict( + type="mx_v2", + host="localhost", + port=8001, + timeout=60, + inference_world_size=1, + inference_model_name="bench/synthetic-1.5B", + same_rank_only=True, + dedup_freshest_per_rank=True, + publish_compile_target=True, + compile_target_filter=None, + publish_self_as_replica=True, + ) + defaults.update(overrides) + return types.SimpleNamespace(**defaults) + + +def _make_parallel_dims( + *, + dp_replicate_enabled: bool = False, + is_primary: bool = True, + fsdp_world_size: int = 1, + tp_size: int = 1, + pp_size: int = 1, + ep_size: int = 1, +): + mesh = MagicMock(name="dp_replicate_mesh") + mesh.get_local_rank.return_value = 0 if is_primary else 1 + pdims = MagicMock(name="ParallelDims") + pdims.dp_replicate_enabled = dp_replicate_enabled + pdims.dp_shard_size = fsdp_world_size + pdims.tp_size = tp_size + pdims.pp_size = pp_size + pdims.ep_size = ep_size + pdims.get_mesh = MagicMock(return_value=mesh) + return pdims + + +def _make_fake_slot(*, name: str, is_expert: bool = False, num_buffers: int = 2): + import torch + + slot = MagicMock(name=f"Slot({name})") + slot.is_expert = is_expert + slot.expert_axis = 0 if is_expert else 0 + slot.owned_expert_ids = (0, 1, 2, 3) if is_expert else () + slot.buffers = [ + (f"{name}.buf_{i}", torch.zeros(4), object()) for i in range(num_buffers) + ] + slot.fill_from = MagicMock() + return slot + + +def _make_fake_model(slots): + model = MagicMock(name="Model") + model.build_slots = MagicMock(return_value=slots) + model.state_dict = MagicMock(return_value={}) + return model + + +# ---------------------------------------------------------------------------- +# Tests +# ---------------------------------------------------------------------------- + + +def test_construction_does_not_initialize_publisher(broadcast_mod): + mod, mocks = broadcast_mod + bc = mod.NIXLMxV2WeightBroadcast( + output_dir=Path("/tmp/out"), + config=_make_config(), + parallel_dims=_make_parallel_dims(), + ) + assert bc.is_initialized is False + assert bc._publisher is None + assert bc._model_slots is None + mocks["publisher_cls"].assert_not_called() + + +def test_is_primary_hsdp_rank_gates_correctly(broadcast_mod): + mod, _ = broadcast_mod + bc1 = mod.NIXLMxV2WeightBroadcast( + output_dir=Path("/tmp/out"), + config=_make_config(), + parallel_dims=_make_parallel_dims(dp_replicate_enabled=False), + ) + assert bc1.is_primary_hsdp_rank is True + + bc2 = mod.NIXLMxV2WeightBroadcast( + output_dir=Path("/tmp/out"), + config=_make_config(), + parallel_dims=_make_parallel_dims( + dp_replicate_enabled=True, is_primary=True + ), + ) + assert bc2.is_primary_hsdp_rank is True + + bc3 = mod.NIXLMxV2WeightBroadcast( + output_dir=Path("/tmp/out"), + config=_make_config(), + parallel_dims=_make_parallel_dims( + dp_replicate_enabled=True, is_primary=False + ), + ) + assert bc3.is_primary_hsdp_rank is False + + +def test_lazy_init_builds_publisher_with_right_args(broadcast_mod): + mod, mocks = broadcast_mod + bc = mod.NIXLMxV2WeightBroadcast( + output_dir=Path("/tmp/out"), + config=_make_config(host="mx-server", port=8001), + parallel_dims=_make_parallel_dims( + fsdp_world_size=4, tp_size=2, pp_size=1, ep_size=2 + ), + ) + model = _make_fake_model([_make_fake_slot(name="layer0")]) + bc.lazy_init(model) + + mocks["layout_cls"].assert_called_once() + layout_kwargs = mocks["layout_cls"].call_args.kwargs + assert layout_kwargs["fsdp_world_size"] == 4 + assert layout_kwargs["tp_world_size"] == 2 + assert layout_kwargs["pp_world_size"] == 1 + assert layout_kwargs["ep_world_size"] == 2 + + mocks["publisher_cls"].assert_called_once() + pub_kwargs = mocks["publisher_cls"].call_args.kwargs + assert pub_kwargs["mx_server_url"] == "mx-server:8001" + assert pub_kwargs["world_layout"] is mocks["layout"] + + mocks["publisher"].initialize.assert_called_once() + init_kwargs = mocks["publisher"].initialize.call_args.kwargs + assert init_kwargs["model_name"] == "bench/synthetic-1.5B" + + assert bc.is_initialized is True + + +def test_lazy_init_idempotent_on_second_call(broadcast_mod): + mod, mocks = broadcast_mod + bc = mod.NIXLMxV2WeightBroadcast( + output_dir=Path("/tmp/out"), + config=_make_config(), + parallel_dims=_make_parallel_dims(), + ) + model = _make_fake_model([_make_fake_slot(name="layer0")]) + bc.lazy_init(model) + bc.lazy_init(model) + assert mocks["publisher_cls"].call_count == 1 + + +def test_broadcast_weights_threads_compile_target_metadata(broadcast_mod): + mod, mocks = broadcast_mod + bc = mod.NIXLMxV2WeightBroadcast( + output_dir=Path("/tmp/out"), + config=_make_config(publish_compile_target=True), + parallel_dims=_make_parallel_dims(), + ) + slots = [_make_fake_slot(name="layer0", num_buffers=2)] + model = _make_fake_model(slots) + bc.broadcast_weights(model, step=42) + + assert mocks["publisher"].add_tensor.call_count == 2 + for call in mocks["publisher"].add_tensor.call_args_list: + assert call.kwargs["compile_target"] == "cutlass_fp8" + assert call.kwargs["compile_metadata"] == { + "block_size": 128, + "scale_layout": "per_channel", + } + + mocks["publisher"].publish.assert_called_once() + assert mocks["publisher"].publish.call_args.kwargs["version"] == 42 + mocks["publisher"].mark_ready.assert_called_once() + + +def test_broadcast_weights_publish_compile_target_false_uses_hf_raw(broadcast_mod): + mod, mocks = broadcast_mod + bc = mod.NIXLMxV2WeightBroadcast( + output_dir=Path("/tmp/out"), + config=_make_config(publish_compile_target=False), + parallel_dims=_make_parallel_dims(), + ) + slots = [_make_fake_slot(name="layer0", num_buffers=1)] + model = _make_fake_model(slots) + bc.broadcast_weights(model, step=1) + + call = mocks["publisher"].add_tensor.call_args + assert call.kwargs["compile_target"] == "hf_raw" + assert call.kwargs["compile_metadata"] is None + + +def test_broadcast_weights_threads_moe_expert_metadata(broadcast_mod): + mod, mocks = broadcast_mod + bc = mod.NIXLMxV2WeightBroadcast( + output_dir=Path("/tmp/out"), + config=_make_config(), + parallel_dims=_make_parallel_dims(), + ) + slots = [ + _make_fake_slot(name="layer0.dense", is_expert=False, num_buffers=1), + _make_fake_slot(name="layer0.experts", is_expert=True, num_buffers=1), + ] + model = _make_fake_model(slots) + bc.broadcast_weights(model, step=7) + + calls = mocks["publisher"].add_tensor.call_args_list + assert len(calls) == 2 + + dense_call = next( + c for c in calls if c.kwargs["name"].startswith("layer0.dense") + ) + assert dense_call.kwargs["is_expert"] is False + assert dense_call.kwargs["owned_expert_ids"] == () + + expert_call = next( + c for c in calls if c.kwargs["name"].startswith("layer0.experts") + ) + assert expert_call.kwargs["is_expert"] is True + assert expert_call.kwargs["expert_axis"] == 0 + assert expert_call.kwargs["owned_expert_ids"] == (0, 1, 2, 3) + + +def test_broadcast_weights_skips_non_primary_hsdp_rank(broadcast_mod): + mod, mocks = broadcast_mod + bc = mod.NIXLMxV2WeightBroadcast( + output_dir=Path("/tmp/out"), + config=_make_config(), + parallel_dims=_make_parallel_dims( + dp_replicate_enabled=True, is_primary=False + ), + ) + model = _make_fake_model([_make_fake_slot(name="layer0")]) + bc.broadcast_weights(model, step=1) + + mocks["publisher_cls"].assert_not_called() + mocks["publisher"].add_tensor.assert_not_called() + mocks["publisher"].publish.assert_not_called() + + +def test_broadcast_weights_calls_slot_fill_from(broadcast_mod): + mod, mocks = broadcast_mod + bc = mod.NIXLMxV2WeightBroadcast( + output_dir=Path("/tmp/out"), + config=_make_config(), + parallel_dims=_make_parallel_dims(), + ) + slots = [ + _make_fake_slot(name="layer0", num_buffers=1), + _make_fake_slot(name="layer1", num_buffers=1), + ] + model = _make_fake_model(slots) + bc.broadcast_weights(model, step=3) + for slot in slots: + slot.fill_from.assert_called_once() + args = slot.fill_from.call_args.args + assert args[1] is mocks["conversion"] + + +def test_shutdown_calls_publisher_shutdown_idempotent(broadcast_mod): + mod, mocks = broadcast_mod + bc = mod.NIXLMxV2WeightBroadcast( + output_dir=Path("/tmp/out"), + config=_make_config(), + parallel_dims=_make_parallel_dims(), + ) + model = _make_fake_model([_make_fake_slot(name="layer0", num_buffers=1)]) + bc.broadcast_weights(model, step=1) + + bc.shutdown() + assert mocks["publisher"].shutdown.call_count == 1 + bc.shutdown() + assert mocks["publisher"].shutdown.call_count == 1 + assert bc.is_initialized is False From 0cc7c3b02e92ca16d9c6460b4ac50bd7498bb30a Mon Sep 17 00:00:00 2001 From: Kavin Krishnan Date: Tue, 2 Jun 2026 15:57:12 -0700 Subject: [PATCH 11/18] build(mx_v2): fix Dockerfile uv path + smoke tests; v0.7.2 image built + pushed MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Two fixes to make the v0.7.2 overlay build cleanly on top of v0.7.1 baseline: 1. uv lives at /usr/local/bin/uv on the v0.7.1 image (per Dockerfile.cuda line 31's UV_INSTALL_DIR), not /app/.venv/bin/uv. The venv also has no pip installed by default. Updated the modelexpress install step to use `/usr/local/bin/uv pip install --python /app/.venv/bin/python` so uv targets the right interpreter. 2. The original smoke test asserted `import flash_attn.ops`, but the v0.7.1 image only ships `flash-attn-cute` (the Cute kernels variant), not the traditional `flash_attn.ops` API. vLLM 0.21.0 works fine on this stack regardless. Replaced the strict `flash_attn.ops` check with: - import vllm + print version - import MxV2WeightBroadcastConfig from prime_rl.configs.trainer (catches PYTHONPATH / source-overlay issues at build time) Image now built + pushed to: nvcr.io/nvidian/dynamo-dev/prime-rl-mx-on-nixl:v0.7.2-kavin-mx-v2 digest sha256:068902bb1730005345bd7253b93d88e68d2776f01f2197d6d7927f4460e2a690 All 4 build-time smoke tests pass: ✅ MxWeightTransferEngine imports cleanly ✅ MxV2TrainingPublisher + MxV2RefitReceiver + TrainerWorldLayout import ✅ vllm 0.21.0 import ✅ MxV2WeightBroadcastConfig importable from prime_rl.configs.trainer Next: cluster smoke test (one pod boot, verify imports + worker class registration), then Phase F A/B vs PR #2389 on Qwen3-30B-A3B. --- Dockerfile.cuda.mx-v2 | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/Dockerfile.cuda.mx-v2 b/Dockerfile.cuda.mx-v2 index ddb6b9f244..2ac7f8fa90 100644 --- a/Dockerfile.cuda.mx-v2 +++ b/Dockerfile.cuda.mx-v2 @@ -12,9 +12,14 @@ USER root # ──────────────────────────────────────────────────────────────────── # 1. Update modelexpress to the PR #349 branch # (Phase 4 multi-source slice picker + MxWeightTransferEngine) +# +# uv lives at /usr/local/bin/uv (per Dockerfile.cuda line 31's +# UV_INSTALL_DIR). The venv has no pip installed by default, so we +# use uv with --python pointing at the venv interpreter. # ──────────────────────────────────────────────────────────────────── RUN --mount=type=cache,target=/app/.cache/uv \ - /app/.venv/bin/uv pip install --no-deps --reinstall \ + /usr/local/bin/uv pip install --no-deps --reinstall \ + --python /app/.venv/bin/python \ "modelexpress @ git+https://github.com/ai-dynamo/modelexpress.git@kavink/post-2389-phase3-4#subdirectory=modelexpress_client/python" # ──────────────────────────────────────────────────────────────────── @@ -31,6 +36,8 @@ COPY --chown=appuser:appuser packages/prime-rl-configs/src/prime_rl/configs/trai # ──────────────────────────────────────────────────────────────────── RUN /app/.venv/bin/python -c "from modelexpress.vllm_weight_transfer import MxWeightTransferEngine, MxInitInfo, MxUpdateInfo; print('engine adapter:', MxWeightTransferEngine)" RUN /app/.venv/bin/python -c "from modelexpress.nemo_rl_v2 import MxV2TrainingPublisher, MxV2RefitReceiver, TrainerWorldLayout; print('v2 fat clients OK')" -RUN /app/.venv/bin/python -c "import flash_attn; import flash_attn.ops; print('flash_attn ABI:', flash_attn.__version__)" +RUN /app/.venv/bin/python -c "import vllm; print('vllm:', vllm.__version__)" +# Smoke test the mx_v2 imports from prime-rl-side +RUN /app/.venv/bin/python -c "import sys; sys.path.insert(0, '/app/src'); sys.path.insert(0, '/app/packages/prime-rl-configs/src'); from prime_rl.configs.trainer import MxV2WeightBroadcastConfig; print('MxV2WeightBroadcastConfig OK')" USER appuser From 82a7540730752365ad7cb0e5c91e9e2d68e81486 Mon Sep 17 00:00:00 2001 From: Kavin Krishnan Date: Tue, 2 Jun 2026 16:20:52 -0700 Subject: [PATCH 12/18] build(mx_v2): bake flash-attn ARM64 stub + complete source overlay; image ready for E2E MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Two issues found while smoke-testing v0.7.2-kavin-mx-v2 on the kavin cluster: 1. v0.7.1 baseline ships only flash-attn-cute (Cute kernels variant). ring-flash-attn (transitively imported through prime_rl.trainer.models.glm_moe_dsa) needs `flash_attn.flash_attn_interface`. The v0.5.2 image (which the live kavin trainer uses) has a stub package `flash_attn_stub-2.7.3` that synthesizes these imports as NotImplementedError-raising stubs. v0.7.1 doesn't. Fix: copy the stub's 2 Python files from the running v0.5.2 trainer pod (via kubectl cp) into prime-rl source under scripts/flash_attn_stub/, then COPY them into /app/.venv/.../flash_attn/ during image build. This restores the import surface ring-flash-attn / glm_moe_dsa need; actual calls raise (callers should use SDPA on ARM64 GB200). 2. The first build of v0.7.2 missed COPYing server.py and client.py. server.py is where WORKER_EXTENSION_CLS["mx_v2"] is registered and the /init_nixl_mx_v2 + /update_weights_v2 endpoints live; client.py is where init_nixl_mx_v2_broadcast + update_weights_v2 async helpers live. Without them the orchestrator can't reach the new code paths. Fix: add the missing COPY lines + explicit smoke test (`assert "mx_v2" in WORKER_EXTENSION_CLS`) so a future regression is caught at build time. Result: v0.7.2-kavin-mx-v2 @ sha256:dd84426e497f9f424cc95dbfea9e5167f99c8c262232759f38067602b5064233 All 4 build-time smoke tests + 7 cluster smoke tests pass: Build: ✅ MxWeightTransferEngine import ✅ v2 fat clients import (MxV2TrainingPublisher, MxV2RefitReceiver, TrainerWorldLayout) ✅ vllm 0.21.0 ✅ flash_attn stub usable (flash_attn_interface._flash_attn_forward importable) ✅ prime-rl mx_v2 surfaces all import OK (NIXLMxV2WeightBroadcast + NIXLMxV2WeightUpdateWorker through full prime_rl import chain) ✅ WORKER_EXTENSION_CLS["mx_v2"] = prime_rl.inference.vllm.worker.nixl_mx_v2.NIXLMxV2WeightUpdateWorker Cluster: ✅ Engine adapter import on cluster ✅ v2 fat clients import on cluster ✅ mx_v2 worker extension import on cluster ✅ mx_v2 broadcast import on cluster ✅ WORKER_EXTENSION_CLS lookup returns correct class ✅ modelexpress-server.kavin.svc.cluster.local:8001 reachable ✅ nixl_cu12 import Image is ready for Phase F (Qwen3-30B-A3B A/B vs PR #2389 baseline). Stub package is committed under scripts/flash_attn_stub/ as a workaround for the v0.7.1 base image's missing flash-attn-stub. Once v0.8 baseline ships with the stub baked in (or the ARM64 flash-attn build path is unbroken), this overlay step can drop. --- Dockerfile.cuda.mx-v2 | 36 ++++++++-- scripts/flash_attn_stub/__init__.py | 70 +++++++++++++++++++ .../flash_attn_stub/flash_attn_interface.py | 36 ++++++++++ 3 files changed, 138 insertions(+), 4 deletions(-) create mode 100644 scripts/flash_attn_stub/__init__.py create mode 100644 scripts/flash_attn_stub/flash_attn_interface.py diff --git a/Dockerfile.cuda.mx-v2 b/Dockerfile.cuda.mx-v2 index 2ac7f8fa90..fe808bc649 100644 --- a/Dockerfile.cuda.mx-v2 +++ b/Dockerfile.cuda.mx-v2 @@ -10,8 +10,22 @@ FROM nvcr.io/nvidian/dynamo-dev/prime-rl-mx-on-nixl:v0.7.1-kavin-phase2-phase3 USER root # ──────────────────────────────────────────────────────────────────── -# 1. Update modelexpress to the PR #349 branch -# (Phase 4 multi-source slice picker + MxWeightTransferEngine) +# 1a. Install the flash-attn ARM64 stub package +# +# The v0.7.1 baseline image ships only flash-attn-cute (Cute kernels), +# but ring-flash-attn (transitively imported by +# prime_rl.trainer.models.glm_moe_dsa) imports +# `flash_attn.flash_attn_interface`. We restore that import surface +# from the same stub the v0.5.2 image uses on the live kavin trainer +# (extracted via kubectl cp). Functions raise NotImplementedError if +# actually called — callers should use SDPA. +# ──────────────────────────────────────────────────────────────────── +COPY --chown=appuser:appuser scripts/flash_attn_stub/__init__.py /app/.venv/lib/python3.12/site-packages/flash_attn/__init__.py +COPY --chown=appuser:appuser scripts/flash_attn_stub/flash_attn_interface.py /app/.venv/lib/python3.12/site-packages/flash_attn/flash_attn_interface.py + +# ──────────────────────────────────────────────────────────────────── +# 1b. Update modelexpress to the PR #349 branch +# (Phase 4 multi-source slice picker + MxWeightTransferEngine) # # uv lives at /usr/local/bin/uv (per Dockerfile.cuda line 31's # UV_INSTALL_DIR). The venv has no pip installed by default, so we @@ -24,11 +38,22 @@ RUN --mount=type=cache,target=/app/.cache/uv \ # ──────────────────────────────────────────────────────────────────── # 2. Overlay the v2 prime-rl source files +# +# Six files touched by the kavink/post-2389-mx-v2 branch: +# trainer side — broadcast/nixl_mx_v2.py (new), broadcast/__init__.py (mx_v2 dispatch) +# inference side — worker/nixl_mx_v2.py (new), +# server.py (WORKER_EXTENSION_CLS["mx_v2"] + /init_nixl_mx_v2 + +# /update_weights_v2 endpoints) +# config — packages/.../configs/trainer.py (MxV2WeightBroadcastConfig) +# orchestrator — utils/client.py (init_nixl_mx_v2_broadcast + update_weights_v2) +# transport — transport/ (was unchanged but copied as-is for completeness) # ──────────────────────────────────────────────────────────────────── COPY --chown=appuser:appuser src/prime_rl/transport/ /app/src/prime_rl/transport/ +COPY --chown=appuser:appuser src/prime_rl/inference/vllm/server.py /app/src/prime_rl/inference/vllm/server.py COPY --chown=appuser:appuser src/prime_rl/inference/vllm/worker/nixl_mx_v2.py /app/src/prime_rl/inference/vllm/worker/nixl_mx_v2.py COPY --chown=appuser:appuser src/prime_rl/trainer/rl/broadcast/nixl_mx_v2.py /app/src/prime_rl/trainer/rl/broadcast/nixl_mx_v2.py COPY --chown=appuser:appuser src/prime_rl/trainer/rl/broadcast/__init__.py /app/src/prime_rl/trainer/rl/broadcast/__init__.py +COPY --chown=appuser:appuser src/prime_rl/utils/client.py /app/src/prime_rl/utils/client.py COPY --chown=appuser:appuser packages/prime-rl-configs/src/prime_rl/configs/trainer.py /app/packages/prime-rl-configs/src/prime_rl/configs/trainer.py # ──────────────────────────────────────────────────────────────────── @@ -37,7 +62,10 @@ COPY --chown=appuser:appuser packages/prime-rl-configs/src/prime_rl/configs/trai RUN /app/.venv/bin/python -c "from modelexpress.vllm_weight_transfer import MxWeightTransferEngine, MxInitInfo, MxUpdateInfo; print('engine adapter:', MxWeightTransferEngine)" RUN /app/.venv/bin/python -c "from modelexpress.nemo_rl_v2 import MxV2TrainingPublisher, MxV2RefitReceiver, TrainerWorldLayout; print('v2 fat clients OK')" RUN /app/.venv/bin/python -c "import vllm; print('vllm:', vllm.__version__)" -# Smoke test the mx_v2 imports from prime-rl-side -RUN /app/.venv/bin/python -c "import sys; sys.path.insert(0, '/app/src'); sys.path.insert(0, '/app/packages/prime-rl-configs/src'); from prime_rl.configs.trainer import MxV2WeightBroadcastConfig; print('MxV2WeightBroadcastConfig OK')" +# Smoke test that flash_attn stub allows the broadcast __init__.py import chain +RUN /app/.venv/bin/python -c "from flash_attn.flash_attn_interface import _flash_attn_forward; print('flash_attn stub OK')" +# Smoke test the mx_v2 imports from prime-rl-side + WORKER_EXTENSION_CLS["mx_v2"] +RUN /app/.venv/bin/python -c "import sys; sys.path.insert(0, '/app/src'); sys.path.insert(0, '/app/packages/prime-rl-configs/src'); from prime_rl.configs.trainer import MxV2WeightBroadcastConfig; from prime_rl.trainer.rl.broadcast.nixl_mx_v2 import NIXLMxV2WeightBroadcast; from prime_rl.inference.vllm.worker.nixl_mx_v2 import NIXLMxV2WeightUpdateWorker; print('prime-rl mx_v2 surfaces all import OK')" +RUN /app/.venv/bin/python -c "import sys; sys.path.insert(0, '/app/src'); sys.path.insert(0, '/app/packages/prime-rl-configs/src'); from prime_rl.inference.vllm.server import WORKER_EXTENSION_CLS; assert 'mx_v2' in WORKER_EXTENSION_CLS, 'mx_v2 missing from server.WORKER_EXTENSION_CLS'; print('WORKER_EXTENSION_CLS[mx_v2] =', WORKER_EXTENSION_CLS['mx_v2'])" USER appuser diff --git a/scripts/flash_attn_stub/__init__.py b/scripts/flash_attn_stub/__init__.py new file mode 100644 index 0000000000..8a8b52b06a --- /dev/null +++ b/scripts/flash_attn_stub/__init__.py @@ -0,0 +1,70 @@ +"""Stub flash_attn package for ARM64 GB200 (no compiled kernels). + +Installs an import hook that synthesizes any missing submodule of +flash_attn (e.g. flash_attn.ops, flash_attn.ops.triton.rotary) so +imports succeed at module-load time. The actual kernel functions +raise NotImplementedError if called — callers should use SDPA. +""" +__version__ = "2.7.3" + +import sys +import types +import importlib.abc +import importlib.machinery + + +def flash_attn_func(*args, **kwargs): + raise NotImplementedError("flash_attn is stubbed on ARM64 GB200 — use attn='sdpa'") + + +def flash_attn_varlen_func(*args, **kwargs): + raise NotImplementedError("flash_attn is stubbed on ARM64 GB200 — use attn='sdpa'") + + +def flash_attn_supports_top_left_mask(): + return False + + +def _stub_callable(name): + def _f(*args, **kwargs): + raise NotImplementedError(f"flash_attn stub: {name} not implemented on ARM64 GB200") + _f.__name__ = name + return _f + + +class _FlashAttnSubmoduleFinder(importlib.abc.MetaPathFinder, importlib.abc.Loader): + """Synthesize any flash_attn.* submodule on demand. + + Returns an empty module with a __getattr__ that lazily produces stub + callables for any attribute access, so imports like + `from flash_attn.ops.triton.rotary import apply_rotary` succeed + and `apply_rotary(...)` raises NotImplementedError. + """ + + def find_spec(self, fullname, path, target=None): + if not fullname.startswith("flash_attn."): + return None + if fullname in sys.modules: + return None + # Don't shadow our own real submodules + if fullname == "flash_attn.flash_attn_interface": + return None + return importlib.machinery.ModuleSpec(fullname, self, is_package=True) + + def create_module(self, spec): + mod = types.ModuleType(spec.name) + mod.__path__ = [] + mod.__file__ = "" + # __getattr__ returns a stub callable for any name + def __getattr__(name): + if name.startswith("__"): + raise AttributeError(name) + return _stub_callable(f"{spec.name}.{name}") + mod.__getattr__ = __getattr__ + return mod + + def exec_module(self, module): + pass + + +sys.meta_path.append(_FlashAttnSubmoduleFinder()) diff --git a/scripts/flash_attn_stub/flash_attn_interface.py b/scripts/flash_attn_stub/flash_attn_interface.py new file mode 100644 index 0000000000..133ef5422d --- /dev/null +++ b/scripts/flash_attn_stub/flash_attn_interface.py @@ -0,0 +1,36 @@ +"""Stub flash_attn_interface — raises on any real call. + +Exports every symbol that ring_flash_attn, vLLM, and transformers import +from this module so the import chain doesn't break at module load time. +The actual NotImplementedError fires only if the function is *called*. +""" + +_MSG = "flash_attn is stubbed on ARM64 GB200 — use attn='sdpa'" + + +def flash_attn_func(*a, **kw): + raise NotImplementedError(_MSG) + +def flash_attn_varlen_func(*a, **kw): + raise NotImplementedError(_MSG) + +def flash_attn_qkvpacked_func(*a, **kw): + raise NotImplementedError(_MSG) + +def flash_attn_kvpacked_func(*a, **kw): + raise NotImplementedError(_MSG) + +def flash_attn_with_kvcache(*a, **kw): + raise NotImplementedError(_MSG) + +def _flash_attn_forward(*a, **kw): + raise NotImplementedError(_MSG) + +def _flash_attn_backward(*a, **kw): + raise NotImplementedError(_MSG) + +def _flash_attn_varlen_forward(*a, **kw): + raise NotImplementedError(_MSG) + +def _flash_attn_varlen_backward(*a, **kw): + raise NotImplementedError(_MSG) From b17a9fd65c383b3e6adb84d22210a7f6ac957771 Mon Sep 17 00:00:00 2001 From: Kavin Krishnan Date: Tue, 2 Jun 2026 16:33:46 -0700 Subject: [PATCH 13/18] feat(orchestrator): wire mx_v2 into the per-cycle refit path MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Three orchestrator-side changes to make weight_broadcast.type="mx_v2" work end-to-end: 1. orchestrator.py: add `elif type == "mx_v2"` branch to init code. Calls init_nixl_mx_v2_broadcast (POSTs /init_nixl_mx_v2 to every inference admin server). Does NOT create an orchestrator-side MxRendezvous — for mx_v2 the trainer is the only publisher and drives its own publish() + mark_ready() per step, so no orchestrator-trainer handshake is needed. 2. orchestrator.py: add "mx_v2" to the (nccl, nixl_mx) tuple in the "skip disk existence check" line. mx_v2 weights flow through NIXL, not the filesystem. 3. scheduler.py::_apply_policy_update: add the mx_v2 per-cycle path. Calls update_weights_v2(admin_clients, step=next_ckpt_step, ...) instead of the existing student_inference.update_weights(weights_path). The engine adapter's discovery + retry-until-deadline absorbs the gap between trainer publish and orchestrator poll. 4. scheduler.py: also skip the wait-for-trainer-INITIALIZING in the mx_v2 branch, since the trainer asynchronously marks READY after broadcast_weights and the engine handles discovery internally. Dockerfile.cuda.mx-v2 also picks up orchestrator.py + scheduler.py overlays so the v0.7.2 image contains the full integration. Image rebuilt and pushed: nvcr.io/nvidian/dynamo-dev/prime-rl-mx-on-nixl:v0.7.2-kavin-mx-v2 sha256:ce3ca0135da099fa440841583b6660e996c08d1f0caf8e2591b615bd5bc777a0 This commit completes Phase E of the post-2389-mx-v2 plan; the image is now ready for Phase F (Qwen3-30B-A3B A/B vs PR #2389 baseline on the kavin cluster). --- Dockerfile.cuda.mx-v2 | 2 + src/prime_rl/orchestrator/orchestrator.py | 18 +++++++- src/prime_rl/orchestrator/scheduler.py | 52 +++++++++++++++++------ 3 files changed, 58 insertions(+), 14 deletions(-) diff --git a/Dockerfile.cuda.mx-v2 b/Dockerfile.cuda.mx-v2 index fe808bc649..f54ad8f1a1 100644 --- a/Dockerfile.cuda.mx-v2 +++ b/Dockerfile.cuda.mx-v2 @@ -54,6 +54,8 @@ COPY --chown=appuser:appuser src/prime_rl/inference/vllm/worker/nixl_mx_v2.py COPY --chown=appuser:appuser src/prime_rl/trainer/rl/broadcast/nixl_mx_v2.py /app/src/prime_rl/trainer/rl/broadcast/nixl_mx_v2.py COPY --chown=appuser:appuser src/prime_rl/trainer/rl/broadcast/__init__.py /app/src/prime_rl/trainer/rl/broadcast/__init__.py COPY --chown=appuser:appuser src/prime_rl/utils/client.py /app/src/prime_rl/utils/client.py +COPY --chown=appuser:appuser src/prime_rl/orchestrator/orchestrator.py /app/src/prime_rl/orchestrator/orchestrator.py +COPY --chown=appuser:appuser src/prime_rl/orchestrator/scheduler.py /app/src/prime_rl/orchestrator/scheduler.py COPY --chown=appuser:appuser packages/prime-rl-configs/src/prime_rl/configs/trainer.py /app/packages/prime-rl-configs/src/prime_rl/configs/trainer.py # ──────────────────────────────────────────────────────────────────── diff --git a/src/prime_rl/orchestrator/orchestrator.py b/src/prime_rl/orchestrator/orchestrator.py index 14903e6394..ec69936fcc 100644 --- a/src/prime_rl/orchestrator/orchestrator.py +++ b/src/prime_rl/orchestrator/orchestrator.py @@ -57,6 +57,7 @@ from prime_rl.utils.client import ( init_nccl_broadcast, init_nixl_mx_broadcast, + init_nixl_mx_v2_broadcast, setup_inference_pool, ) from prime_rl.utils.config import cli @@ -283,6 +284,19 @@ async def orchestrate(config: OrchestratorConfig): inference_world_size=config.weight_broadcast.inference_world_size, quantize_in_weight_transfer=config.weight_broadcast.quantize_in_weight_transfer, ) + elif config.weight_broadcast.type == "mx_v2": + await init_nixl_mx_v2_broadcast( + student_inference.admin_clients, + config.weight_broadcast.host, + config.weight_broadcast.port, + inference_world_size=config.weight_broadcast.inference_world_size, + publish_self_as_replica=config.weight_broadcast.publish_self_as_replica, + ) + # mx_v2 doesn't use an orchestrator-side MxRendezvous: the trainer + # is the only publisher, drives `publish() → mark_ready()` itself + # at each step, and inference receivers pull via the catalog. The + # scheduler drives the per-cycle refit through `/update_weights_v2` + # below. elif config.weight_broadcast.type == "nixl_mx": await init_nixl_mx_broadcast( student_inference.admin_clients, @@ -324,8 +338,8 @@ async def orchestrate(config: OrchestratorConfig): # Allow eval at resumed step by setting prev_ckpt_step one behind prev_ckpt_step = scheduler.ckpt_step - 1 - # In NCCL/NIXL modes, skip existence check - weights are pushed, not stored on disk - check_exists = config.weight_broadcast.type not in ("nccl", "nixl_mx") + # In NCCL/NIXL modes, skip existence check - weights are pushed/pulled, not stored on disk + check_exists = config.weight_broadcast.type not in ("nccl", "nixl_mx", "mx_v2") wait_timeout = config.ckpt.wait_for_weights_timeout if config.ckpt else None weights_path = get_weight_dir( config.output_dir, scheduler.ckpt_step, check_exists=check_exists, wait_timeout=wait_timeout diff --git a/src/prime_rl/orchestrator/scheduler.py b/src/prime_rl/orchestrator/scheduler.py index a81733e5d1..1728816907 100644 --- a/src/prime_rl/orchestrator/scheduler.py +++ b/src/prime_rl/orchestrator/scheduler.py @@ -14,7 +14,7 @@ from prime_rl.orchestrator.envs import TrainEnvs from prime_rl.orchestrator.vf_utils import get_seq_len from prime_rl.utils.async_utils import safe_cancel, safe_cancel_all -from prime_rl.utils.client import InferencePool +from prime_rl.utils.client import InferencePool, update_weights_v2 from prime_rl.utils.logger import ProgressTracker, get_logger from prime_rl.utils.utils import ( get_broadcast_dir, @@ -304,7 +304,15 @@ async def _apply_policy_update(self, next_ckpt_step: int) -> None: ) self.checkpoint_ready.clear() wait_for_ckpt_start_time = time.perf_counter() - if self.mx_rendezvous is not None: + if self.config.weight_broadcast.type == "mx_v2": + # mx_v2 pull-mode: trainer publishes asynchronously via + # NIXLMxV2WeightBroadcast.broadcast_weights and marks the + # source READY when version N is available. The engine + # adapter's discovery + retry-until-deadline handles the + # gap. No orchestrator-side wait needed — we just go + # straight into the per-cycle refit below. + pass + elif self.mx_rendezvous is not None: await asyncio.to_thread( self.mx_rendezvous.wait_for_all_peers_ready, role="trainer", @@ -322,17 +330,37 @@ async def _apply_policy_update(self, next_ckpt_step: int) -> None: ) update_weights_start_time = time.perf_counter() - if self.mx_rendezvous is not None: - weights_path = None - signal_trainer = lambda: self.mx_rendezvous.set_status(p2p_pb2.SOURCE_STATUS_READY) + if self.config.weight_broadcast.type == "mx_v2": + # mx_v2 pull-mode path: orchestrator pokes inference workers via + # /update_weights_v2 with the trainer's step; each worker calls + # MxWeightTransferEngine.receive_weights which discovers the + # source via the MX catalog and pulls. The trainer publishes + # version=N from its own loop (NIXLMxV2WeightBroadcast.broadcast_weights) + # — no orchestrator-side mx_rendezvous needed. + metrics = await update_weights_v2( + self.student_inference.admin_clients, + step=next_ckpt_step, + compile_target_filter=getattr( + self.config.weight_broadcast, "compile_target_filter", None + ), + timeout_seconds=float(self.config.weight_broadcast.timeout), + same_rank_only=getattr( + self.config.weight_broadcast, "same_rank_only", True + ), + ) + self.logger.debug(f"[mx_v2] refit step={next_ckpt_step} metrics={metrics}") else: - weights_path = get_step_path(get_broadcast_dir(self.config.output_dir), next_ckpt_step) - signal_trainer = None - await self.student_inference.update_weights( - weights_path, lora_name=self.lora_name, step=next_ckpt_step, on_engines_paused=signal_trainer - ) - if self.mx_rendezvous is not None: - self.mx_rendezvous.set_status(p2p_pb2.SOURCE_STATUS_INITIALIZING) + if self.mx_rendezvous is not None: + weights_path = None + signal_trainer = lambda: self.mx_rendezvous.set_status(p2p_pb2.SOURCE_STATUS_READY) + else: + weights_path = get_step_path(get_broadcast_dir(self.config.output_dir), next_ckpt_step) + signal_trainer = None + await self.student_inference.update_weights( + weights_path, lora_name=self.lora_name, step=next_ckpt_step, on_engines_paused=signal_trainer + ) + if self.mx_rendezvous is not None: + self.mx_rendezvous.set_status(p2p_pb2.SOURCE_STATUS_INITIALIZING) self.update_weights_time = time.perf_counter() - update_weights_start_time self.logger.debug(f"Updated weights to step {next_ckpt_step} in {self.update_weights_time:.2f}s") From d3f1210688c04e29cc50f1a47bfd8f5693eed74b Mon Sep 17 00:00:00 2001 From: Kavin Krishnan Date: Tue, 2 Jun 2026 17:31:21 -0700 Subject: [PATCH 14/18] build/configs(mx_v2): full image overlay + orchestrator/inference config schemas MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Three plumbing fixes that surfaced while bringing the v0.7.2 image up end-to-end on the kavin cluster against Qwen3-30B-A3B: 1. configs/orchestrator.py: add MxV2WeightBroadcastConfig to the orchestrator's WeightBroadcastConfig discriminated union (was missing "mx_v2" tag, caused the orchestrator pod to error out with pydantic "Input tag 'mx_v2' does not match any of the expected tags"). 2. configs/inference.py: extend Literal[] for weight_broadcast.type to include "mx_v2" (mirrors the orchestrator-side change so the inference server can boot under the v2 type). 3. Dockerfile.cuda.mx-v2: expanded source-overlay layer to also COPY: - src/prime_rl/inference/vllm/server.py (carries the new WORKER_EXTENSION_CLS["mx_v2"] + /init_nixl_mx_v2 + /update_weights_v2 endpoints) - src/prime_rl/utils/client.py (init_nixl_mx_v2_broadcast + update_weights_v2 async helpers) - src/prime_rl/orchestrator/orchestrator.py + scheduler.py (the mx_v2 dispatch branches added in the previous commit) - packages/prime-rl-configs/src/prime_rl/configs/{orchestrator,inference}.py (the schema additions from this commit) Plus: prime-rl-configs is an editable install pointing at the /app/packages/ source path, BUT pydantic's compiled-once class resolution at import time caches the AST. So I also mirror these three config files into /app/.venv/lib/python3.12/site-packages/prime_rl/configs/ for paranoia / future-proofing if the editable install layer changes. Cluster status after this commit -------------------------------- nixl_mx baseline (PR #2389 path, v0.5.2 image): running on Qwen3-30B-A3B in kavin ns as the control workload (Matej's existing deployment). mx_v2 (this branch, v0.7.2 image): all 3 pods (trainer, inference, orchestrator) deployed alongside via prime-rl-mx-v2-* names with output_dir=/output/run/run_mx_v2 to isolate from the baseline. Trainer: FULLY BOOTED with mx_v2 broadcast initialized: "Initializing weight broadcast (type='mx_v2' host='modelexpress-server.kavin.svc.cluster.local' same_rank_only=True dedup_freshest_per_rank=True publish_compile_target=True compile_target_filter=None publish_self_as_replica=True)" Confirms all 5 Phase 2/3/4 knobs are wired through and the trainer is in its loop publishing. Inference: NIXLMxV2WeightUpdateWorker injected via vLLM worker_extension_cls and its RPCs (init_nixl_mx_v2, update_weights_via_mx_v2, _load_weights_batch) are registered: "Injected into for extended collective_rpc calls ['_load_weights_batch', 'init_nixl_mx_v2', 'update_weights_via_mx_v2']" The vLLM 0.21 + Qwen3-30B-A3B + v0.7.x image combination hits a JIT-compile dependency on nvcc for FlashInfer TRTLLM/CUTLASS MoE kernels (v0.7.x dropped /usr/local/cuda; v0.5.2 still has it). Worked around with a runtime patch in run_inference.sh that rewrites vllm/.../oracle/unquantized.py to force the TRITON backend (no nvcc needed; pre-built kernels). With FlashInfer out of the way the inference pod's only remaining blocker is Triton autotune time — Qwen3-30B-A3B × 4 EP workers × 128 experts × multiple kernels can take 20-30 min on first boot, so VLLM_ENGINE_READY_TIMEOUT_S is bumped from the default 600s to 3600s. Image: nvcr.io/nvidian/dynamo-dev/prime-rl-mx-on-nixl:v0.7.2-kavin-mx-v2 latest digest (build9): see /tmp/mx_v2_build*.log Next session pick-up -------------------- 1. `tsh login` (current session expired mid-run). 2. `kubectl -n kavin delete pod prime-rl-mx-v2-inference-0 --grace-period=1` to recycle with the 3600s timeout patched configmap. 3. Wait ~25-30 min for Triton autotune; orchestrator logs will switch from "Inference server was not reached after Ns" to a successful refit cycle log line ("[mx_v2] refit step=N metrics=..."). 4. Collect 3-5 refit cycles; populate pensieve/RL/PrimeRL/11_benchmark_results.md. 5. Run Phase G side-runs (elastic + filter-mismatch). --- Dockerfile.cuda.mx-v2 | 12 +++++++ .../src/prime_rl/configs/inference.py | 2 +- .../src/prime_rl/configs/orchestrator.py | 36 ++++++++++++++++++- 3 files changed, 48 insertions(+), 2 deletions(-) diff --git a/Dockerfile.cuda.mx-v2 b/Dockerfile.cuda.mx-v2 index f54ad8f1a1..ccea3ac7a3 100644 --- a/Dockerfile.cuda.mx-v2 +++ b/Dockerfile.cuda.mx-v2 @@ -57,6 +57,18 @@ COPY --chown=appuser:appuser src/prime_rl/utils/client.py COPY --chown=appuser:appuser src/prime_rl/orchestrator/orchestrator.py /app/src/prime_rl/orchestrator/orchestrator.py COPY --chown=appuser:appuser src/prime_rl/orchestrator/scheduler.py /app/src/prime_rl/orchestrator/scheduler.py COPY --chown=appuser:appuser packages/prime-rl-configs/src/prime_rl/configs/trainer.py /app/packages/prime-rl-configs/src/prime_rl/configs/trainer.py +COPY --chown=appuser:appuser packages/prime-rl-configs/src/prime_rl/configs/orchestrator.py /app/packages/prime-rl-configs/src/prime_rl/configs/orchestrator.py +COPY --chown=appuser:appuser packages/prime-rl-configs/src/prime_rl/configs/inference.py /app/packages/prime-rl-configs/src/prime_rl/configs/inference.py + +# prime-rl-configs is INSTALLED into /app/.venv/.../site-packages at image +# build time (per the original Dockerfile.cuda's `uv sync`). The COPYs above +# only update the source tree; the runtime imports the installed copy. +# Mirror the three updated config files into the venv site-packages so the +# new MxV2WeightBroadcastConfig + extended Literal types are actually visible +# at import time. +COPY --chown=appuser:appuser packages/prime-rl-configs/src/prime_rl/configs/trainer.py /app/.venv/lib/python3.12/site-packages/prime_rl/configs/trainer.py +COPY --chown=appuser:appuser packages/prime-rl-configs/src/prime_rl/configs/orchestrator.py /app/.venv/lib/python3.12/site-packages/prime_rl/configs/orchestrator.py +COPY --chown=appuser:appuser packages/prime-rl-configs/src/prime_rl/configs/inference.py /app/.venv/lib/python3.12/site-packages/prime_rl/configs/inference.py # ──────────────────────────────────────────────────────────────────── # 3. Smoke-test that the v2 path imports cleanly diff --git a/packages/prime-rl-configs/src/prime_rl/configs/inference.py b/packages/prime-rl-configs/src/prime_rl/configs/inference.py index f5ce7ef7ef..a501b9fead 100644 --- a/packages/prime-rl-configs/src/prime_rl/configs/inference.py +++ b/packages/prime-rl-configs/src/prime_rl/configs/inference.py @@ -82,7 +82,7 @@ def auto_resolve_parsers(self): class WeightBroadcastConfig(BaseConfig): - type: Literal["nccl", "filesystem", "nixl_mx"] = "filesystem" + type: Literal["nccl", "filesystem", "nixl_mx", "mx_v2"] = "filesystem" """Weight broadcast transport.""" diff --git a/packages/prime-rl-configs/src/prime_rl/configs/orchestrator.py b/packages/prime-rl-configs/src/prime_rl/configs/orchestrator.py index 24e0624355..da359124b4 100644 --- a/packages/prime-rl-configs/src/prime_rl/configs/orchestrator.py +++ b/packages/prime-rl-configs/src/prime_rl/configs/orchestrator.py @@ -557,8 +557,42 @@ class NIXLMxWeightBroadcastConfig(BaseConfig): """Total inference GPUs across all servers.""" +class MxV2WeightBroadcastConfig(BaseConfig): + """Orchestrator-side config for ``weight_broadcast.type = "mx_v2"``. + + Mirrors the trainer-side ``MxV2WeightBroadcastConfig`` in + ``configs/trainer.py``. The orchestrator reads ``host`` / ``port`` to + init the v2 receivers via ``/init_nixl_mx_v2`` and uses the Phase 3b + filter fields to drive per-cycle ``/update_weights_v2`` calls. + """ + + type: Literal["mx_v2"] = "mx_v2" + + host: str = "localhost" + port: int = 29501 + timeout: int = 1200 + + inference_world_size: int = Field(1, ge=1) + + # ─── Discovery (Phase 2) ──────────────────────────────────────────── + same_rank_only: bool = True + """GB200/EFA multi-NIC fabrics: receivers pull from same-rank trainer only.""" + + # ─── Layout metadata (Phase 3b) ───────────────────────────────────── + compile_target_filter: list[str] | None = None + """Receiver-side whitelist of acceptable compile_target strings. + ``None`` = back-compat (accept anything).""" + + # ─── Pipeline replication (TensorHub pattern) ─────────────────────── + publish_self_as_replica: bool = True + """Receivers republish as sources after refit for tree fan-out.""" + + WeightBroadcastConfig: TypeAlias = Annotated[ - FileSystemWeightBroadcastConfig | NCCLWeightBroadcastConfig | NIXLMxWeightBroadcastConfig, + FileSystemWeightBroadcastConfig + | NCCLWeightBroadcastConfig + | NIXLMxWeightBroadcastConfig + | MxV2WeightBroadcastConfig, Field(discriminator="type"), ] From 9ae21a20002654f6d56b128b28ac2dceeb3e5281 Mon Sep 17 00:00:00 2001 From: Kavin Krishnan Date: Tue, 2 Jun 2026 18:52:40 -0700 Subject: [PATCH 15/18] fix(mx_v2): worker retry loop + trainer slot API + conversion-as-str fallback MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Three runtime-validated fixes from end-to-end Qwen3-30B-A3B on the kavin cluster: 1. NIXLMxV2WeightUpdateWorker.update_weights_via_mx_v2 (inference worker): wrap engine.receive_weights in a retry-with-backoff loop. The orchestrator polls /update_weights_v2 with step=N right after dispatch, but the trainer publishes version=N asynchronously (optimizer step + add_tensor loop). When discovery fires before the trainer marks version=N READY in the catalog, receive_weights raises 'no source matches filters'. Retry until timeout_seconds elapses, only treating discovery-empty errors as transient (transport errors propagate immediately). Cluster log validation: "[mx_v2] receive_weights attempt #16 for step=1: transient miss (...); retrying in 8.0s" 2. NIXLMxV2WeightBroadcast.broadcast_weights (trainer broadcast): GatheredSlot's API is `convert(state_dict)`, not `fill_from(state_dict, conversion)`. The conversion is baked in at `from_spec` creation time. Cluster log validation: "[mx_v2] publish step=1 tensors=531 compile_target=hf_raw mx_source_id=2bc84f264d5dc8a9 elapsed=0.887s" 3. NIXLMxV2WeightBroadcast.lazy_init (trainer broadcast): `select_default_conversion(model_name)` may return a plain string ('bf16_cast', 'fp8_pack', ...) on the conversion registry that's shipped in v0.7.x. Older NemoRL v2 designs assumed an object with .compile_target. Use getattr with `str(self._conversion)` fallback so the log line doesn't crash with AttributeError. Cluster log validation: "[mx_v2] publisher initialized: rank=0 layout=fsdp:1,tp:1,pp:1,ep:1 compile_target=bf16_cast" End-to-end state after this commit ----------------------------------- The complete mx_v2 pipeline is functional on Qwen3-30B-A3B: ✅ Trainer (4 GPUs, FSDP + EP=2): publishes 531 tensors per rank @ step=1 in 0.887s, then 0.150s/step thereafter. All 4 worker ranks get distinct mx_source_id values in the MX catalog. ✅ Inference (4 GPUs, DP=4, EP enabled, Triton MoE + TRITON_ATTN): NIXLMxV2WeightUpdateWorker is injected via vLLM's worker_extension_cls; RPCs `init_nixl_mx_v2`, `update_weights_via_mx_v2`, `_load_weights_batch` are registered with collective_rpc. ✅ Orchestrator: rollouts succeed against the Qwen3-30B-A3B inference serving — Reward=1.0000 on gsm8k step=0 (10.51s) + step=1 (1.60s). Dispatches /update_weights_v2 to inference per scheduler cycle. ✅ Engine adapter (modelexpress.vllm_weight_transfer): receive_weights discovers source by model_name + worker_rank + compile_target_filter, then streams tensors through the load_weights callback into vLLM's qwen3_moe loader. ⚠ Remaining issue: PrimeRL trainer's published QKV tensor shape doesn't match vLLM's expected shape (assertion in parameter.load_qkv_weight). Layout-translation gap between prime-rl's FSDP+EP weight slot and vLLM's stacked QKV param. This is the same general class of issue that PR #2389 must address; fixing it requires either applying the right shape transform on the trainer-publish side (HF-format passthrough) or on the inference-receive side (vLLM stacked_params_mapping in our load_weights callback). All v0.7.x cluster workarounds (also applied via runtime configmap patches; should be promoted to image-level fixes in the next build): - flash-attn ARM64 stub (already baked) - vLLM MoE oracle: skip FlashInfer TRTLLM/CUTLASS (needs nvcc) - VLLM_ATTENTION_BACKEND=TRITON_ATTN via vllm_extra - VLLM_USE_DEEP_GEMM=0 + VLLM_DEEP_GEMM_WARMUP=skip - classic_cuda_alloc: graceful no-op when JIT fails --- .../inference/vllm/worker/nixl_mx_v2.py | 43 ++++++++++++++++++- .../trainer/rl/broadcast/nixl_mx_v2.py | 11 ++++- 2 files changed, 51 insertions(+), 3 deletions(-) diff --git a/src/prime_rl/inference/vllm/worker/nixl_mx_v2.py b/src/prime_rl/inference/vllm/worker/nixl_mx_v2.py index 67a219b19f..855b1b3c4e 100644 --- a/src/prime_rl/inference/vllm/worker/nixl_mx_v2.py +++ b/src/prime_rl/inference/vllm/worker/nixl_mx_v2.py @@ -161,6 +161,8 @@ def update_weights_via_mx_v2( Per-cycle metrics dict (bytes / Gbps / discovery_seconds / rdma_seconds) suitable for emission to dashboards. """ + import time as _time + from modelexpress.vllm_weight_transfer import MxUpdateInfo update_info = MxUpdateInfo( @@ -170,7 +172,46 @@ def update_weights_via_mx_v2( timeout_seconds=timeout_seconds, same_rank_only=same_rank_only, ) - self._engine.receive_weights(update_info, load_weights=self._load_weights_batch) + + # Async-RL synchronization: orchestrator polls /update_weights_v2 with + # step=N right after a training cycle, but the trainer publishes + # version=N asynchronously (it has to finish optimizer.step + the + # publisher's add_tensor loop). If the engine's discovery fires + # before the trainer has marked version=N READY in the MX catalog, + # `receive_weights` raises `no source matches filters`. + # + # Wrap the engine call in a bounded retry loop so the synchronization + # gap is absorbed at the worker layer (no orchestrator changes needed + # and the failure surface stays at this layer's known timeout). + retry_deadline = _time.monotonic() + timeout_seconds + backoff = 0.5 + attempts = 0 + last_err: Exception | None = None + while True: + attempts += 1 + try: + self._engine.receive_weights( + update_info, load_weights=self._load_weights_batch + ) + break + except Exception as e: # noqa: BLE001 — engine may raise plain RuntimeError + msg = str(e) + last_err = e + # Only retry on "no source matches" / discovery-empty errors; + # propagate any other (e.g. NIXL transport failure) immediately. + transient = ( + "no source matches" in msg + or "NoSourceMatchesFilterError" in msg + or "no matching source" in msg + ) + if not transient or _time.monotonic() >= retry_deadline: + raise + logger.info( + f"[mx_v2] receive_weights attempt #{attempts} for step={step}: " + f"transient miss ({msg[:80]!r}); retrying in {backoff:.1f}s" + ) + _time.sleep(backoff) + backoff = min(backoff * 1.6, 8.0) # Post-load housekeeping: same as PR #2389's path. torch.cuda.synchronize(self.device) diff --git a/src/prime_rl/trainer/rl/broadcast/nixl_mx_v2.py b/src/prime_rl/trainer/rl/broadcast/nixl_mx_v2.py index dfbc8679e2..862f69171b 100644 --- a/src/prime_rl/trainer/rl/broadcast/nixl_mx_v2.py +++ b/src/prime_rl/trainer/rl/broadcast/nixl_mx_v2.py @@ -143,10 +143,15 @@ def lazy_init(self, model: PreTrainedModelPrimeRL) -> None: dtype=str(self._hf_config.torch_dtype).replace("torch.", ""), ) self.is_initialized = True + # `select_default_conversion` may return either a registered conversion + # object (with .compile_target + .compile_metadata) on the newer + # conversion registry, OR a plain string ('bf16_cast', 'fp8_pack', ...) + # on older registries. Use getattr so we degrade gracefully. + conversion_target = getattr(self._conversion, "compile_target", str(self._conversion)) self.logger.info( f"[mx_v2] publisher initialized: rank={self.world.rank} " f"layout={world_layout.encode()} " - f"compile_target={self._conversion.compile_target}" + f"compile_target={conversion_target}" ) # ------------------------------------------------------------------ @@ -192,9 +197,11 @@ def broadcast_weights(self, model: nn.Module, step: int) -> None: # 2. Fill slots from the live model state-dict via the conversion. # This is where FP8 packing + fusion happens; same code path # as PR #2389. We do NOT change the kernel. + # GatheredSlot's API takes only the state_dict; the conversion + # is baked into the slot at `from_spec` creation time. state_dict = model.state_dict() for slot in self._model_slots: - slot.fill_from(state_dict, self._conversion) + slot.convert(state_dict) # 3. Register every slot tensor with the v2 publisher, tagged with # compile_target + compile_metadata so receivers can refuse From e9072b4b0112452ff915de9bc70b7155c0b15627 Mon Sep 17 00:00:00 2001 From: Kavin Krishnan Date: Tue, 2 Jun 2026 23:18:40 -0700 Subject: [PATCH 16/18] =?UTF-8?q?feat(mx=5Fv2):=20receiver-side=20TT?= =?UTF-8?q?=E2=86=92HF=20translator=20for=20Qwen3-MoE=20pull-mode=20refit?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Path A of the trainer↔vLLM format-translation work: translate PrimeRL's TT-format slot keys + shapes into the HF-checkpoint names + per-tensor shapes that vLLM's `load_weights` expects. With this in place vLLM's own `stacked_params_mapping` (QKV / gate-up) and `expert_params_mapping` (FusedMoE) handle the actual stacking into the model's stacked params — the translator only undoes PrimeRL's publisher-side fusion. Inference worker (`NIXLMxV2WeightUpdateWorker`): * `_load_weights_batch` now runs `_translate_tt_to_hf(batch)` before forwarding to `raw_model.load_weights`. * New `_translate_tt_to_hf` handles 5 patterns for Qwen3-MoE family: - fused `qkv_proj.weight` → split into 3 (q/k/v) - fused dense `gate_up_proj.weight` → split into 2 (gate/up) - `mlp.router.gate.weight` → rename to `mlp.gate.weight` - stacked-expert `experts.w13_weight` → per-expert split into gate_proj / up_proj with linear global expert IDs (`my_rank * num_local + local_id`) - stacked-expert `experts.w2_weight` → per-expert down_proj - everything else passes through unchanged. * `init_nixl_mx_v2` now probes `AutoConfig.from_pretrained(...)` for the dims the translator needs (q_heads, kv_heads, head_dim, num_experts) and the inference EP layout (DP × TP when `enable_expert_parallel=True`), caching them under `_hf_config`. * Translator is a no-op when `model_type` isn't qwen3_moe/qwen3 — safe to layer onto non-MoE deployments. Tests: 5 new unit tests in `tests/unit/inference/vllm/worker/test_nixl_mx_v2_worker.py`: * QKV split with correct per-projection row counts (q=4096, k=512, v=512 for Qwen3-30B-A3B dims) + row-level data preservation. * Router rename TT→HF. * Stacked-expert w13 per-expert split with global-ID arithmetic on multiple ranks (rank 0 → IDs 0..31; rank 2 → IDs 64..95 with `ep_size=4, num_experts=128`). * Stacked-expert w2 per-expert split. * Passthrough for norms / o_proj / q-k_norm / embed / lm_head. Worker test suite: 11/11 green locally. Full mx_v2 suite: 24 passed, 5 skipped (CI-gated). Test-only fix for the trainer-side test as a side effect: switch `test_broadcast_weights_calls_slot_fill_from` → `test_*_calls_slot_convert` since `GatheredSlot.convert(state_dict)` doesn't take the conversion object (it's baked in at `from_spec` creation time). Matches the source change shipped two commits ago. Cluster status (runtime configmap-overlay versions of these same patches are already deployed in kavin/prime-rl-mx-v2-*): * trainer publishes 531 tensors per rank in TT-format * inference receives, translates, and feeds vLLM * the only remaining shape-failure was traced to the trainer's ShardedSlot allocating 1/N of each non-expert tensor under FSDP+EP; that's a *publisher-side* issue addressed in a separate commit (kavin_pull_mode_gathered: force GatheredSlot for non-expert weights so each rank publishes the full tensor via DTensor allgather, which pull-mode + same-rank routing requires). --- .../inference/vllm/worker/nixl_mx_v2.py | 159 ++++++++++++++++- .../vllm/worker/test_nixl_mx_v2_worker.py | 165 +++++++++++++++++- tests/unit/train/rl/test_nixl_mx_v2.py | 17 +- 3 files changed, 327 insertions(+), 14 deletions(-) diff --git a/src/prime_rl/inference/vllm/worker/nixl_mx_v2.py b/src/prime_rl/inference/vllm/worker/nixl_mx_v2.py index 855b1b3c4e..6f4c65f745 100644 --- a/src/prime_rl/inference/vllm/worker/nixl_mx_v2.py +++ b/src/prime_rl/inference/vllm/worker/nixl_mx_v2.py @@ -124,9 +124,42 @@ def init_nixl_mx_v2( ) ) self._global_rank = global_rank + + # Cache the HF model config + parallel layout so the TT→HF + # translator (`_translate_tt_to_hf`) can split fused tensors into + # the per-tensor / per-expert names vLLM's `load_weights` expects. + try: + from transformers import AutoConfig + hf = AutoConfig.from_pretrained(inference_model_name) + mc = self.model_runner.model_config + ep_size = getattr(mc, "ep_size", None) or getattr( + mc, "data_parallel_size", 1 + ) + self._hf_config = { + "model_type": getattr(hf, "model_type", ""), + "num_attention_heads": getattr(hf, "num_attention_heads", 0), + "num_kv_heads": getattr(hf, "num_key_value_heads", 0) + or getattr(hf, "num_attention_heads", 0), + "head_dim": getattr(hf, "head_dim", 0) + or ( + getattr(hf, "hidden_size", 0) + // max(1, getattr(hf, "num_attention_heads", 1)) + ), + "num_experts": getattr(hf, "num_experts", 0) + or getattr(hf, "num_local_experts", 0), + "ep_size": int(ep_size or 1), + } + except Exception as e: # noqa: BLE001 — never block engine init + logger.warning( + f"[mx_v2] HF config probe failed ({e!r}); TT→HF translator " + f"will fall through to passthrough — non-MoE models only." + ) + self._hf_config = None + logger.info( f"[mx_v2] init: rank={global_rank} model={inference_model_name} " - f"publish_self_as_replica={publish_self_as_replica}" + f"publish_self_as_replica={publish_self_as_replica} " + f"hf_config={self._hf_config}" ) # ------------------------------------------------------------------ @@ -246,10 +279,122 @@ def update_weights_via_mx_v2( def _load_weights_batch(self, batch: list[tuple[str, torch.Tensor]]) -> None: """Feed yielded ``(name, tensor)`` pairs through vLLM's load_weights. - vLLM's :meth:`model.load_weights` handles HF→fused name remapping - via ``stacked_params_mapping`` (e.g. ``q_proj|k_proj|v_proj → - qkv_proj``), so this worker doesn't need to know about fusion — - the engine yields HF-format names and vLLM does the rest. - Matches the NemoRL v2 pattern + Anyscale's RDT pattern. + Translation pass: PrimeRL's trainer-side ``GatheredSlot`` emits + tensors in TT-format (fused ``qkv_proj``, stacked-expert + ``w13_weight``/``w2_weight``, ``mlp.router.gate`` prefix). vLLM's + ``load_weights`` expects HF-checkpoint names + per-expert tensors + so its ``stacked_params_mapping`` (QKV / gate-up) and + ``expert_params_mapping`` (FusedMoE) can route them into the + model's actual stacked params. We translate TT → HF here so the + engine adapter (``MxWeightTransferEngine``) stays model-agnostic. + + The slot-side conversion specs that PrimeRL applies on the + publisher side are the inverse of this translator — see + ``prime_rl.trainer.models.qwen3_moe.converting_qwen3_moe``. """ - self.raw_model.load_weights(batch) + translated = self._translate_tt_to_hf(batch) + if translated: + self.raw_model.load_weights(translated) + + # ------------------------------------------------------------------ + # TT → HF translation + # ------------------------------------------------------------------ + + def _translate_tt_to_hf( + self, + batch: list[tuple[str, torch.Tensor]], + ) -> list[tuple[str, torch.Tensor]]: + """Translate PrimeRL TT-format slot keys to HF checkpoint names. + + Currently supports Qwen3-MoE family (Qwen3MoeForCausalLM); other + models pass through (most non-MoE PrimeRL models already match + HF naming). To extend, add per-prefix unstacking logic. + + Layout assumption: the per-trainer-rank expert subset matches the + per-inference-rank EP subset (i.e. ``trainer.ep == inference.EP``), + so local-expert index lines up with global expert ID via + ``my_rank * num_local + local_id``. Cross-EP slicing (Phase 4 + mixed-TP / multi-source picker) is the follow-up that lifts this + constraint. + """ + cfg = self._hf_config + if cfg is None or cfg.get("model_type") not in {"qwen3_moe", "qwen3"}: + return batch # passthrough for unsupported models + + q_size = cfg["num_attention_heads"] * cfg["head_dim"] + kv_size = cfg["num_kv_heads"] * cfg["head_dim"] + num_experts = cfg.get("num_experts", 0) + ep_size = max(1, cfg.get("ep_size", 1)) + num_local_experts = num_experts // ep_size if num_experts else 0 + my_rank = self._global_rank % ep_size if ep_size > 1 else 0 + + out: list[tuple[str, torch.Tensor]] = [] + for name, tensor in batch: + # ── QKV split (fused → q/k/v) ─────────────────────────────── + if name.endswith(".self_attn.qkv_proj.weight"): + prefix = name.removesuffix(".self_attn.qkv_proj.weight") + expected = q_size + 2 * kv_size + assert tensor.shape[0] == expected, ( + f"qkv_proj rows {tensor.shape[0]} != " + f"q({q_size})+k({kv_size})+v({kv_size})={expected}" + ) + out.append((f"{prefix}.self_attn.q_proj.weight", tensor[:q_size])) + out.append((f"{prefix}.self_attn.k_proj.weight", tensor[q_size : q_size + kv_size])) + out.append((f"{prefix}.self_attn.v_proj.weight", tensor[q_size + kv_size :])) + + # ── Dense MLP gate/up split (future-proof, no-op on Qwen3-30B-A3B) + elif name.endswith(".mlp.gate_up_proj.weight"): + prefix = name.removesuffix(".mlp.gate_up_proj.weight") + mid = tensor.shape[0] // 2 + out.append((f"{prefix}.mlp.gate_proj.weight", tensor[:mid])) + out.append((f"{prefix}.mlp.up_proj.weight", tensor[mid:])) + + # ── Router rename (TT prefix → HF) ────────────────────────── + elif name.endswith(".mlp.router.gate.weight"): + prefix = name.removesuffix(".mlp.router.gate.weight") + out.append((f"{prefix}.mlp.gate.weight", tensor)) + + # ── MoE w13 (fused gate+up, stacked across local experts) ─── + elif name.endswith(".mlp.experts.w13_weight"): + prefix = name.removesuffix(".mlp.experts.w13_weight") + if tensor.ndim != 3: + out.append((name, tensor)) + continue + n_local, fused_dim, _ = tensor.shape + moe_dim = fused_dim // 2 + for j in range(n_local): + global_id = my_rank * num_local_experts + j + out.append( + ( + f"{prefix}.mlp.experts.{global_id}.gate_proj.weight", + tensor[j, :moe_dim].contiguous(), + ) + ) + out.append( + ( + f"{prefix}.mlp.experts.{global_id}.up_proj.weight", + tensor[j, moe_dim:].contiguous(), + ) + ) + + # ── MoE w2 (down, stacked across local experts) ───────────── + elif name.endswith(".mlp.experts.w2_weight"): + prefix = name.removesuffix(".mlp.experts.w2_weight") + if tensor.ndim != 3: + out.append((name, tensor)) + continue + n_local = tensor.shape[0] + for j in range(n_local): + global_id = my_rank * num_local_experts + j + out.append( + ( + f"{prefix}.mlp.experts.{global_id}.down_proj.weight", + tensor[j].contiguous(), + ) + ) + + # ── Passthrough: norms, o_proj, q/k_norm, embed, lm_head ──── + else: + out.append((name, tensor)) + + return out diff --git a/tests/unit/inference/vllm/worker/test_nixl_mx_v2_worker.py b/tests/unit/inference/vllm/worker/test_nixl_mx_v2_worker.py index 78a339bc6b..32c330d4f0 100644 --- a/tests/unit/inference/vllm/worker/test_nixl_mx_v2_worker.py +++ b/tests/unit/inference/vllm/worker/test_nixl_mx_v2_worker.py @@ -247,9 +247,12 @@ def test_update_weights_via_mx_v2_no_filter_passes_none(worker_mod): assert upd_kwargs["compile_target_filter"] is None -def test_load_weights_batch_feeds_through_vllm_model_load_weights(worker_mod): +def test_load_weights_batch_passthrough_when_no_hf_config(worker_mod): + """When _hf_config is None (non-MoE model / probe failed), the translator + falls through to passthrough and forwards the batch unchanged.""" mod, _ = worker_mod worker = _make_worker(mod) + worker._hf_config = None # force passthrough captured_batches = [] worker.raw_model.load_weights = MagicMock( side_effect=lambda batch: captured_batches.append(batch) @@ -261,6 +264,166 @@ def test_load_weights_batch_feeds_through_vllm_model_load_weights(worker_mod): assert captured_batches == [batch_1, batch_2] +def test_translate_tt_to_hf_qkv_split(worker_mod): + """Fused qkv_proj.weight (TT format) splits into q/k/v (HF format) + with the right per-projection row counts derived from head dims.""" + import torch + + mod, _ = worker_mod + worker = _make_worker(mod) + # Qwen3-30B-A3B-like dims: 32 q heads, 4 kv heads, head_dim=128, hidden=2048. + worker._hf_config = { + "model_type": "qwen3_moe", + "num_attention_heads": 32, + "num_kv_heads": 4, + "head_dim": 128, + "num_experts": 128, + "ep_size": 4, + } + worker._global_rank = 0 + q_size = 32 * 128 # 4096 + kv_size = 4 * 128 # 512 + rows = q_size + 2 * kv_size # 5120 + qkv = torch.arange(rows * 2048, dtype=torch.float32).view(rows, 2048) + out = worker._translate_tt_to_hf( + [("model.layers.0.self_attn.qkv_proj.weight", qkv)] + ) + names = [n for n, _ in out] + assert names == [ + "model.layers.0.self_attn.q_proj.weight", + "model.layers.0.self_attn.k_proj.weight", + "model.layers.0.self_attn.v_proj.weight", + ] + assert out[0][1].shape == (q_size, 2048) + assert out[1][1].shape == (kv_size, 2048) + assert out[2][1].shape == (kv_size, 2048) + # Data preserved: first row of q == first row of qkv + assert torch.equal(out[0][1][0], qkv[0]) + assert torch.equal(out[1][1][0], qkv[q_size]) + assert torch.equal(out[2][1][0], qkv[q_size + kv_size]) + + +def test_translate_tt_to_hf_router_rename(worker_mod): + """mlp.router.gate.weight renames to mlp.gate.weight (HF naming).""" + import torch + + mod, _ = worker_mod + worker = _make_worker(mod) + worker._hf_config = { + "model_type": "qwen3_moe", + "num_attention_heads": 32, + "num_kv_heads": 4, + "head_dim": 128, + "num_experts": 128, + "ep_size": 4, + } + worker._global_rank = 0 + gate = torch.randn(128, 2048) + out = worker._translate_tt_to_hf( + [("model.layers.3.mlp.router.gate.weight", gate)] + ) + assert [n for n, _ in out] == ["model.layers.3.mlp.gate.weight"] + assert torch.equal(out[0][1], gate) + + +def test_translate_tt_to_hf_expert_w13_per_expert_split(worker_mod): + """Stacked w13 (gate+up) splits per-expert with the correct global + expert ID derived from rank * num_local + local_id.""" + import torch + + mod, _ = worker_mod + worker = _make_worker(mod) + worker._hf_config = { + "model_type": "qwen3_moe", + "num_attention_heads": 32, + "num_kv_heads": 4, + "head_dim": 128, + "num_experts": 128, + "ep_size": 4, + } + worker._global_rank = 2 # ep_rank=2 → global IDs 64..95 (num_local=32) + moe_dim = 768 + hidden = 2048 + n_local = 32 + w13 = torch.arange(n_local * 2 * moe_dim * hidden, dtype=torch.float32).view( + n_local, 2 * moe_dim, hidden + ) + out = worker._translate_tt_to_hf( + [("model.layers.5.mlp.experts.w13_weight", w13)] + ) + # Each local expert produces TWO tensors (gate + up) + assert len(out) == n_local * 2 + # First emitted should be local-expert-0 → global ID 64 (rank 2 × 32) + first_name, first_t = out[0] + assert first_name == "model.layers.5.mlp.experts.64.gate_proj.weight" + assert first_t.shape == (moe_dim, hidden) + # Second emitted should be local-0's up_proj (global ID 64) + assert out[1][0] == "model.layers.5.mlp.experts.64.up_proj.weight" + # Last local expert (31) → global ID 95 + last_gate_name = f"model.layers.5.mlp.experts.{2 * 32 + 31}.gate_proj.weight" + assert last_gate_name in [n for n, _ in out] + # Data preservation: local-0's gate-slice matches w13[0, :moe_dim] + assert torch.equal(first_t, w13[0, :moe_dim]) + + +def test_translate_tt_to_hf_expert_w2_per_expert(worker_mod): + """w2 (down) splits per-expert with the correct global IDs.""" + import torch + + mod, _ = worker_mod + worker = _make_worker(mod) + worker._hf_config = { + "model_type": "qwen3_moe", + "num_attention_heads": 32, + "num_kv_heads": 4, + "head_dim": 128, + "num_experts": 128, + "ep_size": 4, + } + worker._global_rank = 0 + hidden = 2048 + moe_dim = 768 + n_local = 32 + w2 = torch.randn(n_local, hidden, moe_dim) + out = worker._translate_tt_to_hf([("model.layers.7.mlp.experts.w2_weight", w2)]) + assert len(out) == n_local + assert out[0][0] == "model.layers.7.mlp.experts.0.down_proj.weight" + assert out[-1][0] == "model.layers.7.mlp.experts.31.down_proj.weight" + assert torch.equal(out[0][1], w2[0]) + assert torch.equal(out[31][1], w2[31]) + + +def test_translate_tt_to_hf_passthrough_for_unknown_names(worker_mod): + """Names not in the TT→HF table pass through unchanged (norms, embed, + lm_head, o_proj, q_norm, k_norm, etc.).""" + import torch + + mod, _ = worker_mod + worker = _make_worker(mod) + worker._hf_config = { + "model_type": "qwen3_moe", + "num_attention_heads": 32, + "num_kv_heads": 4, + "head_dim": 128, + "num_experts": 128, + "ep_size": 4, + } + worker._global_rank = 0 + t = torch.randn(2048) + passthrough_cases = [ + ("model.embed_tokens.weight", t), + ("model.norm.weight", t), + ("lm_head.weight", t), + ("model.layers.0.self_attn.o_proj.weight", t), + ("model.layers.0.self_attn.q_norm.weight", t), + ("model.layers.0.self_attn.k_norm.weight", t), + ("model.layers.0.input_layernorm.weight", t), + ("model.layers.0.post_attention_layernorm.weight", t), + ] + out = worker._translate_tt_to_hf(passthrough_cases) + assert out == passthrough_cases # exact same list, unchanged + + def test_update_weights_via_mx_v2_metrics_safe_when_stats_none(worker_mod): mod, mocks = worker_mod mocks["engine"].last_transfer_stats = None diff --git a/tests/unit/train/rl/test_nixl_mx_v2.py b/tests/unit/train/rl/test_nixl_mx_v2.py index e9bbfab4f2..0664810a82 100644 --- a/tests/unit/train/rl/test_nixl_mx_v2.py +++ b/tests/unit/train/rl/test_nixl_mx_v2.py @@ -252,7 +252,7 @@ def _make_fake_slot(*, name: str, is_expert: bool = False, num_buffers: int = 2) slot.buffers = [ (f"{name}.buf_{i}", torch.zeros(4), object()) for i in range(num_buffers) ] - slot.fill_from = MagicMock() + slot.convert = MagicMock() return slot @@ -441,8 +441,12 @@ def test_broadcast_weights_skips_non_primary_hsdp_rank(broadcast_mod): mocks["publisher"].publish.assert_not_called() -def test_broadcast_weights_calls_slot_fill_from(broadcast_mod): - mod, mocks = broadcast_mod +def test_broadcast_weights_calls_slot_convert(broadcast_mod): + """Each slot's `convert(state_dict)` must be invoked exactly once per + broadcast cycle. GatheredSlot's API takes only the state_dict — the + conversion (compile_target / quantization) is baked in at + `from_spec` creation time, not threaded per-call.""" + mod, _ = broadcast_mod bc = mod.NIXLMxV2WeightBroadcast( output_dir=Path("/tmp/out"), config=_make_config(), @@ -455,9 +459,10 @@ def test_broadcast_weights_calls_slot_fill_from(broadcast_mod): model = _make_fake_model(slots) bc.broadcast_weights(model, step=3) for slot in slots: - slot.fill_from.assert_called_once() - args = slot.fill_from.call_args.args - assert args[1] is mocks["conversion"] + slot.convert.assert_called_once() + # convert receives the state_dict (single positional arg). + args = slot.convert.call_args.args + assert isinstance(args[0], dict) def test_shutdown_calls_publisher_shutdown_idempotent(broadcast_mod): From 17b5b4de9ce4c7e40c982a0535dd20d7ec10eb04 Mon Sep 17 00:00:00 2001 From: Kavin Krishnan Date: Tue, 2 Jun 2026 23:54:54 -0700 Subject: [PATCH 17/18] feat(mx_v2): trainer GatheredSlot escalation + receiver NIXL transient retry MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Two source-side companions to the TT→HF translator from the previous commit that close the remaining shape-mismatch + reconnect-race issues that surfaced when validating against Qwen3-30B-A3B on the kavin cluster: 1. NIXLMxV2WeightBroadcast.lazy_init: temporarily raise `slots.SMALL_NON_EXPERT_BYTES` to `1 << 60` for the duration of `model.build_slots(...)`, so every non-expert weight is built as a `GatheredSlot` (full tensor on each rank via DTensor.full_tensor()) instead of `ShardedSlot` (1/fsdp_total). Restore the threshold afterward so other code paths (e.g. nixl_mx push-mode broadcast running in the same process) aren't perturbed. Why: pull-mode + same-rank routing means each inference rank only contacts ONE trainer rank for the pull. ShardedSlot's 1/N FSDP shard would deliver only 1/N of the tensor to the receiver, and vLLM's `param.load_qkv_weight` (TP=1 case) refuses the shape with `assert param_data.shape == loaded_weight.shape`. Push-mode (PR #2389) doesn't have this issue because each trainer rank writes its FSDP shard directly into the inference's pre-allocated buffer at its rank-specific offset, and all N senders contribute to the full tensor in the receiver's memory. Trade-off: extra `full_tensor()` allgather per non-expert tensor per refit. Measured at <50ms total on 4×GB200 NVL for Qwen3-30B-A3B (~4 GB of non-expert weights), well inside the per-cycle budget. The long-term replacement is Phase 4 multi-source slicing in the engine adapter (receivers pull *partial* tensors from *multiple* trainer ranks and assemble locally — same semantics as nixl_mx's push-mode, but receiver-driven). Until that lands, gather-first is the right trade. 2. NIXLMxV2WeightUpdateWorker.update_weights_via_mx_v2: expand the retry-transient set to include `NIXL_ERR_REMOTE_DISCONNECT`, `NIXL_ERR_NOT_ALLOWED`, and `NIXL_ERR_NOT_FOUND` (previously only discovery-empty errors retried). These three error codes correspond to trainer-pod restart races where the MX catalog still has the dead agent's metadata for a few seconds before the heartbeat timeout reaps it. Any other exception (real shape mismatch, real transport failure, anything not on the allowlist) continues to propagate immediately. Unit tests: 1 new `tests/unit/train/rl/test_nixl_mx_v2.py::test_lazy_init_forces_gathered_slots_for_pull_mode` that captures the threshold value `slots.SMALL_NON_EXPERT_BYTES` AT the moment `model.build_slots(...)` is called (verifying the escalation is active during slot construction) and that it's restored to the original value after lazy_init returns. Full suite: 25 passed, 5 skipped (CI-gated). Cluster validation status: * Trainer publishes Q/K/V as separate per-source slots at FULL shape ((4096, 2048) and (512, 2048)) — confirmed via injected `[KAVINDBG-PUB] buf_key='model.layers.0.self_attn.q_proj.weight' shape=(4096, 2048)` log line. * MoE experts publish as (32, 1536, 2048) for w13 and (32, 2048, 768) for w2 — 32 local experts per rank with ep=4 (matches inference EP=4). * Orchestrator gets real Reward != 1.0 rollouts on Qwen3-30B-A3B (e.g. `Reward: 0.5000`, `Reward: 0.6250` on gsm8k) — confirms the inference engine is actually serving from the pre-refit weights. * Engine discovery + load_weights callback dispatch confirmed. * Final E2E close-the-loop is the inference + trainer reconnect race during trainer pod restarts, which this commit's retry expansion absorbs. Steady-state cycles still pending. --- .../inference/vllm/worker/nixl_mx_v2.py | 10 ++++- .../trainer/rl/broadcast/nixl_mx_v2.py | 34 ++++++++++++-- tests/unit/train/rl/test_nixl_mx_v2.py | 44 +++++++++++++++++++ 3 files changed, 82 insertions(+), 6 deletions(-) diff --git a/src/prime_rl/inference/vllm/worker/nixl_mx_v2.py b/src/prime_rl/inference/vllm/worker/nixl_mx_v2.py index 6f4c65f745..4ac2623920 100644 --- a/src/prime_rl/inference/vllm/worker/nixl_mx_v2.py +++ b/src/prime_rl/inference/vllm/worker/nixl_mx_v2.py @@ -230,12 +230,18 @@ def update_weights_via_mx_v2( except Exception as e: # noqa: BLE001 — engine may raise plain RuntimeError msg = str(e) last_err = e - # Only retry on "no source matches" / discovery-empty errors; - # propagate any other (e.g. NIXL transport failure) immediately. + # Retry on discovery-empty errors (trainer hasn't published + # version=N yet) AND on NIXL transient connection errors + # (trainer-pod restart races where the catalog still has the + # dead agent's metadata for a few seconds). Any other error + # (e.g. real shape mismatch in load_weights) propagates. transient = ( "no source matches" in msg or "NoSourceMatchesFilterError" in msg or "no matching source" in msg + or "NIXL_ERR_REMOTE_DISCONNECT" in msg + or "NIXL_ERR_NOT_ALLOWED" in msg + or "NIXL_ERR_NOT_FOUND" in msg ) if not transient or _time.monotonic() >= retry_deadline: raise diff --git a/src/prime_rl/trainer/rl/broadcast/nixl_mx_v2.py b/src/prime_rl/trainer/rl/broadcast/nixl_mx_v2.py index 862f69171b..488cb48d0d 100644 --- a/src/prime_rl/trainer/rl/broadcast/nixl_mx_v2.py +++ b/src/prime_rl/trainer/rl/broadcast/nixl_mx_v2.py @@ -121,10 +121,36 @@ def lazy_init(self, model: PreTrainedModelPrimeRL) -> None: self._hf_config = AutoConfig.from_pretrained(self.config.inference_model_name) self._conversion = select_default_conversion(self.config.inference_model_name) - with classic_cuda_alloc(): - self._model_slots = model.build_slots( - self.parallel_dims, self._conversion, self._hf_config.torch_dtype - ) + # Pull-mode (this broadcast type) + same-rank routing means each + # inference rank only contacts ONE trainer rank, so the trainer must + # have the FULL tensor on each rank — ShardedSlot's 1/N FSDP-shard + # would deliver only 1/N to each receiver and vLLM's load_weights + # would refuse the shape mismatch. Force every non-expert slot + # into GatheredSlot (DTensor.full_tensor() allgather + full tensor + # held per rank) by raising the threshold beyond any realistic + # weight size. Expert slots remain ExpertSlot (each rank owns its + # EP shard, which is exactly what same-rank pull mode wants). + # + # In push-mode (PR #2389's `nixl_mx`) ShardedSlot is correct + # because each trainer rank writes its FSDP shard directly into + # the inference's pre-allocated buffer at its rank-specific + # offset — there each receiver assembles N shards from N senders. + # Pull-mode receivers can't do that without Phase-4 multi-source + # slicing in the engine adapter; until that lands, gather first. + from prime_rl.trainer.models import slots as _slots_mod + if getattr(self, "_orig_small_non_expert_bytes", None) is None: + self._orig_small_non_expert_bytes = _slots_mod.SMALL_NON_EXPERT_BYTES + _slots_mod.SMALL_NON_EXPERT_BYTES = 1 << 60 + + try: + with classic_cuda_alloc(): + self._model_slots = model.build_slots( + self.parallel_dims, self._conversion, self._hf_config.torch_dtype + ) + finally: + # Restore the threshold so we don't perturb other code paths + # (e.g. nixl_mx broadcast running in the same process). + _slots_mod.SMALL_NON_EXPERT_BYTES = self._orig_small_non_expert_bytes # The v2 publisher owns the NIXL agent + MX client + heartbeat. # We pass our rank as ``worker_rank``; receivers with diff --git a/tests/unit/train/rl/test_nixl_mx_v2.py b/tests/unit/train/rl/test_nixl_mx_v2.py index 0664810a82..d59f285a98 100644 --- a/tests/unit/train/rl/test_nixl_mx_v2.py +++ b/tests/unit/train/rl/test_nixl_mx_v2.py @@ -92,6 +92,12 @@ def _install_stubs(): sys.modules["prime_rl.trainer"] = types.ModuleType("prime_rl.trainer") sys.modules["prime_rl.trainer.models"] = pkg_trainer_models + # ─── prime_rl.trainer.models.slots (for SMALL_NON_EXPERT_BYTES) ───── + pkg_slots = types.ModuleType("prime_rl.trainer.models.slots") + pkg_slots.SMALL_NON_EXPERT_BYTES = 2 * 1024 * 1024 # match real default + sys.modules["prime_rl.trainer.models.slots"] = pkg_slots + mocks["slots_mod"] = pkg_slots + # ─── prime_rl.trainer.models.conversions.select_default_conversion ── fake_conversion = types.SimpleNamespace( compile_target="cutlass_fp8", @@ -465,6 +471,44 @@ def test_broadcast_weights_calls_slot_convert(broadcast_mod): assert isinstance(args[0], dict) +def test_lazy_init_forces_gathered_slots_for_pull_mode(broadcast_mod): + """lazy_init must temporarily raise `slots.SMALL_NON_EXPERT_BYTES` to + infinity while `model.build_slots(...)` runs — this forces every + non-expert weight into GatheredSlot (full tensor on each rank via + DTensor.full_tensor()) instead of ShardedSlot (1/N FSDP shard). + Pull-mode + same-rank routing requires the full tensor per rank. + The threshold must be restored after build_slots returns so other + code paths (e.g. nixl_mx push-mode broadcast running in the same + process) aren't perturbed.""" + mod, mocks = broadcast_mod + slots_mod = mocks["slots_mod"] + original_threshold = slots_mod.SMALL_NON_EXPERT_BYTES + + bc = mod.NIXLMxV2WeightBroadcast( + output_dir=Path("/tmp/out"), + config=_make_config(), + parallel_dims=_make_parallel_dims(), + ) + seen_thresholds = [] + model = _make_fake_model([_make_fake_slot(name="layer0", num_buffers=1)]) + # Capture the threshold value AT THE TIME build_slots is called + model.build_slots = MagicMock( + side_effect=lambda *_a, **_kw: ( + seen_thresholds.append(slots_mod.SMALL_NON_EXPERT_BYTES) + or [_make_fake_slot(name="layer0", num_buffers=1)] + ) + ) + bc.lazy_init(model) + + assert seen_thresholds, "build_slots was never called" + assert seen_thresholds[0] > 2 * 1024 * 1024, ( + f"threshold was {seen_thresholds[0]} during build_slots — must be " + f"raised (1<<60) so all non-expert weights become GatheredSlot" + ) + # Restored to original value after lazy_init returns. + assert slots_mod.SMALL_NON_EXPERT_BYTES == original_threshold + + def test_shutdown_calls_publisher_shutdown_idempotent(broadcast_mod): mod, mocks = broadcast_mod bc = mod.NIXLMxV2WeightBroadcast( From cf4acb2719b1422abff75a0fe2614abb2f3d8b97 Mon Sep 17 00:00:00 2001 From: Kavin Krishnan Date: Wed, 3 Jun 2026 21:47:08 -0700 Subject: [PATCH 18/18] chore(docs): move mx_v2 proposal docs to local archive These docs walked through the post-PR #2389 design journey: the kernel-compile plan, the image-build plan, the source-baked image build notes, the status-and-plan doc, and the consolidating mx_v2 RFC. Useful as a process record while the design was in flux, but not part of what should land upstream alongside the ``weight_broadcast.type = "mx_v2"`` code change. The upstream-facing docs that explain how to use the ``mx_v2`` broadcast type will land in a separate docs-only PR once the code path is reviewed and merged. Removed: docs/proposals/build-notes-2026-05-28.md docs/proposals/image-build-mx-v2.md docs/proposals/post-pr2389-kernel-compile-plan.md docs/proposals/post-pr2389-mx-v2.md docs/proposals/post-pr2389-status-and-plan.md Signed-off-by: Kavin Krishnan --- docs/proposals/build-notes-2026-05-28.md | 238 -------- docs/proposals/image-build-mx-v2.md | 96 ---- .../post-pr2389-kernel-compile-plan.md | 523 ------------------ docs/proposals/post-pr2389-mx-v2.md | 271 --------- docs/proposals/post-pr2389-status-and-plan.md | 232 -------- 5 files changed, 1360 deletions(-) delete mode 100644 docs/proposals/build-notes-2026-05-28.md delete mode 100644 docs/proposals/image-build-mx-v2.md delete mode 100644 docs/proposals/post-pr2389-kernel-compile-plan.md delete mode 100644 docs/proposals/post-pr2389-mx-v2.md delete mode 100644 docs/proposals/post-pr2389-status-and-plan.md diff --git a/docs/proposals/build-notes-2026-05-28.md b/docs/proposals/build-notes-2026-05-28.md deleted file mode 100644 index 8f0ee37443..0000000000 --- a/docs/proposals/build-notes-2026-05-28.md +++ /dev/null @@ -1,238 +0,0 @@ -# Build notes — Phase-2 + Phase-3 source-baked image (2026-05-28) - -> **Related docs in this directory**: -> - [`post-pr2389-status-and-plan.md`](./post-pr2389-status-and-plan.md) — executive summary of where things stand + failure-class → fix mapping -> - [`post-pr2389-kernel-compile-plan.md`](./post-pr2389-kernel-compile-plan.md) — the full RFC with phase-by-phase design rationale -> -> This doc is the operational findings: how the source-baked image was built, what broke, what we learned, and how the upstream vLLM native RL APIs reframe everything. - -**Status**: empirical findings from baking Phase 2 (rendezvous fixes) + Phase 3 (conversion-registry extensions) into an ARM64 GB200 image and running it against a live GB200 cluster. Updates the RFC's framing where the build experience contradicted assumptions in the original RFC. - -This document captures **what we learned producing a usable image** containing the two follow-up PRs ([phase-2 rendezvous fixes](https://github.com/KavinKrishnan/prime-rl/pull/1) and [conversion-registry extensions](https://github.com/KavinKrishnan/prime-rl/pull/2)) on top of PR #2389 (HEAD `79ea824d8`). The unit tests for both PRs were already green; this doc records the cluster + image surface area that the unit tests don't cover. - -## 1. What we built - -Two images, in order: - -| Tag | Base | What's added | Status | -|---|---|---|---| -| `prime-rl-mx-on-nixl:v0.7.0-kavin-phase2-phase3` | `nvidia/cuda:12.8.1-cudnn-devel-ubuntu22.04` (full Dockerfile.cuda rebuild) | Phase 2 + Phase 3 source merged in. Built from `kavink/post-2389-image-build-2026-05-28` (which merges PR #1 + PR #2 on top of `79ea824d8`). | **Pushed to nvcr** | -| `prime-rl-mx-on-nixl:v0.7.1-kavin-phase2-phase3` | `v0.7.0-kavin-phase2-phase3` | Adds the `disagg` extra (modelexpress + nixl-cu12 + vllm-router). Fixes the import error from v0.7.0. | **Pushed to nvcr** | - -`v0.7.1` is the one to deploy. `v0.7.0` is kept as a reference of the from-scratch build artifact. - -## 2. Build mechanics (ARM64 GB200 / QEMU) - -The from-scratch ARM64 build of `v0.7.0` took **6h 45min on x86 host with QEMU arm64 emulation** (buildkit `multi-arch` builder). Breakdown: - -| Stage | Time | Notes | -|---|---|---| -| Pull `nvidia/cuda:12.8.1-cudnn-devel-ubuntu22.04` ARM64 base | ~5 min | First time only; cached after | -| `apt-get install` builder + final stages | ~3 min total | Both stages of multi-stage Dockerfile | -| `COPY src/ + packages/ + deps/` | seconds | Trivial | -| **`uv sync --extra ... --locked --no-dev`** | **45 min** | Resolves + downloads + installs ~350 packages including torch 2.7+cu130 (~5 GB), nvidia-cudnn-cu12 (738 MiB), flashinfer-cubin (large), tilelang, xgrammar, vllm 0.21.0+cu129 etc. Under QEMU emulation. | -| **`docker-arm64-post-install.sh` (flash-attn from source for sm_100 / GB200)** | **~3h 45min** | 73 CUDA kernel `.o` files, each compiled via emulated `nvcc` for sm_80 + sm_90 + sm_100. Most expensive kernels are `hdim192_bf16_causal` and `hdim256_bf16` for backward pass (15-40 min each). | -| Final stage `COPY --from=builder /app` + image export | ~7 min | 15.9 GB final image, 6.5 GB of which is one big layer (the venv) | - -`v0.7.1` overlay on top of `v0.7.0` was **~3 min** (the `uv sync` with `disagg` extra reuses every cached layer except the new modelexpress/nixl-cu12/vllm-router wheels). - -**Practical implication**: every meaningful rebuild from the Dockerfile.cuda base is ~7 hours on a non-ARM host. Use overlay Dockerfiles for additive changes. Reserve from-scratch only for `pyproject.toml` / `uv.lock` updates or major source restructuring. - -## 3. Three real issues the build surfaced that aren't in the RFC - -### 3.1 `Dockerfile.cuda` is missing `--extra disagg` for nixl_mx use - -[`Dockerfile.cuda`](../../Dockerfile.cuda) line 52: - -```dockerfile -RUN --mount=type=cache,target=/app/.cache/uv \ - uv sync --extra flash-attn --extra flash-attn-3 --extra flash-attn-cute --extra envs --extra gpt-oss --group mamba-ssm --locked --no-dev -``` - -The `disagg` extra ([`pyproject.toml` line 90](../../pyproject.toml#L90)) contains: - -```toml -disagg = [ - "deep-ep ; platform_machine == 'x86_64'", - "deep-gemm ; platform_machine == 'x86_64'", - "nixl", - "nixl-cu12 ; platform_machine == 'x86_64'", - "vllm-router ; platform_machine == 'x86_64'", - "modelexpress", -] -``` - -Without it, **`modelexpress` is not installed**, and the inference worker crashes at the first import of `prime_rl.inference.vllm.worker.nixl_mx`: - -``` -File "/app/src/prime_rl/inference/vllm/worker/nixl_mx.py", line 7, in - from modelexpress import p2p_pb2 -ModuleNotFoundError: No module named 'modelexpress' -``` - -The pre-PR-#2389 `Dockerfile.cuda` predates the `disagg` extra so this is an accidental gap, not an intentional opt-out. **Suggested change**: add `--extra disagg` (or rely on `--extra all`) for any image targeting `weight_broadcast.type=nixl_mx`. We've shipped `v0.7.1` as a one-line overlay that does this until the change can land in `Dockerfile.cuda` itself. - -### 3.2 `LD_PRELOAD` path for libcudart.so.12 moved - -The existing configmap's three run-scripts (`run_trainer.sh`, `run_inference.sh`, `run_orchestrator.sh`) all preload libcudart for ARM64 NIXL compatibility: - -```bash -export LD_PRELOAD="/usr/local/cuda/lib64/libcudart.so.12:${LD_PRELOAD:-}" -``` - -`/usr/local/cuda` exists in the v0.5.2 image (which appears to have been built from a Dockerfile variant that retained the CUDA tooling in the final stage). In `v0.7.0` (built from the upstream `Dockerfile.cuda` as-is), the final stage is `python:3.12-slim` which **does not** have `/usr/local/cuda`. `libcudart.so.12` lives only inside the pip-installed `nvidia-cuda-runtime` wheel: - -``` -/app/.venv/lib/python3.12/site-packages/nvidia/cuda_runtime/lib/libcudart.so.12 -``` - -Symptom on v0.7.0 with the unmodified configmap: - -``` -ERROR: ld.so: object '/usr/local/cuda/lib64/libcudart.so.12' from LD_PRELOAD cannot be preloaded -``` - -**Fix applied**: the three run-scripts now use the wheel-internal path. Alternative: symlink `/usr/local/cuda/lib64/libcudart.so.12 -> /app/.venv/.../libcudart.so.12` in the Dockerfile's final stage. Either works; we picked the env-var path because it's a configmap edit, no image rebuild. - -### 3.3 The configmap `patch_nixl_mx.py` and Phase 2 source coexist - -The kavin namespace runs a configmap-injected monkeypatch at container start (`patch_nixl_mx.py`) that rewrites `src/prime_rl/trainer/rl/broadcast/nixl_mx.py` to add same-rank-only peer filter + freshest-per-rank dedup *at TransportPlan construction time*. - -Phase 2 ([PR #1](https://github.com/KavinKrishnan/prime-rl/pull/1)) adds the same semantic guarantees but at a different layer — inside `src/prime_rl/transport/mx_rendezvous.py:wait_for_all_peers_ready`. The two patches are **complementary, not redundant**: - -| Code path | Bug class | Covered by | -|---|---|---| -| `trainer/rl/broadcast/nixl_mx.py:lazy_init` → `TransportPlan(peer_metadata=…)` | Trainer adds dead peers as NIXL remote agents during the per-step broadcast | `patch_nixl_mx.py` (runtime monkeypatch) | -| `transport/mx_rendezvous.py:wait_for_all_peers_ready(role="trainer")` | Orchestrator counts historical trainer entries in Redis and times out waiting for `n_historical` to all reach READY when only `n_alive` exist | Phase 2 PR (source-level) | - -On v0.7.1 + the existing configmap, both fire. The trainer log shows `[patch_nixl_mx] PATCHED v2 (kavin_freshest_per_rank)` from the configmap script; the rendezvous wait methods get the Phase 2 dedup automatically because the source is in the baked image. Empirically the orchestrator restart pattern we saw on v0.5.2 (~once per 30-66 min on this workload) should go away on v0.7.1. **Validation pending** — image just deployed at time of writing. - -When PR #1 merges upstream, the configmap monkeypatch becomes redundant for the trainer-side path too and should be removed. Until then, both layers complement each other. - -## 4. Cluster observations under v0.5.2 + configmap monkeypatch - -For the record, the v0.5.2 + configmap-monkeypatch combination we ran for 8+ hours before v0.7.1 deploy: - -- Workload: Qwen3-30B-A3B-Instruct-2507, FSDP 2×2, EP=4 (32/128 experts per rank), FLASHINFER attention, gsm8k env -- Trainer steady state: ~10–21 s/step (varies with sequence length 280–500 tokens) -- Reward signal: variance 0.5–1.0 per orchestrator step — **real learning gradient**, not just reward=1.0 collapse -- Off-policy level: 0 across all observed steps (in-lockstep refit) -- Best uninterrupted window: **183 successful RL refit cycles over 66 min** between orchestrator restarts -- Zero NIXL data-plane errors (no `REMOTE_DISCONNECT`, no `NOT_ALLOWED`, no stale-READY) — confirms the same-rank-only + freshest-per-rank patches are correct -- Recurring orchestrator timeout pattern: `TimeoutError: timed out after 1200.0s waiting for 12 'trainer' peers to reach status 1 (saw 4)` — exactly what Phase 2's rendezvous-level dedup fixes - -That last bullet is the bug class v0.7.1 is meant to eliminate. The configmap monkeypatch couldn't fix it because the relevant call site is in the orchestrator's rendezvous, which is in a different module from the trainer-side broadcast the monkeypatch was rewriting. - -## 5. Branches + image artifacts pushed - -| Branch | What's in it | Where | -|---|---|---| -| [`kavink/post-2389-kernel-compile-plan`](https://github.com/KavinKrishnan/prime-rl/tree/kavink/post-2389-kernel-compile-plan) | RFC document + this build-notes doc | `KavinKrishnan/prime-rl` | -| [`kavink/post-2389-phase2-rendezvous-fixes`](https://github.com/KavinKrishnan/prime-rl/tree/kavink/post-2389-phase2-rendezvous-fixes) | Phase 2 source (heartbeat + dedup + same-rank), 11/11 unit tests green, plus the `modelexpress.heartbeat` module-path tolerance fix | [Draft PR #1](https://github.com/KavinKrishnan/prime-rl/pull/1) | -| [`kavink/post-2389-conversion-registry-extensions`](https://github.com/KavinKrishnan/prime-rl/tree/kavink/post-2389-conversion-registry-extensions) | Phase 3 conversion-registry extensions (`compile_target` + `compile_metadata` + `cutlass_fp8_e4m3_per_channel`), 19/19 unit tests green | [Draft PR #2](https://github.com/KavinKrishnan/prime-rl/pull/2) | -| [`kavink/post-2389-image-build-2026-05-28`](https://github.com/KavinKrishnan/prime-rl/tree/kavink/post-2389-image-build-2026-05-28) | Merge of Phase 2 + Phase 3 + the import-tolerance fix; this is the exact source tree v0.7.0 / v0.7.1 was built from | `KavinKrishnan/prime-rl` (this push) | - -Image artifacts on `nvcr.io/nvidian/dynamo-dev/`: - -- `prime-rl-mx-on-nixl:v0.7.0-kavin-phase2-phase3` — full from-scratch ARM64 build (broken — missing `disagg`) -- `prime-rl-mx-on-nixl:v0.7.1-kavin-phase2-phase3` — overlay that adds `disagg` extra - -MX side ([`ai-dynamo/modelexpress#349`](https://github.com/ai-dynamo/modelexpress/pull/349)) updated with the graduation glue commit that plumbs `ConversionEntry.compile_target` + `ConversionEntry.compile_metadata` through `MxV2TrainingPublisher.add_tensor(compile_target=…, compile_metadata=…)`. Wire round-trip is unit-tested. - -## 6. What to update in the RFC (`post-pr2389-kernel-compile-plan.md`) — but not yet - -These are the four edits queued in [`pensieve/RL/PrimeRL/09_rfc_updates_needed.md`](https://github.com/ai-dynamo/modelexpress/) (internal), augmented by what we learned from the build: - -1. **Reframe Phase 3** — trainer-side post-processed direct is primary; receiver-side compile passes are v4+ (scratch buffers are a fallback only, not the primary v3 design as the original RFC implied). -2. **Add Phase 0** — Phase B UCX/dma-buf env profile as cluster prerequisite (`UCX_TLS=rc,cuda_copy`, `NIXL_UCX_TLS=rc,cuda_copy`, `UCX_CUDA_COPY_DMABUF=yes`, etc.) — from NeMo-RL + Dynamo's empirical 380 Gbps validation. -3. **Mark Phases 2/3/4 as shipped, not paper** (with PR + commit references). -4. **Add a sub-section on conversion registry extensions** documenting the `ConversionEntry` schema extension + how to add a new kernel (~80 LOC per kernel). - -**New from this build experience**: - -5. **Document the `disagg` extra requirement in §0** alongside the env profile — easy gotcha that costs an entire rebuild to discover. -6. **Document the `LD_PRELOAD` path** for libcudart in §0 — pre-existing run-scripts assumed v0.5.2's `/usr/local/cuda` layout. -7. **§4 / §5 on the "fallback path"** — describe vLLM PR #43375 (Ray Direct Transport, Anyscale) as the canonical receiver-pull-via-load_weights instance of the fallback path. Our positioning of trainer-side post-processed direct as "primary path, zero receive-side compute" is unchanged; RDT is the upstream-stamped instance of the alternative receiver-pull path. - -## 7. Open follow-ups - -- **Validate v0.7.1 end-to-end** on kavin (pending — v0.7.1 just deployed). Expected: zero NIXL errors AND zero orchestrator-`wait_for_all_peers_ready` timeouts. If both hold for a long uninterrupted window (>3 hours), we declare the source-baked Phase 2 + Phase 3 production-ready. -- **Send a one-line PR to upstream `Dockerfile.cuda`** adding `--extra disagg` (or `--extra all`). Tiny patch, unblocks every other team trying to bake `nixl_mx` mode into an image. -- **MX side: roll the server with the `SourceIdentity` round-trip fix** (proto change committed but not deployed). After that, the `__mx_v2_meta__` sidecar transport workaround can be dropped from `MxV2TrainingPublisher` (it's already filtered before NIXL register via [PR #295](https://github.com/ai-dynamo/modelexpress/pull/295)). -- **`pull_one(name)` semantic on MX** — inspired by vLLM PR #43375's RDT contract. Would let MX expose Ray-like per-tensor elasticity without abandoning the trainer-side compile model. ~50 LOC; not on the critical path but a clean addition for the post-Phase-4 work. - -## 8. Strategic update — vLLM published native RL APIs (2026-05-28) - -Same day as this doc, the vLLM team published [Native RL APIs in vLLM](https://vllm.ai/blog/2026-05-28-native-rl-apis), announcing a standardized `WeightTransferEngine` abstract base + four-phase lifecycle (`init_weight_transfer_engine` / `start_weight_update` / `update_weights` / `finish_weight_update`) and a pluggable `WeightTransferEngineFactory.register_engine(...)` extension point. Existing built-in backends: NCCL (packed broadcast), IPC (CUDA shared mem). PR [#43375](https://github.com/vllm-project/vllm/pull/43375) (Anyscale "RDT") plugs Ray Direct Transport into the same factory. - -Three things this changes about our post-PR-#2389 plan: - -### 8.1 Phase 2 + Phase 3 + Phase 4 have a clean upstream form: `MxWeightTransferEngine` - -The native API gives us a standard surface to plug MX into vLLM. The upstream-friendly shape of all our follow-up work is a single adapter class: - -```python -class MxWeightTransferEngine(WeightTransferEngine): - init_info_cls = MxInitInfo # mx_server_url, model_name, worker_rank, ... - update_info_cls = MxUpdateInfo # version, compile_target_filter, target_tp_layout, ... - - def init_transfer_engine(self, init_info: MxInitInfo): - self._receiver = MxV2RefitReceiver(...) - - def receive_weights(self, update_info, load_weights): - plan = self._receiver.discover_v2_sources_for_slice( # Phase 4 - model_name=update_info.model_name, - target_layout=update_info.target_tp_layout, - compile_target_filter=update_info.compile_target_filter, # Phase 3b - required_compile_metadata=update_info.required_compile_metadata, - ) - for name, tensor in self._receiver.receive_via_plan(plan): # Phase 4 stitch - load_weights([(name, tensor)]) - - @classmethod - def trainer_send_weights(cls, iterator, trainer_args): - # wraps MxV2TrainingPublisher.add_tensor(compile_target=..., compile_metadata=...) + .publish() - ... -``` - -Registers via `WeightTransferEngineFactory.register_engine("mx_nixl", MxWeightTransferEngine)`. Estimated ~150-200 LOC adapter on top of the MX clients we already have. - -The in-tree `MxRendezvous` reimplementation in PR #2389 + the `worker_extension_cls` injection in `inference/vllm/worker/nixl_mx.py` both **become unnecessary once this lands upstream**. The native API is the integration seam they were emulating; our adapter consumes it directly. - -### 8.2 Matej is acknowledged in the blog — coordination needed before pushing Phase 2 upstream - -The blog credits *"Prime-RL team (especially Matej Sirovatka) and Junjie Zhang for helping to validate and debug the RL APIs with large-scale runs"*. Matej is actively involved in the design. - -This affects the trajectory of [draft PR #1](https://github.com/KavinKrishnan/prime-rl/pull/1) (Phase 2). The semantic fix it ships is correct regardless of which integration path prime-rl converges on, but the *form* differs: - -- **If Matej is mid-flight on a native-APIs rewrite of `nixl_mx`**: our Phase 2 PR retargets to that path (becomes part of the `MxWeightTransferEngine` adapter rather than landing in the in-tree `MxRendezvous` class). -- **If Matej hasn't started**: Phase 2 lands as-is, the rewrite happens later as a separate PR. - -Either way, ask Matej first. - -### 8.3 Validation in the blog is at DPEP32 across 16 nodes — Phase 4 becomes load-bearing - -The blog reports Prime-RL validated `zai-org/GLM-5.1-FP8` in P/D-disaggregated deployment across **16× 8xH200 nodes** (2 replicas of 4P+4D, both **DPEP32**) for 100+ steps with stable KL mismatch and upward RL curve. **256 GPUs total at DPEP32**. - -At this scale, mixed-TP / mixed-EP is the common case (trainer TP/EP layout almost never matches inference TP/EP layout), and Phase 4's `discover_v2_sources_for_slice` + multi-source `receive_via_plan` is the difference between "works" and "doesn't". This validates Phase 4's design direction and sets the next cluster validation target after kavin (DP=4). - -### 8.4 Two async-RL features the blog ships that we don't yet match - -| Feature | What | Our position | -|---|---|---| -| `pause_generation(mode="keep")` | Pause in-flight requests *without* aborting or waiting; resume from the partial-token state | Today we wait for rollouts to complete before refit. Adopting `keep` mode is the unlock for truly async RL. Queue as a follow-up after Phase 2 lands. | -| Two-phase pause/resume for DPEP | `EngineCore`-level pause state + periodic all-reduce coordination to prevent deadlocks in wide-DP deployments | Not applicable at DP=4 (kavin cluster). Mandatory once we scale to DPEP16+. Track for the next scale-up. | - -The blog also mentions a *"new K8s-native weight transfer engine"* and *"sharding-aware, RDMA-native weight transfer in a generic way"* as ongoing work in the vLLM RL community — both of which describe what MX already is. If MX isn't already the implementation they're referring to, reach out to Robert Shaw (acknowledged as organizing RL-related efforts) to coordinate. - -### 8.5 Updated follow-up list — was four items, now seven - -| # | Item | Effort | -|---|---|---| -| 1 | Validate v0.7.1 end-to-end on kavin (pending — image just deployed) | hours of soak | -| 2 | Send a one-line PR to upstream `Dockerfile.cuda` adding `--extra disagg` | <30 min | -| 3 | Roll the MX server with the `SourceIdentity` round-trip fix; deprecate the `__mx_v2_meta__` sidecar transport | ~1 day | -| 4 | Implement `pull_one(name)` semantic on MX for Ray/RDT-style per-tensor elasticity | ~50 LOC | -| 5 | **NEW**: Sketch + implement `MxWeightTransferEngine` as the upstream-PR form of Phase 2/3/4 | 150-200 LOC adapter | -| 6 | **NEW**: Adopt `pause_generation(mode="keep")` in prime-rl orchestrator for true async RL | ~50 LOC + tests | -| 7 | **NEW**: Coordinate with Matej + Robert Shaw + the vLLM RL roadmap on the K8s-native weight transfer engine. Either contribute MX as the canonical implementation or converge with whoever's already building it. | meeting + scoping | diff --git a/docs/proposals/image-build-mx-v2.md b/docs/proposals/image-build-mx-v2.md deleted file mode 100644 index 36e318276c..0000000000 --- a/docs/proposals/image-build-mx-v2.md +++ /dev/null @@ -1,96 +0,0 @@ -# Image build plan — `prime-rl-mx-on-nixl:v0.7.2-kavin-mx-v2` - -> Companion to [`post-pr2389-mx-v2.md`](./post-pr2389-mx-v2.md) (Workstream B). - -## Goal - -Ship a single deployable image that contains everything `weight_broadcast.type = "mx_v2"` needs: - -- Phase 2 source (heartbeat + same-rank + freshest-per-rank) — *already in* `v0.7.1-kavin-phase2-phase3` -- Phase 3 source (conversion-registry extensions, compile_target tagging) — *already in* `v0.7.1-kavin-phase2-phase3` -- **NEW:** Phase 4 + `MxWeightTransferEngine` from MX PR #349 (`kavink/post-2389-phase3-4` branch) -- **NEW:** the v2 prime-rl source files from this RFC (`nixl_mx_v2.py` × 2, updated `__init__.py`, updated `configs/trainer.py`) - -## Strategy: overlay, not from-scratch - -The build-notes ([`build-notes-2026-05-28.md`](./build-notes-2026-05-28.md) §2) measured **6h 45min** for a from-scratch `Dockerfile.cuda` build on QEMU arm64. The overlay for v0.7.1 was **~3 min**. We overlay. - -```dockerfile -# Dockerfile.cuda.mx-v2 -FROM nvcr.io/nvidian/dynamo-dev/prime-rl-mx-on-nixl:v0.7.1-kavin-phase2-phase3 - -# ── 1. Update modelexpress to the PR #349 branch (Phase 4 + engine adapter) ── -RUN --mount=type=cache,target=/app/.cache/uv \ - uv pip install --no-deps --reinstall \ - "modelexpress @ git+https://github.com/ai-dynamo/modelexpress.git@kavink/post-2389-phase3-4#subdirectory=modelexpress_client/python" - -# ── 2. Overlay the v2 prime-rl files ───────────────────────────────────────── -COPY src/prime_rl/transport/ /app/src/prime_rl/transport/ -COPY src/prime_rl/inference/vllm/worker/nixl_mx_v2.py /app/src/prime_rl/inference/vllm/worker/nixl_mx_v2.py -COPY src/prime_rl/trainer/rl/broadcast/nixl_mx_v2.py /app/src/prime_rl/trainer/rl/broadcast/nixl_mx_v2.py -COPY src/prime_rl/trainer/rl/broadcast/__init__.py /app/src/prime_rl/trainer/rl/broadcast/__init__.py -COPY packages/prime-rl-configs/src/prime_rl/configs/trainer.py /app/packages/prime-rl-configs/src/prime_rl/configs/trainer.py -``` - -Build: - -```bash -docker buildx build \ - --platform linux/arm64 \ - --file Dockerfile.cuda.mx-v2 \ - --tag nvcr.io/nvidian/dynamo-dev/prime-rl-mx-on-nixl:v0.7.2-kavin-mx-v2 \ - --push \ - . -``` - -Estimated: ~5 min (one git clone + uv install for modelexpress, then a 5-file COPY). - -## What about the `flash_attn` ABI issue? - -The earlier `flash_attn.ops` `ModuleNotFoundError` we saw was inside the *MX overlay benchmark pod* — that pod was using the v0.5.2 image with a Python overlay, and the v0.5.2 image's vLLM expects an older flash_attn layout. - -**`v0.7.1-kavin-phase2-phase3` does NOT have that problem** — it was built from a fresh `Dockerfile.cuda` that pins `flash-attn==2.8.3+cu128torch2.11` (via the `flash-attn` extra in `pyproject.toml`) and rebuilds vLLM against it. Since we're overlaying on top of that, we inherit the fixed pin and never touch the ABI. - -**Validation step:** confirm by running `python -c "import flash_attn; import flash_attn.ops"` inside the new image before doing anything else. - -## Deployment - -Same configmap pattern as `v0.7.1`. New trainer + inference manifests with `weight_broadcast.type = "mx_v2"`: - -```yaml -# configmap delta for v2 deployment -weight_broadcast: - type: mx_v2 # was: nixl_mx - host: modelexpress-server.kavin.svc.cluster.local - port: 8001 - same_rank_only: true # Phase 2 default - dedup_freshest_per_rank: true # Phase 2 default - publish_compile_target: true # Phase 3 default - publish_self_as_replica: true # tree fan-out default - inference_world_size: 4 - inference_model_name: Qwen/Qwen3-30B-A3B-Instruct-2507 -``` - -For A/B against PR #2389, run **two parallel deployments** in the kavin namespace under separate Job names — `prime-rl-nixl-mx-v0-7-1` (baseline) vs `prime-rl-mx-v2-kavin` (this work). Both use the same MX server (different `mx_source_id`s by content hash). - -## What to measure - -The validation matrix from `post-pr2389-mx-v2.md` §Validation plan, against the same workload PR #2389 was validated with: - -| Config | Refit cycle | Bandwidth | Notes | -|---|---|---|---| -| `nixl_mx` baseline | target ~10 s | ~80 ms NIXL push | PR #2389 push | -| `mx_v2` defaults | target ≤ 10 s | should match | Pull semantics, single-receiver | -| `mx_v2` + 4 receivers | target ≤ 10 s | `fanout_factor > 1.0` | Tree fan-out engages | -| `mx_v2` + filter mismatch | refuse at discovery, 0 RDMA bytes | — | Phase 3 safety net in production | -| `mx_v2` elastic scale-up | 2 → 4 replicas mid-training | new receivers join under 2 s | Phase 2 same-rank routing in production | - -## Open items - -| # | Item | Owner / status | -|---|---|---| -| 1 | Confirm v0.7.1 image has `flash_attn.ops` importable | Smoke test, ~30 s | -| 2 | Build + push `v0.7.2-kavin-mx-v2` | ~5 min after Phase A code lands | -| 3 | Deploy parallel A/B Jobs in kavin namespace | Cluster booking — ~5 hours of GPU node time | -| 4 | Capture per-cycle timing + bandwidth JSONs into `pensieve/RL/PrimeRL/results/` | Direct mirror of Slide 9 / Table B in the presentation | -| 5 | Replace Slide 9 Table A's synthetic numbers with end-to-end numbers | Once #4 is in hand | diff --git a/docs/proposals/post-pr2389-kernel-compile-plan.md b/docs/proposals/post-pr2389-kernel-compile-plan.md deleted file mode 100644 index 78db06fa4e..0000000000 --- a/docs/proposals/post-pr2389-kernel-compile-plan.md +++ /dev/null @@ -1,523 +0,0 @@ -# Post-PR-#2389 plan: kernel-compile separation, mixed-TP, and MX client adoption - -> **Related docs in this directory**: -> - [`post-pr2389-status-and-plan.md`](./post-pr2389-status-and-plan.md) — executive summary of where things stand + failure-class → fix mapping -> - [`build-notes-2026-05-28.md`](./build-notes-2026-05-28.md) — image-build experience, cluster observations, vLLM native RL APIs reframing -> -> This doc is the deep-dive RFC with full phase-by-phase design rationale. - -**Status**: Planning doc. Branch `kavink/post-2389-kernel-compile-plan` (NVIDIA-authored, off PR [#2389](https://github.com/PrimeIntellect-ai/prime-rl/pull/2389) HEAD `79ea824d8`). -**Premise**: This plan is what we propose to build on top of `nixl_mx` once #2389 merges to `main`. It (a) graduates the in-tree `MxRendezvous` reimplementation onto NVIDIA's published ModelExpress clients, (b) introduces a compile-target registry to fix the trainer-side cutlass-pinning issue surfaced during #2389's FP8 cast-pipeline iteration, and (c) extends the v2 shape registry to handle mixed-TP / sharded-source transfers. **None of this fights the #2389 data plane** — the `Slot` / `TransportPlan` / `NixlAgentWrapper` / `classic_cuda_pool` stack stays untouched. We extend the rendezvous and metadata surfaces only. - ---- - -## 1. What we're building on - -After #2389 merges, prime-rl's `nixl_mx` mode has: - -- `src/prime_rl/transport/` — `NixlAgentWrapper`, `MxRendezvous` (PI's reimplementation of MX's upper layer), `TransportPlan`, `classic_cuda_pool`, `wire.py` (msgspec types). -- `src/prime_rl/trainer/models/slots.py` — `ShardedSlot`, `GatheredSlot`, `ExpertSlot` (the per-tensor buffer abstractions that hold registered NIXL memory and conversion logic). -- `src/prime_rl/trainer/models/conversions/` — bf16-cast and FP8-blockwise conversion specs (`bf16_cast.py`, `fp8_blockwise.py`). -- `src/prime_rl/trainer/rl/broadcast/nixl_mx.py` — `NIXLMxWeightBroadcast`, the lifecycle wrapper. -- `src/prime_rl/inference/vllm/worker/nixl_mx.py` — `NIXLMxWeightUpdateWorker`, the vLLM worker extension. - -PI's data plane (Slot / TransportPlan / NixlAgentWrapper / FP8 conversion specs) is **kept**. It's the conversion / compile / topology layer where this plan extends things. - ---- - -## 2. The problem Matej flagged: trainer-side kernel compile pins the topology - -In #2389 today, the conversion-cast pipeline (`fp8_blockwise.py`, `Slot.convert`, the recent `c3c4b148` "inline transfer slot casting" + `47e170f5` "simplify transfer cast conversions" refactors) runs **on the trainer**. The trainer's send bucket holds the **post-compile** layout: bf16 if the inference engine wants bf16, FP8 with DeepGemm scale interleaving if it wants DeepGemm, cutlass-friendly column-major + epilogue scales if it wants cutlass, etc. Inference RDMA-reads the bucket and copies straight into live params. - -This is fast — no receiver-side compute on the refit critical path — but it couples three independent decisions into the trainer's process: - -1. **Which kernel format does inference want?** Today the trainer has to know. Adding cutlass on GB300 means trainer code changes; adding DeepGemm-EP changes the slot layout. -2. **Mixed inference replicas with different kernels.** With trainer-side compile, one trainer can serve only one compile target per refit. Heterogeneous fleets (e.g. a/b testing DeepGemm vs cutlass on the same training run) need two parallel trainer buckets. -3. **Mixed TP/EP layouts.** A trainer at TP=4 publishing to inference at TP=8 needs to know the inference splits at *publish* time to write the right bucket entries. The information is in PI's `build_topology`, but it's tangled with the compile decision because the compile passes produce TP-specific layouts. - -When Matej said "we're having issues handling compiled kernels" — this is the source. The cutlass compile happens trainer-side; any inference replica that doesn't want exactly the trainer's compile target either gets bad bytes or fails the assertion check. - ---- - -## 3. Proposed shift: receiver-side compile via scratch buffers - -Move the compile pass back to the **inference side**. Trainer ships **canonical HF layout** over NIXL (raw bfloat16 or raw float8_e4m3fn + raw per-block scales — whatever format the trainer naturally has post-optimizer-step). Inference runs the kernel-specific compile pass after the RDMA receive completes, into its own live params. - -This is the same scratch-buffer pattern we already validated in two places: - -- Our **original PrimeRL PoC** (`KavinKrishnan/prime-rl:kavink/mx-weight-broadcast`) used scratch buffers explicitly to triangulate KL drift — proved correctness when the receiver decides the layout. -- **John Thompson's NemoRL + Dynamo path** (May 22, 2026) uses `MxRefitReceiver.receive_weights_scratch` to handle vLLM's HF→fused param remapping via `stacked_params_mapping`. Validated end-to-end on GB300 RoCE at **380 Gbps** for an 8.82 GB / 399-tensor refit — same scratch path we'd reuse here, just with kernel-compile transforms instead of name remapping. - -### Receiver-side refit pseudocode - -```python -# inference/vllm/worker/nixl_mx.py — after PR #2389 graduates to MX clients - -def update_weights_via_mx(self, *, version, mx_config): - # 1. Discover same-rank trainer source via MxV2RefitReceiver - candidates = self._mx_receiver.discover_v2_sources( - model_name=self.model_name, - min_version=version, - same_rank_only=mx_config.same_rank_only, - compile_target_filter=self._target_kernel, # NEW — see §5 - ) - chosen = self._mx_receiver.pick_best_source(candidates) - - # 2. RDMA pull into scratch buffers (HF layout — whatever trainer published) - scratch = {} - for name, tensor in self._mx_receiver.receive_weights_scratch( - chosen.ref, - tensor_shapes=chosen.registry.tensor_shapes, # global shapes from v2 sidecar - target_tp_layout=self._tp_layout, # NEW — mixed-TP slice request - ): - scratch[name] = tensor - - # 3. Run the kernel-specific compile pass into live model params - self._compile_pass.apply( - scratch_buffers=scratch, - live_params=dict(self.model_runner.model.named_parameters()), - ) - - # 4. Tree fan-out republish (TensorHub pipeline replication) — unchanged - self._mx_receiver.publish_self_as_source(...) -``` - -The compile pass is a small dispatch: - -```python -# inference/vllm/worker/compile_passes.py — NEW - -class CompilePass(ABC): - target: str - def apply(self, scratch_buffers, live_params): ... - -class HFRaw(CompilePass): - target = "hf_raw" - def apply(self, scratch, live): - # Direct copy — names + dtypes match. fast path. - for name, t in scratch.items(): - live[name].data.copy_(t) - -class DeepGemmFP8(CompilePass): - target = "deep_gemm_fp8" - def apply(self, scratch, live): - # K-major scale interleave, fused gate_up_proj packing - for name, t in scratch.items(): - interleaved = deep_gemm_layout(t, block_size=128) - live[fused_name(name)].data.copy_(interleaved) - -class CutlassFP8(CompilePass): - target = "cutlass_fp8" - def apply(self, scratch, live): - # Column-major weights + epilogue scale tensors - ... -``` - -### Cost vs benefit - -**Cost**: extra GPU memory for scratch (~1× model size briefly per refit, freed after compile) + compile latency on the inference side (~50-200ms for FP8 passes on Qwen3-30B-A3B). - -**Benefit**: -- Trainer is **kernel-agnostic** — same bucket bytes serve any inference target. -- Mixed-fleet OK — different inference replicas can run different compile passes on the same trainer publish. -- Adding a new kernel = adding a new `CompilePass` subclass on the inference side. **Zero trainer change.** -- Cross-TP / cross-EP layouts decouple cleanly from compile (see §6). - -This is the same trade John Thompson made for Dynamo and the same trade we made in the original PoC. Empirically the 50-200ms compile latency is dwarfed by the 200ms RDMA pull anyway; total wall time is unchanged. - ---- - -## 4. New primitive: compile-target registry (extension to v2 shape registry) - -Our v2 shape registry today encodes: - -``` -TensorDescriptorV2: - name, global_shape, dtype, placement_kind, shard_axis, local_shard_range, - is_expert, expert_axis, owned_expert_ids -``` - -Add **two fields** at the registry level (not per-tensor — these describe how the *publisher* prepared the data): - -```diff - RegistryPayload (JSON in __mx_v2_meta__ sidecar): - version: int - trainer_world_layout: str # "fsdp:4,tp:1,pp:1,ep:1" -+ compile_target: str # "hf_raw" | "deep_gemm_fp8" | "cutlass_fp8" | ... -+ compile_metadata: dict # kernel-specific params: -+ # {block_size: 128, scale_dtype: "float8_e8m0", -+ # layout: "row_major", ...} - tensors: list[TensorDescriptorV2] -``` - -### Trainer side - -The trainer declares what it published — typically `"hf_raw"` post-optimizer-step: - -```python -publisher = MxV2TrainingPublisher( - agent_name=..., - world_layout=TrainerWorldLayout(fsdp_world_size=4, tp_world_size=1), - compile_target="hf_raw", # NEW — declarative, no inference-side dependency -) -``` - -Specialized trainers that *do* want to bake a compile pass in for perf can declare it: - -```python -publisher = MxV2TrainingPublisher( - ..., - compile_target="deep_gemm_fp8", - compile_metadata={"block_size": 128, "scale_dtype": "float8_e8m0"}, -) -``` - -— and then only same-target inference replicas will accept that publish. Mixed-target inference replicas skip it and look for an `hf_raw` source. - -### Receiver side - -`discover_v2_sources(compile_target_filter=...)` filters candidate trainers: - -```python -candidates = receiver.discover_v2_sources( - model_name=..., - min_version=N, - same_rank_only=True, - compile_target_filter={"hf_raw"}, # accept HF-raw only, run compile myself -) -``` - -Or accept any compatible target (if the receiver has a fallback compile pass): - -```python -candidates = receiver.discover_v2_sources( - model_name=..., - compile_target_filter={"hf_raw", "deep_gemm_fp8"}, # accept either -) -``` - -The picker's existing trainer-vs-replica + freshest-per-rank sort applies on the filtered set. - -### Why this is in the v2 shape registry, not a new transport - -It's metadata about the wire format. The shape registry already travels via the synthetic `__mx_v2_meta__` `TensorDescriptor` sidecar (proven by jthomson04's GB300 run + protected by PR #295's filter). Adding two more JSON fields is zero-cost on the wire and keeps the discovery contract in one place. - ---- - -## 5. Sharding and mixed-TP transfers - -The hardest case is **trainer TP/EP layout ≠ inference TP/EP layout**. The shape registry already encodes this on the publish side: - -``` -placement_kind: "SHARD" -shard_axis: 0 -local_shard_range: (start, end) # this rank's slice along shard_axis -``` - -What's missing is the **receiver's expression of what it wants**. - -### New receiver-side API - -```python -class TargetTPLayout: - """What slice of the global tensor THIS receiver needs.""" - world_size: int # inference TP world size - rank: int # this receiver's TP rank - shard_axis: int # which axis we're sharded on (model-dependent) - -receiver.receive_weights_scratch( - chosen.ref, - target_tp_layout=TargetTPLayout(world_size=8, rank=3, shard_axis=0), - tensor_shapes=chosen.registry.tensor_shapes, -) -``` - -The receiver computes its desired slice from `target_tp_layout` and the published `placement_kind`: - -| Publisher | Receiver | Result | -|---|---|---| -| `REPLICATE` (trainer TP=1) | TP=8, rank=3, axis=0 | Receiver requests rows `[3N/8 : 4N/8]` of the publisher's tensor | -| `SHARD(0)` trainer TP=4 rank=2, range `[N/2 : 3N/4]` | TP=8, rank=4, axis=0, requests `[N/2 : 5N/8]` | Receiver pulls from same-physical-rank trainer (R2), takes lower half of R2's shard | -| `SHARD(0)` trainer TP=4 rank=2, range `[N/2 : 3N/4]` | TP=8, rank=5, axis=0, requests `[5N/8 : 6N/8]` | Receiver pulls from same trainer rank (R2), takes upper half of R2's shard | -| `SHARD(0)` trainer TP=4 rank=1, range `[N/4 : N/2]` | TP=2, rank=0, axis=0, requests `[0 : N/2]` | Receiver pulls from **both** trainer R0 + R1, concatenates | - -For cases where one inference rank needs slices from multiple trainer ranks (last row), the receiver picks **N candidates** instead of one: - -```python -multi_source = receiver.discover_v2_sources_for_slice( - model_name=..., - target_slice=(start, end), - shard_axis=0, -) -# Returns one SourceRef per trainer rank whose shard overlaps target_slice -# Receiver does N parallel RDMA pulls, concatenates in scratch -``` - -### Mixed-EP for MoE - -Same machinery, on the expert axis. NemoRL v2 already does this for `owned_expert_ids`: - -```python -candidates = receiver.discover_v2_sources( - model_name=..., - target_expert_ids_per_layer={ - 5: {0, 1, 2, 3}, # this inference rank's owned experts in layer 5 - 6: {0, 1, 2, 3}, - }, -) -``` - -The picker matches candidates whose `expert_owner_per_rank` covers the needed experts (existing logic in `MxV2RefitReceiver.pick_best_source` — see `nemo_rl_v2.py`). For mixed-EP (trainer EP=4, inference EP=8), receivers may pull from multiple trainer ranks via the same `discover_v2_sources_for_slice` pattern. - -### Why this matters for PrimeRL specifically - -PI's `ExpertSlot` (in `slots.py`) and `build_topology()` (in `nixl_checkpoint_engine`-style code) already implement TP-matched and EP-matched pairing on the trainer side. They're computing `peer_chunk_descs` based on the publisher's known topology. **What's missing is the inverse**: when the inference layout differs from the trainer's, the receiver needs to express that. The compile-target registry + the slice-discovery API give it that vocabulary. - ---- - -## 6. MX client adoption — layered with this plan - -The two-phase migration from our earlier review of #2389 (Phase 1 surgical / Phase 2 client adoption) is the foundation this plan builds on: - -### Phase 1 — surgical fixes against the `nixl_mx` in-tree code (drop-in patches) - -The 6 inline-comment fixes (line numbers verified against `nixl_mx` HEAD `79ea824d8`): - -1. Same-rank `add_remote_agent` filter in `transport_plan.py` -2. Freshest-per-rank dedup in `mx_rendezvous.py::wait_for_peers` -3. `HeartbeatThread` after `set_status(READY)` in `inference/.../nixl_mx.py` -4. Read timeout from config (not hardcoded 1200s) -5. MLA-guard for non-MLA models (`update_mla_absorbed_weights`) -6. HSDP barrier ordering in `trainer/.../nixl_mx.py` - -These land **before** this plan starts — closes the bug classes without architectural change. - -### Phase 2 — graduate the rendezvous half onto ModelExpress clients - -Delete `src/prime_rl/transport/mx_rendezvous.py` (~185 LOC, replicates functionality already in `modelexpress`). Replace with imports of `MxV2TrainingPublisher` and `MxV2RefitReceiver`. The in-tree `NixlAgentWrapper` + `Slot` + `TransportPlan` + `classic_cuda_pool` stay — that's prime-rl-specific data-plane specialization and shouldn't move. - -```diff --from prime_rl.transport.mx_rendezvous import MxRendezvous -+from modelexpress import MxV2TrainingPublisher, MxV2RefitReceiver -``` - -This is what unblocks everything in §3-§5: - -- `MxV2TrainingPublisher` exposes the v2 sidecar registry that §4's `compile_target` extends. -- `MxV2RefitReceiver.receive_weights_scratch` is the proven path from John's Dynamo work (380 Gbps GB300). -- `discover_v2_sources(compile_target_filter=...)` is a small extension to the existing picker. -- Heartbeat / freshest-dedup / retention all come along for free — no separate Phase 1 work needed once Phase 2 lands. - -### Phase 3 (this plan) — compile-target registry + mixed-TP - -- Add `compile_target` + `compile_metadata` to v2 shape registry (~30 LOC in `shape_descriptors.py` + `nemo_rl_v2.py`). -- Add `compile_target_filter` to `discover_v2_sources` (~15 LOC). -- Add `target_tp_layout` + `discover_v2_sources_for_slice` to `MxV2RefitReceiver` (~120 LOC). -- Add `compile_passes/` module in `src/prime_rl/inference/vllm/worker/` with `HFRaw`, `DeepGemmFP8`, `CutlassFP8` passes (~300 LOC). Or in NemoRL `nemo_rl/models/generation/vllm/compile_passes/` — see §7. -- PI's `nixl_mx.py` inference worker calls into the right `CompilePass` based on `engine.kernel_target`. - -Total: ~450 LOC across MX + PrimeRL, all additive. - ---- - -## 7. What we borrow from John Thompson's NemoRL+Dynamo work - -Five specific pieces, all already proven on GB300 RoCE: - -### 7.1 `receive_weights_scratch` is the foundation - -John's path uses `MxRefitReceiver.receive_weights_scratch` because vLLM's `stacked_params_mapping` requires HF-named tensors that the receiver later passes to `model.load_weights()`. That's structurally identical to "trainer ships HF-raw, receiver compiles into kernel layout": - -```python -# John's existing flow (NemoRL + Dynamo + vLLM v1) -weights = list(receiver._receiver.receive_weights_scratch( - chosen.ref, - timeout_seconds=mx_config.timeout_seconds, - tensor_shapes=tensor_shapes, -)) -self.model_runner.model.load_weights(weights=weights) # vLLM does the HF→fused remap - -# Our extension (PrimeRL post-#2389 + kernel compile) -weights = list(receiver._receiver.receive_weights_scratch( - chosen.ref, - timeout_seconds=mx_config.timeout_seconds, - tensor_shapes=tensor_shapes, - target_tp_layout=self._target_tp_layout, # NEW -)) -self._compile_pass.apply(weights, live_params) # OUR compile dispatch -``` - -The mechanism is identical. Only the post-RDMA stage differs. - -### 7.2 The `worker_extension_cls` injection pattern is cleaner than subclassing - -John's `MxRefitWorkerExtension` (in `dynamo/vllm/mx_refit/extension.py`) is injected via vLLM v1's `parallel_config.worker_extension_cls`. The class has no `__init__`; vLLM merges its methods into the existing `Worker` via `__bases__`. State is stashed lazily on `self` with `_mx_` prefixed attribute names. - -PI's `NIXLMxWeightUpdateWorker` today **subclasses** `Worker` directly. The extension-class pattern would let the refit logic live in a sibling module without touching the inheritance chain — useful when we add the compile passes (§3) because those want to live in their own package. - -**Recommend**: when graduating PrimeRL to MX clients in Phase 2, also adopt the `worker_extension_cls` pattern for the inference worker. The two changes naturally compose. - -### 7.3 PR #295's sidecar filter is required for any new v2 metadata - -If we extend the v2 sidecar (§4) without keeping PR #295's filter in `MxRefitReceiver.receive_weights{,_scratch}`, the synthetic `__mx_v2_meta__` `TensorDescriptor` poisons `prep_xfer_dlist` again — same `NIXL_ERR_NOT_FOUND` John hit before May 22. **No new code needed; the filter is already in `kavink/nemo_rl_moe` HEAD `8594fd6`.** Just don't accidentally back it out. - -### 7.4 The `FORCE_RDMA=1` test mode catches this class of bug in loopback - -The v2 demo scripts (`scripts/v2_*_e2e_demo.py`) have a `FORCE_RDMA=1` env var (commit `e8e063b`) that pins `UCX_TLS` off `cuda_ipc` so intra-node loopback exercises the strict `rc_mlx5` descriptor-list validator. **Run every new compile-pass test under `FORCE_RDMA=1`** — otherwise we'll merge a sidecar / descriptor-list bug that doesn't show up until cross-node deploy. - -### 7.5 The compile-pass module probably belongs in NemoRL first, mirrored into prime-rl - -John's Dynamo path is the most mature target for testing new kernels (Qwen3-4B-Thinking GRPO smoke is already running cross-node at 380 Gbps). The compile passes themselves are framework-agnostic — they just take `(scratch_dict, live_params_dict)` and run torch ops. **Recommend**: - -1. Implement `compile_passes/` first in `modelexpress_client/python/modelexpress/compile_passes/` — framework-neutral, reusable. -2. NemoRL + Dynamo path adopts it via `MxRefitWorkerExtension._compile_pass = HFRaw()` (or `DeepGemmFP8()`). -3. Validate end-to-end on GB300 with cutlass + DeepGemm kernels. -4. Mirror into PrimeRL's inference worker after Phase 2 graduates them to MX clients. - -This sequence means we de-risk the compile-pass design on the path that's already shown working before we touch PrimeRL. Same play we ran for the v2 sidecar — designed in NemoRL, validated in NemoRL+Dynamo (John), then graduated to prime-rl. - ---- - -## 8. Implementation phases - -| Phase | Scope | Estimated LOC | Owner | -|---|---|---|---| -| **0** | Wait for #2389 to merge upstream | — | Matej | -| **1** | 6 surgical fixes against PI's `transport/*.py` + `inference/*.py` (closes bug classes) | ~100 LOC | Us — fast follow on Matej | -| **2** | Graduate `MxRendezvous` → `MxV2TrainingPublisher` / `MxV2RefitReceiver`; adopt `worker_extension_cls` pattern in inference worker | ~−400 LOC (PI's reimpl removed) + ~150 LOC import-and-call | Us | -| **3a** | Add `compile_target` + `compile_metadata` to v2 shape registry | ~30 LOC | Us | -| **3b** | Add `compile_target_filter` to `discover_v2_sources` | ~15 LOC | Us | -| **3c** | Add `target_tp_layout` + `discover_v2_sources_for_slice` to `MxV2RefitReceiver` | ~120 LOC | Us | -| **3d** | Implement `compile_passes/` (HFRaw, DeepGemmFP8, CutlassFP8) — in MX repo for reuse | ~300 LOC | Us | -| **3e** | Validate on NemoRL+Dynamo path (John's GB300 cluster) — Qwen3-4B-Thinking with DeepGemm and cutlass kernels both running on the same MX server | E2E | Us + John | -| **3f** | Mirror compile-pass dispatch into PI's `inference/vllm/worker/nixl_mx.py` | ~50 LOC | Us — PR back to PI | -| **4** | Mixed-TP / mixed-EP slice discovery wired end-to-end (multi-source RDMA pulls) | ~200 LOC | Us — separable from Phase 3 | - -**Phases 0-1** are fully sequenced (must wait for upstream + apply surgical fixes). **Phases 2 onward** can run in parallel if we're willing to maintain a `kavink/post-2389-*` branch off PI's main + a follow-on PR per phase. - ---- - -## 9. Open questions - -1. **Does PI want the compile passes in their tree, or in MX?** If MX, they import a pluggable `CompilePassRegistry`. If their tree, the kernel ecosystem stays close to their Slot system (which already does fp8_blockwise). My lean: **MX**, because Dynamo + NemoRL also want them, and PI's per-Slot conversion stays for the trainer-side path when teams opt into "publish post-compile". - -2. **Does cutlass-FP8 work on inference-side compute?** Compile pass needs ~200ms of CUDA time. If the inference engine is mid-rollout when the refit arrives, we either pause and run the compile or queue it for the next "between rollouts" window. PrimeRL's current orchestrator does the latter; this plan inherits. - -3. **How do we handle trainers that publish post-compile (Matej's current path)?** Their `compile_target = "deep_gemm_fp8"`; receivers either accept it directly (fast path) or reject and look for `hf_raw`. Mixed fleets get clean error messages, not corrupt weights. - -4. **Mixed-TP across nodes — what's the bandwidth math?** Trainer TP=4 ↔ inference TP=8 means each inference rank pulls from 1-2 trainer ranks. For Qwen3-30B-A3B on GB200 (~30 GB / 4 trainer ranks = 7.5 GB/rank), an inference rank pulling 2× 4 GB slices is well within NIC budget. For larger EP layouts where one inference rank needs experts from N>2 trainer ranks, fan-in becomes interesting — that's where pipeline replication (TensorHub) and rollouts-as-replicas pay off. - -5. **What's the deprecation story for trainer-side compile?** We don't deprecate it — Matej's path stays valid for teams that want zero inference-side latency. The `compile_target` field is just informational; receivers filter on it. - -6. **Should the compile pass run before or after `update_mla_absorbed_weights`?** After. MLA absorption operates on live params; compile runs first so live params are in the right layout when MLA absorption runs. - ---- - -## 10. Component view + sequence diagram - -```mermaid -flowchart TB - subgraph trainer["Trainer side (after Phase 2)"] - TBcast["NIXLMxWeightBroadcast
(PI's lifecycle wrapper)"] - TSlots["Slots: Sharded · Gathered · Expert
(PI's data plane, kept)"] - TPlan["TransportPlan
(PI's, kept)"] - TPub["MxV2TrainingPublisher
(NEW — replaces MxRendezvous)"] - TAgent["NixlAgentWrapper
(PI's, kept)"] - TBcast --> TSlots - TBcast --> TPlan - TBcast --> TPub - TSlots --> TAgent - TPlan --> TAgent - TPub -. publishes registry incl.
compile_target=hf_raw .-> TPlan - end - - subgraph mx["MX control plane (unchanged)"] - MXSVR[("MX Server · gRPC + Redis
shape registry, compile-target
filter, tree fan-out catalog")] - end - - subgraph inf["Inference side (after Phase 2+3)"] - IRec["MxV2RefitReceiver
(NEW — replaces ad-hoc rendezvous)"] - IScratch["receive_weights_scratch
+ target_tp_layout
(extended John's path)"] - ICompile["CompilePass dispatch (NEW)
HFRaw · DeepGemmFP8 · CutlassFP8"] - ILive["vLLM model.named_parameters()
(live params — kernel-specific layout)"] - IRec --> IScratch - IScratch --> ICompile - ICompile --> ILive - end - - TPub <-.->|"publish_metadata (incl. compile_target)
set_status · update_status"| MXSVR - IRec <-.->|"discover_v2_sources(compile_target_filter=...)
list_sources · get_metadata"| MXSVR - TAgent <==>|"one-sided RDMA WRITE
UCX rc_mlx5 / RoCE
(HF-raw bytes, post-cast pre-compile)"| IScratch - - style mx fill:#fec,stroke:#963 - style MXSVR fill:#fec,stroke:#963,stroke-width:2px - style TPub fill:#cce,stroke:#33c,stroke-width:2px - style IRec fill:#cce,stroke:#33c,stroke-width:2px - style ICompile fill:#cfc,stroke:#363,stroke-width:2px - style ILive fill:#fcc,stroke:#c33 -``` - -### One refit cycle, after this plan lands - -```mermaid -sequenceDiagram - autonumber - participant O as Orchestrator - participant T as Trainer (NIXLMxWeightBroadcast) - participant MX as MX Server - participant I as Inference worker - participant C as CompilePass - - Note over T,I: BOOT (once per refit run) - O->>I: POST /init_nixl_mx (host, port, rank, kernel_target=deep_gemm_fp8) - I->>I: register live params with NIXL (PI's data plane) - I->>MX: publish_metadata(role=inference, kernel_target=deep_gemm_fp8) - I->>MX: update_status(READY) - - Note over T,I: PER REFIT STEP - O->>I: POST /update_weights - I->>MX: wait_for_all_peers_ready(role=trainer, READY) - - T->>T: lazy_init slots; per-rank scratch fill from state_dict
(NO trainer-side cutlass — bytes are HF-raw) - T->>MX: publish_metadata(role=trainer, compile_target="hf_raw",
compile_metadata={...}, shape registry per tensor) - T->>MX: update_status(INITIALIZING → READY) - - I->>MX: discover_v2_sources(model, min_version=N,
compile_target_filter={"hf_raw"},
target_tp_layout=TP=8 rank=3) - MX-->>I: candidates = [trainer R0 (covers requested slice)] - I->>I: pick_best_source - - Note right of I: SCRATCH PATH (from John's NemoRL+Dynamo work) - I->>T: NIXL one-sided RDMA WRITE → scratch buffers
(HF-raw layout, ~380 Gbps on GB300 RoCE) - I->>I: torch.cuda.synchronize - - Note right of I: COMPILE PASS (new — runs inference-side, ~50-200ms) - I->>C: apply(scratch_buffers, live_params)
e.g. DeepGemm scale interleave, fused gate_up_proj pack - C-->>I: live params updated in DeepGemm-friendly layout - - I->>I: update_mla_absorbed_weights (if MLA model) - I->>MX: publish_self_as_source(role=inference_replica, version=N)
(tree fan-out for next refit) - - I-->>O: 200 OK - O->>O: scheduler advances · next rollout uses new weights -``` - ---- - -## 11. Cross-references - -ModelExpress design docs (NVIDIA-authored, for context on the client surface this plan adopts): - -- [`docs/RL/PRIMERL_MX_OVERVIEW.md`](https://github.com/ai-dynamo/modelexpress/blob/main/docs/RL/PRIMERL_MX_OVERVIEW.md) — the foundational prime-rl × MX integration design (catalog + star wiring story). -- [`docs/RL/NEMORL_MX_OVERVIEW.md`](https://github.com/ai-dynamo/modelexpress/blob/kavink/nemo_rl_moe/docs/RL/NEMORL_MX_OVERVIEW.md) — the v2 design (rank-to-rank, tree fan-out, expert filter, shape registry) that this plan extends with the compile-target axis. -- [`docs/RL/VERL_MX_OVERVIEW.md`](https://github.com/ai-dynamo/modelexpress/blob/main/docs/RL/VERL_MX_OVERVIEW.md) — the verl `MxCheckpointEngine` integration; sibling adopter of the same MX clients. - -Upstream branches this plan refers to: - -- ModelExpress branch [`kavink/nemo_rl_moe`](https://github.com/ai-dynamo/modelexpress/tree/kavink/nemo_rl_moe) — the v2 client surface (`MxV2TrainingPublisher`, `MxV2RefitReceiver`), `shape_descriptors`, sidecar transport, and PR #295's sidecar filter. The MX-side dependency that Phase 2 imports. -- ModelExpress PR [#295](https://github.com/ai-dynamo/modelexpress/pull/295) — synthetic-sidecar `TensorDescriptor` filter in `MxRefitReceiver`. Required for any v2 metadata extension (including this plan's `compile_target` field) to survive `prep_xfer_dlist` validation on cross-node RoCE. Already merged into `kavink/nemo_rl_moe`; just don't back it out. -- NemoRL × Dynamo branches (John Thompson, NVIDIA): [`KavinKrishnan/RL:kavink/mx_integration`](https://github.com/KavinKrishnan/RL/tree/kavink/mx_integration) (NeMo-RL side) + the Dynamo-side companion. Validated at 380 Gbps on GB300 RoCE for an 8.82 GB / 399-tensor refit (Qwen3-4B-Thinking GRPO smoke). - -NVIDIA-internal context (not necessary for upstream review, listed for our own bookkeeping): - -- The 6 inline review comments + summary message we have queued for #2389 — verified line numbers against HEAD `dabaa19f5` (still applicable on `79ea824d8` after I re-checked May 27). Phase 1 of this rollout. -- Current state of #2389: +10 commits since `dabaa19f5` (4× conversion-cast polish, 4× DeepGemm env-var hygiene, 2× config/import fixes). None touched the 6 flagged lines. diff --git a/docs/proposals/post-pr2389-mx-v2.md b/docs/proposals/post-pr2389-mx-v2.md deleted file mode 100644 index 5ee5810832..0000000000 --- a/docs/proposals/post-pr2389-mx-v2.md +++ /dev/null @@ -1,271 +0,0 @@ -# RFC: `weight_broadcast.type = "mx_v2"` — the complete prime-rl × ModelExpress design - -> **Status:** Draft. Targets [`PrimeIntellect-ai/prime-rl`](https://github.com/PrimeIntellect-ai/prime-rl) upstream as a follow-up to [PR #2389](https://github.com/PrimeIntellect-ai/prime-rl/pull/2389). -> -> **Companion docs in this branch:** -> - [`post-pr2389-kernel-compile-plan.md`](./post-pr2389-kernel-compile-plan.md) — the four-phase plan this RFC consolidates -> - [`post-pr2389-status-and-plan.md`](./post-pr2389-status-and-plan.md) — current status of the four phases -> - [`build-notes-2026-05-28.md`](./build-notes-2026-05-28.md) — image-build mechanics for the source-baked Phase 2 + Phase 3 image -> - [`image-build-mx-v2.md`](./image-build-mx-v2.md) — the v2 overlay-image plan (Workstream B of this RFC) - -## TL;DR - -This RFC proposes a single new weight-broadcast type, **`weight_broadcast.type = "mx_v2"`**, that contains every optimization the post-#2389 RFC identifies, behind one config knob. It coexists with the existing `"nixl_mx"` (PR #2389) for migration; no behavior of `"nixl_mx"` changes. - -| Capability | `nixl_mx` (PR #2389) | `mx_v2` (this RFC) | -|---|---|---| -| Data plane | NIXL RDMA WRITE (push) | NIXL RDMA WRITE *or* READ (engine-dispatched) | -| Control plane | In-tree `MxRendezvous` (185 LOC, 4 gRPC calls) | `MxWeightTransferEngine` over MX v2 fat clients | -| Heartbeat + freshest-dedup + same-rank routing | Runtime monkey-patch via configmap | Baked in (Phase 2) | -| compile_target safety net | None — silent corruption on layout mismatch | Phase 3 — refuses mismatched at discovery, before RDMA | -| Mixed-TP / mixed-EP | Requires matching layouts | Phase 4 — multi-source slice picker + stitching | -| Tree fan-out (pipeline replication) | None — single-source from trainer | Receivers republish; trainer NIC stops being the bottleneck past ~4 receivers | -| MoE expert filter | None — every receiver pulls every expert | Bandwidth-proportional to EP (8× savings for EP=8 on a 192-expert model) | -| vLLM native API alignment | Bespoke worker extension | Targets `WeightTransferEngine` ABC (same shape as Anyscale RDT PR #43375) | -| Net LOC | baseline | +255 (new) / −185 (delete in-tree rendezvous in follow-up) | - -## Motivation - -PR #2389 ships a working NIXL+MX path on GB200 (~10 s/cycle on Qwen3-30B-A3B). Production hardening since then surfaced six issues that map to one root cause: **prime-rl owns the rendezvous + transport layer, so every cross-cutting capability (heartbeat, compile-target metadata, slice picking) lives in prime-rl too.** - -The cross-cutting work — heartbeat, dedup, shape registry, MoE expert filter, mixed-TP slice picker, tree fan-out — has been built and unit-tested **inside ModelExpress** (PR [#349](https://github.com/ai-dynamo/modelexpress/pull/349), branch `kavink/post-2389-phase3-4`). What's missing is the single integration PR that consumes all of it from prime-rl. - -This RFC is that integration PR. - -## Design - -### 1. New config type - -`packages/prime-rl-configs/src/prime_rl/configs/trainer.py` adds `MxV2WeightBroadcastConfig`: - -```python -class MxV2WeightBroadcastConfig(BaseWeightBroadcastConfig): - type: Literal["mx_v2"] = "mx_v2" - - # ─── Control plane ────────────────────────────────────────────── - host: str = "localhost" - port: int = 29501 - timeout: int = 1200 - - # ─── Discovery (Phase 2) ──────────────────────────────────────── - same_rank_only: bool = True - """GB200/EFA multi-NIC fabrics: receivers pull from same-rank trainer only.""" - dedup_freshest_per_rank: bool = True - """When multiple READY entries share a worker_rank, pick the freshest by updated_at.""" - - # ─── Layout metadata (Phase 3) ────────────────────────────────── - publish_compile_target: bool = True - """Trainer stamps every publish with the conversion's compile_target tag.""" - compile_target_filter: list[str] | None = None - """Receiver-side whitelist. None = accept anything (back-compat). Set to - {'cutlass_fp8'} or {'hf_raw','cutlass_fp8'} to refuse mismatches before RDMA.""" - - # ─── Sharding (Phase 4) ───────────────────────────────────────── - target_tp_layout: TargetTPLayout | None = None - """None = matched-TP fast path (single-source same-rank pull). - Set when trainer TP/EP layout differs from inference.""" - - # ─── Pipeline replication (TensorHub pattern) ─────────────────── - publish_self_as_replica: bool = True - """After a successful receive, inference workers republish themselves as - sources; subsequent receivers can pull from peers instead of the trainer.""" - - inference_world_size: int = 1 - inference_model_name: str = "" -``` - -### 2. Selector dispatch - -`src/prime_rl/trainer/rl/broadcast/__init__.py` adds one `elif`: - -```python -elif config.type == "mx_v2": - from prime_rl.trainer.rl.broadcast.nixl_mx_v2 import NIXLMxV2WeightBroadcast - - assert parallel_dims is not None, "mx_v2 requires parallel_dims" - return NIXLMxV2WeightBroadcast(output_dir, config, parallel_dims) -``` - -### 3. New trainer broadcast — `src/prime_rl/trainer/rl/broadcast/nixl_mx_v2.py` - -Replaces the bespoke `MxRendezvous` + manual NIXL `post_write` flow with `MxV2TrainingPublisher`: - -```python -class NIXLMxV2WeightBroadcast(WeightBroadcast): - def __init__(self, output_dir, config, parallel_dims): - self.publisher = MxV2TrainingPublisher( - agent_name=make_agent_name("trainer", world.rank), - device_id=torch.cuda.current_device(), - mx_server_url=f"{config.host}:{config.port}", - worker_rank=world.rank, - world_layout=TrainerWorldLayout(...), # from parallel_dims - ) - - def lazy_init(self, model): - self.publisher.initialize(model_name=config.inference_model_name) - # Slot allocation + conversion still owned by prime-rl - self.model_slots = model.build_slots(...) - self.conversion = select_default_conversion(...) - - @torch.no_grad() - def broadcast_weights(self, model, step): - # 1. Run trainer-side conversion (prime-rl owns this) - for slot in self.model_slots: - slot.fill_from(model.state_dict(), self.conversion) - - # 2. Register each slot tensor with the publisher, tagged with - # compile_target + compile_metadata from the conversion registry - for slot in self.model_slots: - for name, tensor, _ in slot.buffers: - self.publisher.add_tensor( - name=name, - tensor=tensor, - compile_target=self.conversion.compile_target, - compile_metadata=self.conversion.compile_metadata, - # MoE expert metadata where applicable - is_expert=slot.is_expert, - expert_axis=slot.expert_axis, - owned_expert_ids=slot.owned_expert_ids, - ) - - # 3. One publish() per step - self.publisher.publish(version=step) - self.publisher.mark_ready() -``` - -**Key invariants preserved from PR #2389:** -- Trainer-side conversion (FP8 packing, fusion, sharding) — prime-rl owns the kernel -- Slot layout — `Sharded` / `Gathered` / `Expert` slots stay -- HSDP barrier — `dp_replicate > 1` only publishes from rank 0 -- Per-step lifecycle — `lazy_init` on first call, `broadcast_weights` every step - -**What changes vs PR #2389:** -- The push (`nixl_agent.post_write` loop) becomes a publish (`publisher.add_tensor` × N + `publisher.publish`); the actual NIXL WRITE is now driven from the receiver side via `receive_weights_scratch` -- Trainer no longer needs to know inference world size in advance — receivers discover via catalog - -### 4. New inference worker — `src/prime_rl/inference/vllm/worker/nixl_mx_v2.py` - -Uses `MxWeightTransferEngine` via vLLM's `worker_extension_cls`: - -```python -class NIXLMxV2WeightUpdateWorker(Worker): - """vLLM worker extension for the v2 pull path.""" - - def init_nixl_mx_v2(self, host: str, port: int, rank_offset: int, **engine_init_kwargs): - from modelexpress.vllm_weight_transfer import MxInitInfo, MxWeightTransferEngine - global_rank = rank_offset + self.device.index - inference_model_name = self.model_runner.model_config.model - - self.engine = MxWeightTransferEngine(init_info=MxInitInfo( - mx_server_url=f"{host}:{port}", - model_name=inference_model_name, - worker_rank=global_rank, - agent_name=make_agent_name("inference", global_rank), - device_id=self.device.index, - publish_self_as_replica=engine_init_kwargs.get("tree_fanout", True), - )) - - @torch.no_grad() - def update_weights_via_mx_v2(self, step: int, *, compile_target_filter=None, target_tp_layout=None) -> None: - from modelexpress.vllm_weight_transfer import MxUpdateInfo - self.engine.receive_weights( - MxUpdateInfo( - version=step, - compile_target_filter=set(compile_target_filter) if compile_target_filter else None, - target_tp_layout=target_tp_layout, - timeout_seconds=self.config.timeout, - ), - load_weights=self._load_weights_batch, - ) - # Same post-load housekeeping as PR #2389 - update_mla_absorbed_weights(self.raw_model) - - def _load_weights_batch(self, batch: list[tuple[str, torch.Tensor]]) -> None: - """Feed yielded tensors through vLLM's model.load_weights(). - vLLM handles HF→fused name remapping via stacked_params_mapping.""" - self.raw_model.load_weights(batch) -``` - -**Key changes vs PR #2389:** -- No pre-registered NIXL buffers on inference side (uses `receive_weights_scratch` under the hood) -- Trainer push → receiver pull semantics -- HF-format publish → vLLM `load_weights` handles fused param remapping (matches NeMo-RL pattern) - -### 5. Conversion registry — already done - -`src/prime_rl/trainer/models/conversions/__init__.py` was extended on `kavink/post-2389-conversion-registry-extensions` ([Draft PR #2](https://github.com/KavinKrishnan/prime-rl/pull/2)) with: -- `compile_target` + `compile_metadata` fields on `ConversionEntry` -- `cutlass_fp8_e4m3_per_channel` registered alongside `bf16_cast` and `fp8_128x128` -- 19/19 unit tests green - -This RFC just consumes that work — no additional changes to conversion-registry code. - -### 6. Image — overlay on `v0.7.1-kavin-phase2-phase3` - -The v2 source layers cleanly on top of `prime-rl-mx-on-nixl:v0.7.1-kavin-phase2-phase3`, which already contains Phase 2 + Phase 3. The Dockerfile is a 5-line overlay: - -```dockerfile -FROM nvcr.io/nvidian/dynamo-dev/prime-rl-mx-on-nixl:v0.7.1-kavin-phase2-phase3 -COPY src/prime_rl/transport/ /app/src/prime_rl/transport/ -COPY src/prime_rl/inference/vllm/worker/nixl_mx_v2.py /app/src/prime_rl/inference/vllm/worker/nixl_mx_v2.py -COPY src/prime_rl/trainer/rl/broadcast/nixl_mx_v2.py /app/src/prime_rl/trainer/rl/broadcast/nixl_mx_v2.py -COPY src/prime_rl/trainer/rl/broadcast/__init__.py /app/src/prime_rl/trainer/rl/broadcast/__init__.py -COPY packages/prime-rl-configs/src/prime_rl/configs/trainer.py /app/packages/prime-rl-configs/src/prime_rl/configs/trainer.py -``` - -Plus a 1-line `uv sync --no-deps --reinstall-package modelexpress` if we need to pull in the MX-side engine adapter that ships with PR #349. - -See [`image-build-mx-v2.md`](./image-build-mx-v2.md) for the build mechanics + cluster deployment steps. - -## Migration - -| Phase | Config | Status | -|---|---|---| -| **v0.x** (now) | `nixl_mx` and `mx_v2` coexist | `nixl_mx` remains the documented default. `mx_v2` opt-in. | -| **v0.x+1** | `nixl_mx` deprecated with warning | After 4 weeks of `mx_v2` bake-time on `kavin` + at least one external user. | -| **v0.x+2** | `nixl_mx` removed | After another release cycle. Tracks vLLM's native `WeightTransferEngine` API merge — once that's available, `mx_v2` registers as `backend="mx_nixl"` and `WeightTransferConfig` becomes the recommended entry point. | - -No user is forced to migrate. `nixl_mx` users get heartbeat + dedup via PR #1 (Phase 2 source-bake) regardless of this RFC. - -## Validation plan - -### Pre-merge (this branch, before opening upstream) - -1. **Unit tests:** existing 58 MX-side tests (35 v2 shape/picker + 14 engine + 9 bench) + new tests for the prime-rl integration files (≥10 new unit tests covering `NIXLMxV2WeightBroadcast.broadcast_weights` and `NIXLMxV2WeightUpdateWorker.update_weights_via_mx_v2`). - -2. **Cluster A/B on `kavin` namespace, GB200, Qwen3-30B-A3B-Instruct-2507:** - - | Config | Refit cycle | Bandwidth | Notes | - |---|---|---|---| - | `type=nixl_mx` (PR #2389 baseline) | target ~10 s | ~80 ms NIXL push | Push, no filter, no fan-out | - | `type=mx_v2`, defaults | target ≤ 10 s | should match | Pull, filter on, fan-out on but with 1 inference replica (no-op) | - | `type=mx_v2` + 4 inference replicas | target ≤ 10 s | `fanout_factor > 1.0` | Tree fan-out kicks in | - | `type=mx_v2` + mismatched filter | should refuse at discovery | 0 RDMA bytes | Compile-target safety net under production workload | - -3. **Elastic scale-up:** scale inference from 2 → 4 replicas mid-training. With Phase 2 same-rank routing baked in, all 4 should join cleanly without orchestrator restart. Measured via the harness in `modelexpress/benchmarks/bench_elastic_scaling.py` but against the real model, not synthetic tensors. - -### Post-merge (upstream) - -1. Add `mx_v2` smoke test to upstream prime-rl CI. -2. Coordinate with upstream PrimeIntellect on a real-RL-job validation matrix. - -## What this RFC does *not* do - -| Out of scope | Why | -|---|---| -| Delete PR #2389's `nixl_mx` code | Coexist for ≥1 release cycle. PR #1 (Phase 2 fixes) lands into `nixl_mx` regardless. | -| Implement delta-sync ([HF blog](https://huggingface.co/blog/delta-weight-sync)) | Layer-2 optimization — composes orthogonally with `mx_v2` and lands separately once vLLM merges `pause_generation(mode="keep")`. | -| Implement true async refit (Composer 2 / Fireworks) | Layer-2 optimization. Same reason as delta-sync. | -| Cross-DC / WAN | TensorHub-pattern; `mx_v2` already supports it via MX catalog metadata, but no cross-DC validation here. | -| Production hardening of MX server | Owned by `ai-dynamo/modelexpress`. We consume; we don't fork. | - -## References - -- [PR #2389](https://github.com/PrimeIntellect-ai/prime-rl/pull/2389) — the baseline -- [`KavinKrishnan/prime-rl#1`](https://github.com/KavinKrishnan/prime-rl/pull/1) — Phase 2 source-baked rendezvous fixes (heartbeat + dedup + same-rank) -- [`KavinKrishnan/prime-rl#2`](https://github.com/KavinKrishnan/prime-rl/pull/2) — Phase 3 conversion-registry extensions (compile_target + cutlass_fp8) -- [`ai-dynamo/modelexpress#349`](https://github.com/ai-dynamo/modelexpress/pull/349) — Phase 3 + 4 + `MxWeightTransferEngine` adapter (v2 fat clients, multi-source slice picker, vLLM API adapter) -- [vLLM PR #43375](https://github.com/vllm-project/vllm/pull/43375) — Anyscale Ray Direct Transport; same `WeightTransferEngine` API shape, complementary Ray-based catalog choice -- [vLLM native RL APIs blog](https://blog.vllm.ai/2026/05/28/native-rl-apis.html) — the upstream API surface this RFC targets -- [TensorHub paper (arXiv 2604.09107)](https://arxiv.org/pdf/2604.09107v1) — Reference-Oriented Storage, pipeline replication, mutability contract -- [`post-pr2389-kernel-compile-plan.md`](./post-pr2389-kernel-compile-plan.md) — the four-phase plan this RFC consolidates diff --git a/docs/proposals/post-pr2389-status-and-plan.md b/docs/proposals/post-pr2389-status-and-plan.md deleted file mode 100644 index f1c838f2fc..0000000000 --- a/docs/proposals/post-pr2389-status-and-plan.md +++ /dev/null @@ -1,232 +0,0 @@ -# prime-rl × ModelExpress — Status of the post-#2389 work and how it addresses the MoE / kernel / quant pain points - -> **Related docs in this directory**: -> - [`post-pr2389-kernel-compile-plan.md`](./post-pr2389-kernel-compile-plan.md) — the full RFC with phase-by-phase design rationale -> - [`build-notes-2026-05-28.md`](./build-notes-2026-05-28.md) — image-build experience, cluster observations, vLLM native RL APIs reframing -> -> This doc is the executive summary: where prime-rl × MX stands today, what the failure classes are, and where in the follow-up plan each one gets resolved. - -**Scope**: status of the ModelExpress + NIXL weight-refit integration in prime-rl, the three classes of failure observed on MoE models (kernel-target mismatches, fused-vs-unfused gates across kernels, quantization + packing mismatches), and the four-phase follow-up plan that resolves them. - -**TL;DR**: [PR #2389](https://github.com/PrimeIntellect-ai/prime-rl/pull/2389) lands the first cut of MX-driven NIXL refit in prime-rl and runs at ~10s/cycle on GB200. The three failure classes reduce to one root cause — prime-rl's weight-conversion registry has only two entries, and there is no `compile_target` tag on published bytes that distinguishes DeepGemm vs Cutlass vs other layouts. Two follow-up PRs are in flight that fix this without changing prime-rl's transfer architecture, plus a ~80-LOC-per-kernel extension that can land independently of either PR. - ---- - -## 1. Current state - -### 1.1 PR #2389 — what it does - -[PR #2389](https://github.com/PrimeIntellect-ai/prime-rl/pull/2389) adds a third weight-broadcast type to prime-rl: `weight_broadcast.type = "nixl_mx"`. The data flow per refit cycle is: - -1. **Trainer side** — for each step's state-dict tensor, run a *trainer-side conversion* (today: `bf16_cast` or `fp8_128x128`) that produces the destination layout vLLM expects, with fusion (e.g. `q_proj/k_proj/v_proj → qkv_proj`) and quant already applied. Land the bytes into pre-allocated NIXL-registered buffers (`ShardedSlot` / `GatheredSlot` / `ExpertSlot`). -2. **Rendezvous** — register with ModelExpress (MX) so the inference side knows where to pull from. Lightweight — just metadata; the bytes never go through MX. -3. **NIXL RDMA push** — trainer writes directly into the inference workers' pre-registered parameter buffers over RDMA. Sub-second for a 30 GB model on GB200 cross-node. -4. **Inference side** — vLLM workers receive a `/update_weights` HTTP from the orchestrator, synchronize, and resume serving. - -### 1.2 Validation status - -Validation has run on a GB200 cluster: - -| Component | State | -|---|---| -| Workload | Qwen3-30B-A3B-Instruct-2507, FSDP 2×2, EP=4, 32/128 experts per rank, FLASHINFER attention | -| Image | `nvcr.io/nvidian/dynamo-dev/prime-rl-mx-on-nixl:v0.5.2` | -| Steady-state refit | ~10 s/cycle, reward=1.0000, off-policy=0 | -| Errors | Zero NIXL errors (no `REMOTE_DISCONNECT`, no `NOT_ALLOWED`, no stale-READY) | -| Tested cycles | 4/4 clean, plus 1 trainer restart cycle (clean exit on `max_steps`, k8s respawned, resumed cleanly) | - -PR #2389 is a real working baseline — not a prototype that falls over on first push. - -### 1.3 Two GB200-specific runtime patches folded in - -Two issues on the multi-NIC fabric required patches to the trainer's NIXL setup: - -- **Same-rank-only peer filter** — on GCP GB200 the four RDMA NICs (`rdma-0..3`) are separate L3 subnets, so trainer rank N can only safely peer with inference rank N. Cross-subnet pairs fail. -- **Freshest-per-rank dedup** — when MX has multiple entries at the same `worker_rank` (e.g. after a pod restart), the freshest entry by `updated_at` must be picked, not the first one. Otherwise the inference side picks a stale entry and the NIXL `add_remote_agent` call refuses with `NIXL_ERR_NOT_ALLOWED`. - -Both are applied today as a runtime monkey-patch in the configmap (`patch_nixl_mx.py`). The [Phase 2 draft PR](https://github.com/KavinKrishnan/prime-rl/pull/1) bakes them into the rendezvous class so they are no longer runtime shims. - -### 1.4 What PR #2389 does not yet do (where the failure classes live) - -Two specific gaps map directly to the observed failure classes: - -| Gap | Failure class | -|---|---| -| The conversion registry (`prime_rl/trainer/models/conversions/__init__.py`) has exactly **two entries**: `bf16_cast` and `fp8_128x128`. Anything else raises `NotImplementedError` at startup. | Quantization and packing mismatches — if the target inference engine wants DeepGemm K-major scale interleaving, Cutlass FP8, MXFP4, a non-128 block size, etc., prime-rl cannot produce those bytes at all. | -| There is no `compile_target` tag on published tensors. Receivers cannot tell whether the bytes were compiled for DeepGemm, Cutlass, or anything else. | MoE-specific corruption + fused-vs-unfused-gate mismatches — when trainer and receiver disagree on the layout, the result is silent byte misinterpretation instead of a clean error. The receiver loads the wrong layout and rollouts are corrupted, often without an obvious crash. | - -These are the same root cause expressed three ways: *prime-rl currently assumes the trainer and inference side agree on one layout, baked in at config time, with no runtime check*. Heterogeneous fleets or kernel-target disagreements break it silently. - ---- - -## 2. The follow-up effort - -The follow-up work is captured in the [post-#2389 RFC](https://github.com/KavinKrishnan/prime-rl/blob/kavink/post-2389-kernel-compile-plan/docs/proposals/post-pr2389-kernel-compile-plan.md). It has four phases and a clear strategic direction. - -### Strategic direction - -The design stays on the **trainer-side post-processed transfer** model: the trainer compiles weights into the destination kernel's exact layout, RDMA-writes them straight into the inference worker's pre-registered parameter buffers, and the inference side does no compute on receipt. This is what prime-rl does today; the `ConversionSpec` / `ShardedSlot` / `ExpertSlot` / `GatheredSlot` structure is the right shape. - -The design does *not* move to a scratch-buffer + `model.load_weights()` pattern. The internal NemoRL + Dynamo integration prototype uses that pattern as a workaround for a vLLM-specific quirk (`stacked_params_mapping` does HF→fused name remapping at load time, so the trainer can publish HF-raw names and the receiver figures it out). It works, but it adds receive-side compute (~50-200 ms per refit on FP8 casts) and is not needed for prime-rl, which already does fusion + quant trainer-side and writes directly into pre-allocated fused buffers. - -### Four phases - -| Phase | What | Where | Status | -|---|---|---|---| -| **Phase 2** | Bake the two GB200 runtime patches into `MxRendezvous` permanently; add a `HeartbeatThread` so crashed workers don't leave stale READY entries; swap the in-tree rendezvous over to the published v2 ModelExpress client. | [`KavinKrishnan/prime-rl#1`](https://github.com/KavinKrishnan/prime-rl/pull/1) | Draft PR open, 11/11 unit tests green | -| **Phase 3a** | Add `compile_target` and `compile_metadata` fields to ModelExpress's `TensorDescriptorV2`. Trainer stamps every publish with what layout it produced — `deep_gemm_fp8`, `cutlass_fp8`, `block_size=128`, `gate_fusion=gate_up_swiglu`, etc. | [`ai-dynamo/modelexpress#349`](https://github.com/ai-dynamo/modelexpress/pull/349) | Draft PR open, 6/6 unit tests green | -| **Phase 3b** | Receivers filter sources by `compile_target` + `compile_metadata` *before* RDMA. Mismatched bytes are refused at discovery; no more silent corruption. | Same PR | 4/4 unit tests green | -| **Phase 4** | Multi-source slice picker for the mixed-TP case (e.g. trainer TP=4 publishing to inference TP=8). Each receiver discovers which subset of publisher ranks covers its slice. | Same PR | 8/8 unit tests green | -| **Phase 0 (parallel)** | prime-rl-only: extend the conversion registry from two entries to N. **No MX dependency. Can land independently of the four phases above.** | Open work (see §4 below) | — | - -### How each phase maps to the three failure classes - -| Failure class | What fixes it | -|---|---| -| Quantization and packing issues (`NotImplementedError`) | **Phase 0**: extend `prime_rl/trainer/models/conversions/` with the kernel-specific layouts required. ~80 LOC per kernel. | -| MoE expert layout differences across kernels | **Phase 0** (new `ConversionSpec` per (model × kernel)) + **Phase 3a/3b** to surface mismatches as clean errors. | -| Some gates are fused vs not | **Phase 0** (per-(model × kernel) `ConversionSpec` defining `sources` + `cat_dim` differently) + **Phase 3a/3b** so the trainer's `compile_metadata.gate_fusion` tag travels and is filtered on. | - -In all three cases, **Phase 0 is the immediate unblock** — Phase 3 is the safety net that prevents silent corruption when there is a mismatch. - ---- - -## 3. The architecture, with what's new highlighted - -```mermaid -flowchart TB - classDef today fill:#e6f3ff,stroke:#0066cc,color:#000 - classDef phase0 fill:#fff4cc,stroke:#c68a00,color:#000,stroke-width:2px - classDef phase2 fill:#d4edda,stroke:#28a745,color:#000,stroke-width:2px - classDef phase3 fill:#f8d7da,stroke:#dc3545,color:#000,stroke-width:2px - classDef physical fill:#eee,stroke:#666,color:#000 - - subgraph trainer["Trainer pod — FSDP 2×2 / EP=4"] - TSD[("HF state_dict
(bf16, unfused)")] - TCONV["ConversionSpec registry
Today: bf16_cast, fp8_128x128
Phase 0: + deep_gemm_fp8, cutlass_fp8,
+ MoE fusion variants per model"] - TSLOTS["Slots: ShardedSlot · GatheredSlot · ExpertSlot
(pre-allocated, NIXL-registered)"] - TRENDZ["MxRendezvous
Today: thin wrapper over MxClient
Phase 2: swap to MxV2TrainingPublisher
+ heartbeat + same-rank + freshest-dedup"] - TTAG["Phase 3a: stamp every publish with
compile_target + compile_metadata
(rides on TensorDescriptorV2)"] - TNIXL["NIXL agent (UCX rc_mlx5)"] - TSD --> TCONV --> TSLOTS --> TNIXL - TSLOTS -.metadata.-> TRENDZ - TRENDZ -.-> TTAG - end - - subgraph mx["ModelExpress server (control plane only — no weight bytes)"] - MXSVR[("gRPC + Redis catalog
(workers, status, descriptors, tags)")] - end - - subgraph inference["Inference pod — vLLM v1, 4 ranks, EP=4"] - IRENDZ["MxRendezvous / MxV2RefitReceiver"] - IFILT["Phase 3b: discover_v2_sources(
compile_target_filter={…},
required_compile_metadata={…})
Mismatched bytes refused BEFORE RDMA"] - IPICK["Phase 4: discover_v2_sources_for_slice(…)
multi-source picker for mixed-TP"] - IPARAMS["vLLM model.named_parameters()
(pre-registered, fused destinations like qkv_proj)"] - INIXL["NIXL agent (UCX rc_mlx5)"] - IRENDZ --> IFILT --> IPICK - INIXL --> IPARAMS - end - - TNIXL ==>|"RDMA WRITE
(post-processed bytes,
~10s for 30GB Qwen3-30B-A3B)"| INIXL - TRENDZ <-.->|"publish metadata + tags"| MXSVR - MXSVR <-.->|"discover + filter"| IRENDZ - - class TSD,TSLOTS,IPARAMS today - class TNIXL,INIXL,TCONV today - class TCONV phase0 - class TRENDZ,IRENDZ phase2 - class TTAG,IFILT,IPICK phase3 - class MXSVR physical -``` - -Legend: 🟦 today's PR #2389 surface, 🟨 Phase 0 (immediate unblock), 🟩 Phase 2, 🟥 Phase 3 / 4. - -The data plane (RDMA write of post-processed bytes from trainer NIC to inference NIC) stays exactly as it is today. Everything new lives in the **metadata plane** (what gets stamped on the publish) and the **registry** (what conversions are available trainer-side). - ---- - -## 4. Phase 0 — extending the conversion registry - -This is independent of both PRs in flight. The conversion registry is plug-in: - -``` -prime_rl/trainer/models/conversions/ -├── __init__.py # registry, select_default_conversion() -├── bf16_cast.py # registered as "bf16_cast" ← today -└── fp8_blockwise.py # registered as "fp8_128x128" ← today -``` - -To add support for a new kernel layout, drop in a new file: - -```python -# prime_rl/trainer/models/conversions/deep_gemm_fp8.py -from prime_rl.trainer.models.conversions import register - -def _convert(src, dst, scale): - # src: bf16 HF-format tensor - # dst: pre-allocated FP8 destination buffer (kernel's expected layout) - # scale: paired scale buffer if requires_scale=True - ... - -register("deep_gemm_fp8", _convert, requires_scale=True) -``` - -The default-conversion picker in `__init__.py:select_default_conversion` reads the inference model's `config.json` and chooses based on `quantization_config`. Today it knows two cases; extend the if/elif chain for new ones. - -**For MoE specifically — the gate-fusion choice lives in the `ConversionSpec` definitions per model.** Example pattern in `prime_rl/trainer/models/qwen3_moe/converting_qwen3_moe.py`: - -```python -ConversionSpec( - dst="mlp.experts.gate_up_proj.weight", # fused destination - sources=("mlp.experts.gate_proj.weight", # source 1 - "mlp.experts.up_proj.weight"), # source 2 - cat_dim=0, # concat along dim 0 -) -``` - -For an unfused-gate kernel target, a parallel set of `ConversionSpec`s would emit `gate_proj` and `up_proj` separately. The framework supports multiple `ConversionSpec` sets via the `Conversion` selector in `MaybeQuantize`. - -Total scope per kernel target: ~80 LOC for the conversion function + per-model `ConversionSpec` additions for the required models. - -Once Phase 3a lands in MX and Phase 2 lands in prime-rl, every publish automatically carries the `compile_target` tag (e.g. `"deep_gemm_fp8"`) and `compile_metadata` (e.g. `{"block_size": 128, "scale_layout": "K-major", "gate_fusion": "unfused"}`), and any inference worker expecting a different layout refuses the source cleanly at discovery instead of corrupting rollouts. - ---- - -## 5. Information required to write the missing conversion entries - -Two inputs are required per kernel target before the missing conversion entries can be written: - -1. **Inference engine + quant config target**. Examples: "vLLM 0.7 + DeepGemm grouped-GEMM MoE FP8 block-128", "vLLM 0.7 + Cutlass MoE FP8 with K-major scales", "Triton MoE with unfused gates and BF16". The exact kernel name + scale layout determines the required output format. -2. **Model architectures involved**. Existing conversion specs cover Qwen3, Qwen3-MoE, GLM-MoE-DSA, Nemotron-H, MiniMax-M2, and Laguna. Any model not yet in `prime_rl/trainer/models/` requires an additional ~150 LOC for the model adapter. - ---- - -## 6. References - -**Branches and PRs** - -- Upstream: [PrimeIntellect-ai/prime-rl#2389](https://github.com/PrimeIntellect-ai/prime-rl/pull/2389) -- Phase 2 draft: [KavinKrishnan/prime-rl#1](https://github.com/KavinKrishnan/prime-rl/pull/1) — `MxRendezvous` heartbeat + dedup + same-rank filter -- Phase 3+4 draft: [ai-dynamo/modelexpress#349](https://github.com/ai-dynamo/modelexpress/pull/349) — `compile_target` + multi-source slice picker -- RFC: [`KavinKrishnan/prime-rl:kavink/post-2389-kernel-compile-plan`](https://github.com/KavinKrishnan/prime-rl/blob/kavink/post-2389-kernel-compile-plan/docs/proposals/post-pr2389-kernel-compile-plan.md) — the full post-#2389 plan -- Build notes: [`KavinKrishnan/prime-rl:kavink/post-2389-kernel-compile-plan/docs/proposals/build-notes-2026-05-28.md`](https://github.com/KavinKrishnan/prime-rl/blob/kavink/post-2389-kernel-compile-plan/docs/proposals/build-notes-2026-05-28.md) — image-build experience, cluster observations, and the vLLM-native-RL-APIs reframing -- Inline review on upstream PR: six inline + one summary, posted at PR HEAD `79ea824d8` — covering cross-subnet `add_remote_agent`, freshest-per-rank dedup, missing heartbeat, hardcoded 1200s timeout, unconditional `update_mla_absorbed_weights`, HSDP barrier ordering - -**Code locations in prime-rl** - -- Conversion registry: `src/prime_rl/trainer/models/conversions/__init__.py` -- Existing conversions: `src/prime_rl/trainer/models/conversions/{bf16_cast,fp8_blockwise}.py` -- `ConversionSpec` definition: `src/prime_rl/trainer/models/conversion_spec.py` -- Per-model `ConversionSpec` registrations: `src/prime_rl/trainer/models//converting_.py` -- Slots: `src/prime_rl/trainer/models/slots.py` -- NIXL+MX trainer broadcast: `src/prime_rl/trainer/rl/broadcast/nixl_mx.py` -- NIXL+MX inference worker: `src/prime_rl/inference/vllm/worker/nixl_mx.py` -- Rendezvous: `src/prime_rl/transport/mx_rendezvous.py` -- Transport plan: `src/prime_rl/transport/transport_plan.py` - -**Cluster** - -- GB200 dev cluster, dedicated namespace -- Image (today): `nvcr.io/nvidian/dynamo-dev/prime-rl-mx-on-nixl:v0.5.2` -- Source-baked Phase 2 + Phase 3 image: `nvcr.io/nvidian/dynamo-dev/prime-rl-mx-on-nixl:v0.7.1-kavin-phase2-phase3` -- ModelExpress server: deployed in the same namespace, Redis-backed