Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
e460190
docs(proposals): post-PR-#2389 plan for kernel compile, mixed-TP, MX …
KavinKrishnan May 27, 2026
7feee0d
docs(proposals): scrub stray internal-pensieve reference
KavinKrishnan May 27, 2026
e958f1c
feat(transport/mx): Phase-2 — heartbeat + freshest-per-rank dedup + s…
KavinKrishnan May 27, 2026
0805833
feat(conversions): cutlass FP8 e4m3 per-channel + compile_target/meta…
KavinKrishnan May 28, 2026
d676523
fix(transport/mx_rendezvous): tolerate both modelexpress.heartbeat mo…
KavinKrishnan May 28, 2026
1b36af8
docs(proposals): build notes for Phase 2 + Phase 3 source-baked image…
KavinKrishnan May 29, 2026
dbe936f
docs(proposals): build notes §8 — vLLM native RL APIs reframe Phase 2…
KavinKrishnan May 29, 2026
4b33e90
docs(proposals): add post-pr2389-status-and-plan.md + cross-link the …
KavinKrishnan May 29, 2026
78c0e0c
RFC: weight_broadcast.type="mx_v2" — the complete prime-rl × ModelExp…
KavinKrishnan Jun 2, 2026
df1b81a
test: add unit tests for mx_v2 worker + broadcast + selector + small …
KavinKrishnan Jun 2, 2026
0cc7c3b
build(mx_v2): fix Dockerfile uv path + smoke tests; v0.7.2 image buil…
KavinKrishnan Jun 2, 2026
82a7540
build(mx_v2): bake flash-attn ARM64 stub + complete source overlay; i…
KavinKrishnan Jun 2, 2026
b17a9fd
feat(orchestrator): wire mx_v2 into the per-cycle refit path
KavinKrishnan Jun 2, 2026
d3f1210
build/configs(mx_v2): full image overlay + orchestrator/inference con…
KavinKrishnan Jun 3, 2026
9ae21a2
fix(mx_v2): worker retry loop + trainer slot API + conversion-as-str …
KavinKrishnan Jun 3, 2026
e9072b4
feat(mx_v2): receiver-side TT→HF translator for Qwen3-MoE pull-mode r…
KavinKrishnan Jun 3, 2026
17b5b4d
feat(mx_v2): trainer GatheredSlot escalation + receiver NIXL transien…
KavinKrishnan Jun 3, 2026
311ee96
Merge remote-tracking branch 'fork/kavink/post-2389-phase2-rendezvous…
KavinKrishnan Jun 3, 2026
3bfb22c
Merge remote-tracking branch 'fork/kavink/post-2389-conversion-regist…
KavinKrishnan Jun 3, 2026
1ac546f
Merge remote-tracking branch 'fork/kavink/post-2389-mx-v2' into kavin…
KavinKrishnan Jun 3, 2026
cf4acb2
chore(docs): move mx_v2 proposal docs to local archive
KavinKrishnan Jun 4, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 85 additions & 0 deletions Dockerfile.cuda.mx-v2
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 prime intellect & contributors
# SPDX-License-Identifier: Apache-2.0
#
# Overlay Dockerfile for the v2 prime-rl × ModelExpress integration.
# Layers on top of v0.7.1-kavin-phase2-phase3 (which already has Phase 2
# + Phase 3 source baked in). See docs/proposals/image-build-mx-v2.md.

FROM nvcr.io/nvidian/dynamo-dev/prime-rl-mx-on-nixl:v0.7.1-kavin-phase2-phase3

USER root

# ────────────────────────────────────────────────────────────────────
# 1a. Install the flash-attn ARM64 stub package
#
# The v0.7.1 baseline image ships only flash-attn-cute (Cute kernels),
# but ring-flash-attn (transitively imported by
# prime_rl.trainer.models.glm_moe_dsa) imports
# `flash_attn.flash_attn_interface`. We restore that import surface
# from the same stub the v0.5.2 image uses on the live kavin trainer
# (extracted via kubectl cp). Functions raise NotImplementedError if
# actually called — callers should use SDPA.
# ────────────────────────────────────────────────────────────────────
COPY --chown=appuser:appuser scripts/flash_attn_stub/__init__.py /app/.venv/lib/python3.12/site-packages/flash_attn/__init__.py
COPY --chown=appuser:appuser scripts/flash_attn_stub/flash_attn_interface.py /app/.venv/lib/python3.12/site-packages/flash_attn/flash_attn_interface.py

