diff --git a/docs/proposals/build-notes-2026-05-28.md b/docs/proposals/build-notes-2026-05-28.md new file mode 100644 index 0000000000..8f0ee37443 --- /dev/null +++ b/docs/proposals/build-notes-2026-05-28.md @@ -0,0 +1,238 @@ +# Build notes — Phase-2 + Phase-3 source-baked image (2026-05-28) + +> **Related docs in this directory**: +> - [`post-pr2389-status-and-plan.md`](./post-pr2389-status-and-plan.md) — executive summary of where things stand + failure-class → fix mapping +> - [`post-pr2389-kernel-compile-plan.md`](./post-pr2389-kernel-compile-plan.md) — the full RFC with phase-by-phase design rationale +> +> This doc is the operational findings: how the source-baked image was built, what broke, what we learned, and how the upstream vLLM native RL APIs reframe everything. + +**Status**: empirical findings from baking Phase 2 (rendezvous fixes) + Phase 3 (conversion-registry extensions) into an ARM64 GB200 image and running it against a live GB200 cluster. Updates the RFC's framing where the build experience contradicted assumptions in the original RFC. + +This document captures **what we learned producing a usable image** containing the two follow-up PRs ([phase-2 rendezvous fixes](https://github.com/KavinKrishnan/prime-rl/pull/1) and [conversion-registry extensions](https://github.com/KavinKrishnan/prime-rl/pull/2)) on top of PR #2389 (HEAD `79ea824d8`). The unit tests for both PRs were already green; this doc records the cluster + image surface area that the unit tests don't cover. + +## 1. What we built + +Two images, in order: + +| Tag | Base | What's added | Status | +|---|---|---|---| +| `prime-rl-mx-on-nixl:v0.7.0-kavin-phase2-phase3` | `nvidia/cuda:12.8.1-cudnn-devel-ubuntu22.04` (full Dockerfile.cuda rebuild) | Phase 2 + Phase 3 source merged in. Built from `kavink/post-2389-image-build-2026-05-28` (which merges PR #1 + PR #2 on top of `79ea824d8`). 
| **Pushed to nvcr** | +| `prime-rl-mx-on-nixl:v0.7.1-kavin-phase2-phase3` | `v0.7.0-kavin-phase2-phase3` | Adds the `disagg` extra (modelexpress + nixl-cu12 + vllm-router). Fixes the import error from v0.7.0. | **Pushed to nvcr** | + +`v0.7.1` is the one to deploy. `v0.7.0` is kept as a reference of the from-scratch build artifact. + +## 2. Build mechanics (ARM64 GB200 / QEMU) + +The from-scratch ARM64 build of `v0.7.0` took **6h 45min on x86 host with QEMU arm64 emulation** (buildkit `multi-arch` builder). Breakdown: + +| Stage | Time | Notes | +|---|---|---| +| Pull `nvidia/cuda:12.8.1-cudnn-devel-ubuntu22.04` ARM64 base | ~5 min | First time only; cached after | +| `apt-get install` builder + final stages | ~3 min total | Both stages of multi-stage Dockerfile | +| `COPY src/ + packages/ + deps/` | seconds | Trivial | +| **`uv sync --extra ... --locked --no-dev`** | **45 min** | Resolves + downloads + installs ~350 packages including torch 2.7+cu130 (~5 GB), nvidia-cudnn-cu12 (738 MiB), flashinfer-cubin (large), tilelang, xgrammar, vllm 0.21.0+cu129 etc. Under QEMU emulation. | +| **`docker-arm64-post-install.sh` (flash-attn from source for sm_100 / GB200)** | **~3h 45min** | 73 CUDA kernel `.o` files, each compiled via emulated `nvcc` for sm_80 + sm_90 + sm_100. Most expensive kernels are `hdim192_bf16_causal` and `hdim256_bf16` for backward pass (15-40 min each). | +| Final stage `COPY --from=builder /app` + image export | ~7 min | 15.9 GB final image, 6.5 GB of which is one big layer (the venv) | + +`v0.7.1` overlay on top of `v0.7.0` was **~3 min** (the `uv sync` with `disagg` extra reuses every cached layer except the new modelexpress/nixl-cu12/vllm-router wheels). + +**Practical implication**: every meaningful rebuild from the Dockerfile.cuda base is ~7 hours on a non-ARM host. Use overlay Dockerfiles for additive changes. Reserve from-scratch only for `pyproject.toml` / `uv.lock` updates or major source restructuring. + +## 3. 
Three real issues the build surfaced that aren't in the RFC + +### 3.1 `Dockerfile.cuda` is missing `--extra disagg` for nixl_mx use + +[`Dockerfile.cuda`](../../Dockerfile.cuda) line 52: + +```dockerfile +RUN --mount=type=cache,target=/app/.cache/uv \ + uv sync --extra flash-attn --extra flash-attn-3 --extra flash-attn-cute --extra envs --extra gpt-oss --group mamba-ssm --locked --no-dev +``` + +The `disagg` extra ([`pyproject.toml` line 90](../../pyproject.toml#L90)) contains: + +```toml +disagg = [ + "deep-ep ; platform_machine == 'x86_64'", + "deep-gemm ; platform_machine == 'x86_64'", + "nixl", + "nixl-cu12 ; platform_machine == 'x86_64'", + "vllm-router ; platform_machine == 'x86_64'", + "modelexpress", +] +``` + +Without it, **`modelexpress` is not installed**, and the inference worker crashes at the first import of `prime_rl.inference.vllm.worker.nixl_mx`: + +``` +File "/app/src/prime_rl/inference/vllm/worker/nixl_mx.py", line 7, in <module> + from modelexpress import p2p_pb2 +ModuleNotFoundError: No module named 'modelexpress' +``` + +The pre-PR-#2389 `Dockerfile.cuda` predates the `disagg` extra so this is an accidental gap, not an intentional opt-out. **Suggested change**: add `--extra disagg` (or rely on `--extra all`) for any image targeting `weight_broadcast.type=nixl_mx`. We've shipped `v0.7.1` as a one-line overlay that does this until the change can land in `Dockerfile.cuda` itself. + +### 3.2 `LD_PRELOAD` path for libcudart.so.12 moved + +The existing configmap's three run-scripts (`run_trainer.sh`, `run_inference.sh`, `run_orchestrator.sh`) all preload libcudart for ARM64 NIXL compatibility: + +```bash +export LD_PRELOAD="/usr/local/cuda/lib64/libcudart.so.12:${LD_PRELOAD:-}" +``` + +`/usr/local/cuda` exists in the v0.5.2 image (which appears to have been built from a Dockerfile variant that retained the CUDA tooling in the final stage). 
In `v0.7.0` (built from the upstream `Dockerfile.cuda` as-is), the final stage is `python:3.12-slim` which **does not** have `/usr/local/cuda`. `libcudart.so.12` lives only inside the pip-installed `nvidia-cuda-runtime` wheel: + +``` +/app/.venv/lib/python3.12/site-packages/nvidia/cuda_runtime/lib/libcudart.so.12 +``` + +Symptom on v0.7.0 with the unmodified configmap: + +``` +ERROR: ld.so: object '/usr/local/cuda/lib64/libcudart.so.12' from LD_PRELOAD cannot be preloaded +``` + +**Fix applied**: the three run-scripts now use the wheel-internal path. Alternative: symlink `/usr/local/cuda/lib64/libcudart.so.12 -> /app/.venv/.../libcudart.so.12` in the Dockerfile's final stage. Either works; we picked the env-var path because it's a configmap edit, no image rebuild. + +### 3.3 The configmap `patch_nixl_mx.py` and Phase 2 source coexist + +The kavin namespace runs a configmap-injected monkeypatch at container start (`patch_nixl_mx.py`) that rewrites `src/prime_rl/trainer/rl/broadcast/nixl_mx.py` to add same-rank-only peer filter + freshest-per-rank dedup *at TransportPlan construction time*. + +Phase 2 ([PR #1](https://github.com/KavinKrishnan/prime-rl/pull/1)) adds the same semantic guarantees but at a different layer — inside `src/prime_rl/transport/mx_rendezvous.py:wait_for_all_peers_ready`. The two patches are **complementary, not redundant**: + +| Code path | Bug class | Covered by | +|---|---|---| +| `trainer/rl/broadcast/nixl_mx.py:lazy_init` → `TransportPlan(peer_metadata=…)` | Trainer adds dead peers as NIXL remote agents during the per-step broadcast | `patch_nixl_mx.py` (runtime monkeypatch) | +| `transport/mx_rendezvous.py:wait_for_all_peers_ready(role="trainer")` | Orchestrator counts historical trainer entries in Redis and times out waiting for `n_historical` to all reach READY when only `n_alive` exist | Phase 2 PR (source-level) | + +On v0.7.1 + the existing configmap, both fire. 
The trainer log shows `[patch_nixl_mx] PATCHED v2 (kavin_freshest_per_rank)` from the configmap script; the rendezvous wait methods get the Phase 2 dedup automatically because the source is in the baked image. Empirically the orchestrator restart pattern we saw on v0.5.2 (~once per 30-66 min on this workload) should go away on v0.7.1. **Validation pending** — image just deployed at time of writing. + +When PR #1 merges upstream, the configmap monkeypatch becomes redundant for the trainer-side path too and should be removed. Until then, both layers complement each other. + +## 4. Cluster observations under v0.5.2 + configmap monkeypatch + +For the record, the v0.5.2 + configmap-monkeypatch combination we ran for 8+ hours before v0.7.1 deploy: + +- Workload: Qwen3-30B-A3B-Instruct-2507, FSDP 2×2, EP=4 (32/128 experts per rank), FLASHINFER attention, gsm8k env +- Trainer steady state: ~10–21 s/step (varies with sequence length 280–500 tokens) +- Reward signal: variance 0.5–1.0 per orchestrator step — **real learning gradient**, not just reward=1.0 collapse +- Off-policy level: 0 across all observed steps (in-lockstep refit) +- Best uninterrupted window: **183 successful RL refit cycles over 66 min** between orchestrator restarts +- Zero NIXL data-plane errors (no `REMOTE_DISCONNECT`, no `NOT_ALLOWED`, no stale-READY) — confirms the same-rank-only + freshest-per-rank patches are correct +- Recurring orchestrator timeout pattern: `TimeoutError: timed out after 1200.0s waiting for 12 'trainer' peers to reach status 1 (saw 4)` — exactly what Phase 2's rendezvous-level dedup fixes + +That last bullet is the bug class v0.7.1 is meant to eliminate. The configmap monkeypatch couldn't fix it because the relevant call site is in the orchestrator's rendezvous, which is in a different module from the trainer-side broadcast the monkeypatch was rewriting. + +## 5. 
Branches + image artifacts pushed + +| Branch | What's in it | Where | +|---|---|---| +| [`kavink/post-2389-kernel-compile-plan`](https://github.com/KavinKrishnan/prime-rl/tree/kavink/post-2389-kernel-compile-plan) | RFC document + this build-notes doc | `KavinKrishnan/prime-rl` | +| [`kavink/post-2389-phase2-rendezvous-fixes`](https://github.com/KavinKrishnan/prime-rl/tree/kavink/post-2389-phase2-rendezvous-fixes) | Phase 2 source (heartbeat + dedup + same-rank), 11/11 unit tests green, plus the `modelexpress.heartbeat` module-path tolerance fix | [Draft PR #1](https://github.com/KavinKrishnan/prime-rl/pull/1) | +| [`kavink/post-2389-conversion-registry-extensions`](https://github.com/KavinKrishnan/prime-rl/tree/kavink/post-2389-conversion-registry-extensions) | Phase 3 conversion-registry extensions (`compile_target` + `compile_metadata` + `cutlass_fp8_e4m3_per_channel`), 19/19 unit tests green | [Draft PR #2](https://github.com/KavinKrishnan/prime-rl/pull/2) | +| [`kavink/post-2389-image-build-2026-05-28`](https://github.com/KavinKrishnan/prime-rl/tree/kavink/post-2389-image-build-2026-05-28) | Merge of Phase 2 + Phase 3 + the import-tolerance fix; this is the exact source tree v0.7.0 / v0.7.1 was built from | `KavinKrishnan/prime-rl` (this push) | + +Image artifacts on `nvcr.io/nvidian/dynamo-dev/`: + +- `prime-rl-mx-on-nixl:v0.7.0-kavin-phase2-phase3` — full from-scratch ARM64 build (broken — missing `disagg`) +- `prime-rl-mx-on-nixl:v0.7.1-kavin-phase2-phase3` — overlay that adds `disagg` extra + +MX side ([`ai-dynamo/modelexpress#349`](https://github.com/ai-dynamo/modelexpress/pull/349)) updated with the graduation glue commit that plumbs `ConversionEntry.compile_target` + `ConversionEntry.compile_metadata` through `MxV2TrainingPublisher.add_tensor(compile_target=…, compile_metadata=…)`. Wire round-trip is unit-tested. + +## 6. 
What to update in the RFC (`post-pr2389-kernel-compile-plan.md`) — but not yet + +These are the four edits queued in [`pensieve/RL/PrimeRL/09_rfc_updates_needed.md`](https://github.com/ai-dynamo/modelexpress/) (internal), augmented by what we learned from the build: + +1. **Reframe Phase 3** — trainer-side post-processed direct is primary; receiver-side compile passes are v4+ (scratch buffers are a fallback only, not the primary v3 design as the original RFC implied). +2. **Add Phase 0** — Phase B UCX/dma-buf env profile as cluster prerequisite (`UCX_TLS=rc,cuda_copy`, `NIXL_UCX_TLS=rc,cuda_copy`, `UCX_CUDA_COPY_DMABUF=yes`, etc.) — from NeMo-RL + Dynamo's empirical 380 Gbps validation. +3. **Mark Phases 2/3/4 as shipped, not paper** (with PR + commit references). +4. **Add a sub-section on conversion registry extensions** documenting the `ConversionEntry` schema extension + how to add a new kernel (~80 LOC per kernel). + +**New from this build experience**: + +5. **Document the `disagg` extra requirement in §0** alongside the env profile — easy gotcha that costs an entire rebuild to discover. +6. **Document the `LD_PRELOAD` path** for libcudart in §0 — pre-existing run-scripts assumed v0.5.2's `/usr/local/cuda` layout. +7. **§4 / §5 on the "fallback path"** — describe vLLM PR #43375 (Ray Direct Transport, Anyscale) as the canonical receiver-pull-via-load_weights instance of the fallback path. Our positioning of trainer-side post-processed direct as "primary path, zero receive-side compute" is unchanged; RDT is the upstream-stamped instance of the alternative receiver-pull path. + +## 7. Open follow-ups + +- **Validate v0.7.1 end-to-end** on kavin (pending — v0.7.1 just deployed). Expected: zero NIXL errors AND zero orchestrator-`wait_for_all_peers_ready` timeouts. If both hold for a long uninterrupted window (>3 hours), we declare the source-baked Phase 2 + Phase 3 production-ready. 
+- **Send a one-line PR to upstream `Dockerfile.cuda`** adding `--extra disagg` (or `--extra all`). Tiny patch, unblocks every other team trying to bake `nixl_mx` mode into an image. +- **MX side: roll the server with the `SourceIdentity` round-trip fix** (proto change committed but not deployed). After that, the `__mx_v2_meta__` sidecar transport workaround can be dropped from `MxV2TrainingPublisher` (it's already filtered before NIXL register via [PR #295](https://github.com/ai-dynamo/modelexpress/pull/295)). +- **`pull_one(name)` semantic on MX** — inspired by vLLM PR #43375's RDT contract. Would let MX expose Ray-like per-tensor elasticity without abandoning the trainer-side compile model. ~50 LOC; not on the critical path but a clean addition for the post-Phase-4 work. + +## 8. Strategic update — vLLM published native RL APIs (2026-05-28) + +Same day as this doc, the vLLM team published [Native RL APIs in vLLM](https://vllm.ai/blog/2026-05-28-native-rl-apis), announcing a standardized `WeightTransferEngine` abstract base + four-phase lifecycle (`init_weight_transfer_engine` / `start_weight_update` / `update_weights` / `finish_weight_update`) and a pluggable `WeightTransferEngineFactory.register_engine(...)` extension point. Existing built-in backends: NCCL (packed broadcast), IPC (CUDA shared mem). PR [#43375](https://github.com/vllm-project/vllm/pull/43375) (Anyscale "RDT") plugs Ray Direct Transport into the same factory. + +Three things this changes about our post-PR-#2389 plan: + +### 8.1 Phase 2 + Phase 3 + Phase 4 have a clean upstream form: `MxWeightTransferEngine` + +The native API gives us a standard surface to plug MX into vLLM. The upstream-friendly shape of all our follow-up work is a single adapter class: + +```python +class MxWeightTransferEngine(WeightTransferEngine): + init_info_cls = MxInitInfo # mx_server_url, model_name, worker_rank, ... + update_info_cls = MxUpdateInfo # version, compile_target_filter, target_tp_layout, ... 
+ + def init_transfer_engine(self, init_info: MxInitInfo): + self._receiver = MxV2RefitReceiver(...) + + def receive_weights(self, update_info, load_weights): + plan = self._receiver.discover_v2_sources_for_slice( # Phase 4 + model_name=update_info.model_name, + target_layout=update_info.target_tp_layout, + compile_target_filter=update_info.compile_target_filter, # Phase 3b + required_compile_metadata=update_info.required_compile_metadata, + ) + for name, tensor in self._receiver.receive_via_plan(plan): # Phase 4 stitch + load_weights([(name, tensor)]) + + @classmethod + def trainer_send_weights(cls, iterator, trainer_args): + # wraps MxV2TrainingPublisher.add_tensor(compile_target=..., compile_metadata=...) + .publish() + ... +``` + +Registers via `WeightTransferEngineFactory.register_engine("mx_nixl", MxWeightTransferEngine)`. Estimated ~150-200 LOC adapter on top of the MX clients we already have. + +The in-tree `MxRendezvous` reimplementation in PR #2389 + the `worker_extension_cls` injection in `inference/vllm/worker/nixl_mx.py` both **become unnecessary once this lands upstream**. The native API is the integration seam they were emulating; our adapter consumes it directly. + +### 8.2 Matej is acknowledged in the blog — coordination needed before pushing Phase 2 upstream + +The blog credits *"Prime-RL team (especially Matej Sirovatka) and Junjie Zhang for helping to validate and debug the RL APIs with large-scale runs"*. Matej is actively involved in the design. + +This affects the trajectory of [draft PR #1](https://github.com/KavinKrishnan/prime-rl/pull/1) (Phase 2). The semantic fix it ships is correct regardless of which integration path prime-rl converges on, but the *form* differs: + +- **If Matej is mid-flight on a native-APIs rewrite of `nixl_mx`**: our Phase 2 PR retargets to that path (becomes part of the `MxWeightTransferEngine` adapter rather than landing in the in-tree `MxRendezvous` class). 
+- **If Matej hasn't started**: Phase 2 lands as-is, the rewrite happens later as a separate PR. + +Either way, ask Matej first. + +### 8.3 Validation in the blog is at DPEP32 across 16 nodes — Phase 4 becomes load-bearing + +The blog reports Prime-RL validated `zai-org/GLM-5.1-FP8` in P/D-disaggregated deployment across **16× 8xH200 nodes** (2 replicas of 4P+4D, both **DPEP32**) for 100+ steps with stable KL mismatch and upward RL curve. **256 GPUs total at DPEP32**. + +At this scale, mixed-TP / mixed-EP is the common case (trainer TP/EP layout almost never matches inference TP/EP layout), and Phase 4's `discover_v2_sources_for_slice` + multi-source `receive_via_plan` is the difference between "works" and "doesn't". This validates Phase 4's design direction and sets the next cluster validation target after kavin (DP=4). + +### 8.4 Two async-RL features the blog ships that we don't yet match + +| Feature | What | Our position | +|---|---|---| +| `pause_generation(mode="keep")` | Pause in-flight requests *without* aborting or waiting; resume from the partial-token state | Today we wait for rollouts to complete before refit. Adopting `keep` mode is the unlock for truly async RL. Queue as a follow-up after Phase 2 lands. | +| Two-phase pause/resume for DPEP | `EngineCore`-level pause state + periodic all-reduce coordination to prevent deadlocks in wide-DP deployments | Not applicable at DP=4 (kavin cluster). Mandatory once we scale to DPEP16+. Track for the next scale-up. | + +The blog also mentions a *"new K8s-native weight transfer engine"* and *"sharding-aware, RDMA-native weight transfer in a generic way"* as ongoing work in the vLLM RL community — both of which describe what MX already is. If MX isn't already the implementation they're referring to, reach out to Robert Shaw (acknowledged as organizing RL-related efforts) to coordinate. 
+ +### 8.5 Updated follow-up list — was four items, now seven + +| # | Item | Effort | +|---|---|---| +| 1 | Validate v0.7.1 end-to-end on kavin (pending — image just deployed) | hours of soak | +| 2 | Send a one-line PR to upstream `Dockerfile.cuda` adding `--extra disagg` | <30 min | +| 3 | Roll the MX server with the `SourceIdentity` round-trip fix; deprecate the `__mx_v2_meta__` sidecar transport | ~1 day | +| 4 | Implement `pull_one(name)` semantic on MX for Ray/RDT-style per-tensor elasticity | ~50 LOC | +| 5 | **NEW**: Sketch + implement `MxWeightTransferEngine` as the upstream-PR form of Phase 2/3/4 | 150-200 LOC adapter | +| 6 | **NEW**: Adopt `pause_generation(mode="keep")` in prime-rl orchestrator for true async RL | ~50 LOC + tests | +| 7 | **NEW**: Coordinate with Matej + Robert Shaw + the vLLM RL roadmap on the K8s-native weight transfer engine. Either contribute MX as the canonical implementation or converge with whoever's already building it. | meeting + scoping | diff --git a/docs/proposals/post-pr2389-kernel-compile-plan.md b/docs/proposals/post-pr2389-kernel-compile-plan.md new file mode 100644 index 0000000000..78db06fa4e --- /dev/null +++ b/docs/proposals/post-pr2389-kernel-compile-plan.md @@ -0,0 +1,523 @@ +# Post-PR-#2389 plan: kernel-compile separation, mixed-TP, and MX client adoption + +> **Related docs in this directory**: +> - [`post-pr2389-status-and-plan.md`](./post-pr2389-status-and-plan.md) — executive summary of where things stand + failure-class → fix mapping +> - [`build-notes-2026-05-28.md`](./build-notes-2026-05-28.md) — image-build experience, cluster observations, vLLM native RL APIs reframing +> +> This doc is the deep-dive RFC with full phase-by-phase design rationale. + +**Status**: Planning doc. Branch `kavink/post-2389-kernel-compile-plan` (NVIDIA-authored, off PR [#2389](https://github.com/PrimeIntellect-ai/prime-rl/pull/2389) HEAD `79ea824d8`). 
+**Premise**: This plan is what we propose to build on top of `nixl_mx` once #2389 merges to `main`. It (a) graduates the in-tree `MxRendezvous` reimplementation onto NVIDIA's published ModelExpress clients, (b) introduces a compile-target registry to fix the trainer-side cutlass-pinning issue surfaced during #2389's FP8 cast-pipeline iteration, and (c) extends the v2 shape registry to handle mixed-TP / sharded-source transfers. **None of this fights the #2389 data plane** — the `Slot` / `TransportPlan` / `NixlAgentWrapper` / `classic_cuda_pool` stack stays untouched. We extend the rendezvous and metadata surfaces only. + +--- + +## 1. What we're building on + +After #2389 merges, prime-rl's `nixl_mx` mode has: + +- `src/prime_rl/transport/` — `NixlAgentWrapper`, `MxRendezvous` (PI's reimplementation of MX's upper layer), `TransportPlan`, `classic_cuda_pool`, `wire.py` (msgspec types). +- `src/prime_rl/trainer/models/slots.py` — `ShardedSlot`, `GatheredSlot`, `ExpertSlot` (the per-tensor buffer abstractions that hold registered NIXL memory and conversion logic). +- `src/prime_rl/trainer/models/conversions/` — bf16-cast and FP8-blockwise conversion specs (`bf16_cast.py`, `fp8_blockwise.py`). +- `src/prime_rl/trainer/rl/broadcast/nixl_mx.py` — `NIXLMxWeightBroadcast`, the lifecycle wrapper. +- `src/prime_rl/inference/vllm/worker/nixl_mx.py` — `NIXLMxWeightUpdateWorker`, the vLLM worker extension. + +PI's data plane (Slot / TransportPlan / NixlAgentWrapper / FP8 conversion specs) is **kept**. It's the conversion / compile / topology layer where this plan extends things. + +--- + +## 2. The problem Matej flagged: trainer-side kernel compile pins the topology + +In #2389 today, the conversion-cast pipeline (`fp8_blockwise.py`, `Slot.convert`, the recent `c3c4b148` "inline transfer slot casting" + `47e170f5` "simplify transfer cast conversions" refactors) runs **on the trainer**. 
The trainer's send bucket holds the **post-compile** layout: bf16 if the inference engine wants bf16, FP8 with DeepGemm scale interleaving if it wants DeepGemm, cutlass-friendly column-major + epilogue scales if it wants cutlass, etc. Inference RDMA-reads the bucket and copies straight into live params. + +This is fast — no receiver-side compute on the refit critical path — but it couples three independent decisions into the trainer's process: + +1. **Which kernel format does inference want?** Today the trainer has to know. Adding cutlass on GB300 means trainer code changes; adding DeepGemm-EP changes the slot layout. +2. **Mixed inference replicas with different kernels.** With trainer-side compile, one trainer can serve only one compile target per refit. Heterogeneous fleets (e.g. a/b testing DeepGemm vs cutlass on the same training run) need two parallel trainer buckets. +3. **Mixed TP/EP layouts.** A trainer at TP=4 publishing to inference at TP=8 needs to know the inference splits at *publish* time to write the right bucket entries. The information is in PI's `build_topology`, but it's tangled with the compile decision because the compile passes produce TP-specific layouts. + +When Matej said "we're having issues handling compiled kernels" — this is the source. The cutlass compile happens trainer-side; any inference replica that doesn't want exactly the trainer's compile target either gets bad bytes or fails the assertion check. + +--- + +## 3. Proposed shift: receiver-side compile via scratch buffers + +Move the compile pass back to the **inference side**. Trainer ships **canonical HF layout** over NIXL (raw bfloat16 or raw float8_e4m3fn + raw per-block scales — whatever format the trainer naturally has post-optimizer-step). Inference runs the kernel-specific compile pass after the RDMA receive completes, into its own live params. 
+ +This is the same scratch-buffer pattern we already validated in two places: + +- Our **original PrimeRL PoC** (`KavinKrishnan/prime-rl:kavink/mx-weight-broadcast`) used scratch buffers explicitly to triangulate KL drift — proved correctness when the receiver decides the layout. +- **John Thompson's NemoRL + Dynamo path** (May 22, 2026) uses `MxRefitReceiver.receive_weights_scratch` to handle vLLM's HF→fused param remapping via `stacked_params_mapping`. Validated end-to-end on GB300 RoCE at **380 Gbps** for an 8.82 GB / 399-tensor refit — same scratch path we'd reuse here, just with kernel-compile transforms instead of name remapping. + +### Receiver-side refit pseudocode + +```python +# inference/vllm/worker/nixl_mx.py — after PR #2389 graduates to MX clients + +def update_weights_via_mx(self, *, version, mx_config): + # 1. Discover same-rank trainer source via MxV2RefitReceiver + candidates = self._mx_receiver.discover_v2_sources( + model_name=self.model_name, + min_version=version, + same_rank_only=mx_config.same_rank_only, + compile_target_filter=self._target_kernel, # NEW — see §4 + ) + chosen = self._mx_receiver.pick_best_source(candidates) + + # 2. RDMA pull into scratch buffers (HF layout — whatever trainer published) + scratch = {} + for name, tensor in self._mx_receiver.receive_weights_scratch( + chosen.ref, + tensor_shapes=chosen.registry.tensor_shapes, # global shapes from v2 sidecar + target_tp_layout=self._tp_layout, # NEW — mixed-TP slice request + ): + scratch[name] = tensor + + # 3. Run the kernel-specific compile pass into live model params + self._compile_pass.apply( + scratch_buffers=scratch, + live_params=dict(self.model_runner.model.named_parameters()), + ) + + # 4. Tree fan-out republish (TensorHub pipeline replication) — unchanged + self._mx_receiver.publish_self_as_source(...) 
+``` + +The compile pass is a small dispatch: + +```python +# inference/vllm/worker/compile_passes.py — NEW + +from abc import ABC + +class CompilePass(ABC): + target: str + def apply(self, scratch_buffers, live_params): ... + +class HFRaw(CompilePass): + target = "hf_raw" + def apply(self, scratch_buffers, live_params): + # Direct copy — names + dtypes match. Fast path. + for name, t in scratch_buffers.items(): + live_params[name].data.copy_(t) + +class DeepGemmFP8(CompilePass): + target = "deep_gemm_fp8" + def apply(self, scratch_buffers, live_params): + # K-major scale interleave, fused gate_up_proj packing + for name, t in scratch_buffers.items(): + interleaved = deep_gemm_layout(t, block_size=128) + live_params[fused_name(name)].data.copy_(interleaved) + +class CutlassFP8(CompilePass): + target = "cutlass_fp8" + def apply(self, scratch_buffers, live_params): + # Column-major weights + epilogue scale tensors + ... +``` + +### Cost vs benefit + +**Cost**: extra GPU memory for scratch (~1× model size briefly per refit, freed after compile) + compile latency on the inference side (~50-200ms for FP8 passes on Qwen3-30B-A3B). + +**Benefit**: +- Trainer is **kernel-agnostic** — same bucket bytes serve any inference target. +- Mixed-fleet OK — different inference replicas can run different compile passes on the same trainer publish. +- Adding a new kernel = adding a new `CompilePass` subclass on the inference side. **Zero trainer change.** +- Cross-TP / cross-EP layouts decouple cleanly from compile (see §5). + +This is the same trade John Thompson made for Dynamo and the same trade we made in the original PoC. Empirically the 50-200ms compile latency is dwarfed by the 200ms RDMA pull anyway; total wall time is unchanged. + +--- + +## 4. 
New primitive: compile-target registry (extension to v2 shape registry) + +Our v2 shape registry today encodes: + +``` +TensorDescriptorV2: + name, global_shape, dtype, placement_kind, shard_axis, local_shard_range, + is_expert, expert_axis, owned_expert_ids +``` + +Add **two fields** at the registry level (not per-tensor — these describe how the *publisher* prepared the data): + +```diff + RegistryPayload (JSON in __mx_v2_meta__ sidecar): + version: int + trainer_world_layout: str # "fsdp:4,tp:1,pp:1,ep:1" ++ compile_target: str # "hf_raw" | "deep_gemm_fp8" | "cutlass_fp8" | ... ++ compile_metadata: dict # kernel-specific params: ++ # {block_size: 128, scale_dtype: "float8_e8m0", ++ # layout: "row_major", ...} + tensors: list[TensorDescriptorV2] +``` + +### Trainer side + +The trainer declares what it published — typically `"hf_raw"` post-optimizer-step: + +```python +publisher = MxV2TrainingPublisher( + agent_name=..., + world_layout=TrainerWorldLayout(fsdp_world_size=4, tp_world_size=1), + compile_target="hf_raw", # NEW — declarative, no inference-side dependency +) +``` + +Specialized trainers that *do* want to bake a compile pass in for perf can declare it: + +```python +publisher = MxV2TrainingPublisher( + ..., + compile_target="deep_gemm_fp8", + compile_metadata={"block_size": 128, "scale_dtype": "float8_e8m0"}, +) +``` + +— and then only same-target inference replicas will accept that publish. Mixed-target inference replicas skip it and look for an `hf_raw` source. 
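The acceptance rule just described (a receiver takes a publish only when the declared `compile_target` is in the set it accepts, and otherwise looks for an `hf_raw` source) can be sketched in isolation. This is a minimal illustration and not the ModelExpress implementation: `SourceCandidate` and `filter_by_compile_target` are hypothetical names, and ordering the survivors newest-first is an assumption made only so the example has a deterministic result.

```python
from dataclasses import dataclass, field


@dataclass(frozen=True)
class SourceCandidate:
    """Hypothetical stand-in for one discovered publisher entry."""
    agent_name: str
    version: int
    compile_target: str
    compile_metadata: dict = field(default_factory=dict)


def filter_by_compile_target(candidates, accepted_targets, min_version=0):
    """Keep candidates whose declared compile_target the receiver accepts.

    Survivors are returned newest version first (an assumption of this sketch).
    """
    kept = [
        c for c in candidates
        if c.compile_target in accepted_targets and c.version >= min_version
    ]
    return sorted(kept, key=lambda c: c.version, reverse=True)


published = [
    SourceCandidate("trainer-r0", 7, "hf_raw"),
    SourceCandidate("trainer-r1", 8, "deep_gemm_fp8", {"block_size": 128}),
    SourceCandidate("trainer-r2", 6, "cutlass_fp8"),
]

# A receiver that always runs its own compile pass accepts hf_raw only.
hf_only = filter_by_compile_target(published, {"hf_raw"})

# A receiver that can also consume pre-compiled DeepGemm bytes accepts either.
either = filter_by_compile_target(published, {"hf_raw", "deep_gemm_fp8"})
```

Keeping the filter a pure function over already-discovered candidates means the sort the picker applies afterwards does not need to know about compile targets at all.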
+ +### Receiver side + +`discover_v2_sources(compile_target_filter=...)` filters candidate trainers: + +```python +candidates = receiver.discover_v2_sources( + model_name=..., + min_version=N, + same_rank_only=True, + compile_target_filter={"hf_raw"}, # accept HF-raw only, run compile myself +) +``` + +Or accept any compatible target (if the receiver has a fallback compile pass): + +```python +candidates = receiver.discover_v2_sources( + model_name=..., + compile_target_filter={"hf_raw", "deep_gemm_fp8"}, # accept either +) +``` + +The picker's existing trainer-vs-replica + freshest-per-rank sort applies on the filtered set. + +### Why this is in the v2 shape registry, not a new transport + +It's metadata about the wire format. The shape registry already travels via the synthetic `__mx_v2_meta__` `TensorDescriptor` sidecar (proven by jthomson04's GB300 run + protected by PR #295's filter). Adding two more JSON fields is zero-cost on the wire and keeps the discovery contract in one place. + +--- + +## 5. Sharding and mixed-TP transfers + +The hardest case is **trainer TP/EP layout ≠ inference TP/EP layout**. The shape registry already encodes this on the publish side: + +``` +placement_kind: "SHARD" +shard_axis: 0 +local_shard_range: (start, end) # this rank's slice along shard_axis +``` + +What's missing is the **receiver's expression of what it wants**. 
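Given the publish-side `local_shard_range` values above, the receiver-side arithmetic is small: derive the slice this rank wants, then intersect it with each published shard. The sketch below assumes an even split along the shard axis and uses hypothetical helper names (`even_shard_range`, `overlapping_sources`); it illustrates the slice math only and does not model the RDMA pull itself.

```python
def even_shard_range(dim_size, world_size, rank):
    """Half-open [start, end) slice owned by `rank` under an even split."""
    if dim_size % world_size != 0:
        raise ValueError("sketch assumes dim_size divides evenly by world_size")
    step = dim_size // world_size
    return rank * step, (rank + 1) * step


def overlapping_sources(target, published_ranges):
    """Map each source rank whose shard overlaps `target` to the overlap.

    Ranges are half-open and expressed in global coordinates.
    """
    t_start, t_end = target
    plan = {}
    for src_rank, (s_start, s_end) in sorted(published_ranges.items()):
        lo, hi = max(t_start, s_start), min(t_end, s_end)
        if lo < hi:  # empty intersections are skipped
            plan[src_rank] = (lo, hi)
    return plan


N = 1024  # illustrative size of the sharded axis
trainer = {r: even_shard_range(N, 4, r) for r in range(4)}  # trainer at TP=4

# Inference TP=8, rank 4: the slice falls inside a single trainer shard.
want_tp8_r4 = even_shard_range(N, 8, 4)
plan_single = overlapping_sources(want_tp8_r4, trainer)

# Inference TP=2, rank 0: the slice spans two trainer shards.
want_tp2_r0 = even_shard_range(N, 2, 0)
plan_multi = overlapping_sources(want_tp2_r0, trainer)
```

A plan with one entry corresponds to a single-source pull; a plan with several entries is the case where the receiver has to pull from multiple trainer ranks and concatenate.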
+ +### New receiver-side API + +```python +class TargetTPLayout: + """What slice of the global tensor THIS receiver needs.""" + world_size: int # inference TP world size + rank: int # this receiver's TP rank + shard_axis: int # which axis we're sharded on (model-dependent) + +receiver.receive_weights_scratch( + chosen.ref, + target_tp_layout=TargetTPLayout(world_size=8, rank=3, shard_axis=0), + tensor_shapes=chosen.registry.tensor_shapes, +) +``` + +The receiver computes its desired slice from `target_tp_layout` and the published `placement_kind`: + +| Publisher | Receiver | Result | +|---|---|---| +| `REPLICATE` (trainer TP=1) | TP=8, rank=3, axis=0 | Receiver requests rows `[3N/8 : 4N/8]` of the publisher's tensor | +| `SHARD(0)` trainer TP=4 rank=2, range `[N/2 : 3N/4]` | TP=8, rank=4, axis=0, requests `[N/2 : 5N/8]` | Receiver pulls from same-physical-rank trainer (R2), takes lower half of R2's shard | +| `SHARD(0)` trainer TP=4 rank=2, range `[N/2 : 3N/4]` | TP=8, rank=5, axis=0, requests `[5N/8 : 6N/8]` | Receiver pulls from same trainer rank (R2), takes upper half of R2's shard | +| `SHARD(0)` trainer TP=4 rank=1, range `[N/4 : N/2]` | TP=2, rank=0, axis=0, requests `[0 : N/2]` | Receiver pulls from **both** trainer R0 + R1, concatenates | + +For cases where one inference rank needs slices from multiple trainer ranks (last row), the receiver picks **N candidates** instead of one: + +```python +multi_source = receiver.discover_v2_sources_for_slice( + model_name=..., + target_slice=(start, end), + shard_axis=0, +) +# Returns one SourceRef per trainer rank whose shard overlaps target_slice +# Receiver does N parallel RDMA pulls, concatenates in scratch +``` + +### Mixed-EP for MoE + +Same machinery, on the expert axis. 
NemoRL v2 already does this for `owned_expert_ids`: + +```python +candidates = receiver.discover_v2_sources( + model_name=..., + target_expert_ids_per_layer={ + 5: {0, 1, 2, 3}, # this inference rank's owned experts in layer 5 + 6: {0, 1, 2, 3}, + }, +) +``` + +The picker matches candidates whose `expert_owner_per_rank` covers the needed experts (existing logic in `MxV2RefitReceiver.pick_best_source` — see `nemo_rl_v2.py`). For mixed-EP (trainer EP=4, inference EP=8), receivers may pull from multiple trainer ranks via the same `discover_v2_sources_for_slice` pattern. + +### Why this matters for PrimeRL specifically + +PI's `ExpertSlot` (in `slots.py`) and `build_topology()` (in `nixl_checkpoint_engine`-style code) already implement TP-matched and EP-matched pairing on the trainer side. They're computing `peer_chunk_descs` based on the publisher's known topology. **What's missing is the inverse**: when the inference layout differs from the trainer's, the receiver needs to express that. The compile-target registry + the slice-discovery API give it that vocabulary. + +--- + +## 6. MX client adoption — layered with this plan + +The two-phase migration from our earlier review of #2389 (Phase 1 surgical / Phase 2 client adoption) is the foundation this plan builds on: + +### Phase 1 — surgical fixes against the `nixl_mx` in-tree code (drop-in patches) + +The 6 inline-comment fixes (line numbers verified against `nixl_mx` HEAD `79ea824d8`): + +1. Same-rank `add_remote_agent` filter in `transport_plan.py` +2. Freshest-per-rank dedup in `mx_rendezvous.py::wait_for_peers` +3. `HeartbeatThread` after `set_status(READY)` in `inference/.../nixl_mx.py` +4. Read timeout from config (not hardcoded 1200s) +5. MLA-guard for non-MLA models (`update_mla_absorbed_weights`) +6. HSDP barrier ordering in `trainer/.../nixl_mx.py` + +These land **before** this plan starts — closes the bug classes without architectural change. 
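Fixes 1 and 2 in the list above are small enough to sketch. The version below is illustrative only: `PeerEntry`, `freshest_per_rank`, and `same_rank_peer` are hypothetical names, and the `worker_rank` / `updated_at` fields are assumed to mirror what the MX catalog returns per worker. It is not the code in `mx_rendezvous.py` or `transport_plan.py`.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class PeerEntry:
    """Hypothetical view of one catalog entry for a worker."""
    worker_rank: int
    updated_at: float  # assumed: larger means more recently refreshed
    agent_name: str


def freshest_per_rank(entries: list[PeerEntry]) -> dict[int, PeerEntry]:
    """Keep only the most recently updated entry for each worker_rank."""
    best: dict[int, PeerEntry] = {}
    for entry in entries:
        current = best.get(entry.worker_rank)
        if current is None or entry.updated_at > current.updated_at:
            best[entry.worker_rank] = entry
    return best


def same_rank_peer(entries: list[PeerEntry], my_rank: int) -> Optional[PeerEntry]:
    """Return the freshest peer at my own rank, or None if there is none."""
    return freshest_per_rank(entries).get(my_rank)


# A restarted pod leaves a stale rank-0 entry next to the fresh one.
catalog = [
    PeerEntry(worker_rank=0, updated_at=10.0, agent_name="rank0-stale"),
    PeerEntry(worker_rank=0, updated_at=20.0, agent_name="rank0-fresh"),
    PeerEntry(worker_rank=1, updated_at=15.0, agent_name="rank1"),
]
peer = same_rank_peer(catalog, my_rank=0)
```

Doing the dedup before the same-rank lookup means a stale entry can never be handed to `add_remote_agent`, whichever order the catalog returns entries in.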
+ +### Phase 2 — graduate the rendezvous half onto ModelExpress clients + +Delete `src/prime_rl/transport/mx_rendezvous.py` (~185 LOC, replicates functionality already in `modelexpress`). Replace with imports of `MxV2TrainingPublisher` and `MxV2RefitReceiver`. The in-tree `NixlAgentWrapper` + `Slot` + `TransportPlan` + `classic_cuda_pool` stay — that's prime-rl-specific data-plane specialization and shouldn't move. + +```diff +-from prime_rl.transport.mx_rendezvous import MxRendezvous ++from modelexpress import MxV2TrainingPublisher, MxV2RefitReceiver +``` + +This is what unblocks everything in §3-§5: + +- `MxV2TrainingPublisher` exposes the v2 sidecar registry that §4's `compile_target` extends. +- `MxV2RefitReceiver.receive_weights_scratch` is the proven path from John's Dynamo work (380 Gbps GB300). +- `discover_v2_sources(compile_target_filter=...)` is a small extension to the existing picker. +- Heartbeat / freshest-dedup / retention all come along for free — no separate Phase 1 work needed once Phase 2 lands. + +### Phase 3 (this plan) — compile-target registry + mixed-TP + +- Add `compile_target` + `compile_metadata` to v2 shape registry (~30 LOC in `shape_descriptors.py` + `nemo_rl_v2.py`). +- Add `compile_target_filter` to `discover_v2_sources` (~15 LOC). +- Add `target_tp_layout` + `discover_v2_sources_for_slice` to `MxV2RefitReceiver` (~120 LOC). +- Add `compile_passes/` module in `src/prime_rl/inference/vllm/worker/` with `HFRaw`, `DeepGemmFP8`, `CutlassFP8` passes (~300 LOC). Or in NemoRL `nemo_rl/models/generation/vllm/compile_passes/` — see §7. +- PI's `nixl_mx.py` inference worker calls into the right `CompilePass` based on `engine.kernel_target`. + +Total: ~450 LOC across MX + PrimeRL, all additive. + +--- + +## 7. 
What we borrow from John Thompson's NemoRL+Dynamo work + +Five specific pieces, all already proven on GB300 RoCE: + +### 7.1 `receive_weights_scratch` is the foundation + +John's path uses `MxRefitReceiver.receive_weights_scratch` because vLLM's `stacked_params_mapping` requires HF-named tensors that the receiver later passes to `model.load_weights()`. That's structurally identical to "trainer ships HF-raw, receiver compiles into kernel layout": + +```python +# John's existing flow (NemoRL + Dynamo + vLLM v1) +weights = list(receiver._receiver.receive_weights_scratch( + chosen.ref, + timeout_seconds=mx_config.timeout_seconds, + tensor_shapes=tensor_shapes, +)) +self.model_runner.model.load_weights(weights=weights) # vLLM does the HF→fused remap + +# Our extension (PrimeRL post-#2389 + kernel compile) +weights = list(receiver._receiver.receive_weights_scratch( + chosen.ref, + timeout_seconds=mx_config.timeout_seconds, + tensor_shapes=tensor_shapes, + target_tp_layout=self._target_tp_layout, # NEW +)) +self._compile_pass.apply(weights, live_params) # OUR compile dispatch +``` + +The mechanism is identical. Only the post-RDMA stage differs. + +### 7.2 The `worker_extension_cls` injection pattern is cleaner than subclassing + +John's `MxRefitWorkerExtension` (in `dynamo/vllm/mx_refit/extension.py`) is injected via vLLM v1's `parallel_config.worker_extension_cls`. The class has no `__init__`; vLLM merges its methods into the existing `Worker` via `__bases__`. State is stashed lazily on `self` with `_mx_` prefixed attribute names. + +PI's `NIXLMxWeightUpdateWorker` today **subclasses** `Worker` directly. The extension-class pattern would let the refit logic live in a sibling module without touching the inheritance chain — useful when we add the compile passes (§3) because those want to live in their own package. + +**Recommend**: when graduating PrimeRL to MX clients in Phase 2, also adopt the `worker_extension_cls` pattern for the inference worker. 
The two changes naturally compose. + +### 7.3 PR #295's sidecar filter is required for any new v2 metadata + +If we extend the v2 sidecar (§4) without keeping PR #295's filter in `MxRefitReceiver.receive_weights{,_scratch}`, the synthetic `__mx_v2_meta__` `TensorDescriptor` poisons `prep_xfer_dlist` again — same `NIXL_ERR_NOT_FOUND` John hit before May 22. **No new code needed; the filter is already in `kavink/nemo_rl_moe` HEAD `8594fd6`.** Just don't accidentally back it out. + +### 7.4 The `FORCE_RDMA=1` test mode catches this class of bug in loopback + +The v2 demo scripts (`scripts/v2_*_e2e_demo.py`) have a `FORCE_RDMA=1` env var (commit `e8e063b`) that pins `UCX_TLS` off `cuda_ipc` so intra-node loopback exercises the strict `rc_mlx5` descriptor-list validator. **Run every new compile-pass test under `FORCE_RDMA=1`** — otherwise we'll merge a sidecar / descriptor-list bug that doesn't show up until cross-node deploy. + +### 7.5 The compile-pass module probably belongs in NemoRL first, mirrored into prime-rl + +John's Dynamo path is the most mature target for testing new kernels (Qwen3-4B-Thinking GRPO smoke is already running cross-node at 380 Gbps). The compile passes themselves are framework-agnostic — they just take `(scratch_dict, live_params_dict)` and run torch ops. **Recommend**: + +1. Implement `compile_passes/` first in `modelexpress_client/python/modelexpress/compile_passes/` — framework-neutral, reusable. +2. NemoRL + Dynamo path adopts it via `MxRefitWorkerExtension._compile_pass = HFRaw()` (or `DeepGemmFP8()`). +3. Validate end-to-end on GB300 with cutlass + DeepGemm kernels. +4. Mirror into PrimeRL's inference worker after Phase 2 graduates them to MX clients. + +This sequence means we de-risk the compile-pass design on the path that's already shown working before we touch PrimeRL. Same play we ran for the v2 sidecar — designed in NemoRL, validated in NemoRL+Dynamo (John), then graduated to prime-rl. + +--- + +## 8. 
Implementation phases + +| Phase | Scope | Estimated LOC | Owner | +|---|---|---|---| +| **0** | Wait for #2389 to merge upstream | — | Matej | +| **1** | 6 surgical fixes against PI's `transport/*.py` + `inference/*.py` (closes bug classes) | ~100 LOC | Us — fast follow on Matej | +| **2** | Graduate `MxRendezvous` → `MxV2TrainingPublisher` / `MxV2RefitReceiver`; adopt `worker_extension_cls` pattern in inference worker | ~−400 LOC (PI's reimpl removed) + ~150 LOC import-and-call | Us | +| **3a** | Add `compile_target` + `compile_metadata` to v2 shape registry | ~30 LOC | Us | +| **3b** | Add `compile_target_filter` to `discover_v2_sources` | ~15 LOC | Us | +| **3c** | Add `target_tp_layout` + `discover_v2_sources_for_slice` to `MxV2RefitReceiver` | ~120 LOC | Us | +| **3d** | Implement `compile_passes/` (HFRaw, DeepGemmFP8, CutlassFP8) — in MX repo for reuse | ~300 LOC | Us | +| **3e** | Validate on NemoRL+Dynamo path (John's GB300 cluster) — Qwen3-4B-Thinking with DeepGemm and cutlass kernels both running on the same MX server | E2E | Us + John | +| **3f** | Mirror compile-pass dispatch into PI's `inference/vllm/worker/nixl_mx.py` | ~50 LOC | Us — PR back to PI | +| **4** | Mixed-TP / mixed-EP slice discovery wired end-to-end (multi-source RDMA pulls) | ~200 LOC | Us — separable from Phase 3 | + +**Phases 0-1** are fully sequenced (must wait for upstream + apply surgical fixes). **Phases 2 onward** can run in parallel if we're willing to maintain a `kavink/post-2389-*` branch off PI's main + a follow-on PR per phase. + +--- + +## 9. Open questions + +1. **Does PI want the compile passes in their tree, or in MX?** If MX, they import a pluggable `CompilePassRegistry`. If their tree, the kernel ecosystem stays close to their Slot system (which already does fp8_blockwise). My lean: **MX**, because Dynamo + NemoRL also want them, and PI's per-Slot conversion stays for the trainer-side path when teams opt into "publish post-compile". + +2. 
**Does cutlass-FP8 work on inference-side compute?** Compile pass needs ~200ms of CUDA time. If the inference engine is mid-rollout when the refit arrives, we either pause and run the compile or queue it for the next "between rollouts" window. PrimeRL's current orchestrator does the latter; this plan inherits that behavior. + +3. **How do we handle trainers that publish post-compile (Matej's current path)?** Their `compile_target = "deep_gemm_fp8"`; receivers either accept it directly (fast path) or reject and look for `hf_raw`. Mixed fleets get clean error messages, not corrupt weights. + +4. **Mixed-TP across nodes — what's the bandwidth math?** Trainer TP=4 ↔ inference TP=8 means each inference rank pulls from 1-2 trainer ranks. For Qwen3-30B-A3B on GB200 (~30 GB / 4 trainer ranks = 7.5 GB/rank), an inference rank pulling 2× 4 GB slices is well within NIC budget. For larger EP layouts where one inference rank needs experts from N>2 trainer ranks, fan-in becomes interesting — that's where pipeline replication (TensorHub) and rollouts-as-replicas pay off. + +5. **What's the deprecation story for trainer-side compile?** We don't deprecate it — Matej's path stays valid for teams that want zero inference-side latency. The `compile_target` field is just informational; receivers filter on it. + +6. **Should the compile pass run before or after `update_mla_absorbed_weights`?** Before. MLA absorption operates on live params, so the compile pass runs first and the live params are already in the right layout when MLA absorption runs. + +--- + +## 10. Component view + sequence diagram + +```mermaid +flowchart TB + subgraph trainer["Trainer side (after Phase 2)"] + TBcast["NIXLMxWeightBroadcast<br/>(PI's lifecycle wrapper)"] + TSlots["Slots: Sharded · Gathered · Expert<br/>(PI's data plane, kept)"] + TPlan["TransportPlan<br/>(PI's, kept)"] + TPub["MxV2TrainingPublisher<br/>(NEW — replaces MxRendezvous)"] + TAgent["NixlAgentWrapper<br/>(PI's, kept)"] + TBcast --> TSlots + TBcast --> TPlan + TBcast --> TPub + TSlots --> TAgent + TPlan --> TAgent + TPub -. publishes registry incl.<br/>compile_target=hf_raw .-> TPlan + end + + subgraph mx["MX control plane (unchanged)"] + MXSVR[("MX Server · gRPC + Redis<br/>shape registry, compile-target<br/>filter, tree fan-out catalog")] + end + + subgraph inf["Inference side (after Phase 2+3)"] + IRec["MxV2RefitReceiver<br/>(NEW — replaces ad-hoc rendezvous)"] + IScratch["receive_weights_scratch<br/>+ target_tp_layout<br/>(extended John's path)"] + ICompile["CompilePass dispatch (NEW)<br/>HFRaw · DeepGemmFP8 · CutlassFP8"] + ILive["vLLM model.named_parameters()<br/>(live params — kernel-specific layout)"] + IRec --> IScratch + IScratch --> ICompile + ICompile --> ILive + end + + TPub <-.->|"publish_metadata (incl. compile_target)<br/>set_status · update_status"| MXSVR + IRec <-.->|"discover_v2_sources(compile_target_filter=...)<br/>list_sources · get_metadata"| MXSVR + TAgent <==>|"one-sided RDMA WRITE<br/>UCX rc_mlx5 / RoCE<br/>(HF-raw bytes, post-cast pre-compile)"| IScratch + + style mx fill:#fec,stroke:#963 + style MXSVR fill:#fec,stroke:#963,stroke-width:2px + style TPub fill:#cce,stroke:#33c,stroke-width:2px + style IRec fill:#cce,stroke:#33c,stroke-width:2px + style ICompile fill:#cfc,stroke:#363,stroke-width:2px + style ILive fill:#fcc,stroke:#c33 +``` + +### One refit cycle, after this plan lands + +```mermaid +sequenceDiagram + autonumber + participant O as Orchestrator + participant T as Trainer (NIXLMxWeightBroadcast) + participant MX as MX Server + participant I as Inference worker + participant C as CompilePass + + Note over T,I: BOOT (once per refit run) + O->>I: POST /init_nixl_mx (host, port, rank, kernel_target=deep_gemm_fp8) + I->>I: register live params with NIXL (PI's data plane) + I->>MX: publish_metadata(role=inference, kernel_target=deep_gemm_fp8) + I->>MX: update_status(READY) + + Note over T,I: PER REFIT STEP + O->>I: POST /update_weights + I->>MX: wait_for_all_peers_ready(role=trainer, READY) + + T->>T: lazy_init slots; per-rank scratch fill from state_dict<br/>(NO trainer-side cutlass — bytes are HF-raw) + T->>MX: publish_metadata(role=trainer, compile_target="hf_raw",<br/>compile_metadata={...}, shape registry per tensor) + T->>MX: update_status(INITIALIZING → READY) + + I->>MX: discover_v2_sources(model, min_version=N,<br/>compile_target_filter={"hf_raw"},<br/>target_tp_layout=TP=8 rank=3) + MX-->>I: candidates = [trainer R0 (covers requested slice)] + I->>I: pick_best_source + + Note right of I: SCRATCH PATH (from John's NemoRL+Dynamo work) + I->>T: NIXL one-sided RDMA WRITE → scratch buffers<br/>(HF-raw layout, ~380 Gbps on GB300 RoCE) + I->>I: torch.cuda.synchronize + + Note right of I: COMPILE PASS (new — runs inference-side, ~50-200ms) + I->>C: apply(scratch_buffers, live_params)<br/>e.g. DeepGemm scale interleave, fused gate_up_proj pack + C-->>I: live params updated in DeepGemm-friendly layout + + I->>I: update_mla_absorbed_weights (if MLA model) + I->>MX: publish_self_as_source(role=inference_replica, version=N)<br/>(tree fan-out for next refit) + + I-->>O: 200 OK + O->>O: scheduler advances · next rollout uses new weights +``` + +--- + +## 11. Cross-references + +ModelExpress design docs (NVIDIA-authored, for context on the client surface this plan adopts): + +- [`docs/RL/PRIMERL_MX_OVERVIEW.md`](https://github.com/ai-dynamo/modelexpress/blob/main/docs/RL/PRIMERL_MX_OVERVIEW.md) — the foundational prime-rl × MX integration design (catalog + star wiring story). +- [`docs/RL/NEMORL_MX_OVERVIEW.md`](https://github.com/ai-dynamo/modelexpress/blob/kavink/nemo_rl_moe/docs/RL/NEMORL_MX_OVERVIEW.md) — the v2 design (rank-to-rank, tree fan-out, expert filter, shape registry) that this plan extends with the compile-target axis. +- [`docs/RL/VERL_MX_OVERVIEW.md`](https://github.com/ai-dynamo/modelexpress/blob/main/docs/RL/VERL_MX_OVERVIEW.md) — the verl `MxCheckpointEngine` integration; sibling adopter of the same MX clients. + +Upstream branches this plan refers to: + +- ModelExpress branch [`kavink/nemo_rl_moe`](https://github.com/ai-dynamo/modelexpress/tree/kavink/nemo_rl_moe) — the v2 client surface (`MxV2TrainingPublisher`, `MxV2RefitReceiver`), `shape_descriptors`, sidecar transport, and PR #295's sidecar filter. The MX-side dependency that Phase 2 imports. +- ModelExpress PR [#295](https://github.com/ai-dynamo/modelexpress/pull/295) — synthetic-sidecar `TensorDescriptor` filter in `MxRefitReceiver`. Required for any v2 metadata extension (including this plan's `compile_target` field) to survive `prep_xfer_dlist` validation on cross-node RoCE. Already merged into `kavink/nemo_rl_moe`; just don't back it out. +- NemoRL × Dynamo branches (John Thompson, NVIDIA): [`KavinKrishnan/RL:kavink/mx_integration`](https://github.com/KavinKrishnan/RL/tree/kavink/mx_integration) (NeMo-RL side) + the Dynamo-side companion. Validated at 380 Gbps on GB300 RoCE for an 8.82 GB / 399-tensor refit (Qwen3-4B-Thinking GRPO smoke). 
+ +NVIDIA-internal context (not necessary for upstream review, listed for our own bookkeeping): + +- The 6 inline review comments + summary message we have queued for #2389 — verified line numbers against HEAD `dabaa19f5` (still applicable on `79ea824d8` after I re-checked May 27). Phase 1 of this rollout. +- Current state of #2389: +10 commits since `dabaa19f5` (4× conversion-cast polish, 4× DeepGemm env-var hygiene, 2× config/import fixes). None touched the 6 flagged lines. diff --git a/docs/proposals/post-pr2389-status-and-plan.md b/docs/proposals/post-pr2389-status-and-plan.md new file mode 100644 index 0000000000..f1c838f2fc --- /dev/null +++ b/docs/proposals/post-pr2389-status-and-plan.md @@ -0,0 +1,232 @@ +# prime-rl × ModelExpress — Status of the post-#2389 work and how it addresses the MoE / kernel / quant pain points + +> **Related docs in this directory**: +> - [`post-pr2389-kernel-compile-plan.md`](./post-pr2389-kernel-compile-plan.md) — the full RFC with phase-by-phase design rationale +> - [`build-notes-2026-05-28.md`](./build-notes-2026-05-28.md) — image-build experience, cluster observations, vLLM native RL APIs reframing +> +> This doc is the executive summary: where prime-rl × MX stands today, what the failure classes are, and where in the follow-up plan each one gets resolved. + +**Scope**: status of the ModelExpress + NIXL weight-refit integration in prime-rl, the three classes of failure observed on MoE models (kernel-target mismatches, fused-vs-unfused gates across kernels, quantization + packing mismatches), and the four-phase follow-up plan that resolves them. + +**TL;DR**: [PR #2389](https://github.com/PrimeIntellect-ai/prime-rl/pull/2389) lands the first cut of MX-driven NIXL refit in prime-rl and runs at ~10s/cycle on GB200. 
The three failure classes reduce to one root cause — prime-rl's weight-conversion registry has only two entries, and there is no `compile_target` tag on published bytes that distinguishes DeepGemm vs Cutlass vs other layouts. Two follow-up PRs are in flight that fix this without changing prime-rl's transfer architecture, plus a ~80-LOC-per-kernel extension that can land independently of either PR. + +--- + +## 1. Current state + +### 1.1 PR #2389 — what it does + +[PR #2389](https://github.com/PrimeIntellect-ai/prime-rl/pull/2389) adds a third weight-broadcast type to prime-rl: `weight_broadcast.type = "nixl_mx"`. The data flow per refit cycle is: + +1. **Trainer side** — for each step's state-dict tensor, run a *trainer-side conversion* (today: `bf16_cast` or `fp8_128x128`) that produces the destination layout vLLM expects, with fusion (e.g. `q_proj/k_proj/v_proj → qkv_proj`) and quant already applied. Land the bytes into pre-allocated NIXL-registered buffers (`ShardedSlot` / `GatheredSlot` / `ExpertSlot`). +2. **Rendezvous** — register with ModelExpress (MX) so the inference side knows where to pull from. Lightweight — just metadata; the bytes never go through MX. +3. **NIXL RDMA push** — trainer writes directly into the inference workers' pre-registered parameter buffers over RDMA. Sub-second for a 30 GB model on GB200 cross-node. +4. **Inference side** — vLLM workers receive a `/update_weights` HTTP from the orchestrator, synchronize, and resume serving. 
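Step 1's "fuse, then land the bytes in a pre-allocated buffer" can be pictured with plain bytes. This is a toy sketch rather than prime-rl's slot code: it assumes contiguous row-major tensors, for which concatenating along dim 0 is the same as concatenating the raw bytes, and `fuse_into_slot` is a hypothetical name.

```python
def fuse_into_slot(slot: memoryview, sources: list[bytes]) -> int:
    """Copy each source back-to-back into a pre-allocated slot.

    Returns the number of bytes written. Raises if the slot is too small,
    so a layout disagreement fails loudly instead of truncating.
    """
    offset = 0
    for src in sources:
        end = offset + len(src)
        if end > len(slot):
            raise ValueError("fused sources do not fit in the destination slot")
        slot[offset:end] = src  # in-place write, no reallocation of the slot
        offset = end
    return offset


# Stand-ins for the raw bytes of q_proj, k_proj and v_proj weights.
q, k, v = b"qq", b"kk", b"vv"
qkv_slot = memoryview(bytearray(len(q) + len(k) + len(v)))  # allocated once
written = fuse_into_slot(qkv_slot, [q, k, v])
```

Because the destination is allocated once and only written in place, its address stays stable across refit cycles, which is the property a registered RDMA buffer needs.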
+ +### 1.2 Validation status + +Validation has run on a GB200 cluster: + +| Component | State | +|---|---| +| Workload | Qwen3-30B-A3B-Instruct-2507, FSDP 2×2, EP=4, 32/128 experts per rank, FLASHINFER attention | +| Image | `nvcr.io/nvidian/dynamo-dev/prime-rl-mx-on-nixl:v0.5.2` | +| Steady-state refit | ~10 s/cycle, reward=1.0000, off-policy=0 | +| Errors | Zero NIXL errors (no `REMOTE_DISCONNECT`, no `NOT_ALLOWED`, no stale-READY) | +| Tested cycles | 4/4 clean, plus 1 trainer restart cycle (clean exit on `max_steps`, k8s respawned, resumed cleanly) | + +PR #2389 is a real working baseline — not a prototype that falls over on first push. + +### 1.3 Two GB200-specific runtime patches folded in + +Two issues on the multi-NIC fabric required patches to the trainer's NIXL setup: + +- **Same-rank-only peer filter** — on GCP GB200 the four RDMA NICs (`rdma-0..3`) are separate L3 subnets, so trainer rank N can only safely peer with inference rank N. Cross-subnet pairs fail. +- **Freshest-per-rank dedup** — when MX has multiple entries at the same `worker_rank` (e.g. after a pod restart), the freshest entry by `updated_at` must be picked, not the first one. Otherwise the inference side picks a stale entry and the NIXL `add_remote_agent` call refuses with `NIXL_ERR_NOT_ALLOWED`. + +Both are applied today as a runtime monkey-patch in the configmap (`patch_nixl_mx.py`). The [Phase 2 draft PR](https://github.com/KavinKrishnan/prime-rl/pull/1) bakes them into the rendezvous class so they are no longer runtime shims. + +### 1.4 What PR #2389 does not yet do (where the failure classes live) + +Two specific gaps map directly to the observed failure classes: + +| Gap | Failure class | +|---|---| +| The conversion registry (`prime_rl/trainer/models/conversions/__init__.py`) has exactly **two entries**: `bf16_cast` and `fp8_128x128`. Anything else raises `NotImplementedError` at startup. 
| Quantization and packing mismatches — if the target inference engine wants DeepGemm K-major scale interleaving, Cutlass FP8, MXFP4, a non-128 block size, etc., prime-rl cannot produce those bytes at all. | +| There is no `compile_target` tag on published tensors. Receivers cannot tell whether the bytes were compiled for DeepGemm, Cutlass, or anything else. | MoE-specific corruption + fused-vs-unfused-gate mismatches — when trainer and receiver disagree on the layout, the result is silent byte misinterpretation instead of a clean error. The receiver loads the wrong layout and rollouts are corrupted, often without an obvious crash. | + +These are the same root cause expressed three ways: *prime-rl currently assumes the trainer and inference side agree on one layout, baked in at config time, with no runtime check*. Heterogeneous fleets or kernel-target disagreements break it silently. + +--- + +## 2. The follow-up effort + +The follow-up work is captured in the [post-#2389 RFC](https://github.com/KavinKrishnan/prime-rl/blob/kavink/post-2389-kernel-compile-plan/docs/proposals/post-pr2389-kernel-compile-plan.md). It has four phases and a clear strategic direction. + +### Strategic direction + +The design stays on the **trainer-side post-processed transfer** model: the trainer compiles weights into the destination kernel's exact layout, RDMA-writes them straight into the inference worker's pre-registered parameter buffers, and the inference side does no compute on receipt. This is what prime-rl does today; the `ConversionSpec` / `ShardedSlot` / `ExpertSlot` / `GatheredSlot` structure is the right shape. + +The design does *not* move to a scratch-buffer + `model.load_weights()` pattern. The internal NemoRL + Dynamo integration prototype uses that pattern as a workaround for a vLLM-specific quirk (`stacked_params_mapping` does HF→fused name remapping at load time, so the trainer can publish HF-raw names and the receiver figures it out). 
It works, but it adds receive-side compute (~50-200 ms per refit on FP8 casts) and is not needed for prime-rl, which already does fusion + quant trainer-side and writes directly into pre-allocated fused buffers. + +### Four phases + +| Phase | What | Where | Status | +|---|---|---|---| +| **Phase 2** | Bake the two GB200 runtime patches into `MxRendezvous` permanently; add a `HeartbeatThread` so crashed workers don't leave stale READY entries; swap the in-tree rendezvous over to the published v2 ModelExpress client. | [`KavinKrishnan/prime-rl#1`](https://github.com/KavinKrishnan/prime-rl/pull/1) | Draft PR open, 11/11 unit tests green | +| **Phase 3a** | Add `compile_target` and `compile_metadata` fields to ModelExpress's `TensorDescriptorV2`. Trainer stamps every publish with what layout it produced — `deep_gemm_fp8`, `cutlass_fp8`, `block_size=128`, `gate_fusion=gate_up_swiglu`, etc. | [`ai-dynamo/modelexpress#349`](https://github.com/ai-dynamo/modelexpress/pull/349) | Draft PR open, 6/6 unit tests green | +| **Phase 3b** | Receivers filter sources by `compile_target` + `compile_metadata` *before* RDMA. Mismatched bytes are refused at discovery; no more silent corruption. | Same PR | 4/4 unit tests green | +| **Phase 4** | Multi-source slice picker for the mixed-TP case (e.g. trainer TP=4 publishing to inference TP=8). Each receiver discovers which subset of publisher ranks covers its slice. | Same PR | 8/8 unit tests green | +| **Phase 0 (parallel)** | prime-rl-only: extend the conversion registry from two entries to N. **No MX dependency. Can land independently of the four phases above.** | Open work (see §4 below) | — | + +### How each phase maps to the three failure classes + +| Failure class | What fixes it | +|---|---| +| Quantization and packing issues (`NotImplementedError`) | **Phase 0**: extend `prime_rl/trainer/models/conversions/` with the kernel-specific layouts required. ~80 LOC per kernel. 
| +| MoE expert layout differences across kernels | **Phase 0** (new `ConversionSpec` per (model × kernel)) + **Phase 3a/3b** to surface mismatches as clean errors. | +| Some gates are fused vs not | **Phase 0** (per-(model × kernel) `ConversionSpec` defining `sources` + `cat_dim` differently) + **Phase 3a/3b** so the trainer's `compile_metadata.gate_fusion` tag travels and is filtered on. | + +In all three cases, **Phase 0 is the immediate unblock** — Phase 3 is the safety net that prevents silent corruption when there is a mismatch. + +--- + +## 3. The architecture, with what's new highlighted + +```mermaid +flowchart TB + classDef today fill:#e6f3ff,stroke:#0066cc,color:#000 + classDef phase0 fill:#fff4cc,stroke:#c68a00,color:#000,stroke-width:2px + classDef phase2 fill:#d4edda,stroke:#28a745,color:#000,stroke-width:2px + classDef phase3 fill:#f8d7da,stroke:#dc3545,color:#000,stroke-width:2px + classDef physical fill:#eee,stroke:#666,color:#000 + + subgraph trainer["Trainer pod — FSDP 2×2 / EP=4"] + TSD[("HF state_dict<br/>(bf16, unfused)")] + TCONV["ConversionSpec registry<br/>Today: bf16_cast, fp8_128x128<br/>Phase 0: + deep_gemm_fp8, cutlass_fp8,<br/>+ MoE fusion variants per model"] + TSLOTS["Slots: ShardedSlot · GatheredSlot · ExpertSlot<br/>(pre-allocated, NIXL-registered)"] + TRENDZ["MxRendezvous<br/>Today: thin wrapper over MxClient<br/>Phase 2: swap to MxV2TrainingPublisher<br/>+ heartbeat + same-rank + freshest-dedup"] + TTAG["Phase 3a: stamp every publish with<br/>compile_target + compile_metadata<br/>(rides on TensorDescriptorV2)"] + TNIXL["NIXL agent (UCX rc_mlx5)"] + TSD --> TCONV --> TSLOTS --> TNIXL + TSLOTS -.metadata.-> TRENDZ + TRENDZ -.-> TTAG + end + + subgraph mx["ModelExpress server (control plane only — no weight bytes)"] + MXSVR[("gRPC + Redis catalog<br/>(workers, status, descriptors, tags)")] + end + + subgraph inference["Inference pod — vLLM v1, 4 ranks, EP=4"] + IRENDZ["MxRendezvous / MxV2RefitReceiver"] + IFILT["Phase 3b: discover_v2_sources(<br/>compile_target_filter={…},<br/>required_compile_metadata={…})<br/>Mismatched bytes refused BEFORE RDMA"] + IPICK["Phase 4: discover_v2_sources_for_slice(…)<br/>multi-source picker for mixed-TP"] + IPARAMS["vLLM model.named_parameters()<br/>(pre-registered, fused destinations like qkv_proj)"] + INIXL["NIXL agent (UCX rc_mlx5)"] + IRENDZ --> IFILT --> IPICK + INIXL --> IPARAMS + end + + TNIXL ==>|"RDMA WRITE<br/>(post-processed bytes,<br/>~10s for 30GB Qwen3-30B-A3B)"| INIXL + TRENDZ <-.->|"publish metadata + tags"| MXSVR + MXSVR <-.->|"discover + filter"| IRENDZ + + class TSD,TSLOTS,IPARAMS today + class TNIXL,INIXL,TCONV today + class TCONV phase0 + class TRENDZ,IRENDZ phase2 + class TTAG,IFILT,IPICK phase3 + class MXSVR physical +``` + +Legend: 🟦 today's PR #2389 surface, 🟨 Phase 0 (immediate unblock), 🟩 Phase 2, 🟥 Phase 3 / 4. + +The data plane (RDMA write of post-processed bytes from trainer NIC to inference NIC) stays exactly as it is today. Everything new lives in the **metadata plane** (what gets stamped on the publish) and the **registry** (what conversions are available trainer-side). + +--- + +## 4. Phase 0 — extending the conversion registry + +This is independent of both PRs in flight. The conversion registry is plug-in: + +``` +prime_rl/trainer/models/conversions/ +├── __init__.py # registry, select_default_conversion() +├── bf16_cast.py # registered as "bf16_cast" ← today +└── fp8_blockwise.py # registered as "fp8_128x128" ← today +``` + +To add support for a new kernel layout, drop in a new file: + +```python +# prime_rl/trainer/models/conversions/deep_gemm_fp8.py +from prime_rl.trainer.models.conversions import register + +def _convert(src, dst, scale): + # src: bf16 HF-format tensor + # dst: pre-allocated FP8 destination buffer (kernel's expected layout) + # scale: paired scale buffer if requires_scale=True + ... + +register("deep_gemm_fp8", _convert, requires_scale=True) +``` + +The default-conversion picker in `__init__.py:select_default_conversion` reads the inference model's `config.json` and chooses based on `quantization_config`. Today it knows two cases; extend the if/elif chain for new ones. 
+ +**For MoE specifically — the gate-fusion choice lives in the `ConversionSpec` definitions per model.** Example pattern in `prime_rl/trainer/models/qwen3_moe/converting_qwen3_moe.py`: + +```python +ConversionSpec( + dst="mlp.experts.gate_up_proj.weight", # fused destination + sources=("mlp.experts.gate_proj.weight", # source 1 + "mlp.experts.up_proj.weight"), # source 2 + cat_dim=0, # concat along dim 0 +) +``` + +For an unfused-gate kernel target, a parallel set of `ConversionSpec`s would emit `gate_proj` and `up_proj` separately. The framework supports multiple `ConversionSpec` sets via the `Conversion` selector in `MaybeQuantize`. + +Total scope per kernel target: ~80 LOC for the conversion function + per-model `ConversionSpec` additions for the required models. + +Once Phase 3a lands in MX and Phase 2 lands in prime-rl, every publish automatically carries the `compile_target` tag (e.g. `"deep_gemm_fp8"`) and `compile_metadata` (e.g. `{"block_size": 128, "scale_layout": "K-major", "gate_fusion": "unfused"}`), and any inference worker expecting a different layout refuses the source cleanly at discovery instead of corrupting rollouts. + +--- + +## 5. Information required to write the missing conversion entries + +Two inputs are required per kernel target before the missing conversion entries can be written: + +1. **Inference engine + quant config target**. Examples: "vLLM 0.7 + DeepGemm grouped-GEMM MoE FP8 block-128", "vLLM 0.7 + Cutlass MoE FP8 with K-major scales", "Triton MoE with unfused gates and BF16". The exact kernel name + scale layout determines the required output format. +2. **Model architectures involved**. Existing conversion specs cover Qwen3, Qwen3-MoE, GLM-MoE-DSA, Nemotron-H, MiniMax-M2, and Laguna. Any model not yet in `prime_rl/trainer/models/` requires an additional ~150 LOC for the model adapter. + +--- + +## 6. 
References + +**Branches and PRs** + +- Upstream: [PrimeIntellect-ai/prime-rl#2389](https://github.com/PrimeIntellect-ai/prime-rl/pull/2389) +- Phase 2 draft: [KavinKrishnan/prime-rl#1](https://github.com/KavinKrishnan/prime-rl/pull/1) — `MxRendezvous` heartbeat + dedup + same-rank filter +- Phase 3+4 draft: [ai-dynamo/modelexpress#349](https://github.com/ai-dynamo/modelexpress/pull/349) — `compile_target` + multi-source slice picker +- RFC: [`KavinKrishnan/prime-rl:kavink/post-2389-kernel-compile-plan`](https://github.com/KavinKrishnan/prime-rl/blob/kavink/post-2389-kernel-compile-plan/docs/proposals/post-pr2389-kernel-compile-plan.md) — the full post-#2389 plan +- Build notes: [`KavinKrishnan/prime-rl:kavink/post-2389-kernel-compile-plan/docs/proposals/build-notes-2026-05-28.md`](https://github.com/KavinKrishnan/prime-rl/blob/kavink/post-2389-kernel-compile-plan/docs/proposals/build-notes-2026-05-28.md) — image-build experience, cluster observations, and the vLLM-native-RL-APIs reframing +- Inline review on upstream PR: six inline + one summary, posted at PR HEAD `79ea824d8` — covering cross-subnet `add_remote_agent`, freshest-per-rank dedup, missing heartbeat, hardcoded 1200s timeout, unconditional `update_mla_absorbed_weights`, HSDP barrier ordering + +**Code locations in prime-rl** + +- Conversion registry: `src/prime_rl/trainer/models/conversions/__init__.py` +- Existing conversions: `src/prime_rl/trainer/models/conversions/{bf16_cast,fp8_blockwise}.py` +- `ConversionSpec` definition: `src/prime_rl/trainer/models/conversion_spec.py` +- Per-model `ConversionSpec` registrations: `src/prime_rl/trainer/models//converting_.py` +- Slots: `src/prime_rl/trainer/models/slots.py` +- NIXL+MX trainer broadcast: `src/prime_rl/trainer/rl/broadcast/nixl_mx.py` +- NIXL+MX inference worker: `src/prime_rl/inference/vllm/worker/nixl_mx.py` +- Rendezvous: `src/prime_rl/transport/mx_rendezvous.py` +- Transport plan: `src/prime_rl/transport/transport_plan.py` + +**Cluster** + +- 
GB200 dev cluster, dedicated namespace +- Image (today): `nvcr.io/nvidian/dynamo-dev/prime-rl-mx-on-nixl:v0.5.2` +- Source-baked Phase 2 + Phase 3 image: `nvcr.io/nvidian/dynamo-dev/prime-rl-mx-on-nixl:v0.7.1-kavin-phase2-phase3` +- ModelExpress server: deployed in the same namespace, Redis-backed