feat(mx): MX V2 Support #2700

Draft
KavinKrishnan wants to merge 21 commits into PrimeIntellect-ai:nixl_mx from KavinKrishnan:kavink/post-2389-mx-v2-combined

@KavinKrishnan

Summary

Stacks the post-#2389 work into one branch, targeting nixl_mx as the base so it lands cleanly on top:

  1. Phase 2, rendezvous fixes (transport/mx_rendezvous.py + tests): the same three runtime patches that the #2389 deployment ([Weight Transfer] NIXL + MX Integration) runs today, baked into the tree: heartbeat-based liveness, a same-rank-only peer filter, and freshest-per-rank dedup on add_remote_agent. Closes the multi-NIC L3-subnet pairing failures and the stale-worker_rank NIXL_ERR_NOT_ALLOWED race.
  2. Phase 3a — conversion registry extensions (trainer/models/conversions/): adds cutlass_fp8 (e4m3 per-channel) alongside the existing bf16_cast + fp8_blockwise, with compile_target + compile_metadata tagging plumbed through TensorDescriptorV2. Gives receivers a Phase 3b-ready handle for refusing mismatched kernel layouts at discovery time.
  3. mx_v2 pull-mode rail (weight_broadcast.type = "mx_v2"): a sibling of the nixl_mx push-mode rail. Same MX server, same NIXL data plane, same per-cycle HTTP poke, but inverted so that the inference side pulls via vLLM's native WeightTransferEngine contract. Coexists with nixl_mx; both are selectable per-run via the existing weight_broadcast.type discriminator.

Full design: see docs/proposals/post-pr2389-mx-v2.md. The companion external-audience write-up that walks through the same design with mermaid diagrams (architecture, push-vs-pull sequence, TT→HF flow, coexistence) is at temp/PrimeRL_mx_v2_Design.md in my workspace — happy to inline it into the PR description if that helps review.

Dependency

This PR depends on ai-dynamo/modelexpress#349 (currently DRAFT) — which ships the MX v2 client library (MxV2TrainingPublisher, MxV2RefitReceiver, MxWeightTransferEngine) + Phase 4 multi-source slice planner + cluster bug fixes. The Dockerfile here (Dockerfile.cuda.mx-v2) pulls modelexpress from that branch (@kavink/post-2389-phase3-4). Once #349 lands in a tagged MX release, that pin becomes a regular version bump.

Why pull-mode on top of push-mode

The push-mode design in #2389 is the right shape for the per-cycle critical path on a single homogeneous RL deployment. Four things motivated a parallel pull-mode rail:

  1. vLLM's native RL APIs blog (2026-05-28) standardizes WeightTransferEngine as the receiver-side contract every framework is converging on (vLLM, NemoRL/Dynamo, verl, prime-rl). A pull engine is one contract, four consumers.
  2. Elastic inference scaling — a newly-joining inference worker has no pre-registered NIXL buffers at rendezvous time; pull-mode lets it discover via the MX catalog and ask for the current version.
  3. Tree fan-out — once a worker has version N, it republishes as a source so newer joiners pull from it instead of the trainer (publish_self_as_replica). Trainer egress stays 1× regardless of receiver count.
  4. Multi-source slicing across heterogeneous topologies — receivers can pick which source-side slices to pull (Phase 4) when trainer FSDP/TP/EP ≠ inference DP/TP/EP. Push-mode forces the trainer to know the receiver layout in advance.

None of these require deleting the push-mode path. The right shape is two co-existing rails, selected per-run. Nothing in #2389 changes; mx_v2 is additive.

Coexistence story

Each layer (config schema, trainer selector, inference worker registry, orchestrator scheduler) is a discriminated union keyed off config.type:

| Layer | nixl_mx (#2389) | mx_v2 (this PR) |
| --- | --- | --- |
| Trainer broadcast | NIXLMxWeightBroadcast (ShardedSlot + transport_plan.prepare_writes) | NIXLMxV2WeightBroadcast (GatheredSlot + publisher.add_tensor + publish()) |
| Inference worker | NIXLMxWeightUpdateWorker (pre-registered param buffers) | NIXLMxV2WeightUpdateWorker (MxWeightTransferEngine.receive_weights) |
| Orchestrator | update_weights + MxRendezvous status flips | update_weights_v2 + engine-driven discovery + pull |

A/B comparison is a one-line type = ... flip.
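
As a rough illustration of the discriminated-union idea, the sketch below dispatches on `config.type` through a registry. The two config classes and the `select_broadcast` helper are hypothetical stand-ins written for this example (the PR's real config is MxV2WeightBroadcastConfig, and its real selector lives in the trainer broadcast package); only the tag strings and the broadcast class names come from the table above.

```python
from dataclasses import dataclass
from typing import Literal, Union


@dataclass
class NixlMxConfig:
    # Hypothetical stand-in for the push-mode config.
    type: Literal["nixl_mx"] = "nixl_mx"


@dataclass
class MxV2Config:
    # Hypothetical stand-in for MxV2WeightBroadcastConfig.
    type: Literal["mx_v2"] = "mx_v2"
    publish_compile_target: bool = True


WeightBroadcastConfig = Union[NixlMxConfig, MxV2Config]

# The registry maps the discriminator to a class name only, so the sketch
# stays importable without prime-rl installed.
TRAINER_BROADCASTS = {
    "nixl_mx": "NIXLMxWeightBroadcast",
    "mx_v2": "NIXLMxV2WeightBroadcast",
}


def select_broadcast(config: WeightBroadcastConfig) -> str:
    """Return the broadcast implementation name for the given config."""
    try:
        return TRAINER_BROADCASTS[config.type]
    except KeyError:
        raise ValueError(f"unknown weight_broadcast.type: {config.type!r}") from None
```

Under these assumptions, swapping `NixlMxConfig()` for `MxV2Config()` is the only change needed to move a run from one rail to the other.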

What's in this PR (in commits)

  • Phase 2: feat(transport/mx): Phase-2 — heartbeat + freshest-per-rank dedup + same-rank filter
  • Phase 2 import shim: fix(transport/mx_rendezvous): tolerate both modelexpress.heartbeat module paths
  • Phase 3a: feat(conversions): cutlass FP8 e4m3 per-channel + compile_target/metadata tagging
  • mx_v2 RFC: RFC: weight_broadcast.type="mx_v2" — the complete prime-rl × ModelExpress design
  • Trainer wiring: feat(orchestrator): wire mx_v2 into the per-cycle refit path
  • Build artifacts: build(mx_v2): fix Dockerfile uv path + smoke tests, bake flash-attn ARM64 stub + complete source overlay
  • Configs: build/configs(mx_v2): full image overlay + orchestrator/inference config schemas
  • Worker retry / slot API: fix(mx_v2): worker retry loop + trainer slot API + conversion-as-str fallback
  • TT→HF translator (Path A): feat(mx_v2): receiver-side TT→HF translator for Qwen3-MoE pull-mode refit
  • GatheredSlot escalation (Path A companion): feat(mx_v2): trainer GatheredSlot escalation + receiver NIXL transient retry

Tests

25 new unit tests, 5 CI-gated:

| File | Coverage |
| --- | --- |
| tests/unit/transport/test_mx_rendezvous_phase2.py | heartbeat liveness + same-rank filter + freshest-per-rank dedup |
| tests/unit/train/models/conversions/test_cutlass_fp8.py | cutlass_fp8 per-channel + compile_target tagging |
| tests/unit/train/rl/test_nixl_mx_v2.py (10 tests) | broadcast construct, HSDP gate, lazy init, compile_target threading, MoE expert metadata, GatheredSlot escalation, shutdown idempotency |
| tests/unit/inference/vllm/worker/test_nixl_mx_v2_worker.py (11 tests) | init wiring, update dispatch, load_weights callback, metrics shape, TT→HF translator: QKV split, gate_up split, router rename, w13 per-expert global ID, w2 per-expert, passthrough |
| tests/unit/inference/vllm/test_mx_v2_server_endpoints.py (3 + 5 gated) | WORKER_EXTENSION_CLS dispatch, server endpoints, orchestrator-client helpers |

Local result: 25 passed, 5 skipped, no failures. The 5 skipped tests require the full prime-rl install for the FastAPI/httpx endpoint surface and run in CI.

Validation status

| Surface | Status |
| --- | --- |
| Trainer broadcast init | ✅ all 5 Phase 2/3/4 knobs wired through, logged at startup |
| Trainer publishes per-step | ✅ 531 tensors per rank, ~0.9s cold + ~0.15s warm on Qwen3-30B-A3B / 4×GB200 |
| GatheredSlot escalation under FSDP+EP | ✅ confirmed via injected log: full QKV `(4096, 2048)` per rank instead of FSDP-sharded `(1280, 2048)` |
| vLLM worker injection (NIXLMxV2WeightUpdateWorker) | `Injected ... for extended collective_rpc calls ['_load_weights_batch', '_translate_tt_to_hf', 'init_nixl_mx_v2', 'update_weights_via_mx_v2']` |
| Orchestrator config parse + dispatch | ✅ rollouts on Qwen3-30B-A3B via vLLM, Reward: 1.0000 / 0.5000 / 0.6250 on gsm8k |
| /update_weights_v2 round-trip | ✅ POST received, dispatched to collective_rpc("update_weights_via_mx_v2", ...) |
| Engine source discovery (Phase 2) | ✅ engine finds trainer by model_name + worker_rank + filters |
| Receiver retry on discovery miss / NIXL transient | |
| Engine → load_weights callback streams tensors | |
| First successful steady-state refit cycle | ⏳ last remaining cluster yard (see §13 of docs/proposals/post-pr2389-mx-v2.md); 20-min sync with #2389's owner on vLLM QKVParallelLinear narrow semantics will close it |

The synthetic NIXL benchmarks (~30-50 GB/s on NVL, ~10ms catalog hit, retry backoff 0.5s→8.0s capped at timeout_seconds) are in modelexpress#349.
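
For readers who want to see what a "0.5s→8.0s capped at timeout_seconds" schedule can look like, here is a minimal sketch. The doubling factor and the `backoff_delays` name are assumptions made for this illustration; the text only states the first delay, the per-delay ceiling, and that the total is bounded by `timeout_seconds`.

```python
def backoff_delays(timeout_seconds, initial=0.5, cap=8.0, factor=2.0):
    """Yield sleep intervals whose sum never exceeds timeout_seconds.

    Each delay grows by `factor` (assumed to be 2.0 here) until it
    reaches `cap`; the final delay is trimmed to fit the budget.
    """
    elapsed = 0.0
    delay = initial
    while elapsed < timeout_seconds:
        step = min(delay, cap, timeout_seconds - elapsed)
        yield step
        elapsed += step
        delay = min(delay * factor, cap)


# With a 20 second budget the schedule is 0.5, 1, 2, 4, 8, then 4.5.
schedule = list(backoff_delays(20.0))
```

A receiver loop would sleep for each yielded interval between discovery attempts and give up once the generator is exhausted.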

Open coordination questions

These are decisions where input from the #2389 owner would help before this lands ready-for-review (full list in §10 of docs/proposals/post-pr2389-mx-v2.md):

  1. Discriminator naming: is "mx_v2" the right tag, or do we want "nixl_mx_pull" to make the relationship to "nixl_mx" explicit?
  2. Engine adapter contract location: MxWeightTransferEngine lives in ai-dynamo/modelexpress. Long-term, vLLM's WeightTransferEngine ABC is the right home, and MX-specific init_info_cls/update_info_cls are the natural integration point the blog calls out. Confirm we're aligned with this trajectory.
  3. Image bump: the v0.7.2 overlay applies five small workarounds for things that aren't mx_v2-specific (flash-attn ARM64 stub, vLLM MoE oracle skipping FlashInfer TRTLLM/CUTLASS that needs nvcc, VLLM_USE_DEEP_GEMM=0, Triton MoE + TRITON_ATTN, classic_cuda_alloc graceful no-op when nvcc is missing). All pure additions to the base image. Happy to promote them into the base image build.
  4. GatheredSlot escalation policy: currently mx_v2 patches SMALL_NON_EXPERT_BYTES = 1 << 60 for the duration of model.build_slots(...). The alternative is a per-broadcast non_expert_layout = "gathered" | "sharded" flag on model.build_slots. Which interface do you prefer?
  5. Multi-source slicing (Phase 4) timeline: gather-first works today but adds an FSDP allgather per non-expert weight per refit. Phase 4 lifts that constraint by having receivers pull partial tensors from multiple trainer ranks (same semantics as push-mode, but receiver-driven). Open question whether that lands as a follow-up mx_v2.1 or replaces this GatheredSlot escalation.
  6. Test layout: the importlib.util.spec_from_file_location + sys.modules-stub pattern lets these tests run without a full prime-rl install. Same shape as the engine adapter tests on modelexpress#349. Happy to switch if you prefer a different convention.
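
To make question 4 concrete, the following sketch shows one way a threshold can be patched only for the duration of a call and then restored. `override_attr`, `slots_module`, and this `build_slots` are hypothetical stand-ins for the example; the starting threshold value is made up, and only the name SMALL_NON_EXPERT_BYTES and the `1 << 60` override come from the list above.

```python
from contextlib import contextmanager
from types import SimpleNamespace


@contextmanager
def override_attr(obj, name, value):
    """Temporarily set obj.name to value, restoring it even on error."""
    original = getattr(obj, name)
    setattr(obj, name, value)
    try:
        yield
    finally:
        setattr(obj, name, original)


# Stand-in for the module that owns the threshold (value is illustrative).
slots_module = SimpleNamespace(SMALL_NON_EXPERT_BYTES=1 << 20)


def build_slots():
    # Hypothetical: a real build_slots would compare tensor sizes against
    # the threshold; here we only report the value it would see.
    return slots_module.SMALL_NON_EXPERT_BYTES


with override_attr(slots_module, "SMALL_NON_EXPERT_BYTES", 1 << 60):
    seen_during_build = build_slots()
```

The explicit `non_expert_layout` flag would avoid mutating shared state, which is the trade-off the question is asking reviewers to weigh.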

Test plan

  • Re-run upstream CI on this branch (uses the same matrix as #2389, [Weight Transfer] NIXL + MX Integration)
  • Land ai-dynamo/modelexpress#349 first so the modelexpress pin in Dockerfile.cuda.mx-v2 can move from @kavink/post-2389-phase3-4 to a tagged release
  • Close the vLLM QKVParallelLinear shape question (§13 of the design doc) — one debugging session
  • Steady-state refit cycle numbers on Qwen3-30B-A3B + A/B vs nixl_mx baseline
  • Convert from draft → ready-for-review

Related

…le, mixed-TP, MX clients

Proposes the next phase of work on top of `nixl_mx` once PrimeIntellect-ai#2389 merges:

1. Phase-1 — six surgical fixes against the in-tree code that close
   the bug classes we hit during GB200 bring-up (cross-subnet
   add_remote_agent full-mesh; stale READY peer dedup; heartbeat /
   STALE-on-shutdown; hardcoded 1200s timeout; non-MLA model guard;
   HSDP barrier ordering). Line-pinned against HEAD `79ea824d8`.

2. Phase-2 — graduate `src/prime_rl/transport/mx_rendezvous.py` onto
   NVIDIA's published `modelexpress` Python clients
   (`MxV2TrainingPublisher` / `MxV2RefitReceiver`). Deletes ~185 LOC
   of in-tree rendezvous that duplicates the upstream client.
   Inherits heartbeat + freshest-per-rank dedup + retention +
   sidecar-filter for free. `NixlAgentWrapper` / `Slot` /
   `TransportPlan` / `classic_cuda_pool` stay — those are prime-rl
   specialization.

3. Phase-3 — solves the trainer-side kernel-compile issue surfaced
   during PrimeIntellect-ai#2389's FP8 cast-pipeline iteration. Trainer publishes
   HF-raw bytes (kernel-agnostic); inference compiles into its
   target layout (DeepGemm, cutlass, ...) via a receiver-side
   scratch-buffer pass. Extends the v2 shape registry with
   `compile_target` + `compile_metadata`. Heterogeneous fleets
   (mixed kernels on the same training run) now work without
   trainer-side branching.

4. Phase-3 also generalizes the v2 sharding metadata to handle
   mixed-TP/EP via `TargetTPLayout` + multi-source slice discovery
   in the same machinery NemoRL v2 uses for MoE expert filtering.

Pulls heavily on the NemoRL × Dynamo path (NVIDIA, John Thompson)
which is already running at 380 Gbps on GB300 RoCE for an 8.82 GB
refit — same scratch-buffer + worker-extension-cls pattern this
plan adopts.
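
As a back-of-the-envelope check on the figures quoted above, the wire time for an 8.82 GB refit at 380 Gbps can be computed directly. The sketch assumes decimal gigabytes (10^9 bytes) and ignores any discovery or registration overhead, neither of which the text specifies.

```python
refit_bytes = 8.82e9        # 8.82 GB, assuming decimal gigabytes
link_bits_per_s = 380e9     # 380 Gbps

# bytes -> bits, then divide by line rate
transfer_seconds = refit_bytes * 8 / link_bits_per_s
print(f"{transfer_seconds:.3f} s")  # about 0.186 s of pure wire time
```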

Component + per-refit sequence diagrams (mermaid) included.
Estimated ~450 LOC additive across modelexpress + prime-rl for
Phases 3-4 (plus the ~400 LOC subtraction from Phase 2).

Doc only. Implementation phases sequenced behind the upstream
merge of PrimeIntellect-ai#2389.
…ame-rank filter

Codifies the two runtime patches we applied on GB200 to unblock Qwen3-30B-A3B
bring-up against PR PrimeIntellect-ai#2389, plus a third surgical fix (heartbeat) that closes
the stale-READY-after-restart class of bugs.

The three changes are intentionally separable from PrimeIntellect-ai#2389 and additive to the
existing rendezvous API:

1. **HeartbeatThread on publish()**: When publish() succeeds we start a
   modelexpress.metadata.heartbeat.HeartbeatThread keyed on (mx_source_id,
   worker_id, worker_rank). The MX server's reaper then transitions
   crashed workers to STALE on its own. New `enable_heartbeat: bool = True`
   field on the dataclass to opt out for tests / one-shots. New `close()`
   method to stop the thread on graceful shutdown.

2. **Freshest-per-(role, rank) dedup**: New module-level helper
   `_freshest_per_rank(instances, *, metas)` keeps only the entry with the
   largest `updated_at` per `worker_rank`. wait_for_peers() and
   wait_for_all_peers_ready() default to using it; instances missing
   from `metas` get ts=0 and lose to anything timestamped. This was the
   second GB200 patch: a stale READY from the previous run was beating the
   fresh READY from the restarted trainer.

3. **same_rank_only filter**: New `_filter_same_rank(instances, *, rank)`
   helper. wait_for_peers() and wait_for_all_peers_ready() now accept
   `same_rank_only: bool = False` (off by default for back-compat); when
   set, only peers with `worker_rank == self.rank` are returned. Required
   on GCP GB200's multi-NIC fabric where cross-subnet routing fails.

`_collect_updated_at(instances)` does the GetMetadata fan-out used by the
dedup; failures are mapped to ts=0 so the picker doesn't crash on partial
catalog state.

Unit tests added under tests/unit/transport/test_mx_rendezvous_phase2.py
(11 tests, all green; direct-loads mx_rendezvous.py to bypass
prime_rl.transport.__init__'s heavy import chain so the suite runs with
only `modelexpress` installed — no docker-compose required).

Sub-tests cover:
  - _filter_same_rank: rank match
  - _freshest_per_rank: largest updated_at wins; missing-updated_at
    loses to known-updated_at; lone unknown is kept; stable rank-order
  - publish() spawns HeartbeatThread with correct kwargs (worker_rank,
    mx_source_id, worker_id, nixl_manager=None)
  - close() stops the thread; idempotent
  - enable_heartbeat=False skips the thread entirely
  - publish() swallows heartbeat start failures (broken heartbeat must
    not break rendezvous)
  - _collect_updated_at returns 0 on RPC failure, 0 on not_found,
    real value on success

No breaking changes to the existing tests/unit/transport/test_mx_rendezvous.py
suite (those are integration tests against a docker-compose'd MX server).
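
The heartbeat lifecycle that the sub-tests exercise (start on publish, opt-out flag, idempotent close, failures swallowed) can be sketched with the standard library alone. `Heartbeat` and `Rendezvous` below are hypothetical stand-ins for HeartbeatThread and the rendezvous dataclass; their constructor arguments are invented for this example and do not mirror the modelexpress signatures.

```python
import threading


class Heartbeat:
    """Stand-in heartbeat: calls beat() every interval seconds until stopped."""

    def __init__(self, beat, interval=1.0):
        self._beat = beat
        self._interval = interval
        self._stop = threading.Event()
        self._thread = threading.Thread(target=self._run, daemon=True)

    def start(self):
        self._thread.start()

    def _run(self):
        # Event.wait returns True once stop() is called, ending the loop.
        while not self._stop.wait(self._interval):
            self._beat()

    def stop(self):
        self._stop.set()
        if self._thread.is_alive():
            self._thread.join()


class Rendezvous:
    def __init__(self, enable_heartbeat=True, heartbeat_factory=Heartbeat):
        self.enable_heartbeat = enable_heartbeat
        self._factory = heartbeat_factory
        self._heartbeat = None

    def publish(self):
        # The actual catalog publish would happen before this point.
        if not self.enable_heartbeat:
            return
        try:
            heartbeat = self._factory(lambda: None)
            heartbeat.start()
            self._heartbeat = heartbeat
        except Exception:
            # A broken heartbeat must not break rendezvous.
            self._heartbeat = None

    def close(self):
        # Swap first so a second close() finds nothing to stop.
        heartbeat, self._heartbeat = self._heartbeat, None
        if heartbeat is not None:
            heartbeat.stop()
```

Injecting the factory is a choice made for the sketch; the commit message says the real tests instead patch the HeartbeatThread symbol on the module.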
…data tagging

Extends prime_rl/trainer/models/conversions/ to address the live coworker
complaint that prime-rl breaks on Qwen3-MoE with cutlass kernels — the
registry currently has only `bf16_cast` and `fp8_128x128`; anything else
raises NotImplementedError, and there's no compile_target tag on the
publish so wrong-target receivers silently misinterpret bytes.

This is the trainer-side half of the design fix; the receiver-side
filtering API is already shipped in modelexpress as PR ai-dynamo/modelexpress#349 (Phase 3a/3b
on kavink/post-2389-phase3-4). Once Phase 2 graduation lands on
#1, the MxV2TrainingPublisher will read each
tensor's resolved ConversionEntry.compile_target + compile_metadata and
tag the v2 publish so receivers can filter via
discover_v2_sources(compile_target_filter=…, required_compile_metadata=…).

What lands:

ConversionEntry gains two new fields with safe defaults:
  - compile_target: str = "hf_raw"
  - compile_metadata: dict[str, Any] = {}
register(...) takes them as kwargs; existing call sites are unchanged.
Mirrors the constants in modelexpress.shape_descriptors (Phase 3a) but
without a hard import dep in either direction — both repos keep their
own canonical string set.

select_default_conversion is refactored to a table-driven design. The
old if/else chain is replaced by _DEFAULT_RULES: list[(predicate, name)]
which the resolver walks in order. Adding a new kernel = adding one row
via register_default_rule(predicate, name) from the kernel's own module
on import. A predicate that raises on a malformed config is treated as
"doesn't match" and skipped, keeping the resolver robust to model-card
weirdness without forcing every predicate to be defensive.

The AutoConfig import is deferred into the function body so the
registry loads without requiring `transformers` (the registry is
imported by tests + tooling that have no HF download capability).

Existing entries get their tags retroactively:
  - bf16_cast / fp32_cast: compile_target="hf_raw"
  - fp8_128x128: compile_target="deep_gemm_fp8" + metadata{block_size:
    [128,128], scale_layout:"blockwise", dtype:"e4m3"}

New conversion: cutlass_fp8_e4m3_per_channel
  - One scalar scale per output row (vs DeepGemm's per-128x128-block).
  - 2D dispatch: (out, in) weight → (out,) scale. 3D dispatch:
    (E, out, in) stacked MoE → (E, out) scale.
  - compile_target="cutlass_fp8", compile_metadata={dtype:"e4m3",
    scale_layout:"per_channel", scale_axis:-1, activation_scheme:
    "dynamic"} — matches cutlass scaled_mm + vLLM's native FP8 path.
  - Two default-resolver predicates:
      * quant_method="fp8" + quant_format="cutlass" (explicit)
      * quant_method="fp8" + weight_block_size=None +
        activation_scheme="dynamic" (the vLLM-published convention)
    Both predicates run AFTER the deep-gemm rule, so models with
    block_size=[128,128] AND activation_scheme="dynamic" still resolve
    to fp8_128x128 (regression-tested).

Per-channel helpers in trainer/models/fp8.py:
  - fp8_per_channel_quantize(weight) → (q_e4m3, scale_f32). Handles 2D
    and 3D via the same code path; reduction over the innermost axis.
  - fp8_per_channel_quantize_into(weight, out, sf) — writes into
    preallocated buffers, matches the convention of fp8_block_quantize.
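
To illustrate the per-channel scale computation only, here is a dependency-free sketch on nested lists. It is not the PR's `fp8_per_channel_quantize`: it handles just the 2D case, skips the actual cast to FP8, and assumes the e4m3fn variant (largest finite value 448) and a scale of 1.0 for all-zero rows, none of which the commit message states.

```python
E4M3_MAX = 448.0  # largest finite value of the e4m3fn format (assumed variant)


def per_channel_scales(weight):
    """Return one scale per output row of a 2D (out, in) weight."""
    scales = []
    for row in weight:
        amax = max(abs(x) for x in row)  # reduction over the innermost axis
        # All-zero rows get scale 1.0 to avoid dividing by zero (assumption).
        scales.append(amax / E4M3_MAX if amax > 0 else 1.0)
    return scales


def scale_rows(weight, scales):
    """Divide each row by its scale so values fall within the FP8 range."""
    return [[x / s for x in row] for row, s in zip(weight, scales)]


weight = [[896.0, -448.0], [0.0, 0.0], [112.0, 448.0]]
scales = per_channel_scales(weight)
scaled = scale_rows(weight, scales)
```

A real implementation would follow the scaling with a cast to an FP8 dtype and return the quantized tensor together with the float32 scales.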

Tests: 19/19 green via direct-load + transformers stub.
Categories:
  - Per-channel quantize: 2D shape, 3D shape, 1D rejected,
    bf16 dequant accuracy (≤5% median rel error), into-buffer write.
  - Registry: existing entries carry correct compile_target +
    compile_metadata, cutlass entry registered + listed, default-rule
    insert/append ordering works, unknown quant error message lists
    registered names.
  - select_default_conversion dispatch: no-quant → bf16, [128,128]
    blockwise → fp8_128x128, quant_format=cutlass → cutlass, no
    weight_block_size + dynamic → cutlass, deep-gemm wins when both
    rules match.
  - Conversion fn dispatch: 2D linear path correctness, 3D MoE path
    correctness, requires_scale=True enforced.

Adding a sibling kernel (per-token cutlass, awq, gptq, mxfp4, …) is now
one new module ~80 LOC: write the quant fn, register() it with
appropriate compile_target/metadata, register_default_rule() with its
HF-config predicate.

Branches off PR PrimeIntellect-ai#2389 head 79ea824. Independent of the Phase 2
graduation PR — these can land in parallel.
…dule paths

The v0.5.2 trainer image ships MX 0.3.0 where HeartbeatThread is at
modelexpress.heartbeat. Newer MX (0.4+ on kavink/nemo_rl_moe) moved it
under modelexpress.metadata.heartbeat as part of the metadata-module
reorg.

Tolerate both with a try/except at import time so the same source code
works against either MX version. No behavior change at runtime — the
class is identical between paths.
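
The import-time tolerance described above amounts to trying one module path and falling back to the other. The sketch below generalizes that into a small helper, `import_first`, which is a name invented for this example; it is demonstrated on a standard-library module so that it runs without modelexpress installed.

```python
import importlib


def import_first(attr, module_paths):
    """Return attr from the first module in module_paths that provides it."""
    last_error = None
    for path in module_paths:
        try:
            return getattr(importlib.import_module(path), attr)
        except (ImportError, AttributeError) as exc:
            last_error = exc
    raise ImportError(f"{attr} not found in any of {module_paths}") from last_error


# In the rendezvous module the call would look roughly like this:
#   HeartbeatThread = import_first(
#       "HeartbeatThread",
#       ["modelexpress.metadata.heartbeat", "modelexpress.heartbeat"],
#   )

# Runnable demonstration: the first path is missing, the second resolves.
OrderedDict = import_first(
    "OrderedDict", ["module_that_does_not_exist_demo", "collections"]
)
```

A plain `try: ... except ImportError: ...` pair does the same job for two paths; the helper only becomes worthwhile if a third location appears.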

Unit tests (11/11) still pass since the test fixture patches the
HeartbeatThread symbol on the module post-import.
… (v0.7.x)

Captures the empirical findings from baking PRs #1 and #2 into an ARM64
GB200 image and running it on the kavin namespace for 8+ hours on
Qwen3-30B-A3B-Instruct-2507 with gsm8k.

Documents three real surprises the unit tests didn't cover:

1. Dockerfile.cuda's `uv sync` is missing `--extra disagg`, so modelexpress
   isn't installed in stock images; inference workers crash at the first
   import. Shipped v0.7.1 as a one-line overlay that adds the extra until
   the upstream Dockerfile.cuda can be updated.

2. `LD_PRELOAD` path for libcudart.so.12 — v0.5.2 had /usr/local/cuda
   present in the final stage; v0.7.0 (built from upstream Dockerfile.cuda
   as-is) doesn't. The pip-installed wheel path
   (/app/.venv/lib/python3.12/site-packages/nvidia/cuda_runtime/lib/) is
   the new canonical location.

3. The configmap monkeypatch (patch_nixl_mx.py) and Phase 2's source-baked
   fixes are complementary — they patch different layers (broadcast vs
   rendezvous-wait) and both should stay until PR #1 merges upstream.

Build experience numbers:
  - v0.7.0 from-scratch ARM64 build under QEMU: 6h45min (uv sync 45m,
    flash-attn from source 3h45m).
  - v0.7.1 overlay on top of v0.7.0: ~3 min.

Cluster observations from v0.5.2 + configmap monkeypatch (the
runtime-patched path our PR #1 codifies into source):
  - 183 successful RL refit cycles in one 66-min uninterrupted window
  - Reward variance 0.5-1.0 across orchestrator steps (real learning)
  - Off-policy level = 0 throughout
  - Zero NIXL data-plane errors
  - Recurring orchestrator wait_for_all_peers_ready timeout (~once per
    30-66 min) is the exact bug class Phase 2's rendezvous-level dedup
    eliminates

Also notes seven RFC updates queued in
pensieve/RL/PrimeRL/09_rfc_updates_needed.md, three of which are new from
this build experience (disagg extra, LD_PRELOAD path, vLLM PR #43375 /
Anyscale RDT positioning).

Companion to the RFC at docs/proposals/post-pr2389-kernel-compile-plan.md.
…/3/4 upstream form

vLLM published https://vllm.ai/blog/2026-05-28-native-rl-apis the same day,
announcing a standardized WeightTransferEngine abstract base + 4-phase
lifecycle (init / start / update / finish) + a pluggable
WeightTransferEngineFactory.register_engine(...) extension point.

This is the upstream integration seam that the in-tree MxRendezvous
reimplementation in PR PrimeIntellect-ai#2389 and the worker_extension_cls injection in
inference/vllm/worker/nixl_mx.py have been emulating. The cleanest form
of all our Phase 2/3/4 work upstream is a single MxWeightTransferEngine
adapter (~150-200 LOC) that subclasses WeightTransferEngine and wraps the
existing MxV2RefitReceiver + MxV2TrainingPublisher.

Four immediate consequences captured in §8:

  §8.1 — Phase 2/3/4 should be repackaged as MxWeightTransferEngine for
         upstream contribution; the existing patches stay correct, the
         packaging just becomes upstream-native.

  §8.2 — The blog credits Matej Sirovatka specifically. He's likely
         mid-flight on a native-APIs rewrite of prime-rl's nixl_mx
         broadcast. Ask him before pushing Phase 2 upstream; the work
         may retarget to the adapter path directly.

  §8.3 — Their validation was at 16x 8xH200, DPEP32, 256 GPUs total. That
         scale makes Phase 4's multi-source slice planning load-bearing
         (mixed-TP/EP is the common case), not optional. Validates the
         design direction and sets the next cluster validation target
         after the DP=4 kavin smoke.

  §8.4 — pause_generation(mode="keep") + two-phase DPEP pause are
         features we don't yet match. Keep mode unlocks true async RL;
         queue after Phase 2 lands.

Updated follow-up list grows from 4 to 7 items, with the three new ones
being: write MxWeightTransferEngine, adopt keep-mode pause in the
orchestrator, and coordinate with Robert Shaw / the vLLM RL roadmap on
the K8s-native weight transfer engine they mention as ongoing work
(which describes MX itself, modulo who's driving the upstream PR).
…three docs

The three proposal docs now form a coherent set:

  - post-pr2389-status-and-plan.md  — executive summary; failure-class
                                       to fix mapping; mermaid diagram
                                       of the data + metadata planes;
                                       Phase 0 unblock guidance
  - post-pr2389-kernel-compile-plan.md — full RFC with phase-by-phase
                                          design rationale (unchanged
                                          except for cross-link header)
  - build-notes-2026-05-28.md        — operational findings from the
                                       source-baked image build, plus
                                       the vLLM native RL APIs reframe
                                       in section 8

Each doc now has a header block linking to the other two so readers
can navigate based on intent (status vs design vs operational).

The status-and-plan doc is the natural entry point for someone coming
to the work cold; the RFC and build-notes are the deep dives.
…ress design

Adds a single new weight-broadcast type that consolidates every
post-PrimeIntellect-ai#2389 optimization into one config knob:

  weight_broadcast.type = "mx_v2"

Coexists with the existing "nixl_mx" (PR PrimeIntellect-ai#2389) for migration; no
behavior of "nixl_mx" is affected by this change.

What's included
---------------

Trainer side
  * src/prime_rl/trainer/rl/broadcast/nixl_mx_v2.py
    - NIXLMxV2WeightBroadcast — drop-in replacement for
      NIXLMxWeightBroadcast (PR PrimeIntellect-ai#2389)
    - Uses MxV2TrainingPublisher from the MX v2 fat clients
    - Heartbeat + freshest-per-rank dedup + same-rank routing baked
      in (Phase 2 — no more configmap monkeypatch)
    - Stamps every publish with compile_target + compile_metadata
      from the conversion registry (Phase 3a) — receivers can refuse
      mismatched layouts at discovery
    - Preserves prime-rl's trainer-side conversion + slot layout +
      HSDP barrier ordering unchanged
    - Per-step: slot.fill_from(state_dict) → publisher.add_tensor()
      ×N → publisher.publish(version=step) → publisher.mark_ready()

Inference side
  * src/prime_rl/inference/vllm/worker/nixl_mx_v2.py
    - NIXLMxV2WeightUpdateWorker — pull-mode worker extension
    - Uses MxWeightTransferEngine (vLLM WeightTransferEngine adapter
      from MX PR ai-dynamo/modelexpress#349; same shape as Anyscale RDT PR #43375)
    - Phase 3b receiver-side compile_target_filter — refuses
      incompatible bytes BEFORE RDMA
    - Tree fan-out via publish_self_as_replica=True (TensorHub
      pattern; receivers republish so newcomers pull from peers
      instead of trainer)
    - Surfaces per-cycle metrics (bytes/Gbps/discovery_seconds/
      source_worker_rank) back through the RPC return value

Config + selector
  * packages/prime-rl-configs/src/prime_rl/configs/trainer.py
    - New MxV2WeightBroadcastConfig with the Phase 2/3 knobs
    - Discriminated union extended; existing configs unchanged
  * src/prime_rl/trainer/rl/broadcast/__init__.py
    - Selector dispatches "mx_v2" to NIXLMxV2WeightBroadcast

Inference server + orchestrator wiring
  * src/prime_rl/inference/vllm/server.py
    - WORKER_EXTENSION_CLS["mx_v2"] mapping
    - POST /init_nixl_mx_v2 (mirrors /init_nixl_mx)
    - POST /update_weights_v2 (per-cycle refit; returns metrics)
  * src/prime_rl/utils/client.py
    - init_nixl_mx_v2_broadcast() async helper
    - update_weights_v2() async helper that returns per-server metrics

Image
  * Dockerfile.cuda.mx-v2 — overlay on v0.7.1-kavin-phase2-phase3:
    1. `uv pip install` the MX PR ai-dynamo/modelexpress#349 branch (Phase 4 + engine)
    2. COPY the 5 v2 prime-rl files
    3. Smoke tests at build time (engine import, flash_attn ABI)
  * docs/proposals/image-build-mx-v2.md — build mechanics + A/B
    deployment plan

RFC
  * docs/proposals/post-pr2389-mx-v2.md — full design doc:
    - Capability comparison table (nixl_mx vs mx_v2)
    - Module-by-module design
    - Migration plan (v0.x → v0.x+1 deprecation → v0.x+2 removal)
    - Validation matrix against PR PrimeIntellect-ai#2389 on the same workload
    - References to all related PRs (#1 Phase 2, #2 Phase 3,
      ai-dynamo/modelexpress#349, vLLM #43375, TensorHub paper,
      vLLM native RL APIs blog)

What's NOT in this commit
-------------------------

  * Unit tests for the new prime-rl integration files (TODO)
  * Built + pushed image artifact (TODO — needs Docker buildx)
  * End-to-end cluster validation on Qwen3-30B-A3B (TODO — needs
    cluster booking + parallel deployment)
  * Deletion of "nixl_mx" code (intentional — coexist for ≥1 release)

The 58 MX-side unit tests on PR ai-dynamo/modelexpress#349 already cover the v2 fat clients
+ engine adapter that this RFC consumes. The new tests TODO is for
the thin prime-rl-side glue (~250 LOC across the 2 new files).
…log cleanup

Adds 24 unit tests covering the new weight_broadcast.type="mx_v2" path:

  tests/unit/train/rl/test_nixl_mx_v2.py                   (10 tests)
  tests/unit/inference/vllm/worker/test_nixl_mx_v2_worker.py  (6 tests)
  tests/unit/inference/vllm/test_mx_v2_server_endpoints.py    (3 + 5 gated)

What's covered
--------------

Trainer broadcast (NIXLMxV2WeightBroadcast):
  * Construction doesn't eagerly initialize the publisher
  * is_primary_hsdp_rank gates correctly for the 3 cases
    (no-HSDP, HSDP-primary, HSDP-non-primary)
  * lazy_init builds MxV2TrainingPublisher with the right
    TrainerWorldLayout, mx_server_url, and model_name
  * lazy_init is idempotent on repeated calls
  * broadcast_weights threads compile_target + compile_metadata into
    every publisher.add_tensor call when publish_compile_target=True
  * broadcast_weights falls back to "hf_raw" when
    publish_compile_target=False (back-compat default)
  * broadcast_weights threads is_expert / expert_axis / owned_expert_ids
    for MoE slots correctly
  * Non-primary HSDP ranks skip publish entirely
  * Each slot's fill_from is invoked with the resolved conversion
  * shutdown() is idempotent

Inference worker (NIXLMxV2WeightUpdateWorker):
  * init_nixl_mx_v2 constructs the right MxInitInfo and pins UCX rail
  * publish_self_as_replica False propagates correctly
  * update_weights_via_mx_v2 constructs the right MxUpdateInfo and
    calls engine.receive_weights with the load_weights callback
  * No compile_target_filter passes None (back-compat)
  * _load_weights_batch forwards to raw_model.load_weights (HF→fused
    name remap via vLLM's stacked_params_mapping)
  * Metrics dict is well-formed even when engine.last_transfer_stats
    is None (early-cycle robustness)

Server-side glue:
  * WORKER_EXTENSION_CLS table has "mx_v2" entry pointing at
    NIXLMxV2WeightUpdateWorker
  * Existing nccl / filesystem / nixl_mx entries preserved
  * Trainer-side selector __init__.py routes mx_v2 to the new broadcast
  * 5 HTTP endpoint + orchestrator-client tests gated to CI (need full
    prime-rl install — they skip locally cleanly with explanation)

Test pattern
-----------

Uses importlib.util.spec_from_file_location + sys.modules stubs so the
tests run anywhere torch + pytest are available (no need for the full
prime-rl venv install). Same pattern as the MX-side
test_vllm_weight_transfer.py tests on PR ai-dynamo/modelexpress#349.
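
For readers unfamiliar with the direct-load pattern, the sketch below loads a single source file by path after planting a stub for a dependency that is not installed. The helper name, the demo module names, and the stubbed attribute are all invented for this example; only the use of `importlib.util.spec_from_file_location` and `sys.modules` stubs comes from the description above.

```python
import importlib.util
import sys
import tempfile
import types
from pathlib import Path


def load_module_from_path(name, path, stubs=None):
    """Load one source file as module `name`, with stub modules pre-installed."""
    for stub_name, stub in (stubs or {}).items():
        # setdefault keeps a real installation if one is already imported.
        sys.modules.setdefault(stub_name, stub)
    spec = importlib.util.spec_from_file_location(name, path)
    module = importlib.util.module_from_spec(spec)
    sys.modules[name] = module
    spec.loader.exec_module(module)
    return module


# Demo: the target file imports a dependency that does not exist on disk.
source = "import heavy_dep_for_demo\nVALUE = heavy_dep_for_demo.ANSWER + 1\n"

with tempfile.TemporaryDirectory() as tmp:
    target = Path(tmp) / "target_module.py"
    target.write_text(source)

    stub = types.ModuleType("heavy_dep_for_demo")
    stub.ANSWER = 41

    loaded = load_module_from_path(
        "target_module_demo", target, {"heavy_dep_for_demo": stub}
    )
```

Because the file is loaded by path, the package's `__init__.py` never runs, which is how the suites avoid the heavy import chain.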

Local result: 19 passed, 5 skipped, no failures.

Source fix
----------

nixl_mx_v2.py:_build_world_layout() was being called twice in lazy_init
(once for the publisher's world_layout arg, once for the log message).
Refactored to call once and bind to a local. Pure correctness +
efficiency cleanup, no behavior change.
…t + pushed

Two fixes to make the v0.7.2 overlay build cleanly on top of v0.7.1
baseline:

1. uv lives at /usr/local/bin/uv on the v0.7.1 image (per Dockerfile.cuda
   line 31's UV_INSTALL_DIR), not /app/.venv/bin/uv. The venv also has
   no pip installed by default. Updated the modelexpress install step
   to use `/usr/local/bin/uv pip install --python /app/.venv/bin/python`
   so uv targets the right interpreter.

2. The original smoke test asserted `import flash_attn.ops`, but the
   v0.7.1 image only ships `flash-attn-cute` (the Cute kernels variant),
   not the traditional `flash_attn.ops` API. vLLM 0.21.0 works fine on
   this stack regardless. Replaced the strict `flash_attn.ops` check
   with:
   - import vllm + print version
   - import MxV2WeightBroadcastConfig from prime_rl.configs.trainer
     (catches PYTHONPATH / source-overlay issues at build time)

Image now built + pushed to:
  nvcr.io/nvidian/dynamo-dev/prime-rl-mx-on-nixl:v0.7.2-kavin-mx-v2
  digest sha256:068902bb1730005345bd7253b93d88e68d2776f01f2197d6d7927f4460e2a690

All 4 build-time smoke tests pass:
  ✅ MxWeightTransferEngine imports cleanly
  ✅ MxV2TrainingPublisher + MxV2RefitReceiver + TrainerWorldLayout import
  ✅ vllm 0.21.0 import
  ✅ MxV2WeightBroadcastConfig importable from prime_rl.configs.trainer

Next: cluster smoke test (one pod boot, verify imports + worker class
registration), then Phase F A/B vs PR PrimeIntellect-ai#2389 on Qwen3-30B-A3B.
…mage ready for E2E

Two issues found while smoke-testing v0.7.2-kavin-mx-v2 on the kavin cluster:

1. v0.7.1 baseline ships only flash-attn-cute (Cute kernels variant).
   ring-flash-attn (transitively imported through
   prime_rl.trainer.models.glm_moe_dsa) needs `flash_attn.flash_attn_interface`.
   The v0.5.2 image (which the live kavin trainer uses) has a stub package
   `flash_attn_stub-2.7.3` that synthesizes these imports as
   NotImplementedError-raising stubs. v0.7.1 doesn't.

   Fix: copy the stub's 2 Python files from the running v0.5.2 trainer pod
   (via kubectl cp) into prime-rl source under scripts/flash_attn_stub/,
   then COPY them into /app/.venv/.../flash_attn/ during image build.
   This restores the import surface ring-flash-attn / glm_moe_dsa need;
   actual calls raise (callers should use SDPA on ARM64 GB200).

2. The first build of v0.7.2 missed COPYing server.py and client.py.
   server.py is where WORKER_EXTENSION_CLS["mx_v2"] is registered and the
   /init_nixl_mx_v2 + /update_weights_v2 endpoints live; client.py is
   where init_nixl_mx_v2_broadcast + update_weights_v2 async helpers live.
   Without them the orchestrator can't reach the new code paths.

   Fix: add the missing COPY lines + explicit smoke test
   (`assert "mx_v2" in WORKER_EXTENSION_CLS`) so a future regression is
   caught at build time.
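
A rough sketch of what such a build-time registry check could look like, assuming the table maps broadcast types to dotted class paths. Only the `mx_v2` value is taken from the build log; the other three paths and the helper names are hypothetical placeholders.

```python
import importlib

# Stand-in table. Only the "mx_v2" value comes from the build log; the other
# dotted paths are hypothetical placeholders for the existing entries.
WORKER_EXTENSION_CLS = {
    "nccl": "example_pkg.nccl.NCCLWorker",
    "filesystem": "example_pkg.filesystem.FileSystemWorker",
    "nixl_mx": "example_pkg.nixl_mx.NIXLMxWorker",
    "mx_v2": "prime_rl.inference.vllm.worker.nixl_mx_v2.NIXLMxV2WeightUpdateWorker",
}


def split_dotted(path):
    """Split 'pkg.module.Class' into ('pkg.module', 'Class')."""
    module_name, _, attr = path.rpartition(".")
    if not module_name or not attr:
        raise ValueError(f"not a dotted class path: {path!r}")
    return module_name, attr


def resolve(path):
    """Import the module and return the named attribute."""
    module_name, attr = split_dotted(path)
    return getattr(importlib.import_module(module_name), attr)


def missing_types(table, required):
    """Return the required broadcast types that have no registry entry."""
    return [name for name in required if name not in table]


missing = missing_types(WORKER_EXTENSION_CLS, ["nccl", "filesystem", "nixl_mx", "mx_v2"])
mx_v2_module, mx_v2_class = split_dotted(WORKER_EXTENSION_CLS["mx_v2"])
# resolve() is demonstrated on a stdlib class because prime_rl is not importable here.
resolved = resolve("collections.OrderedDict")
```

Inside the image, `resolve(WORKER_EXTENSION_CLS["mx_v2"])` would additionally catch a missing COPY of the worker module itself, not just a missing registry key.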

Result: v0.7.2-kavin-mx-v2 @ sha256:dd84426e497f9f424cc95dbfea9e5167f99c8c262232759f38067602b5064233

All 6 build-time smoke tests + 7 cluster smoke tests pass:
  Build:
    ✅ MxWeightTransferEngine import
    ✅ v2 fat clients import (MxV2TrainingPublisher, MxV2RefitReceiver, TrainerWorldLayout)
    ✅ vllm 0.21.0
    ✅ flash_attn stub usable (flash_attn_interface._flash_attn_forward importable)
    ✅ prime-rl mx_v2 surfaces all import OK (NIXLMxV2WeightBroadcast +
       NIXLMxV2WeightUpdateWorker through full prime_rl import chain)
    ✅ WORKER_EXTENSION_CLS["mx_v2"] = prime_rl.inference.vllm.worker.nixl_mx_v2.NIXLMxV2WeightUpdateWorker
  Cluster:
    ✅ Engine adapter import on cluster
    ✅ v2 fat clients import on cluster
    ✅ mx_v2 worker extension import on cluster
    ✅ mx_v2 broadcast import on cluster
    ✅ WORKER_EXTENSION_CLS lookup returns correct class
    ✅ modelexpress-server.kavin.svc.cluster.local:8001 reachable
    ✅ nixl_cu12 import

Image is ready for Phase F (Qwen3-30B-A3B A/B vs PR PrimeIntellect-ai#2389 baseline).

Stub package is committed under scripts/flash_attn_stub/ as a workaround
for the v0.7.1 base image's missing flash-attn-stub. Once v0.8 baseline
ships with the stub baked in (or the ARM64 flash-attn build path is
fixed), this overlay step can be dropped.
Four orchestrator-side changes to make weight_broadcast.type="mx_v2"
work end-to-end:

1. orchestrator.py: add `elif type == "mx_v2"` branch to init code.
   Calls init_nixl_mx_v2_broadcast (POSTs /init_nixl_mx_v2 to every
   inference admin server). Does NOT create an orchestrator-side
   MxRendezvous — for mx_v2 the trainer is the only publisher and
   drives its own publish() + mark_ready() per step, so no
   orchestrator-trainer handshake is needed.

2. orchestrator.py: add "mx_v2" to the (nccl, nixl_mx) tuple in the
   "skip disk existence check" line. mx_v2 weights flow through
   NIXL, not the filesystem.

3. scheduler.py::_apply_policy_update: add the mx_v2 per-cycle path.
   Calls update_weights_v2(admin_clients, step=next_ckpt_step, ...)
   instead of the existing student_inference.update_weights(weights_path).
   The engine adapter's discovery + retry-until-deadline absorbs the
   gap between trainer publish and orchestrator poll.

4. scheduler.py: also skip the wait-for-trainer-INITIALIZING in the
   mx_v2 branch, since the trainer asynchronously marks READY after
   broadcast_weights and the engine handles discovery internally.

Dockerfile.cuda.mx-v2 also picks up orchestrator.py + scheduler.py
overlays so the v0.7.2 image contains the full integration. Image
rebuilt and pushed:

  nvcr.io/nvidian/dynamo-dev/prime-rl-mx-on-nixl:v0.7.2-kavin-mx-v2
  sha256:ce3ca0135da099fa440841583b6660e996c08d1f0caf8e2591b615bd5bc777a0

This commit completes Phase E of the post-2389-mx-v2 plan; the image
is now ready for Phase F (Qwen3-30B-A3B A/B vs PR PrimeIntellect-ai#2389 baseline on
the kavin cluster).
…fig schemas

Three plumbing fixes that surfaced while bringing the v0.7.2 image up
end-to-end on the kavin cluster against Qwen3-30B-A3B:

1. configs/orchestrator.py: add MxV2WeightBroadcastConfig to the
   orchestrator's WeightBroadcastConfig discriminated union (was missing
   "mx_v2" tag, caused the orchestrator pod to error out with pydantic
   "Input tag 'mx_v2' does not match any of the expected tags").

2. configs/inference.py: extend Literal[] for weight_broadcast.type to
   include "mx_v2" (mirrors the orchestrator-side change so the inference
   server can boot under the v2 type).

3. Dockerfile.cuda.mx-v2: expanded source-overlay layer to also COPY:
   - src/prime_rl/inference/vllm/server.py (carries the new
     WORKER_EXTENSION_CLS["mx_v2"] + /init_nixl_mx_v2 + /update_weights_v2
     endpoints)
   - src/prime_rl/utils/client.py (init_nixl_mx_v2_broadcast +
     update_weights_v2 async helpers)
   - src/prime_rl/orchestrator/orchestrator.py + scheduler.py (the
     mx_v2 dispatch branches added in the previous commit)
   - packages/prime-rl-configs/src/prime_rl/configs/{orchestrator,inference}.py
     (the schema additions from this commit)

   Plus: prime-rl-configs is an editable install pointing at the
   /app/packages/ source path, BUT pydantic's compiled-once class
   resolution at import time caches the AST. So I also mirror these
   three config files into /app/.venv/lib/python3.12/site-packages/prime_rl/configs/
   for paranoia / future-proofing if the editable install layer changes.

Cluster status after this commit
--------------------------------

  nixl_mx baseline (PR PrimeIntellect-ai#2389 path, v0.5.2 image): running on Qwen3-30B-A3B
    in kavin ns as the control workload (Matej's existing deployment).

  mx_v2 (this branch, v0.7.2 image): all 3 pods (trainer, inference,
    orchestrator) deployed alongside via prime-rl-mx-v2-* names with
    output_dir=/output/run/run_mx_v2 to isolate from the baseline.

    Trainer: FULLY BOOTED with mx_v2 broadcast initialized:
      "Initializing weight broadcast (type='mx_v2'
       host='modelexpress-server.kavin.svc.cluster.local'
       same_rank_only=True dedup_freshest_per_rank=True
       publish_compile_target=True compile_target_filter=None
       publish_self_as_replica=True)"
      Confirms all 5 Phase 2/3/4 knobs are wired through and the
      trainer is in its loop publishing.

    Inference: NIXLMxV2WeightUpdateWorker injected via vLLM
      worker_extension_cls and its RPCs (init_nixl_mx_v2,
      update_weights_via_mx_v2, _load_weights_batch) are registered:
      "Injected <class 'prime_rl.inference.vllm.worker.nixl_mx_v2.NIXLMxV2WeightUpdateWorker'>
       into <class 'vllm.v1.worker.gpu_worker.Worker'> for extended
       collective_rpc calls ['_load_weights_batch', 'init_nixl_mx_v2',
       'update_weights_via_mx_v2']"

      The vLLM 0.21 + Qwen3-30B-A3B + v0.7.x image combination hits
      a JIT-compile dependency on nvcc for FlashInfer TRTLLM/CUTLASS
      MoE kernels (v0.7.x dropped /usr/local/cuda; v0.5.2 still has it).
      Worked around with a runtime patch in run_inference.sh that
      rewrites vllm/.../oracle/unquantized.py to force the TRITON
      backend (no nvcc needed; pre-built kernels). With FlashInfer
      out of the way the inference pod's only remaining blocker is
      Triton autotune time — Qwen3-30B-A3B × 4 EP workers × 128
      experts × multiple kernels can take 20-30 min on first boot,
      so VLLM_ENGINE_READY_TIMEOUT_S is bumped from the default 600s
      to 3600s.

Image: nvcr.io/nvidian/dynamo-dev/prime-rl-mx-on-nixl:v0.7.2-kavin-mx-v2
       latest digest (build9): see /tmp/mx_v2_build*.log

Next session pick-up
--------------------
1. `tsh login` (current session expired mid-run).
2. `kubectl -n kavin delete pod prime-rl-mx-v2-inference-0 --grace-period=1`
   to recycle with the 3600s timeout patched configmap.
3. Wait ~25-30 min for Triton autotune; orchestrator logs will switch
   from "Inference server was not reached after Ns" to a successful
   refit cycle log line ("[mx_v2] refit step=N metrics=...").
4. Collect 3-5 refit cycles; populate
   pensieve/RL/PrimeRL/11_benchmark_results.md.
5. Run Phase G side-runs (elastic + filter-mismatch).
…fallback

Three runtime-validated fixes from end-to-end Qwen3-30B-A3B on the kavin
cluster:

1. NIXLMxV2WeightUpdateWorker.update_weights_via_mx_v2 (inference worker):
   wrap engine.receive_weights in a retry-with-backoff loop. The
   orchestrator polls /update_weights_v2 with step=N right after dispatch,
   but the trainer publishes version=N asynchronously (optimizer step +
   add_tensor loop). When discovery fires before the trainer marks
   version=N READY in the catalog, receive_weights raises
   'no source matches filters'. Retry until timeout_seconds elapses,
   only treating discovery-empty errors as transient (transport errors
   propagate immediately). Cluster log validation:
     "[mx_v2] receive_weights attempt #16 for step=1: transient miss
      (...); retrying in 8.0s"

2. NIXLMxV2WeightBroadcast.broadcast_weights (trainer broadcast):
   GatheredSlot's API is `convert(state_dict)`, not
   `fill_from(state_dict, conversion)`. The conversion is baked in at
   `from_spec` creation time. Cluster log validation:
     "[mx_v2] publish step=1 tensors=531 compile_target=hf_raw
      mx_source_id=2bc84f264d5dc8a9 elapsed=0.887s"

3. NIXLMxV2WeightBroadcast.lazy_init (trainer broadcast):
   `select_default_conversion(model_name)` may return a plain string
   ('bf16_cast', 'fp8_pack', ...) on the conversion registry that's
   shipped in v0.7.x. Older NemoRL v2 designs assumed an object with
   .compile_target. Use getattr with `str(self._conversion)` fallback
   so the log line doesn't crash with AttributeError. Cluster log
   validation:
     "[mx_v2] publisher initialized: rank=0
      layout=fsdp:1,tp:1,pp:1,ep:1 compile_target=bf16_cast"
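
The retry behaviour in fix 1 can be sketched roughly as below. This is an illustration under assumptions, not the worker's code: `DiscoveryEmpty`, `receive_with_retry`, and the backoff parameters are hypothetical names and values (the 8.0 s cap merely matches the delay seen in the log line above), and the real worker may classify transient errors differently.

```python
import time


class DiscoveryEmpty(Exception):
    """Hypothetical stand-in for the 'no source matches filters' error."""


def receive_with_retry(receive, timeout_seconds, base_delay=0.5, max_delay=8.0,
                       sleep=time.sleep, clock=time.monotonic):
    """Call receive() until it succeeds or the deadline passes.

    Only DiscoveryEmpty is treated as transient; any other exception
    propagates immediately.
    """
    deadline = clock() + timeout_seconds
    delay = base_delay
    while True:
        try:
            return receive()
        except DiscoveryEmpty:
            remaining = deadline - clock()
            if remaining <= 0:
                raise  # deadline reached: surface the last discovery miss
            sleep(min(delay, max_delay, remaining))
            delay = min(delay * 2, max_delay)  # exponential backoff with a cap


# Demo: the source appears on the third attempt.
calls = {"n": 0}
waits = []


def flaky_receive():
    calls["n"] += 1
    if calls["n"] < 3:
        raise DiscoveryEmpty("no source matches filters")
    return "weights loaded"


outcome = receive_with_retry(flaky_receive, timeout_seconds=60.0, sleep=waits.append)
```

Injecting `sleep` and `clock` keeps the loop testable without real waiting.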

End-to-end state after this commit
-----------------------------------

The complete mx_v2 pipeline is functional on Qwen3-30B-A3B:

  ✅ Trainer (4 GPUs, FSDP + EP=2): publishes 531 tensors per rank @
     step=1 in 0.887s, then 0.150s/step thereafter. All 4 worker ranks
     get distinct mx_source_id values in the MX catalog.

  ✅ Inference (4 GPUs, DP=4, EP enabled, Triton MoE + TRITON_ATTN):
     NIXLMxV2WeightUpdateWorker is injected via vLLM's worker_extension_cls;
     RPCs `init_nixl_mx_v2`, `update_weights_via_mx_v2`,
     `_load_weights_batch` are registered with collective_rpc.

  ✅ Orchestrator: rollouts succeed against the Qwen3-30B-A3B inference
     server, with Reward=1.0000 on gsm8k step=0 (10.51s) + step=1 (1.60s).
     Dispatches /update_weights_v2 to inference per scheduler cycle.

  ✅ Engine adapter (modelexpress.vllm_weight_transfer):
     receive_weights discovers source by model_name + worker_rank +
     compile_target_filter, then streams tensors through the
     load_weights callback into vLLM's qwen3_moe loader.

  ⚠ Remaining issue: PrimeRL trainer's published QKV tensor shape
     doesn't match vLLM's expected shape (assertion in
     parameter.load_qkv_weight). Layout-translation gap between
     prime-rl's FSDP+EP weight slot and vLLM's stacked QKV param. This
     is the same general class of issue that PR PrimeIntellect-ai#2389 must address;
     fixing it requires either applying the right shape transform on
     the trainer-publish side (HF-format passthrough) or on the
     inference-receive side (vLLM stacked_params_mapping in our
     load_weights callback).

All v0.7.x cluster workarounds (also applied via runtime configmap
patches; should be promoted to image-level fixes in the next build):
  - flash-attn ARM64 stub (already baked)
  - vLLM MoE oracle: skip FlashInfer TRTLLM/CUTLASS (needs nvcc)
  - VLLM_ATTENTION_BACKEND=TRITON_ATTN via vllm_extra
  - VLLM_USE_DEEP_GEMM=0 + VLLM_DEEP_GEMM_WARMUP=skip
  - classic_cuda_alloc: graceful no-op when JIT fails
…efit

Path A of the trainer↔vLLM format-translation work: translate PrimeRL's
TT-format slot keys + shapes into the HF-checkpoint names + per-tensor
shapes that vLLM's `load_weights` expects. With this in place vLLM's
own `stacked_params_mapping` (QKV / gate-up) and `expert_params_mapping`
(FusedMoE) handle the actual stacking into the model's stacked params —
the translator only undoes PrimeRL's publisher-side fusion.

Inference worker (`NIXLMxV2WeightUpdateWorker`):
  * `_load_weights_batch` now runs `_translate_tt_to_hf(batch)` before
    forwarding to `raw_model.load_weights`.
  * New `_translate_tt_to_hf` handles 5 patterns for Qwen3-MoE family:
      - fused `qkv_proj.weight`            → split into 3 (q/k/v)
      - fused dense `gate_up_proj.weight`  → split into 2 (gate/up)
      - `mlp.router.gate.weight`           → rename to `mlp.gate.weight`
      - stacked-expert `experts.w13_weight` → per-expert split into
        gate_proj / up_proj with linear global expert IDs
        (`my_rank * num_local + local_id`)
      - stacked-expert `experts.w2_weight`  → per-expert down_proj
      - everything else passes through unchanged.
  * `init_nixl_mx_v2` now probes `AutoConfig.from_pretrained(...)` for
    the dims the translator needs (q_heads, kv_heads, head_dim,
    num_experts) and the inference EP layout (DP × TP when
    `enable_expert_parallel=True`), caching them under `_hf_config`.
  * Translator is a no-op when `model_type` isn't qwen3_moe/qwen3 — safe
    to layer onto non-MoE deployments.
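
The shape and ID arithmetic behind the QKV split and the expert split can be illustrated with a pure-Python sketch. The function names are hypothetical, lists of row indices stand in for tensors, the fused rows are assumed to be laid out as q, then k, then v, and the head counts are chosen to reproduce the row counts quoted in the tests below.

```python
def qkv_row_counts(q_heads, kv_heads, head_dim):
    """Row counts of the q, k and v blocks inside a fused qkv weight."""
    return q_heads * head_dim, kv_heads * head_dim, kv_heads * head_dim


def split_fused_qkv(name, rows, q_heads, kv_heads, head_dim):
    """Split fused rows into q/k/v entries (assumes q, k, v row order)."""
    q_n, k_n, v_n = qkv_row_counts(q_heads, kv_heads, head_dim)
    if len(rows) != q_n + k_n + v_n:
        raise ValueError(f"{name}: expected {q_n + k_n + v_n} rows, got {len(rows)}")
    prefix = name[: -len("qkv_proj.weight")]
    return {
        prefix + "q_proj.weight": rows[:q_n],
        prefix + "k_proj.weight": rows[q_n:q_n + k_n],
        prefix + "v_proj.weight": rows[q_n + k_n:],
    }


def global_expert_ids(ep_rank, ep_size, num_experts):
    """Linear global IDs for one rank: ep_rank * num_local + local_id."""
    num_local = num_experts // ep_size
    return [ep_rank * num_local + local_id for local_id in range(num_local)]


# Row indices stand in for tensor rows; 32 query heads, 4 KV heads, head_dim 128.
counts = qkv_row_counts(32, 4, 128)
fused_rows = list(range(sum(counts)))
split = split_fused_qkv("model.layers.0.self_attn.qkv_proj.weight", fused_rows, 32, 4, 128)
rank0_ids = global_expert_ids(0, ep_size=4, num_experts=128)
rank2_ids = global_expert_ids(2, ep_size=4, num_experts=128)
```

With these inputs the q block has 4096 rows, k and v have 512 each, and rank 2 owns global expert IDs 64 through 95.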

Tests: 5 new unit tests in
  `tests/unit/inference/vllm/worker/test_nixl_mx_v2_worker.py`:
  * QKV split with correct per-projection row counts (q=4096, k=512, v=512
    for Qwen3-30B-A3B dims) + row-level data preservation.
  * Router rename TT→HF.
  * Stacked-expert w13 per-expert split with global-ID arithmetic on
    multiple ranks (rank 0 → IDs 0..31; rank 2 → IDs 64..95 with
    `ep_size=4, num_experts=128`).
  * Stacked-expert w2 per-expert split.
  * Passthrough for norms / o_proj / q-k_norm / embed / lm_head.
Worker test suite: 11/11 green locally.
Full mx_v2 suite: 24 passed, 5 skipped (CI-gated).

Test-only fix for the trainer-side test as a side effect: switch
`test_broadcast_weights_calls_slot_fill_from` → `test_*_calls_slot_convert`
since `GatheredSlot.convert(state_dict)` doesn't take the conversion
object (it's baked in at `from_spec` creation time). Matches the source
change shipped two commits ago.

Cluster status (runtime configmap-overlay versions of these same patches
are already deployed in kavin/prime-rl-mx-v2-*):
  * trainer publishes 531 tensors per rank in TT-format
  * inference receives, translates, and feeds vLLM
  * the only remaining shape-failure was traced to the trainer's
    ShardedSlot allocating 1/N of each non-expert tensor under FSDP+EP;
    that's a *publisher-side* issue addressed in a separate commit
    (kavin_pull_mode_gathered: force GatheredSlot for non-expert weights
    so each rank publishes the full tensor via DTensor allgather, which
    pull-mode + same-rank routing requires).
…t retry

Two source-side companions to the TT→HF translator from the previous
commit that close the remaining shape-mismatch + reconnect-race issues
that surfaced when validating against Qwen3-30B-A3B on the kavin cluster:

1. NIXLMxV2WeightBroadcast.lazy_init: temporarily raise
   `slots.SMALL_NON_EXPERT_BYTES` to `1 << 60` for the duration of
   `model.build_slots(...)`, so every non-expert weight is built as a
   `GatheredSlot` (full tensor on each rank via DTensor.full_tensor())
   instead of `ShardedSlot` (1/fsdp_total). Restore the threshold
   afterward so other code paths (e.g. nixl_mx push-mode broadcast
   running in the same process) aren't perturbed.

   Why: pull-mode + same-rank routing means each inference rank only
   contacts ONE trainer rank for the pull. ShardedSlot's 1/N FSDP
   shard would deliver only 1/N of the tensor to the receiver, and
   vLLM's `param.load_qkv_weight` (TP=1 case) refuses the shape with
   `assert param_data.shape == loaded_weight.shape`.

   Push-mode (PR PrimeIntellect-ai#2389) doesn't have this issue because each trainer
   rank writes its FSDP shard directly into the inference's
   pre-allocated buffer at its rank-specific offset, and all N senders
   contribute to the full tensor in the receiver's memory.

   Trade-off: extra `full_tensor()` allgather per non-expert tensor
   per refit. Measured at <50ms total on 4×GB200 NVL for Qwen3-30B-A3B
   (~4 GB of non-expert weights), well inside the per-cycle budget.
   The long-term replacement is Phase 4 multi-source slicing in the
   engine adapter (receivers pull *partial* tensors from *multiple*
   trainer ranks and assemble locally — same semantics as nixl_mx's
   push-mode, but receiver-driven). Until that lands, gather-first is
   the right trade.

2. NIXLMxV2WeightUpdateWorker.update_weights_via_mx_v2: expand the
   retry-transient set to include `NIXL_ERR_REMOTE_DISCONNECT`,
   `NIXL_ERR_NOT_ALLOWED`, and `NIXL_ERR_NOT_FOUND` (previously only
   discovery-empty errors retried). These three error codes correspond
   to trainer-pod restart races where the MX catalog still has the
   dead agent's metadata for a few seconds before the heartbeat
   timeout reaps it. Any other exception (real shape mismatch, real
   transport failure, anything not on the allowlist) continues to
   propagate immediately.
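
The temporary threshold override in item 1 follows a set-then-restore pattern that can be sketched as follows. The context manager name `temporarily_set`, the `SimpleNamespace` stand-in for the `slots` module, and the original threshold value are hypothetical; only the attribute name and the `1 << 60` override come from the description above.

```python
import contextlib
import types


@contextlib.contextmanager
def temporarily_set(obj, attr, value):
    """Set obj.attr to value for the duration of the block, then restore it."""
    original = getattr(obj, attr)
    setattr(obj, attr, value)
    try:
        yield
    finally:
        # Restore even if slot construction raises, so other code paths
        # in the same process see the original threshold.
        setattr(obj, attr, original)


# Stand-in for the real `slots` module; the original value is hypothetical.
slots = types.SimpleNamespace(SMALL_NON_EXPERT_BYTES=1 << 20)
seen_during_build = []


def build_slots():
    # Record the threshold that slot construction would observe.
    seen_during_build.append(slots.SMALL_NON_EXPERT_BYTES)


with temporarily_set(slots, "SMALL_NON_EXPERT_BYTES", 1 << 60):
    build_slots()
```

The `try`/`finally` is what guarantees the restore that the new unit test checks for.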

Unit tests: 1 new
   `tests/unit/train/rl/test_nixl_mx_v2.py::test_lazy_init_forces_gathered_slots_for_pull_mode`
that captures the threshold value `slots.SMALL_NON_EXPERT_BYTES` AT
the moment `model.build_slots(...)` is called (verifying the escalation
is active during slot construction) and that it's restored to the
original value after lazy_init returns. Full suite: 25 passed, 5
skipped (CI-gated).

Cluster validation status:
* Trainer publishes Q/K/V as separate per-source slots at FULL shape
  ((4096, 2048) and (512, 2048)) — confirmed via injected
  `[KAVINDBG-PUB] buf_key='model.layers.0.self_attn.q_proj.weight'
  shape=(4096, 2048)` log line.
* MoE experts publish as (32, 1536, 2048) for w13 and (32, 2048, 768)
  for w2 — 32 local experts per rank with ep=4 (matches inference EP=4).
* Orchestrator gets real Reward != 1.0 rollouts on Qwen3-30B-A3B
  (e.g. `Reward: 0.5000`, `Reward: 0.6250` on gsm8k) — confirms the
  inference engine is actually serving from the pre-refit weights.
* Engine discovery + load_weights callback dispatch confirmed.
* The last blocker to closing the E2E loop is the inference + trainer
  reconnect race during trainer pod restarts, which this commit's retry
  expansion absorbs. Steady-state cycles are still pending.
…-fixes' into kavink/post-2389-mx-v2-combined
…ry-extensions' into kavink/post-2389-mx-v2-combined
@KavinKrishnan changed the title from "feat(mx): Phase 2 + Phase 3a + mx_v2 pull-mode rail on top of #2389" to "feat(mx): MX Weight Transfer Overhaul" on Jun 3, 2026
@KavinKrishnan changed the title from "feat(mx): MX Weight Transfer Overhaul" to "feat(mx): MX V2 Support" on Jun 4, 2026
These docs walked through the post-PR PrimeIntellect-ai#2389 design journey: the
kernel-compile plan, the image-build plan, the source-baked image
build notes, the status-and-plan doc, and the consolidating mx_v2
RFC. Useful as a process record while the design was in flux, but
not part of what should land upstream alongside the
``weight_broadcast.type = "mx_v2"`` code change. The upstream-facing
docs that explain how to use the ``mx_v2`` broadcast type will land
in a separate docs-only PR once the code path is reviewed and merged.

Removed:
  docs/proposals/build-notes-2026-05-28.md
  docs/proposals/image-build-mx-v2.md
  docs/proposals/post-pr2389-kernel-compile-plan.md
  docs/proposals/post-pr2389-mx-v2.md
  docs/proposals/post-pr2389-status-and-plan.md
Signed-off-by: Kavin Krishnan <kavink@nvidia.com>