# ────────────────────────────────────────────────────────────────────
# 1b. Update modelexpress to the PR #349 branch
# (Phase 4 multi-source slice picker + MxWeightTransferEngine)
#
# uv lives at /usr/local/bin/uv (per Dockerfile.cuda line 31's
# UV_INSTALL_DIR). The venv has no pip installed by default, so we
# use uv with --python pointing at the venv interpreter.
# ────────────────────────────────────────────────────────────────────
RUN --mount=type=cache,target=/app/.cache/uv \
/usr/local/bin/uv pip install --no-deps --reinstall \
--python /app/.venv/bin/python \
"modelexpress @ git+https://github.com/ai-dynamo/modelexpress.git@kavink/post-2389-phase3-4#subdirectory=modelexpress_client/python"

# ────────────────────────────────────────────────────────────────────
# 2. Overlay the v2 prime-rl source files
#
# Six files touched by the kavink/post-2389-mx-v2 branch:
# trainer side — broadcast/nixl_mx_v2.py (new), broadcast/__init__.py (mx_v2 dispatch)
# inference side — worker/nixl_mx_v2.py (new),
# server.py (WORKER_EXTENSION_CLS["mx_v2"] + /init_nixl_mx_v2 +
# /update_weights_v2 endpoints)
# config — packages/.../configs/trainer.py (MxV2WeightBroadcastConfig)
# orchestrator — utils/client.py (init_nixl_mx_v2_broadcast + update_weights_v2)
# transport — transport/ (was unchanged but copied as-is for completeness)
# ────────────────────────────────────────────────────────────────────
COPY --chown=appuser:appuser src/prime_rl/transport/ /app/src/prime_rl/transport/
COPY --chown=appuser:appuser src/prime_rl/inference/vllm/server.py /app/src/prime_rl/inference/vllm/server.py
COPY --chown=appuser:appuser src/prime_rl/inference/vllm/worker/nixl_mx_v2.py /app/src/prime_rl/inference/vllm/worker/nixl_mx_v2.py
COPY --chown=appuser:appuser src/prime_rl/trainer/rl/broadcast/nixl_mx_v2.py /app/src/prime_rl/trainer/rl/broadcast/nixl_mx_v2.py
COPY --chown=appuser:appuser src/prime_rl/trainer/rl/broadcast/__init__.py /app/src/prime_rl/trainer/rl/broadcast/__init__.py
COPY --chown=appuser:appuser src/prime_rl/utils/client.py /app/src/prime_rl/utils/client.py
COPY --chown=appuser:appuser src/prime_rl/orchestrator/orchestrator.py /app/src/prime_rl/orchestrator/orchestrator.py
COPY --chown=appuser:appuser src/prime_rl/orchestrator/scheduler.py /app/src/prime_rl/orchestrator/scheduler.py
COPY --chown=appuser:appuser packages/prime-rl-configs/src/prime_rl/configs/trainer.py /app/packages/prime-rl-configs/src/prime_rl/configs/trainer.py
COPY --chown=appuser:appuser packages/prime-rl-configs/src/prime_rl/configs/orchestrator.py /app/packages/prime-rl-configs/src/prime_rl/configs/orchestrator.py
COPY --chown=appuser:appuser packages/prime-rl-configs/src/prime_rl/configs/inference.py /app/packages/prime-rl-configs/src/prime_rl/configs/inference.py

# prime-rl-configs is INSTALLED into /app/.venv/.../site-packages at image
# build time (per the original Dockerfile.cuda's `uv sync`). The COPYs above
# only update the source tree; the runtime imports the installed copy.
# Mirror the three updated config files into the venv site-packages so the
# new MxV2WeightBroadcastConfig + extended Literal types are actually visible
# at import time.
COPY --chown=appuser:appuser packages/prime-rl-configs/src/prime_rl/configs/trainer.py /app/.venv/lib/python3.12/site-packages/prime_rl/configs/trainer.py
COPY --chown=appuser:appuser packages/prime-rl-configs/src/prime_rl/configs/orchestrator.py /app/.venv/lib/python3.12/site-packages/prime_rl/configs/orchestrator.py
COPY --chown=appuser:appuser packages/prime-rl-configs/src/prime_rl/configs/inference.py /app/.venv/lib/python3.12/site-packages/prime_rl/configs/inference.py

# ────────────────────────────────────────────────────────────────────
# 3. Smoke-test that the v2 path imports cleanly
# ────────────────────────────────────────────────────────────────────
RUN /app/.venv/bin/python -c "from modelexpress.vllm_weight_transfer import MxWeightTransferEngine, MxInitInfo, MxUpdateInfo; print('engine adapter:', MxWeightTransferEngine)"
RUN /app/.venv/bin/python -c "from modelexpress.nemo_rl_v2 import MxV2TrainingPublisher, MxV2RefitReceiver, TrainerWorldLayout; print('v2 fat clients OK')"
RUN /app/.venv/bin/python -c "import vllm; print('vllm:', vllm.__version__)"
# Smoke test that flash_attn stub allows the broadcast __init__.py import chain
RUN /app/.venv/bin/python -c "from flash_attn.flash_attn_interface import _flash_attn_forward; print('flash_attn stub OK')"
# Smoke test the mx_v2 imports from prime-rl-side + WORKER_EXTENSION_CLS["mx_v2"]
RUN /app/.venv/bin/python -c "import sys; sys.path.insert(0, '/app/src'); sys.path.insert(0, '/app/packages/prime-rl-configs/src'); from prime_rl.configs.trainer import MxV2WeightBroadcastConfig; from prime_rl.trainer.rl.broadcast.nixl_mx_v2 import NIXLMxV2WeightBroadcast; from prime_rl.inference.vllm.worker.nixl_mx_v2 import NIXLMxV2WeightUpdateWorker; print('prime-rl mx_v2 surfaces all import OK')"
RUN /app/.venv/bin/python -c "import sys; sys.path.insert(0, '/app/src'); sys.path.insert(0, '/app/packages/prime-rl-configs/src'); from prime_rl.inference.vllm.server import WORKER_EXTENSION_CLS; assert 'mx_v2' in WORKER_EXTENSION_CLS, 'mx_v2 missing from server.WORKER_EXTENSION_CLS'; print('WORKER_EXTENSION_CLS[mx_v2] =', WORKER_EXTENSION_CLS['mx_v2'])"

USER appuser
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def auto_resolve_parsers(self):


class WeightBroadcastConfig(BaseConfig):
type: Literal["nccl", "filesystem", "nixl_mx"] = "filesystem"
type: Literal["nccl", "filesystem", "nixl_mx", "mx_v2"] = "filesystem"
"""Weight broadcast transport."""


Expand Down
36 changes: 35 additions & 1 deletion packages/prime-rl-configs/src/prime_rl/configs/orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -557,8 +557,42 @@ class NIXLMxWeightBroadcastConfig(BaseConfig):
"""Total inference GPUs across all servers."""


class MxV2WeightBroadcastConfig(BaseConfig):
"""Orchestrator-side config for ``weight_broadcast.type = "mx_v2"``.

Mirrors the trainer-side ``MxV2WeightBroadcastConfig`` in
``configs/trainer.py``. The orchestrator reads ``host`` / ``port`` to
init the v2 receivers via ``/init_nixl_mx_v2`` and uses the Phase 3b
filter fields to drive per-cycle ``/update_weights_v2`` calls.
"""

type: Literal["mx_v2"] = "mx_v2"

host: str = "localhost"
port: int = 29501
timeout: int = 1200

inference_world_size: int = Field(1, ge=1)

# ─── Discovery (Phase 2) ────────────────────────────────────────────
same_rank_only: bool = True
"""GB200/EFA multi-NIC fabrics: receivers pull from same-rank trainer only."""

# ─── Layout metadata (Phase 3b) ─────────────────────────────────────
compile_target_filter: list[str] | None = None
"""Receiver-side whitelist of acceptable compile_target strings.
``None`` = back-compat (accept anything)."""

# ─── Pipeline replication (TensorHub pattern) ───────────────────────
publish_self_as_replica: bool = True
"""Receivers republish as sources after refit for tree fan-out."""


WeightBroadcastConfig: TypeAlias = Annotated[
FileSystemWeightBroadcastConfig | NCCLWeightBroadcastConfig | NIXLMxWeightBroadcastConfig,
FileSystemWeightBroadcastConfig
| NCCLWeightBroadcastConfig
| NIXLMxWeightBroadcastConfig
| MxV2WeightBroadcastConfig,
Field(discriminator="type"),
]

Expand Down
66 changes: 65 additions & 1 deletion packages/prime-rl-configs/src/prime_rl/configs/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -494,8 +494,72 @@ class NIXLMxWeightBroadcastConfig(BaseWeightBroadcastConfig):
"""HF model name of the inference target."""


class MxV2WeightBroadcastConfig(BaseWeightBroadcastConfig):
"""v2 weight broadcast over NIXL + ModelExpress fat clients.

Selectable from config via ``weight_broadcast.type = "mx_v2"``.
Coexists with the existing ``"nixl_mx"`` path (PR #2389) for
migration. See ``docs/proposals/post-pr2389-mx-v2.md`` for the
full design. Maps to
:class:`prime_rl.trainer.rl.broadcast.nixl_mx_v2.NIXLMxV2WeightBroadcast`
(trainer) and
:class:`prime_rl.inference.vllm.worker.nixl_mx_v2.NIXLMxV2WeightUpdateWorker`
(inference).
"""

type: Literal["mx_v2"] = "mx_v2"

# ─── Control plane (same as nixl_mx) ────────────────────────────
host: str = "localhost"
"""Host for the ModelExpress server."""

port: int = 29501
"""Port for the ModelExpress server."""

timeout: int = 1200
"""Timeout in seconds for rendezvous and per-step transfers."""

inference_world_size: int = 1
"""Number of GPUs used for inference."""

inference_model_name: str = ""
"""HF model name of the inference target."""

# ─── Discovery (Phase 2) ────────────────────────────────────────
same_rank_only: bool = True
"""GB200/EFA multi-NIC fabrics: receivers pull from same-rank trainer only.
rdma-0..3 are separate L3 subnets, so cross-rank writes are unrouted."""

dedup_freshest_per_rank: bool = True
"""When multiple READY entries share a worker_rank (e.g. after a pod
restart), pick the freshest by ``updated_at``. Without this, stale
catalog entries cause ``NIXL_ERR_NOT_ALLOWED`` on ``add_remote_agent``."""

# ─── Layout metadata (Phase 3) ──────────────────────────────────
publish_compile_target: bool = True
"""Trainer stamps every publish with the conversion's compile_target tag
(e.g. ``cutlass_fp8``, ``deep_gemm_fp8``, ``hf_raw``) so receivers can
refuse mismatched layouts at discovery, before any RDMA cycle."""

compile_target_filter: list[str] | None = None
"""Receiver-side whitelist of acceptable compile_target strings.
``None`` (default) = accept anything — back-compat with PR #2389
publishers that don't carry the tag. Set e.g. ``["cutlass_fp8"]``
or ``["cutlass_fp8", "hf_raw"]`` to refuse mismatches."""

# ─── Pipeline replication (TensorHub pattern) ───────────────────
publish_self_as_replica: bool = True
"""After a successful receive, inference workers republish their
NIXL buffers as additional sources. Subsequent receivers can pull
from peers instead of the trainer, amplifying total egress
bandwidth. Trainer NIC stops being the bottleneck past ~4 receivers."""


WeightBroadcastConfig: TypeAlias = Annotated[
FileSystemWeightBroadcastConfig | NCCLWeightBroadcastConfig | NIXLMxWeightBroadcastConfig,
FileSystemWeightBroadcastConfig
| NCCLWeightBroadcastConfig
| NIXLMxWeightBroadcastConfig
| MxV2WeightBroadcastConfig,
Field(discriminator="type"),
]

Expand Down
70 changes: 70 additions & 0 deletions scripts/flash_attn_stub/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
"""Stub flash_attn package for ARM64 GB200 (no compiled kernels).

Installs an import hook that synthesizes any missing submodule of
flash_attn (e.g. flash_attn.ops, flash_attn.ops.triton.rotary) so
imports succeed at module-load time. The actual kernel functions
raise NotImplementedError if called — callers should use SDPA.
"""
__version__ = "2.7.3"

import sys
import types
import importlib.abc
import importlib.machinery


def flash_attn_func(*args, **kwargs):
raise NotImplementedError("flash_attn is stubbed on ARM64 GB200 — use attn='sdpa'")


def flash_attn_varlen_func(*args, **kwargs):
raise NotImplementedError("flash_attn is stubbed on ARM64 GB200 — use attn='sdpa'")


def flash_attn_supports_top_left_mask():
return False


def _stub_callable(name):
def _f(*args, **kwargs):
raise NotImplementedError(f"flash_attn stub: {name} not implemented on ARM64 GB200")
_f.__name__ = name
return _f


class _FlashAttnSubmoduleFinder(importlib.abc.MetaPathFinder, importlib.abc.Loader):
"""Synthesize any flash_attn.* submodule on demand.

Returns an empty module with a __getattr__ that lazily produces stub
callables for any attribute access, so imports like
`from flash_attn.ops.triton.rotary import apply_rotary` succeed
and `apply_rotary(...)` raises NotImplementedError.
"""

def find_spec(self, fullname, path, target=None):
if not fullname.startswith("flash_attn."):
return None
if fullname in sys.modules:
return None
# Don't shadow our own real submodules
if fullname == "flash_attn.flash_attn_interface":
return None
return importlib.machinery.ModuleSpec(fullname, self, is_package=True)

def create_module(self, spec):
mod = types.ModuleType(spec.name)
mod.__path__ = []
mod.__file__ = "<flash_attn stub>"
# __getattr__ returns a stub callable for any name
def __getattr__(name):
if name.startswith("__"):
raise AttributeError(name)
return _stub_callable(f"{spec.name}.{name}")
mod.__getattr__ = __getattr__
return mod

def exec_module(self, module):
pass


sys.meta_path.append(_FlashAttnSubmoduleFinder())
36 changes: 36 additions & 0 deletions scripts/flash_attn_stub/flash_attn_interface.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
"""Stub flash_attn_interface — raises on any real call.

Exports every symbol that ring_flash_attn, vLLM, and transformers import
from this module so the import chain doesn't break at module load time.
The actual NotImplementedError fires only if the function is *called*.
"""

_MSG = "flash_attn is stubbed on ARM64 GB200 — use attn='sdpa'"


def flash_attn_func(*a, **kw):
raise NotImplementedError(_MSG)

def flash_attn_varlen_func(*a, **kw):
raise NotImplementedError(_MSG)

def flash_attn_qkvpacked_func(*a, **kw):
raise NotImplementedError(_MSG)

def flash_attn_kvpacked_func(*a, **kw):
raise NotImplementedError(_MSG)

def flash_attn_with_kvcache(*a, **kw):
raise NotImplementedError(_MSG)

def _flash_attn_forward(*a, **kw):
raise NotImplementedError(_MSG)

def _flash_attn_backward(*a, **kw):
raise NotImplementedError(_MSG)

def _flash_attn_varlen_forward(*a, **kw):
raise NotImplementedError(_MSG)

def _flash_attn_varlen_backward(*a, **kw):
raise NotImplementedError(_MSG)
51 changes: 51 additions & 0 deletions src/prime_rl/inference/vllm/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ def models(request: Request) -> OpenAIServingModels:
"nccl": "prime_rl.inference.vllm.worker.nccl.NCCLWeightUpdateWorker",
"filesystem": "prime_rl.inference.vllm.worker.filesystem.FileSystemWeightUpdateWorker",
"nixl_mx": "prime_rl.inference.vllm.worker.nixl_mx.NIXLMxWeightUpdateWorker",
"mx_v2": "prime_rl.inference.vllm.worker.nixl_mx_v2.NIXLMxV2WeightUpdateWorker",
}


Expand Down Expand Up @@ -128,6 +129,56 @@ async def init_nixl_mx(request: Request):
return {"status": "ok"}


@router.post("/init_nixl_mx_v2")
async def init_nixl_mx_v2(request: Request):
"""Boot-time init for the ``mx_v2`` worker extension.

Mirrors ``/init_nixl_mx`` but targets
:meth:`NIXLMxV2WeightUpdateWorker.init_nixl_mx_v2`. Accepts optional
``publish_self_as_replica`` (tree fan-out; default True) and
``listen_port`` (NIXL listen port; default None = auto-pick).
"""
data = await request.json()
await engine_client(request).collective_rpc(
"init_nixl_mx_v2",
args=(data["host"], data["port"], data["rank_offset"]),
kwargs={
"publish_self_as_replica": data.get("publish_self_as_replica", True),
"listen_port": data.get("listen_port"),
},
)
return {"status": "ok"}


@router.post("/update_weights_v2")
async def update_weights_v2(request: Request):
"""Per-cycle refit RPC for the ``mx_v2`` worker extension.

Body fields:
step (int, required): trainer version to pull (engine accepts
sources with ``version >= step``).
compile_target_filter (list[str], optional): Phase 3b filter.
``None`` = accept anything (back-compat).
timeout_seconds (float, optional): per-receive RDMA wait cap.
same_rank_only (bool, optional): default True (Phase 2).

Returns the per-worker metrics dict aggregated across collective_rpc
fan-out so the orchestrator can emit them to dashboards without log
parsing.
"""
data = await request.json()
metrics = await engine_client(request).collective_rpc(
"update_weights_via_mx_v2",
args=(int(data["step"]),),
kwargs={
"compile_target_filter": data.get("compile_target_filter"),
"timeout_seconds": float(data.get("timeout_seconds", 300.0)),
"same_rank_only": bool(data.get("same_rank_only", True)),
},
)
return {"status": "ok", "metrics": metrics}


async def custom_init_app_state(
engine_client: EngineClient,
state: State,
Expand Down
Loading