diff --git a/Dockerfile.cuda.mx-v2 b/Dockerfile.cuda.mx-v2 new file mode 100644 index 0000000000..ccea3ac7a3 --- /dev/null +++ b/Dockerfile.cuda.mx-v2 @@ -0,0 +1,85 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 prime intellect & contributors +# SPDX-License-Identifier: Apache-2.0 +# +# Overlay Dockerfile for the v2 prime-rl × ModelExpress integration. +# Layers on top of v0.7.1-kavin-phase2-phase3 (which already has Phase 2 +# + Phase 3 source baked in). See docs/proposals/image-build-mx-v2.md. + +FROM nvcr.io/nvidian/dynamo-dev/prime-rl-mx-on-nixl:v0.7.1-kavin-phase2-phase3 + +USER root + +# ──────────────────────────────────────────────────────────────────── +# 1a. Install the flash-attn ARM64 stub package +# +# The v0.7.1 baseline image ships only flash-attn-cute (Cute kernels), +# but ring-flash-attn (transitively imported by +# prime_rl.trainer.models.glm_moe_dsa) imports +# `flash_attn.flash_attn_interface`. We restore that import surface +# from the same stub the v0.5.2 image uses on the live kavin trainer +# (extracted via kubectl cp). Functions raise NotImplementedError if +# actually called — callers should use SDPA. +# ──────────────────────────────────────────────────────────────────── +COPY --chown=appuser:appuser scripts/flash_attn_stub/__init__.py /app/.venv/lib/python3.12/site-packages/flash_attn/__init__.py +COPY --chown=appuser:appuser scripts/flash_attn_stub/flash_attn_interface.py /app/.venv/lib/python3.12/site-packages/flash_attn/flash_attn_interface.py + +# ──────────────────────────────────────────────────────────────────── +# 1b. Update modelexpress to the PR #349 branch +# (Phase 4 multi-source slice picker + MxWeightTransferEngine) +# +# uv lives at /usr/local/bin/uv (per Dockerfile.cuda line 31's +# UV_INSTALL_DIR). The venv has no pip installed by default, so we +# use uv with --python pointing at the venv interpreter. +# ──────────────────────────────────────────────────────────────────── +RUN --mount=type=cache,target=/app/.cache/uv \ + /usr/local/bin/uv pip install --no-deps --reinstall \ + --python /app/.venv/bin/python \ + "modelexpress @ git+https://github.com/ai-dynamo/modelexpress.git@kavink/post-2389-phase3-4#subdirectory=modelexpress_client/python" + +# ──────────────────────────────────────────────────────────────────── +# 2. Overlay the v2 prime-rl source files +# +# Six files touched by the kavink/post-2389-mx-v2 branch: +# trainer side — broadcast/nixl_mx_v2.py (new), broadcast/__init__.py (mx_v2 dispatch) +# inference side — worker/nixl_mx_v2.py (new), +# server.py (WORKER_EXTENSION_CLS["mx_v2"] + /init_nixl_mx_v2 + +# /update_weights_v2 endpoints) +# config — packages/.../configs/trainer.py (MxV2WeightBroadcastConfig) +# orchestrator — utils/client.py (init_nixl_mx_v2_broadcast + update_weights_v2) +# transport — transport/ (was unchanged but copied as-is for completeness) +# ──────────────────────────────────────────────────────────────────── +COPY --chown=appuser:appuser src/prime_rl/transport/ /app/src/prime_rl/transport/ +COPY --chown=appuser:appuser src/prime_rl/inference/vllm/server.py /app/src/prime_rl/inference/vllm/server.py +COPY --chown=appuser:appuser src/prime_rl/inference/vllm/worker/nixl_mx_v2.py /app/src/prime_rl/inference/vllm/worker/nixl_mx_v2.py +COPY --chown=appuser:appuser src/prime_rl/trainer/rl/broadcast/nixl_mx_v2.py /app/src/prime_rl/trainer/rl/broadcast/nixl_mx_v2.py +COPY --chown=appuser:appuser src/prime_rl/trainer/rl/broadcast/__init__.py /app/src/prime_rl/trainer/rl/broadcast/__init__.py +COPY --chown=appuser:appuser src/prime_rl/utils/client.py /app/src/prime_rl/utils/client.py +COPY --chown=appuser:appuser src/prime_rl/orchestrator/orchestrator.py /app/src/prime_rl/orchestrator/orchestrator.py +COPY --chown=appuser:appuser src/prime_rl/orchestrator/scheduler.py /app/src/prime_rl/orchestrator/scheduler.py +COPY --chown=appuser:appuser packages/prime-rl-configs/src/prime_rl/configs/trainer.py /app/packages/prime-rl-configs/src/prime_rl/configs/trainer.py +COPY --chown=appuser:appuser packages/prime-rl-configs/src/prime_rl/configs/orchestrator.py /app/packages/prime-rl-configs/src/prime_rl/configs/orchestrator.py +COPY --chown=appuser:appuser packages/prime-rl-configs/src/prime_rl/configs/inference.py /app/packages/prime-rl-configs/src/prime_rl/configs/inference.py + +# prime-rl-configs is INSTALLED into /app/.venv/.../site-packages at image +# build time (per the original Dockerfile.cuda's `uv sync`). The COPYs above +# only update the source tree; the runtime imports the installed copy. +# Mirror the three updated config files into the venv site-packages so the +# new MxV2WeightBroadcastConfig + extended Literal types are actually visible +# at import time. +COPY --chown=appuser:appuser packages/prime-rl-configs/src/prime_rl/configs/trainer.py /app/.venv/lib/python3.12/site-packages/prime_rl/configs/trainer.py +COPY --chown=appuser:appuser packages/prime-rl-configs/src/prime_rl/configs/orchestrator.py /app/.venv/lib/python3.12/site-packages/prime_rl/configs/orchestrator.py +COPY --chown=appuser:appuser packages/prime-rl-configs/src/prime_rl/configs/inference.py /app/.venv/lib/python3.12/site-packages/prime_rl/configs/inference.py + +# ──────────────────────────────────────────────────────────────────── +# 3. Smoke-test that the v2 path imports cleanly +# ──────────────────────────────────────────────────────────────────── +RUN /app/.venv/bin/python -c "from modelexpress.vllm_weight_transfer import MxWeightTransferEngine, MxInitInfo, MxUpdateInfo; print('engine adapter:', MxWeightTransferEngine)" +RUN /app/.venv/bin/python -c "from modelexpress.nemo_rl_v2 import MxV2TrainingPublisher, MxV2RefitReceiver, TrainerWorldLayout; print('v2 fat clients OK')" +RUN /app/.venv/bin/python -c "import vllm; print('vllm:', vllm.__version__)" +# Smoke test that flash_attn stub allows the broadcast __init__.py import chain +RUN /app/.venv/bin/python -c "from flash_attn.flash_attn_interface import _flash_attn_forward; print('flash_attn stub OK')" +# Smoke test the mx_v2 imports from prime-rl-side + WORKER_EXTENSION_CLS["mx_v2"] +RUN /app/.venv/bin/python -c "import sys; sys.path.insert(0, '/app/src'); sys.path.insert(0, '/app/packages/prime-rl-configs/src'); from prime_rl.configs.trainer import MxV2WeightBroadcastConfig; from prime_rl.trainer.rl.broadcast.nixl_mx_v2 import NIXLMxV2WeightBroadcast; from prime_rl.inference.vllm.worker.nixl_mx_v2 import NIXLMxV2WeightUpdateWorker; print('prime-rl mx_v2 surfaces all import OK')" +RUN /app/.venv/bin/python -c "import sys; sys.path.insert(0, '/app/src'); sys.path.insert(0, '/app/packages/prime-rl-configs/src'); from prime_rl.inference.vllm.server import WORKER_EXTENSION_CLS; assert 'mx_v2' in WORKER_EXTENSION_CLS, 'mx_v2 missing from server.WORKER_EXTENSION_CLS'; print('WORKER_EXTENSION_CLS[mx_v2] =', WORKER_EXTENSION_CLS['mx_v2'])" + +USER appuser diff --git a/packages/prime-rl-configs/src/prime_rl/configs/inference.py b/packages/prime-rl-configs/src/prime_rl/configs/inference.py index f5ce7ef7ef..a501b9fead 100644 --- a/packages/prime-rl-configs/src/prime_rl/configs/inference.py +++ b/packages/prime-rl-configs/src/prime_rl/configs/inference.py @@ -82,7 +82,7 @@ def auto_resolve_parsers(self): class WeightBroadcastConfig(BaseConfig): - type: Literal["nccl", "filesystem", "nixl_mx"] = "filesystem" + type: Literal["nccl", "filesystem", "nixl_mx", "mx_v2"] = "filesystem" """Weight broadcast transport.""" diff --git a/packages/prime-rl-configs/src/prime_rl/configs/orchestrator.py b/packages/prime-rl-configs/src/prime_rl/configs/orchestrator.py index 24e0624355..da359124b4 100644 --- a/packages/prime-rl-configs/src/prime_rl/configs/orchestrator.py +++ b/packages/prime-rl-configs/src/prime_rl/configs/orchestrator.py @@ -557,8 +557,42 @@ class NIXLMxWeightBroadcastConfig(BaseConfig): """Total inference GPUs across all servers.""" +class MxV2WeightBroadcastConfig(BaseConfig): + """Orchestrator-side config for ``weight_broadcast.type = "mx_v2"``. + + Mirrors the trainer-side ``MxV2WeightBroadcastConfig`` in + ``configs/trainer.py``. The orchestrator reads ``host`` / ``port`` to + init the v2 receivers via ``/init_nixl_mx_v2`` and uses the Phase 3b + filter fields to drive per-cycle ``/update_weights_v2`` calls. + """ + + type: Literal["mx_v2"] = "mx_v2" + + host: str = "localhost" + port: int = 29501 + timeout: int = 1200 + + inference_world_size: int = Field(1, ge=1) + + # ─── Discovery (Phase 2) ──────────────────────────────────────────── + same_rank_only: bool = True + """GB200/EFA multi-NIC fabrics: receivers pull from same-rank trainer only.""" + + # ─── Layout metadata (Phase 3b) ───────────────────────────────────── + compile_target_filter: list[str] | None = None + """Receiver-side whitelist of acceptable compile_target strings. + ``None`` = back-compat (accept anything).""" + + # ─── Pipeline replication (TensorHub pattern) ─────────────────────── + publish_self_as_replica: bool = True + """Receivers republish as sources after refit for tree fan-out.""" + + WeightBroadcastConfig: TypeAlias = Annotated[ - FileSystemWeightBroadcastConfig | NCCLWeightBroadcastConfig | NIXLMxWeightBroadcastConfig, + FileSystemWeightBroadcastConfig + | NCCLWeightBroadcastConfig + | NIXLMxWeightBroadcastConfig + | MxV2WeightBroadcastConfig, Field(discriminator="type"), ] diff --git a/packages/prime-rl-configs/src/prime_rl/configs/trainer.py b/packages/prime-rl-configs/src/prime_rl/configs/trainer.py index fd61ceb49a..a890ed85dd 100644 --- a/packages/prime-rl-configs/src/prime_rl/configs/trainer.py +++ b/packages/prime-rl-configs/src/prime_rl/configs/trainer.py @@ -494,8 +494,72 @@ class NIXLMxWeightBroadcastConfig(BaseWeightBroadcastConfig): """HF model name of the inference target.""" +class MxV2WeightBroadcastConfig(BaseWeightBroadcastConfig): + """v2 weight broadcast over NIXL + ModelExpress fat clients. + + Selectable from config via ``weight_broadcast.type = "mx_v2"``. + Coexists with the existing ``"nixl_mx"`` path (PR #2389) for + migration. See ``docs/proposals/post-pr2389-mx-v2.md`` for the + full design. Maps to + :class:`prime_rl.trainer.rl.broadcast.nixl_mx_v2.NIXLMxV2WeightBroadcast` + (trainer) and + :class:`prime_rl.inference.vllm.worker.nixl_mx_v2.NIXLMxV2WeightUpdateWorker` + (inference). + """ + + type: Literal["mx_v2"] = "mx_v2" + + # ─── Control plane (same as nixl_mx) ──────────────────────────── + host: str = "localhost" + """Host for the ModelExpress server.""" + + port: int = 29501 + """Port for the ModelExpress server.""" + + timeout: int = 1200 + """Timeout in seconds for rendezvous and per-step transfers.""" + + inference_world_size: int = 1 + """Number of GPUs used for inference.""" + + inference_model_name: str = "" + """HF model name of the inference target.""" + + # ─── Discovery (Phase 2) ──────────────────────────────────────── + same_rank_only: bool = True + """GB200/EFA multi-NIC fabrics: receivers pull from same-rank trainer only. + rdma-0..3 are separate L3 subnets, so cross-rank writes are unrouted.""" + + dedup_freshest_per_rank: bool = True + """When multiple READY entries share a worker_rank (e.g. after a pod + restart), pick the freshest by ``updated_at``. Without this, stale + catalog entries cause ``NIXL_ERR_NOT_ALLOWED`` on ``add_remote_agent``.""" + + # ─── Layout metadata (Phase 3) ────────────────────────────────── + publish_compile_target: bool = True + """Trainer stamps every publish with the conversion's compile_target tag + (e.g. ``cutlass_fp8``, ``deep_gemm_fp8``, ``hf_raw``) so receivers can + refuse mismatched layouts at discovery, before any RDMA cycle.""" + + compile_target_filter: list[str] | None = None + """Receiver-side whitelist of acceptable compile_target strings. + ``None`` (default) = accept anything — back-compat with PR #2389 + publishers that don't carry the tag. Set e.g. ``["cutlass_fp8"]`` + or ``["cutlass_fp8", "hf_raw"]`` to refuse mismatches.""" + + # ─── Pipeline replication (TensorHub pattern) ─────────────────── + publish_self_as_replica: bool = True + """After a successful receive, inference workers republish their + NIXL buffers as additional sources. Subsequent receivers can pull + from peers instead of the trainer, amplifying total egress + bandwidth. Trainer NIC stops being the bottleneck past ~4 receivers.""" + + WeightBroadcastConfig: TypeAlias = Annotated[ - FileSystemWeightBroadcastConfig | NCCLWeightBroadcastConfig | NIXLMxWeightBroadcastConfig, + FileSystemWeightBroadcastConfig + | NCCLWeightBroadcastConfig + | NIXLMxWeightBroadcastConfig + | MxV2WeightBroadcastConfig, Field(discriminator="type"), ] diff --git a/scripts/flash_attn_stub/__init__.py b/scripts/flash_attn_stub/__init__.py new file mode 100644 index 0000000000..8a8b52b06a --- /dev/null +++ b/scripts/flash_attn_stub/__init__.py @@ -0,0 +1,70 @@ +"""Stub flash_attn package for ARM64 GB200 (no compiled kernels). + +Installs an import hook that synthesizes any missing submodule of +flash_attn (e.g. flash_attn.ops, flash_attn.ops.triton.rotary) so +imports succeed at module-load time. The actual kernel functions +raise NotImplementedError if called — callers should use SDPA. +""" +__version__ = "2.7.3" + +import sys +import types +import importlib.abc +import importlib.machinery + + +def flash_attn_func(*args, **kwargs): + raise NotImplementedError("flash_attn is stubbed on ARM64 GB200 — use attn='sdpa'") + + +def flash_attn_varlen_func(*args, **kwargs): + raise NotImplementedError("flash_attn is stubbed on ARM64 GB200 — use attn='sdpa'") + + +def flash_attn_supports_top_left_mask(): + return False + + +def _stub_callable(name): + def _f(*args, **kwargs): + raise NotImplementedError(f"flash_attn stub: {name} not implemented on ARM64 GB200") + _f.__name__ = name + return _f + + +class _FlashAttnSubmoduleFinder(importlib.abc.MetaPathFinder, importlib.abc.Loader): + """Synthesize any flash_attn.* submodule on demand. + + Returns an empty module with a __getattr__ that lazily produces stub + callables for any attribute access, so imports like + `from flash_attn.ops.triton.rotary import apply_rotary` succeed + and `apply_rotary(...)` raises NotImplementedError. + """ + + def find_spec(self, fullname, path, target=None): + if not fullname.startswith("flash_attn."): + return None + if fullname in sys.modules: + return None + # Don't shadow our own real submodules + if fullname == "flash_attn.flash_attn_interface": + return None + return importlib.machinery.ModuleSpec(fullname, self, is_package=True) + + def create_module(self, spec): + mod = types.ModuleType(spec.name) + mod.__path__ = [] + mod.__file__ = "" + # __getattr__ returns a stub callable for any name + def __getattr__(name): + if name.startswith("__"): + raise AttributeError(name) + return _stub_callable(f"{spec.name}.{name}") + mod.__getattr__ = __getattr__ + return mod + + def exec_module(self, module): + pass + + +sys.meta_path.append(_FlashAttnSubmoduleFinder()) diff --git a/scripts/flash_attn_stub/flash_attn_interface.py b/scripts/flash_attn_stub/flash_attn_interface.py new file mode 100644 index 0000000000..133ef5422d --- /dev/null +++ b/scripts/flash_attn_stub/flash_attn_interface.py @@ -0,0 +1,36 @@ +"""Stub flash_attn_interface — raises on any real call. + +Exports every symbol that ring_flash_attn, vLLM, and transformers import +from this module so the import chain doesn't break at module load time. +The actual NotImplementedError fires only if the function is *called*. +""" + +_MSG = "flash_attn is stubbed on ARM64 GB200 — use attn='sdpa'" + + +def flash_attn_func(*a, **kw): + raise NotImplementedError(_MSG) + +def flash_attn_varlen_func(*a, **kw): + raise NotImplementedError(_MSG) + +def flash_attn_qkvpacked_func(*a, **kw): + raise NotImplementedError(_MSG) + +def flash_attn_kvpacked_func(*a, **kw): + raise NotImplementedError(_MSG) + +def flash_attn_with_kvcache(*a, **kw): + raise NotImplementedError(_MSG) + +def _flash_attn_forward(*a, **kw): + raise NotImplementedError(_MSG) + +def _flash_attn_backward(*a, **kw): + raise NotImplementedError(_MSG) + +def _flash_attn_varlen_forward(*a, **kw): + raise NotImplementedError(_MSG) + +def _flash_attn_varlen_backward(*a, **kw): + raise NotImplementedError(_MSG) diff --git a/src/prime_rl/inference/vllm/server.py b/src/prime_rl/inference/vllm/server.py index 09c0d1a201..d7d5a55d86 100644 --- a/src/prime_rl/inference/vllm/server.py +++ b/src/prime_rl/inference/vllm/server.py @@ -57,6 +57,7 @@ def models(request: Request) -> OpenAIServingModels: "nccl": "prime_rl.inference.vllm.worker.nccl.NCCLWeightUpdateWorker", "filesystem": "prime_rl.inference.vllm.worker.filesystem.FileSystemWeightUpdateWorker", "nixl_mx": "prime_rl.inference.vllm.worker.nixl_mx.NIXLMxWeightUpdateWorker", + "mx_v2": "prime_rl.inference.vllm.worker.nixl_mx_v2.NIXLMxV2WeightUpdateWorker", } @@ -128,6 +129,56 @@ async def init_nixl_mx(request: Request): return {"status": "ok"} +@router.post("/init_nixl_mx_v2") +async def init_nixl_mx_v2(request: Request): + """Boot-time init for the ``mx_v2`` worker extension. + + Mirrors ``/init_nixl_mx`` but targets + :meth:`NIXLMxV2WeightUpdateWorker.init_nixl_mx_v2`. Accepts optional + ``publish_self_as_replica`` (tree fan-out; default True) and + ``listen_port`` (NIXL listen port; default None = auto-pick). + """ + data = await request.json() + await engine_client(request).collective_rpc( + "init_nixl_mx_v2", + args=(data["host"], data["port"], data["rank_offset"]), + kwargs={ + "publish_self_as_replica": data.get("publish_self_as_replica", True), + "listen_port": data.get("listen_port"), + }, + ) + return {"status": "ok"} + + +@router.post("/update_weights_v2") +async def update_weights_v2(request: Request): + """Per-cycle refit RPC for the ``mx_v2`` worker extension. + + Body fields: + step (int, required): trainer version to pull (engine accepts + sources with ``version >= step``). + compile_target_filter (list[str], optional): Phase 3b filter. + ``None`` = accept anything (back-compat). + timeout_seconds (float, optional): per-receive RDMA wait cap. + same_rank_only (bool, optional): default True (Phase 2). + + Returns the per-worker metrics dict aggregated across collective_rpc + fan-out so the orchestrator can emit them to dashboards without log + parsing. + """ + data = await request.json() + metrics = await engine_client(request).collective_rpc( + "update_weights_via_mx_v2", + args=(int(data["step"]),), + kwargs={ + "compile_target_filter": data.get("compile_target_filter"), + "timeout_seconds": float(data.get("timeout_seconds", 300.0)), + "same_rank_only": bool(data.get("same_rank_only", True)), + }, + ) + return {"status": "ok", "metrics": metrics} + + async def custom_init_app_state( engine_client: EngineClient, state: State, diff --git a/src/prime_rl/inference/vllm/worker/nixl_mx_v2.py b/src/prime_rl/inference/vllm/worker/nixl_mx_v2.py new file mode 100644 index 0000000000..4ac2623920 --- /dev/null +++ b/src/prime_rl/inference/vllm/worker/nixl_mx_v2.py @@ -0,0 +1,406 @@ +"""v2 inference-worker extension for ModelExpress weight refits. + +The v2 of :class:`NIXLMxWeightUpdateWorker` (PR #2389), built on the +``MxWeightTransferEngine`` adapter from ModelExpress PR #349. The adapter +wraps every Phase 2/3/4 capability behind vLLM's :class:`WeightTransferEngine` +ABC (the same shape Anyscale's RDT PR `#43375 +`_ targets), so this +module is **maximally thin** — it just instantiates the engine and +plumbs vLLM's ``load_weights`` callback through. + +Key differences from PR #2389: + +- **Pull semantics, not push.** The trainer ``publish()``-es weights to + the MX catalog; this worker calls ``engine.receive_weights(...)`` + which discovers + pulls. No pre-registered NIXL buffers on the + inference side (the engine uses the scratch-buffer path internally). +- **Compile-target safety net (Phase 3b).** Optional + ``compile_target_filter`` refuses sources whose tensors don't match + the kernel layout this worker expects — BEFORE any RDMA cycle is + spent. Set ``filter=None`` for back-compat (accept anything). +- **Mixed-TP path (Phase 4).** When ``target_tp_layout`` is set, the + engine uses the multi-source slice picker; otherwise it uses the + matched-TP single-source fast path. +- **Tree fan-out (TensorHub pattern).** When + ``publish_self_as_replica=True`` in the engine's ``init_info``, the + worker republishes itself as a source after each successful receive, + so subsequent receivers can pull from peers instead of the trainer. + +See :file:`docs/proposals/post-pr2389-mx-v2.md` for the design rationale +and migration plan. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import torch +from torch.nn import Module +from vllm.logger import init_logger + +from prime_rl.inference.vllm.worker.weight_transfer import update_mla_absorbed_weights +from prime_rl.transport.nixl_agent import make_agent_name, pin_ucx_rail + +if TYPE_CHECKING: + from vllm.v1.worker.gpu_worker import Worker + + Worker = Worker # type: ignore +else: + Worker = object # type: ignore + +logger = init_logger("vllm.inference.vllm.worker_nixl_mx_v2") + + +class NIXLMxV2WeightUpdateWorker(Worker): + """vLLM worker extension for the v2 (pull-mode) weight-refit path. + + Mounted via vLLM's ``worker_extension_cls`` plumbing — same hook + PR #2389 uses for ``NIXLMxWeightUpdateWorker``. Two RPC endpoints: + + - :meth:`init_nixl_mx_v2` — called once at worker boot, sets up the + :class:`MxWeightTransferEngine` for this rank. + - :meth:`update_weights_via_mx_v2` — called per refit cycle by the + orchestrator; engine discovers + pulls + feeds vLLM's + ``load_weights``. + """ + + # ------------------------------------------------------------------ + # Model accessor (matches PR #2389) + # ------------------------------------------------------------------ + + @property + def raw_model(self) -> Module: + model_runner = self.model_runner + model = ( + model_runner.model.runnable + if hasattr(model_runner.model, "runnable") + else model_runner.model + ) + assert isinstance(model, Module) + return model + + # ------------------------------------------------------------------ + # Init RPC + # ------------------------------------------------------------------ + + def init_nixl_mx_v2( + self, + host: str, + port: int, + rank_offset: int, + *, + publish_self_as_replica: bool = True, + listen_port: int | None = None, + ) -> None: + """Build the :class:`MxWeightTransferEngine` for this worker. + + Args: + host, port: ``modelexpress-server`` URL. + rank_offset: orchestrator-assigned base rank for this pod; + ``global_rank = rank_offset + self.device.index``. + publish_self_as_replica: if True (default), after each + successful receive this worker republishes itself as + a source so newcomers can pull from it (tree fan-out). + listen_port: optional explicit NIXL listen port; ``None`` + = auto. + """ + from modelexpress.vllm_weight_transfer import MxInitInfo, MxWeightTransferEngine + + local_rank = self.device.index + global_rank = rank_offset + local_rank + inference_model_name = self.model_runner.model_config.model + + pin_ucx_rail(local_rank) + + self._engine = MxWeightTransferEngine( + init_info=MxInitInfo( + mx_server_url=f"{host}:{port}", + model_name=inference_model_name, + worker_rank=global_rank, + agent_name=make_agent_name("inference", global_rank), + device_id=local_rank, + listen_port=listen_port, + publish_self_as_replica=publish_self_as_replica, + ) + ) + self._global_rank = global_rank + + # Cache the HF model config + parallel layout so the TT→HF + # translator (`_translate_tt_to_hf`) can split fused tensors into + # the per-tensor / per-expert names vLLM's `load_weights` expects. + try: + from transformers import AutoConfig + hf = AutoConfig.from_pretrained(inference_model_name) + mc = self.model_runner.model_config + ep_size = getattr(mc, "ep_size", None) or getattr( + mc, "data_parallel_size", 1 + ) + self._hf_config = { + "model_type": getattr(hf, "model_type", ""), + "num_attention_heads": getattr(hf, "num_attention_heads", 0), + "num_kv_heads": getattr(hf, "num_key_value_heads", 0) + or getattr(hf, "num_attention_heads", 0), + "head_dim": getattr(hf, "head_dim", 0) + or ( + getattr(hf, "hidden_size", 0) + // max(1, getattr(hf, "num_attention_heads", 1)) + ), + "num_experts": getattr(hf, "num_experts", 0) + or getattr(hf, "num_local_experts", 0), + "ep_size": int(ep_size or 1), + } + except Exception as e: # noqa: BLE001 — never block engine init + logger.warning( + f"[mx_v2] HF config probe failed ({e!r}); TT→HF translator " + f"will fall through to passthrough — non-MoE models only." + ) + self._hf_config = None + + logger.info( + f"[mx_v2] init: rank={global_rank} model={inference_model_name} " + f"publish_self_as_replica={publish_self_as_replica} " + f"hf_config={self._hf_config}" + ) + + # ------------------------------------------------------------------ + # Per-refit RPC + # ------------------------------------------------------------------ + + @torch.no_grad() + def update_weights_via_mx_v2( + self, + step: int, + *, + compile_target_filter: list[str] | None = None, + timeout_seconds: float = 300.0, + same_rank_only: bool = True, + ) -> dict[str, float | int | None]: + """Pull version ``step`` of the weights from the catalog. + + Args: + step: training-step counter; engine pulls sources with + ``version >= step``. + compile_target_filter: receiver-side Phase 3b filter. + ``None`` (default) = back-compat, accept any layout. + Set e.g. ``["cutlass_fp8"]`` or + ``["cutlass_fp8", "hf_raw"]`` to refuse mismatches at + discovery (no RDMA cycle spent on refusal). + timeout_seconds: cap on the engine's per-receive RDMA wait. + same_rank_only: enforce same-rank routing (required on + GB200/EFA multi-NIC fabrics where rdma-0..3 are + separate L3 subnets). + + Returns: + Per-cycle metrics dict (bytes / Gbps / discovery_seconds / + rdma_seconds) suitable for emission to dashboards. + """ + import time as _time + + from modelexpress.vllm_weight_transfer import MxUpdateInfo + + update_info = MxUpdateInfo( + version=step, + compile_target_filter=set(compile_target_filter) if compile_target_filter else None, + target_tp_layout=None, # matched-TP fast path; Phase 4 wire-up future + timeout_seconds=timeout_seconds, + same_rank_only=same_rank_only, + ) + + # Async-RL synchronization: orchestrator polls /update_weights_v2 with + # step=N right after a training cycle, but the trainer publishes + # version=N asynchronously (it has to finish optimizer.step + the + # publisher's add_tensor loop). If the engine's discovery fires + # before the trainer has marked version=N READY in the MX catalog, + # `receive_weights` raises `no source matches filters`. + # + # Wrap the engine call in a bounded retry loop so the synchronization + # gap is absorbed at the worker layer (no orchestrator changes needed + # and the failure surface stays at this layer's known timeout). + retry_deadline = _time.monotonic() + timeout_seconds + backoff = 0.5 + attempts = 0 + last_err: Exception | None = None + while True: + attempts += 1 + try: + self._engine.receive_weights( + update_info, load_weights=self._load_weights_batch + ) + break + except Exception as e: # noqa: BLE001 — engine may raise plain RuntimeError + msg = str(e) + last_err = e + # Retry on discovery-empty errors (trainer hasn't published + # version=N yet) AND on NIXL transient connection errors + # (trainer-pod restart races where the catalog still has the + # dead agent's metadata for a few seconds). Any other error + # (e.g. real shape mismatch in load_weights) propagates. + transient = ( + "no source matches" in msg + or "NoSourceMatchesFilterError" in msg + or "no matching source" in msg + or "NIXL_ERR_REMOTE_DISCONNECT" in msg + or "NIXL_ERR_NOT_ALLOWED" in msg + or "NIXL_ERR_NOT_FOUND" in msg + ) + if not transient or _time.monotonic() >= retry_deadline: + raise + logger.info( + f"[mx_v2] receive_weights attempt #{attempts} for step={step}: " + f"transient miss ({msg[:80]!r}); retrying in {backoff:.1f}s" + ) + _time.sleep(backoff) + backoff = min(backoff * 1.6, 8.0) + + # Post-load housekeeping: same as PR #2389's path. + torch.cuda.synchronize(self.device) + update_mla_absorbed_weights(self.raw_model) + + # Surface the engine's metrics so the orchestrator / dashboards + # can read per-cycle bandwidth + discovery latency without + # parsing logs. + stats = self._engine.last_transfer_stats + metrics = { + "step": step, + "bytes_received": stats.bytes_received if stats else 0, + "tensors_received": stats.tensors_received if stats else 0, + "rdma_seconds": stats.elapsed_seconds if stats else 0.0, + "bandwidth_gbps": stats.bandwidth_gbps if stats else 0.0, + "discovery_seconds": self._engine.last_discovery_seconds, + "source_worker_rank": stats.source_worker_rank if stats else None, + } + logger.info( + f"[mx_v2] refit step={step} " + f"bytes={metrics['bytes_received'] / 1e6:.1f}MB " + f"rdma={metrics['rdma_seconds']:.3f}s " + f"{metrics['bandwidth_gbps']:.1f}Gbps " + f"from_rank={metrics['source_worker_rank']}" + ) + return metrics + + # ------------------------------------------------------------------ + # vLLM load-weights bridge + # ------------------------------------------------------------------ + + def _load_weights_batch(self, batch: list[tuple[str, torch.Tensor]]) -> None: + """Feed yielded ``(name, tensor)`` pairs through vLLM's load_weights. + + Translation pass: PrimeRL's trainer-side ``GatheredSlot`` emits + tensors in TT-format (fused ``qkv_proj``, stacked-expert + ``w13_weight``/``w2_weight``, ``mlp.router.gate`` prefix). vLLM's + ``load_weights`` expects HF-checkpoint names + per-expert tensors + so its ``stacked_params_mapping`` (QKV / gate-up) and + ``expert_params_mapping`` (FusedMoE) can route them into the + model's actual stacked params. We translate TT → HF here so the + engine adapter (``MxWeightTransferEngine``) stays model-agnostic. + + The slot-side conversion specs that PrimeRL applies on the + publisher side are the inverse of this translator — see + ``prime_rl.trainer.models.qwen3_moe.converting_qwen3_moe``. + """ + translated = self._translate_tt_to_hf(batch) + if translated: + self.raw_model.load_weights(translated) + + # ------------------------------------------------------------------ + # TT → HF translation + # ------------------------------------------------------------------ + + def _translate_tt_to_hf( + self, + batch: list[tuple[str, torch.Tensor]], + ) -> list[tuple[str, torch.Tensor]]: + """Translate PrimeRL TT-format slot keys to HF checkpoint names. + + Currently supports Qwen3-MoE family (Qwen3MoeForCausalLM); other + models pass through (most non-MoE PrimeRL models already match + HF naming). To extend, add per-prefix unstacking logic. + + Layout assumption: the per-trainer-rank expert subset matches the + per-inference-rank EP subset (i.e. ``trainer.ep == inference.EP``), + so local-expert index lines up with global expert ID via + ``my_rank * num_local + local_id``. Cross-EP slicing (Phase 4 + mixed-TP / multi-source picker) is the follow-up that lifts this + constraint. + """ + cfg = self._hf_config + if cfg is None or cfg.get("model_type") not in {"qwen3_moe", "qwen3"}: + return batch # passthrough for unsupported models + + q_size = cfg["num_attention_heads"] * cfg["head_dim"] + kv_size = cfg["num_kv_heads"] * cfg["head_dim"] + num_experts = cfg.get("num_experts", 0) + ep_size = max(1, cfg.get("ep_size", 1)) + num_local_experts = num_experts // ep_size if num_experts else 0 + my_rank = self._global_rank % ep_size if ep_size > 1 else 0 + + out: list[tuple[str, torch.Tensor]] = [] + for name, tensor in batch: + # ── QKV split (fused → q/k/v) ─────────────────────────────── + if name.endswith(".self_attn.qkv_proj.weight"): + prefix = name.removesuffix(".self_attn.qkv_proj.weight") + expected = q_size + 2 * kv_size + assert tensor.shape[0] == expected, ( + f"qkv_proj rows {tensor.shape[0]} != " + f"q({q_size})+k({kv_size})+v({kv_size})={expected}" + ) + out.append((f"{prefix}.self_attn.q_proj.weight", tensor[:q_size])) + out.append((f"{prefix}.self_attn.k_proj.weight", tensor[q_size : q_size + kv_size])) + out.append((f"{prefix}.self_attn.v_proj.weight", tensor[q_size + kv_size :])) + + # ── Dense MLP gate/up split (future-proof, no-op on Qwen3-30B-A3B) + elif name.endswith(".mlp.gate_up_proj.weight"): + prefix = name.removesuffix(".mlp.gate_up_proj.weight") + mid = tensor.shape[0] // 2 + out.append((f"{prefix}.mlp.gate_proj.weight", tensor[:mid])) + out.append((f"{prefix}.mlp.up_proj.weight", tensor[mid:])) + + # ── Router rename (TT prefix → HF) ────────────────────────── + elif name.endswith(".mlp.router.gate.weight"): + prefix = name.removesuffix(".mlp.router.gate.weight") + out.append((f"{prefix}.mlp.gate.weight", tensor)) + + # ── MoE w13 (fused gate+up, stacked across local experts) ─── + elif name.endswith(".mlp.experts.w13_weight"): + prefix = name.removesuffix(".mlp.experts.w13_weight") + if tensor.ndim != 3: + out.append((name, tensor)) + continue + n_local, fused_dim, _ = tensor.shape + moe_dim = fused_dim // 2 + for j in range(n_local): + global_id = my_rank * num_local_experts + j + out.append( + ( + f"{prefix}.mlp.experts.{global_id}.gate_proj.weight", + tensor[j, :moe_dim].contiguous(), + ) + ) + out.append( + ( + f"{prefix}.mlp.experts.{global_id}.up_proj.weight", + tensor[j, moe_dim:].contiguous(), + ) + ) + + # ── MoE w2 (down, stacked across local experts) ───────────── + elif name.endswith(".mlp.experts.w2_weight"): + prefix = name.removesuffix(".mlp.experts.w2_weight") + if tensor.ndim != 3: + out.append((name, tensor)) + continue + n_local = tensor.shape[0] + for j in range(n_local): + global_id = my_rank * num_local_experts + j + out.append( + ( + f"{prefix}.mlp.experts.{global_id}.down_proj.weight", + tensor[j].contiguous(), + ) + ) + + # ── Passthrough: norms, o_proj, q/k_norm, embed, lm_head ──── + else: + out.append((name, tensor)) + + return out diff --git a/src/prime_rl/orchestrator/orchestrator.py b/src/prime_rl/orchestrator/orchestrator.py index 14903e6394..ec69936fcc 100644 --- a/src/prime_rl/orchestrator/orchestrator.py +++ b/src/prime_rl/orchestrator/orchestrator.py @@ -57,6 +57,7 @@ from prime_rl.utils.client import ( init_nccl_broadcast, init_nixl_mx_broadcast, + init_nixl_mx_v2_broadcast, setup_inference_pool, ) from prime_rl.utils.config import cli @@ -283,6 +284,19 @@ async def orchestrate(config: OrchestratorConfig): inference_world_size=config.weight_broadcast.inference_world_size, quantize_in_weight_transfer=config.weight_broadcast.quantize_in_weight_transfer, ) + elif config.weight_broadcast.type == "mx_v2": + await init_nixl_mx_v2_broadcast( + student_inference.admin_clients, + config.weight_broadcast.host, + config.weight_broadcast.port, + inference_world_size=config.weight_broadcast.inference_world_size, + publish_self_as_replica=config.weight_broadcast.publish_self_as_replica, + ) + # mx_v2 doesn't use an orchestrator-side MxRendezvous: the trainer + # is the only publisher, drives `publish() → mark_ready()` itself + # at each step, and inference receivers pull via the catalog. The + # scheduler drives the per-cycle refit through `/update_weights_v2` + # below. elif config.weight_broadcast.type == "nixl_mx": await init_nixl_mx_broadcast( student_inference.admin_clients, @@ -324,8 +338,8 @@ async def orchestrate(config: OrchestratorConfig): # Allow eval at resumed step by setting prev_ckpt_step one behind prev_ckpt_step = scheduler.ckpt_step - 1 - # In NCCL/NIXL modes, skip existence check - weights are pushed, not stored on disk - check_exists = config.weight_broadcast.type not in ("nccl", "nixl_mx") + # In NCCL/NIXL modes, skip existence check - weights are pushed/pulled, not stored on disk + check_exists = config.weight_broadcast.type not in ("nccl", "nixl_mx", "mx_v2") wait_timeout = config.ckpt.wait_for_weights_timeout if config.ckpt else None weights_path = get_weight_dir( config.output_dir, scheduler.ckpt_step, check_exists=check_exists, wait_timeout=wait_timeout diff --git a/src/prime_rl/orchestrator/scheduler.py b/src/prime_rl/orchestrator/scheduler.py index a81733e5d1..1728816907 100644 --- a/src/prime_rl/orchestrator/scheduler.py +++ b/src/prime_rl/orchestrator/scheduler.py @@ -14,7 +14,7 @@ from prime_rl.orchestrator.envs import TrainEnvs from prime_rl.orchestrator.vf_utils import get_seq_len from prime_rl.utils.async_utils import safe_cancel, safe_cancel_all -from prime_rl.utils.client import InferencePool +from prime_rl.utils.client import InferencePool, update_weights_v2 from prime_rl.utils.logger import ProgressTracker, get_logger from prime_rl.utils.utils import ( get_broadcast_dir, @@ -304,7 +304,15 @@ async def _apply_policy_update(self, next_ckpt_step: int) -> None: ) self.checkpoint_ready.clear() wait_for_ckpt_start_time = time.perf_counter() - if self.mx_rendezvous is not None: + if self.config.weight_broadcast.type == "mx_v2": + # mx_v2 pull-mode: trainer publishes asynchronously via + # NIXLMxV2WeightBroadcast.broadcast_weights and marks the + # source READY when version N is available. The engine + # adapter's discovery + retry-until-deadline handles the + # gap. No orchestrator-side wait needed — we just go + # straight into the per-cycle refit below. + pass + elif self.mx_rendezvous is not None: await asyncio.to_thread( self.mx_rendezvous.wait_for_all_peers_ready, role="trainer", @@ -322,17 +330,37 @@ async def _apply_policy_update(self, next_ckpt_step: int) -> None: ) update_weights_start_time = time.perf_counter() - if self.mx_rendezvous is not None: - weights_path = None - signal_trainer = lambda: self.mx_rendezvous.set_status(p2p_pb2.SOURCE_STATUS_READY) + if self.config.weight_broadcast.type == "mx_v2": + # mx_v2 pull-mode path: orchestrator pokes inference workers via + # /update_weights_v2 with the trainer's step; each worker calls + # MxWeightTransferEngine.receive_weights which discovers the + # source via the MX catalog and pulls. The trainer publishes + # version=N from its own loop (NIXLMxV2WeightBroadcast.broadcast_weights) + # — no orchestrator-side mx_rendezvous needed. + metrics = await update_weights_v2( + self.student_inference.admin_clients, + step=next_ckpt_step, + compile_target_filter=getattr( + self.config.weight_broadcast, "compile_target_filter", None + ), + timeout_seconds=float(self.config.weight_broadcast.timeout), + same_rank_only=getattr( + self.config.weight_broadcast, "same_rank_only", True + ), + ) + self.logger.debug(f"[mx_v2] refit step={next_ckpt_step} metrics={metrics}") else: - weights_path = get_step_path(get_broadcast_dir(self.config.output_dir), next_ckpt_step) - signal_trainer = None - await self.student_inference.update_weights( - weights_path, lora_name=self.lora_name, step=next_ckpt_step, on_engines_paused=signal_trainer - ) - if self.mx_rendezvous is not None: - self.mx_rendezvous.set_status(p2p_pb2.SOURCE_STATUS_INITIALIZING) + if self.mx_rendezvous is not None: + weights_path = None + signal_trainer = lambda: self.mx_rendezvous.set_status(p2p_pb2.SOURCE_STATUS_READY) + else: + weights_path = get_step_path(get_broadcast_dir(self.config.output_dir), next_ckpt_step) + signal_trainer = None + await self.student_inference.update_weights( + weights_path, lora_name=self.lora_name, step=next_ckpt_step, on_engines_paused=signal_trainer + ) + if self.mx_rendezvous is not None: + self.mx_rendezvous.set_status(p2p_pb2.SOURCE_STATUS_INITIALIZING) self.update_weights_time = time.perf_counter() - update_weights_start_time self.logger.debug(f"Updated weights to step {next_ckpt_step} in {self.update_weights_time:.2f}s") diff --git a/src/prime_rl/trainer/models/conversions/__init__.py b/src/prime_rl/trainer/models/conversions/__init__.py index 1790ca1284..1c9247502a 100644 --- a/src/prime_rl/trainer/models/conversions/__init__.py +++ b/src/prime_rl/trainer/models/conversions/__init__.py @@ -2,7 +2,10 @@ A conversion is a function that writes one source tensor into one destination tensor, optionally producing a paired scale buffer. Each conversion is -registered under a string name (e.g. ``"fp8_128x128"``). +registered under a string name (e.g. ``"fp8_128x128"``) and carries a +``compile_target`` tag plus a ``compile_metadata`` dict that downstream MX +clients (Phase 3a on ``ai-dynamo/modelexpress:kavink/post-2389-phase3-4``) +use to advertise the bytes' layout to receivers. Resolution flow at startup: @@ -10,29 +13,71 @@ :func:`select_default_conversion` to pick one conversion name to use as the default for every spec that doesn't pin its own. The choice is driven entirely by ``config.quantization_config`` (or its absence). + The resolver is **table-driven** (see ``_DEFAULT_RULES``) so adding a + new kernel = adding one row, not editing if/else chains. 2. For each :class:`~prime_rl.trainer.models.conversion_spec.ConversionSpec`, :func:`resolve` returns the registry entry — explicit ``conversion_type`` on the spec wins, otherwise the startup-chosen default applies. The registry never inspects destination buffer dtype; slot allocation is owned by the transfer slot builder. + +When the Phase 2 graduation of ``MxRendezvous`` onto +``MxV2TrainingPublisher`` lands (see +``KavinKrishnan/prime-rl:kavink/post-2389-phase2-rendezvous-fixes``), the +publisher reads each tensor's resolved ``ConversionEntry.compile_target`` +and ``compile_metadata`` and tags ``TensorDescriptorV2`` accordingly. +Receivers filter via ``MxV2RefitReceiver.discover_v2_sources( +compile_target_filter=…, required_compile_metadata=…)``. Until graduation +lands, the fields are populated but unused — callers can read them via +``ConversionEntry.compile_target`` to plumb manually if needed. """ from __future__ import annotations -from dataclasses import dataclass -from typing import Callable +from dataclasses import dataclass, field +from typing import Any, Callable from torch import Tensor -from transformers import AutoConfig ConversionFn = Callable[[Tensor, Tensor, "Tensor | None"], None] +# Canonical compile-target strings. Mirror the constants in +# ``modelexpress.shape_descriptors`` (Phase 3a, kavink/post-2389-phase3-4) +# so the two repos use exactly the same vocabulary without a hard import +# dependency in either direction. +COMPILE_TARGET_HF_RAW = "hf_raw" +COMPILE_TARGET_DEEPGEMM_FP8 = "deep_gemm_fp8" +COMPILE_TARGET_CUTLASS_FP8 = "cutlass_fp8" +COMPILE_TARGET_VLLM_FUSED = "vllm_fused" +COMPILE_TARGET_TRTLLM = "trtllm" + + @dataclass(frozen=True) class ConversionEntry: + """Registry record for one trainer→inference conversion kernel. + + Fields: + fn: The actual conversion function. Signature + ``(src, out, scale_out_or_None) -> None``. + requires_scale: True if ``fn`` writes a scale buffer; the slot + builder must allocate one. + compile_target: One of the ``COMPILE_TARGET_*`` strings. Identifies + the layout family the output bytes belong to. Receivers filter + on this via the v2 MX client. Default ``"hf_raw"`` means "no + kernel-specific layout, plain HF state-dict". + compile_metadata: Free-form key/value blob describing the specific + compile invocation (e.g. ``{"block_size": 128, + "scale_layout": "K-major"}``). Receivers should treat a + mismatch on any byte-affecting field as a hard reject even + if ``compile_target`` matches. + """ + fn: ConversionFn requires_scale: bool + compile_target: str = COMPILE_TARGET_HF_RAW + compile_metadata: dict[str, Any] = field(default_factory=dict) _REGISTRY: dict[str, ConversionEntry] = {} @@ -43,10 +88,17 @@ def register( fn: ConversionFn, *, requires_scale: bool, + compile_target: str = COMPILE_TARGET_HF_RAW, + compile_metadata: dict[str, Any] | None = None, ) -> None: if name in _REGISTRY: raise ValueError(f"conversion {name!r} is already registered") - _REGISTRY[name] = ConversionEntry(fn=fn, requires_scale=requires_scale) + _REGISTRY[name] = ConversionEntry( + fn=fn, + requires_scale=requires_scale, + compile_target=compile_target, + compile_metadata=dict(compile_metadata) if compile_metadata else {}, + ) def get(name: str) -> ConversionEntry: @@ -55,29 +107,86 @@ def get(name: str) -> ConversionEntry: return _REGISTRY[name] +def registered_names() -> list[str]: + """Snapshot the currently-registered conversion names. Used in tests + diagnostics.""" + return sorted(_REGISTRY) + + +# Table-driven default selection. Each row is a predicate on the parsed +# HF ``quantization_config`` plus the conversion name to return when it +# matches. Walked in order; first match wins. Extending support for a new +# kernel = appending one row (or registering a row from the kernel's +# module on import — see how cutlass_fp8.py does this). +_QuantPredicate = Callable[[dict[str, Any]], bool] +_DEFAULT_RULES: list[tuple[_QuantPredicate, str]] = [] + + +def register_default_rule( + predicate: _QuantPredicate, + name: str, + *, + insert_first: bool = False, +) -> None: + """Add a rule to the default-conversion resolver. + + Args: + predicate: callable taking the dict form of the HF + ``quantization_config`` (always non-None — the resolver + short-circuits to ``"bf16_cast"`` when no quantization_config + is present). Return True to claim this config. + name: registered conversion name to return on match. Must already + be in ``_REGISTRY`` (or be registered before + ``select_default_conversion`` is called). + insert_first: if True, prepend the rule so it beats earlier- + registered rules. Use sparingly — preferred is to append and + let earlier rules with stricter predicates win. + """ + pair = (predicate, name) + if insert_first: + _DEFAULT_RULES.insert(0, pair) + else: + _DEFAULT_RULES.append(pair) + + def select_default_conversion(inference_model_name: str) -> str: """Pick the default conversion name for the given inference model. - Loads the HF config and inspects ``quantization_config``: - - * absent → ``"bf16_cast"`` (no quantization; trainer→inference is a - plain dtype cast). - * ``quant_method == "fp8"`` with ``weight_block_size == [128, 128]`` → - ``"fp8_128x128"``. - * anything else → :class:`NotImplementedError`. + Loads the HF config and inspects ``quantization_config``. When no + quantization_config is present we short-circuit to ``"bf16_cast"`` so + test environments without a real HF download can still exercise the + default path. When present, we walk the ``_DEFAULT_RULES`` table in + order and return the first matching name. If nothing matches the + function raises :class:`NotImplementedError` with the full set of + registered conversions in the message — extend support by adding a + row to ``_DEFAULT_RULES`` (see :func:`register_default_rule`) from + the kernel's own module. """ + # Deferred import: ``transformers`` is a heavy dep we don't want to + # pay at registry-load time (the registry is imported by tests and + # tooling that have no HF download capability). The function is the + # only place that needs it. + from transformers import AutoConfig + config = AutoConfig.from_pretrained(inference_model_name) quant = getattr(config, "quantization_config", None) if quant is None: return "bf16_cast" if hasattr(quant, "to_dict"): quant = quant.to_dict() - method = quant["quant_method"] - block_size = tuple(quant.get("weight_block_size") or ()) - if method == "fp8" and block_size == (128, 128): - return "fp8_128x128" + for predicate, name in _DEFAULT_RULES: + try: + if predicate(quant): + return name + except Exception: + # A predicate that raises on an unexpected config shape should + # not crash the resolver — treat it as "doesn't match" and + # move on. This keeps registry hooks robust to model-name + # weirdness without forcing every predicate to be defensive. + continue raise NotImplementedError( - f"unsupported inference quantization: quant_method={method!r}, weight_block_size={block_size}" + f"unsupported inference quantization: {quant!r}; " + f"registered conversions: {sorted(_REGISTRY)}; " + f"register a new rule via prime_rl.trainer.models.conversions.register_default_rule" ) @@ -88,12 +197,20 @@ def resolve(conversion_type: str | None, default: str) -> ConversionEntry: from prime_rl.trainer.models.conversions import bf16_cast as _bf16_cast # noqa: E402, F401 from prime_rl.trainer.models.conversions import fp8_blockwise as _fp8_blockwise # noqa: E402, F401 +from prime_rl.trainer.models.conversions import cutlass_fp8 as _cutlass_fp8 # noqa: E402, F401 __all__ = [ + "COMPILE_TARGET_CUTLASS_FP8", + "COMPILE_TARGET_DEEPGEMM_FP8", + "COMPILE_TARGET_HF_RAW", + "COMPILE_TARGET_TRTLLM", + "COMPILE_TARGET_VLLM_FUSED", "ConversionEntry", "ConversionFn", - "register", "get", + "register", + "register_default_rule", + "registered_names", "resolve", "select_default_conversion", ] diff --git a/src/prime_rl/trainer/models/conversions/bf16_cast.py b/src/prime_rl/trainer/models/conversions/bf16_cast.py index 16b8dae4fe..bb2450d21c 100644 --- a/src/prime_rl/trainer/models/conversions/bf16_cast.py +++ b/src/prime_rl/trainer/models/conversions/bf16_cast.py @@ -5,7 +5,10 @@ import torch from torch import Tensor -from prime_rl.trainer.models.conversions import register +from prime_rl.trainer.models.conversions import ( + COMPILE_TARGET_HF_RAW, + register, +) def bf16_cast(src: Tensor, out: Tensor, scale_out: Tensor | None = None) -> None: @@ -18,5 +21,17 @@ def fp32_cast(src: Tensor, out: Tensor, scale_out: Tensor | None = None) -> None out.copy_(src.to(torch.float32)) -register("bf16_cast", bf16_cast, requires_scale=False) -register("fp32_cast", fp32_cast, requires_scale=False) +register( + "bf16_cast", + bf16_cast, + requires_scale=False, + compile_target=COMPILE_TARGET_HF_RAW, + compile_metadata={"dtype": "bfloat16"}, +) +register( + "fp32_cast", + fp32_cast, + requires_scale=False, + compile_target=COMPILE_TARGET_HF_RAW, + compile_metadata={"dtype": "float32"}, +) diff --git a/src/prime_rl/trainer/models/conversions/cutlass_fp8.py b/src/prime_rl/trainer/models/conversions/cutlass_fp8.py new file mode 100644 index 0000000000..e24a9b0c72 --- /dev/null +++ b/src/prime_rl/trainer/models/conversions/cutlass_fp8.py @@ -0,0 +1,105 @@ +"""Cutlass-style FP8 e4m3 with per-output-channel scaling. Registered as +``"cutlass_fp8_e4m3_per_channel"``. + +Layout contract (matches cutlass ``scaled_mm`` + vLLM's native FP8 path): + +* 2D linear weights: ``W.shape == (out_features, in_features)``, scale is + one float32 per output row → ``scale.shape == (out_features,)``. +* 3D stacked-expert MoE weights: + ``W.shape == (num_local_experts, out_features, in_features)``, scale is + one float32 per (expert, output-row) → ``scale.shape == (num_local_experts, out_features)``. + +Dispatches between the 2D and 3D paths via :func:`fp8_per_channel_quantize_into` +based on ``src.ndim`` — same dispatch convention as ``fp8_128x128``. + +Tagged with ``compile_target="cutlass_fp8"`` so receivers running cutlass +kernels can filter for it via the v2 MX client's +``discover_v2_sources(compile_target_filter={"cutlass_fp8"})`` (Phase 3b, +``ai-dynamo/modelexpress:kavink/post-2389-phase3-4``). + +``compile_metadata`` documents the byte-affecting choices: + +* ``dtype``: ``"e4m3"`` (vs ``"e5m2"`` for higher-range cutlass variants — + add as a separate entry when needed). +* ``scale_layout``: ``"per_channel"`` — receiver must allocate a 1D scale + per output row, not a 2D blockwise scale. +* ``scale_axis``: ``-1`` — reduction was over the input-features axis; + receiver dequantizes by broadcasting scale along the same axis. +* ``activation_scheme``: ``"dynamic"`` — matches HF's + ``quantization_config.activation_scheme="dynamic"`` for cutlass FP8. + +Adding a sibling cutlass entry (e.g. per-token activations, e5m2, etc.) is +~80 LOC in another file that calls :func:`register` and +:func:`register_default_rule` for its own HF-config signature. +""" + +from __future__ import annotations + +from torch import Tensor + +from prime_rl.trainer.models.conversions import ( + COMPILE_TARGET_CUTLASS_FP8, + register, + register_default_rule, +) +from prime_rl.trainer.models.fp8 import fp8_per_channel_quantize_into + + +def cutlass_fp8_e4m3_per_channel( + src: Tensor, + out: Tensor, + scale_out: Tensor | None, +) -> None: + """Quantize ``src`` (bf16 or fp32) into per-channel FP8 e4m3. + + Writes into preallocated ``out`` (e4m3) + ``scale_out`` (float32). + Dispatches 2D vs 3D via ``src.ndim`` — same convention as + ``fp8_128x128``. + """ + assert scale_out is not None, ( + "cutlass_fp8_e4m3_per_channel requires a scale_out buffer" + ) + fp8_per_channel_quantize_into(src, out=out, sf=scale_out) + + +register( + "cutlass_fp8_e4m3_per_channel", + cutlass_fp8_e4m3_per_channel, + requires_scale=True, + compile_target=COMPILE_TARGET_CUTLASS_FP8, + compile_metadata={ + "dtype": "e4m3", + "scale_layout": "per_channel", + "scale_axis": -1, + "activation_scheme": "dynamic", + }, +) + + +def _is_cutlass_fp8_per_channel(quant: dict) -> bool: + """HF ``quantization_config`` signature for cutlass per-channel FP8. + + Two recognised shapes: + + * ``{"quant_method": "fp8", "weight_block_size": None, + "activation_scheme": "dynamic"}`` — what vLLM and most cutlass- + targeting checkpoints publish. + * ``{"quant_method": "fp8", "quant_format": "cutlass"}`` — used by a + few model cards (Qwen3-MoE FP8 cutlass variants in particular) + that disambiguate cutlass from DeepGemm by setting an explicit + format string instead of leaving ``weight_block_size`` empty. + + The DeepGemm 128x128 rule (registered earlier) takes precedence when + both predicates would match because that rule was registered before + this one in ``_DEFAULT_RULES``. + """ + if quant.get("quant_method") != "fp8": + return False + if quant.get("quant_format") == "cutlass": + return True + block_size = tuple(quant.get("weight_block_size") or ()) + activation_scheme = quant.get("activation_scheme") + return block_size == () and activation_scheme == "dynamic" + + +register_default_rule(_is_cutlass_fp8_per_channel, "cutlass_fp8_e4m3_per_channel") diff --git a/src/prime_rl/trainer/models/conversions/fp8_blockwise.py b/src/prime_rl/trainer/models/conversions/fp8_blockwise.py index 3a9256ab7d..2a5907a9a0 100644 --- a/src/prime_rl/trainer/models/conversions/fp8_blockwise.py +++ b/src/prime_rl/trainer/models/conversions/fp8_blockwise.py @@ -1,14 +1,20 @@ """FP8 e4m3 blockwise quantization, 128x128 blocks. Registered as ``"fp8_128x128"``. Dispatches between the 2D linear layer path and the 3D stacked-expert path -based on ``src.ndim``. +based on ``src.ndim``. Tagged with ``compile_target="deep_gemm_fp8"`` so +receivers running DeepGemm kernels can filter for it via the v2 MX +client's ``discover_v2_sources(compile_target_filter=…)`` (Phase 3b). """ from __future__ import annotations from torch import Tensor -from prime_rl.trainer.models.conversions import register +from prime_rl.trainer.models.conversions import ( + COMPILE_TARGET_DEEPGEMM_FP8, + register, + register_default_rule, +) from prime_rl.trainer.models.fp8 import fp8_block_quantize, grouped_fp8_block_quantize @@ -20,4 +26,23 @@ def fp8_128x128(src: Tensor, out: Tensor, scale_out: Tensor | None) -> None: fp8_block_quantize(src, out=out, sf=scale_out) -register("fp8_128x128", fp8_128x128, requires_scale=True) +register( + "fp8_128x128", + fp8_128x128, + requires_scale=True, + compile_target=COMPILE_TARGET_DEEPGEMM_FP8, + compile_metadata={ + "dtype": "e4m3", + "scale_layout": "blockwise", + "block_size": [128, 128], + }, +) + +# HF config signature for DeepGemm-style FP8: 128x128 blockwise. +register_default_rule( + lambda quant: ( + quant.get("quant_method") == "fp8" + and tuple(quant.get("weight_block_size") or ()) == (128, 128) + ), + "fp8_128x128", +) diff --git a/src/prime_rl/trainer/models/fp8.py b/src/prime_rl/trainer/models/fp8.py index c04bf2f3b7..35576b0303 100644 --- a/src/prime_rl/trainer/models/fp8.py +++ b/src/prime_rl/trainer/models/fp8.py @@ -86,3 +86,57 @@ def grouped_fp8_block_quantize( if sf is not None: sf.copy_(s_accum) return q_accum, s_accum + + +# ---------------------------------------------------------------------------- +# Per-output-channel FP8 (cutlass-style): one scale per row of W. Used by +# cutlass scaled_mm + vLLM's native FP8 path. For a 2D weight of shape +# (out_features, in_features), reduction is over in_features (axis=-1) and +# the resulting scale has shape (out_features,). For a 3D stacked-expert +# weight of shape (num_local_experts, out_features, in_features) we run the +# same recipe per expert, producing a (num_local_experts, out_features) +# scale tensor. No padding / block reshuffling — the bytes go out in the +# same layout the trainer holds them in, which matches cutlass's +# RowMajor + per-channel scale convention. +# ---------------------------------------------------------------------------- + + +def fp8_per_channel_quantize( + weight: Tensor, +) -> tuple[Tensor, Tensor]: + """Per-output-channel symmetric FP8 e4m3 quantization. + + Supports both 2D ``(out, in)`` linear weights and 3D + ``(E, out, in)`` stacked-expert weights via the same code path. + Returns ``(quantized, scale)`` where ``scale`` has shape + ``weight.shape[:-1]`` (i.e. one scalar per output row, per expert). + """ + if weight.ndim not in (2, 3): + raise ValueError( + f"fp8_per_channel_quantize expects 2D or 3D, got shape={tuple(weight.shape)}" + ) + fp8_max = torch.finfo(torch.float8_e4m3fn).max # 448 for e4m3 + # amax over the innermost (input-features) axis. + amax = weight.detach().float().abs().amax(dim=-1, keepdim=True).clamp(min=1e-12) + scale = (amax / fp8_max).clamp(min=1e-12) + q = (weight.float() / scale).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn) + return q.contiguous(), scale.squeeze(-1).to(torch.float32).contiguous() + + +def fp8_per_channel_quantize_into( + weight: Tensor, + out: Tensor | None = None, + sf: Tensor | None = None, +) -> tuple[Tensor, Tensor]: + """Per-channel FP8 quantize, optionally writing into preallocated buffers. + + Shape contract: + - ``out.shape == weight.shape``, dtype ``torch.float8_e4m3fn`` + - ``sf.shape == weight.shape[:-1]``, dtype ``torch.float32`` + """ + q, s = fp8_per_channel_quantize(weight) + if out is not None: + out.copy_(q) + if sf is not None: + sf.copy_(s) + return q, s diff --git a/src/prime_rl/trainer/rl/broadcast/__init__.py b/src/prime_rl/trainer/rl/broadcast/__init__.py index d3883cbd1f..549f2d9273 100644 --- a/src/prime_rl/trainer/rl/broadcast/__init__.py +++ b/src/prime_rl/trainer/rl/broadcast/__init__.py @@ -23,5 +23,10 @@ def setup_weight_broadcast( assert parallel_dims is not None, "nixl_mx requires parallel_dims" return NIXLMxWeightBroadcast(output_dir, config, parallel_dims) + elif config.type == "mx_v2": + from prime_rl.trainer.rl.broadcast.nixl_mx_v2 import NIXLMxV2WeightBroadcast + + assert parallel_dims is not None, "mx_v2 requires parallel_dims" + return NIXLMxV2WeightBroadcast(output_dir, config, parallel_dims) else: raise ValueError(f"Invalid weight broadcast type: {config.type}") diff --git a/src/prime_rl/trainer/rl/broadcast/nixl_mx_v2.py b/src/prime_rl/trainer/rl/broadcast/nixl_mx_v2.py new file mode 100644 index 0000000000..488cb48d0d --- /dev/null +++ b/src/prime_rl/trainer/rl/broadcast/nixl_mx_v2.py @@ -0,0 +1,287 @@ +"""v2 trainer-side weight broadcast using the ModelExpress v2 fat clients. + +This is the v2 of :class:`NIXLMxWeightBroadcast` (PR #2389), built on +:class:`MxV2TrainingPublisher` instead of the in-tree :class:`MxRendezvous`. +The data plane is unchanged — NIXL RDMA, GPU-direct, no CPU staging — +but the control-plane glue (heartbeat, freshest-per-rank dedup, same-rank +routing, compile_target metadata, multi-source slice picker, tree +fan-out) is graduated onto the published MX v2 surface. + +The trainer-side conversion (FP8 packing, fusion, sharding into +``Sharded`` / ``Gathered`` / ``Expert`` slots) is *unchanged* from +PR #2389 — prime-rl still owns that kernel. What changes is **how the +already-converted bytes get published** (one ``publisher.publish()`` +per step instead of a per-tensor ``post_write`` loop driven from the +trainer), and **what metadata rides along** (``compile_target`` + +``compile_metadata`` from the conversion registry + per-tensor MoE +expert ownership). + +HSDP: when ``dp_replicate > 1`` only the primary replica (``dp_replicate +rank 0``) participates. Non-primary replicas hold bit-identical weights; +broadcasting a second copy would be pure waste. + +See :file:`docs/proposals/post-pr2389-mx-v2.md` for the design rationale +and migration plan. +""" + +from __future__ import annotations + +import time +from pathlib import Path +from typing import Any + +import torch +import torch.distributed as dist +import torch.nn as nn +from modelexpress.nemo_rl_v2 import MxV2TrainingPublisher, TrainerWorldLayout +from transformers import AutoConfig + +from prime_rl.configs.trainer import MxV2WeightBroadcastConfig +from prime_rl.trainer.models import PreTrainedModelPrimeRL +from prime_rl.trainer.models.conversions import select_default_conversion +from prime_rl.trainer.parallel_dims import ParallelDims +from prime_rl.trainer.rl.broadcast.base import WeightBroadcast +from prime_rl.trainer.runs import get_multi_run_manager +from prime_rl.trainer.utils import get_world +from prime_rl.transport.classic_cuda_pool import classic_cuda_alloc +from prime_rl.transport.nixl_agent import make_agent_name, pin_ucx_rail + + +class NIXLMxV2WeightBroadcast(WeightBroadcast): + """v2 weight broadcast over NIXL + ModelExpress fat clients. + + Selectable from config via ``weight_broadcast.type = "mx_v2"``. + Coexists with the existing ``"nixl_mx"`` path (PR #2389); no + behavior of ``"nixl_mx"`` is affected by importing this module. + + Args: + output_dir: training output directory (forwarded to base class). + config: parsed :class:`MxV2WeightBroadcastConfig`. + parallel_dims: ``ParallelDims`` instance describing the trainer's + FSDP / TP / EP / DP layout — used to construct the + ``TrainerWorldLayout`` carried in v2 metadata. + """ + + def __init__( + self, + output_dir: Path, + config: MxV2WeightBroadcastConfig, + parallel_dims: ParallelDims, + ) -> None: + super().__init__(output_dir) + self.config = config + self.world = get_world() + self.parallel_dims = parallel_dims + + self.is_initialized = False + self._publisher: MxV2TrainingPublisher | None = None + self._model_slots: list[Any] | None = None + self._conversion = None + self._hf_config = None + + if self.is_primary_hsdp_rank: + pin_ucx_rail(torch.cuda.current_device()) + + self._multi_run_manager = get_multi_run_manager() + + # ------------------------------------------------------------------ + # HSDP gate — only rank 0 of dp_replicate publishes + # ------------------------------------------------------------------ + + @property + def is_primary_hsdp_rank(self) -> bool: + if self.parallel_dims.dp_replicate_enabled: + return self.parallel_dims.get_mesh("dp_replicate").get_local_rank() == 0 + return True + + # ------------------------------------------------------------------ + # Lifecycle + # ------------------------------------------------------------------ + + def _build_world_layout(self) -> TrainerWorldLayout: + """Translate prime-rl's ParallelDims into the MX v2 world layout.""" + return TrainerWorldLayout( + fsdp_world_size=getattr(self.parallel_dims, "dp_shard_size", 1), + tp_world_size=getattr(self.parallel_dims, "tp_size", 1), + pp_world_size=getattr(self.parallel_dims, "pp_size", 1), + ep_world_size=getattr(self.parallel_dims, "ep_size", 1), + ) + + def lazy_init(self, model: PreTrainedModelPrimeRL) -> None: + """Build the v2 publisher + slot layout on first call. + + The model isn't available at ``__init__`` time (the WeightBroadcast + instance is constructed before the trainer model is materialized), + so slot construction and publisher initialization happen on the + first ``broadcast_weights`` call. + """ + if self.is_initialized: + return + + self._hf_config = AutoConfig.from_pretrained(self.config.inference_model_name) + self._conversion = select_default_conversion(self.config.inference_model_name) + + # Pull-mode (this broadcast type) + same-rank routing means each + # inference rank only contacts ONE trainer rank, so the trainer must + # have the FULL tensor on each rank — ShardedSlot's 1/N FSDP-shard + # would deliver only 1/N to each receiver and vLLM's load_weights + # would refuse the shape mismatch. Force every non-expert slot + # into GatheredSlot (DTensor.full_tensor() allgather + full tensor + # held per rank) by raising the threshold beyond any realistic + # weight size. Expert slots remain ExpertSlot (each rank owns its + # EP shard, which is exactly what same-rank pull mode wants). + # + # In push-mode (PR #2389's `nixl_mx`) ShardedSlot is correct + # because each trainer rank writes its FSDP shard directly into + # the inference's pre-allocated buffer at its rank-specific + # offset — there each receiver assembles N shards from N senders. + # Pull-mode receivers can't do that without Phase-4 multi-source + # slicing in the engine adapter; until that lands, gather first. + from prime_rl.trainer.models import slots as _slots_mod + if getattr(self, "_orig_small_non_expert_bytes", None) is None: + self._orig_small_non_expert_bytes = _slots_mod.SMALL_NON_EXPERT_BYTES + _slots_mod.SMALL_NON_EXPERT_BYTES = 1 << 60 + + try: + with classic_cuda_alloc(): + self._model_slots = model.build_slots( + self.parallel_dims, self._conversion, self._hf_config.torch_dtype + ) + finally: + # Restore the threshold so we don't perturb other code paths + # (e.g. nixl_mx broadcast running in the same process). + _slots_mod.SMALL_NON_EXPERT_BYTES = self._orig_small_non_expert_bytes + + # The v2 publisher owns the NIXL agent + MX client + heartbeat. + # We pass our rank as ``worker_rank``; receivers with + # ``same_rank_only=True`` (Phase 2 default) will only pull from + # the trainer rank matching their own. + world_layout = self._build_world_layout() + self._publisher = MxV2TrainingPublisher( + agent_name=make_agent_name("trainer", self.world.rank), + device_id=torch.cuda.current_device(), + mx_server_url=f"{self.config.host}:{self.config.port}", + worker_rank=self.world.rank, + world_layout=world_layout, + ) + self._publisher.initialize( + model_name=self.config.inference_model_name, + dtype=str(self._hf_config.torch_dtype).replace("torch.", ""), + ) + self.is_initialized = True + # `select_default_conversion` may return either a registered conversion + # object (with .compile_target + .compile_metadata) on the newer + # conversion registry, OR a plain string ('bf16_cast', 'fp8_pack', ...) + # on older registries. Use getattr so we degrade gracefully. + conversion_target = getattr(self._conversion, "compile_target", str(self._conversion)) + self.logger.info( + f"[mx_v2] publisher initialized: rank={self.world.rank} " + f"layout={world_layout.encode()} " + f"compile_target={conversion_target}" + ) + + # ------------------------------------------------------------------ + # Per-step broadcast + # ------------------------------------------------------------------ + + @torch.no_grad() + def broadcast_weights(self, model: nn.Module, step: int) -> None: + """Publish version ``step`` of the converted weights. + + Per-step lifecycle: + + 1. (HSDP) only the primary replica participates; others barrier. + 2. Fill the conversion slots from ``model.state_dict()`` — + **same code path PR #2389 uses**, prime-rl owns the kernel. + 3. For each slot's buffers, call ``publisher.add_tensor(...)`` + tagged with the conversion's ``compile_target`` / + ``compile_metadata`` (Phase 3) and any per-tensor MoE + expert metadata. + 4. ``publisher.publish(version=step)`` + ``mark_ready()`` — + catalog entry now visible to receivers polling for + ``min_version=step``. + 5. Bump the heartbeat (the publisher's ``HeartbeatThread`` + runs in the background; nothing to do here). + """ + if self.is_primary_hsdp_rank: + self.lazy_init(model) + + if self.world.is_master: + for idx in self._multi_run_manager.used_idxs: + if self._multi_run_manager.ready_to_update[idx]: + self._multi_run_manager.ready_to_update[idx] = False + + dist.barrier() + + if not self.is_primary_hsdp_rank: + # Non-primary HSDP replicas: bit-identical weights; nothing to publish. + dist.barrier() + return + + start = time.perf_counter() + + # 2. Fill slots from the live model state-dict via the conversion. + # This is where FP8 packing + fusion happens; same code path + # as PR #2389. We do NOT change the kernel. + # GatheredSlot's API takes only the state_dict; the conversion + # is baked into the slot at `from_spec` creation time. + state_dict = model.state_dict() + for slot in self._model_slots: + slot.convert(state_dict) + + # 3. Register every slot tensor with the v2 publisher, tagged with + # compile_target + compile_metadata so receivers can refuse + # mismatched layouts at discovery (Phase 3). + # Falls back to the safe "hf_raw" default when: + # - publish_compile_target=False (caller opts out), or + # - the conversion is on an older registry without the + # compile_target/compile_metadata fields (graceful + # degradation; back-compat with PR #2389 conversions). + if self.config.publish_compile_target: + compile_target = getattr(self._conversion, "compile_target", "hf_raw") + compile_metadata = getattr(self._conversion, "compile_metadata", None) + else: + compile_target = "hf_raw" + compile_metadata = None + + n_tensors = 0 + for slot in self._model_slots: + slot_is_expert = bool(getattr(slot, "is_expert", False)) + slot_expert_axis = int(getattr(slot, "expert_axis", 0)) + slot_owned_experts = tuple(getattr(slot, "owned_expert_ids", ())) + for buf_key, tensor, _ in slot.buffers: + self._publisher.add_tensor( + name=buf_key, + tensor=tensor, + is_expert=slot_is_expert, + expert_axis=slot_expert_axis, + owned_expert_ids=slot_owned_experts, + compile_target=compile_target, + compile_metadata=compile_metadata, + ) + n_tensors += 1 + + # 4. Publish + mark READY in one shot. + mx_source_id = self._publisher.publish(version=step) + self._publisher.mark_ready() + + elapsed = time.perf_counter() - start + self.logger.info( + f"[mx_v2] publish step={step} tensors={n_tensors} " + f"compile_target={compile_target} mx_source_id={mx_source_id} " + f"elapsed={elapsed:.3f}s" + ) + + dist.barrier() + + # ------------------------------------------------------------------ + # Teardown + # ------------------------------------------------------------------ + + def shutdown(self) -> None: + if self._publisher is not None: + try: + self._publisher.shutdown() + finally: + self._publisher = None + self.is_initialized = False diff --git a/src/prime_rl/transport/mx_rendezvous.py b/src/prime_rl/transport/mx_rendezvous.py index 817c1ab7dc..10414808a5 100644 --- a/src/prime_rl/transport/mx_rendezvous.py +++ b/src/prime_rl/transport/mx_rendezvous.py @@ -7,20 +7,89 @@ (role baked into ``SourceIdentity.extra_parameters`` so trainer/inference hash to different ``mx_source_id``s) and the polling loop, and delegates all gRPC to ``modelexpress.MxClient``. + +Phase-2 fixes (post-#2389) baked in: + +- **Heartbeat**: spawning :class:`MxRendezvous` starts a background + :class:`HeartbeatThread` on ``publish()`` so the MX server's reaper can + detect crashed workers and mark them ``STALE``. Crashed workers were + leaving permanent ``READY`` rows that broke restarts on GB200. +- **Freshest-per-(role, rank) dedup**: when multiple entries for the same + (role, rank) live in the catalog (e.g. after a partial pod restart), + callers see only the most recently updated one. This is the second of + the two GB200 runtime patches. +- **Same-rank-only filter**: optional ``same_rank_only=True`` on the wait + methods restricts results to peers with ``worker_rank == self.rank``, + closing the cross-subnet full-mesh path that fails on GCP multi-NIC + RDMA fabrics. Off by default; the caller opts in. """ from __future__ import annotations +import logging import time import uuid -from dataclasses import dataclass +from dataclasses import dataclass, field from typing import Iterable, Literal from modelexpress import p2p_pb2 from modelexpress.client import MxClient +# HeartbeatThread moved in MX 0.4+ from ``modelexpress.heartbeat`` to +# ``modelexpress.metadata.heartbeat`` as part of the metadata-module +# reorganization. Tolerate both so this code works against the v0.5.2 +# image (MX 0.3.0, old path) and the newer ``kavink/nemo_rl_moe`` MX +# (which exposes the new path). The MX-side migration tracker is in +# ``pensieve/RL/PrimeRL/09_rfc_updates_needed.md``. +try: + from modelexpress.metadata.heartbeat import HeartbeatThread # MX 0.4+ +except ImportError: # pragma: no cover - environment-dependent + from modelexpress.heartbeat import HeartbeatThread # MX 0.3 + Role = Literal["trainer", "inference", "orchestrator"] +_log = logging.getLogger("prime_rl.transport.mx_rendezvous") + + +def _freshest_per_rank( + instances: Iterable[p2p_pb2.SourceInstanceRef], + *, + metas: dict[str, int], +) -> list[p2p_pb2.SourceInstanceRef]: + """Dedup peers by ``worker_rank``, keeping the one with the largest + ``updated_at`` from ``metas``. + + ``metas`` maps ``worker_id`` → ``updated_at`` (ms-epoch as reported by + the MX server). Instances whose ``worker_id`` is missing from ``metas`` + are kept (we err on the side of "visible but not freshness-known"). + + This is the Phase-2 codification of the runtime patch we applied on + GB200: the prime-rl trainer's NIXL agent rotated ``mx_source_id`` on + restart, leaving a stale ``READY`` entry at the same ``worker_rank``; + receivers picked the stale one and got ``NIXL_ERR_NOT_ALLOWED`` when + they tried to ``add_remote_agent``. + """ + by_rank: dict[int, tuple[int, p2p_pb2.SourceInstanceRef]] = {} + for inst in instances: + ts = metas.get(inst.worker_id, 0) + cur = by_rank.get(inst.worker_rank) + if cur is None or ts > cur[0]: + by_rank[inst.worker_rank] = (ts, inst) + return [v[1] for _, v in sorted(by_rank.items())] + + +def _filter_same_rank( + instances: Iterable[p2p_pb2.SourceInstanceRef], *, rank: int +) -> list[p2p_pb2.SourceInstanceRef]: + """Keep only peers whose ``worker_rank == rank``. + + The cross-subnet full-mesh routing path failed on GCP GB200's multi-NIC + fabric — each rank has its own IB subnet, so trainer rank N can only + safely peer with inference rank N. Filtering at the rendezvous layer + prevents the broken connections from ever being attempted. + """ + return [inst for inst in instances if inst.worker_rank == rank] + @dataclass class MxRendezvous: @@ -47,11 +116,13 @@ class MxRendezvous: peer_world_size: int model_name: str worker_id: str = "" + enable_heartbeat: bool = True def __post_init__(self) -> None: if not self.worker_id: self.worker_id = str(uuid.uuid4()) self._mx_source_id: str | None = None + self._heartbeat: HeartbeatThread | None = None @property def peer_role(self) -> Role: @@ -84,21 +155,58 @@ def publish( nixl_metadata: bytes = b"", tensors: Iterable[p2p_pb2.TensorDescriptor] = (), ) -> str: - """Publish this worker's metadata. Returns the assigned ``mx_source_id``.""" + """Publish this worker's metadata. Returns the assigned ``mx_source_id``. + + Side effect (Phase 2): if ``enable_heartbeat`` is True, a + :class:`HeartbeatThread` is started after a successful publish so + the MX server's reaper can detect liveness. Heartbeat is idempotent + — calling ``publish()`` again on the same instance is a no-op for + the heartbeat (the existing thread keeps running). + """ worker = p2p_pb2.WorkerMetadata( worker_rank=self.rank, nixl_metadata=nixl_metadata, tensors=list(tensors), ) self._mx_source_id = self.client.publish_metadata(self._identity(self.role), worker, self.worker_id) + + if self.enable_heartbeat and self._heartbeat is None: + try: + self._heartbeat = HeartbeatThread( + mx_client=self.client, + mx_source_id=self._mx_source_id, + worker_id=self.worker_id, + worker_rank=self.rank, + nixl_manager=None, # prime-rl drives NIXL outside MX's manager + ) + self._heartbeat.start() + except Exception as e: # noqa: BLE001 + _log.warning( + "MxRendezvous: failed to start HeartbeatThread (role=%s rank=%s): %s", + self.role, + self.rank, + e, + ) + return self._mx_source_id + def close(self) -> None: + """Stop the heartbeat thread. Safe to call multiple times.""" + if self._heartbeat is not None: + try: + self._heartbeat.stop() + except Exception as e: # noqa: BLE001 + _log.warning("MxRendezvous: heartbeat.stop() failed: %s", e) + self._heartbeat = None + def wait_for_peers( self, *, status: int | None = None, timeout: float = 1200.0, poll_interval: float = 1.0, + same_rank_only: bool = False, + dedup_freshest_per_rank: bool = True, ) -> list[p2p_pb2.SourceInstanceRef]: """Block until ``peer_world_size`` peers of the counterpart role are visible. @@ -106,32 +214,74 @@ def wait_for_peers( status: If set, only count peers in this :class:`p2p_pb2.SourceStatus`. timeout: Wall-clock seconds to wait before raising :class:`TimeoutError`. poll_interval: Seconds between ``ListSources`` polls. + same_rank_only: If True, only return peers whose ``worker_rank`` + equals this rendezvous's own rank. Required on GB200's + multi-NIC fabric where cross-subnet routing fails. Off by + default to preserve the pre-Phase-2 single-NIC behaviour. + dedup_freshest_per_rank: If True (default), keep only the + freshest ``SourceInstanceRef`` per ``worker_rank``. This + neutralises the stale-READY-after-restart bug we caught on + GB200. Pass ``False`` to keep all duplicates (e.g. debug). """ - import logging - - _log = logging.getLogger("prime_rl.transport.mx_rendezvous") deadline = time.monotonic() + timeout peer_id = self._identity(self.peer_role) _logged = False while True: resp = self.client.list_sources(peer_id, status_filter=status) + kept = list(resp.instances) + if same_rank_only: + kept = _filter_same_rank(kept, rank=self.rank) + if dedup_freshest_per_rank and kept: + kept = _freshest_per_rank( + kept, metas=self._collect_updated_at(kept) + ) if not _logged: all_resp = self.client.list_sources(peer_id) _log.info( - f"wait_for_peers: role={self.peer_role} need={self.peer_world_size} " - f"found_with_status={len(resp.instances)} found_any={len(all_resp.instances)} " - f"status_filter={status} model={peer_id.model_name}" + "wait_for_peers: role=%s need=%s found_with_status=%s found_any=%s " + "post_filter=%s status_filter=%s model=%s same_rank_only=%s", + self.peer_role, + self.peer_world_size, + len(resp.instances), + len(all_resp.instances), + len(kept), + status, + peer_id.model_name, + same_rank_only, ) _logged = True - if len(resp.instances) >= self.peer_world_size: - return list(resp.instances) + if len(kept) >= self.peer_world_size: + return kept if time.monotonic() >= deadline: raise TimeoutError( f"timed out after {timeout}s waiting for {self.peer_world_size} " - f"{self.peer_role!r} peers (saw {len(resp.instances)})" + f"{self.peer_role!r} peers (saw {len(kept)} after filters; " + f"{len(resp.instances)} raw)" ) time.sleep(poll_interval) + def _collect_updated_at( + self, instances: Iterable[p2p_pb2.SourceInstanceRef] + ) -> dict[str, int]: + """Fetch ``updated_at`` per peer in one round of GetMetadata calls. + + Used by the freshest-per-rank dedup. Failures (missing worker, RPC + errors) are mapped to ``0`` so the stale entries lose to anything + with a real timestamp. + """ + out: dict[str, int] = {} + for inst in instances: + try: + resp = self.client.get_metadata(inst.mx_source_id, inst.worker_id) + except Exception: # noqa: BLE001 + out[inst.worker_id] = 0 + continue + if not getattr(resp, "found", False): + out[inst.worker_id] = 0 + continue + out[inst.worker_id] = int(getattr(resp.worker, "updated_at", 0) or 0) + return out + def wait_for_all_peers_ready( self, *, @@ -139,6 +289,8 @@ def wait_for_all_peers_ready( status: int = p2p_pb2.SOURCE_STATUS_READY, timeout: float = 1200.0, poll_interval: float = 0.05, + same_rank_only: bool = False, + dedup_freshest_per_rank: bool = True, ) -> list[p2p_pb2.SourceInstanceRef]: """Discover peer count from MX, then block until ALL of them reach ``status``. @@ -147,27 +299,45 @@ def wait_for_all_peers_ready( entries exist in MX (any status) and uses that count as the target. Each side publishes one entry per rank, so the count equals the peer's world size — no config plumbing needed. + + Phase-2 additions (``same_rank_only`` and ``dedup_freshest_per_rank``) + behave identically to :meth:`wait_for_peers`. """ target_role = role or self.peer_role peer_id = self._identity(target_role) deadline = time.monotonic() + timeout + def _apply_filters( + insts: list[p2p_pb2.SourceInstanceRef], + ) -> list[p2p_pb2.SourceInstanceRef]: + kept = insts + if same_rank_only: + kept = _filter_same_rank(kept, rank=self.rank) + if dedup_freshest_per_rank and kept: + kept = _freshest_per_rank( + kept, metas=self._collect_updated_at(kept) + ) + return kept + peer_count = 0 while peer_count == 0: - peer_count = len(self.client.list_sources(peer_id).instances) + insts = list(self.client.list_sources(peer_id).instances) + kept = _apply_filters(insts) + peer_count = len(kept) if peer_count == 0: if time.monotonic() >= deadline: raise TimeoutError(f"timed out waiting for {target_role!r} peers to appear in MX") time.sleep(poll_interval) while True: - matched = self.client.list_sources(peer_id, status_filter=status) - if len(matched.instances) >= peer_count: - return list(matched.instances) + insts = list(self.client.list_sources(peer_id, status_filter=status).instances) + kept = _apply_filters(insts) + if len(kept) >= peer_count: + return kept if time.monotonic() >= deadline: raise TimeoutError( f"timed out after {timeout}s waiting for {peer_count} " - f"{target_role!r} peers to reach status {status} (saw {len(matched.instances)})" + f"{target_role!r} peers to reach status {status} (saw {len(kept)})" ) time.sleep(poll_interval) diff --git a/src/prime_rl/utils/client.py b/src/prime_rl/utils/client.py index d633bbbba0..308f9d7fdd 100644 --- a/src/prime_rl/utils/client.py +++ b/src/prime_rl/utils/client.py @@ -539,3 +539,75 @@ async def _init(admin_client: AsyncClient, rank_offset: int) -> None: response.raise_for_status() await asyncio.gather(*[_init(admin_client, i * gpus_per_server) for i, admin_client in enumerate(admin_clients)]) + + +async def init_nixl_mx_v2_broadcast( + admin_clients: list[AsyncClient], + host: str, + port: int, + inference_world_size: int, + *, + publish_self_as_replica: bool = True, + listen_port: int | None = None, +) -> None: + """Initialize the ``mx_v2`` (pull-mode) receivers on inference servers. + + Mirrors :func:`init_nixl_mx_broadcast` but targets the v2 worker + extension (``NIXLMxV2WeightUpdateWorker``) which uses the published + :class:`MxWeightTransferEngine` adapter instead of the in-tree + :class:`MxRendezvous`. + """ + logger = get_logger() + gpus_per_server = inference_world_size // len(admin_clients) + + logger.info( + f"Initializing NIXL+MX v2 broadcast: {len(admin_clients)} servers, " + f"inference_world_size={inference_world_size}, gpus_per_server={gpus_per_server}, " + f"publish_self_as_replica={publish_self_as_replica}" + ) + + async def _init(admin_client: AsyncClient, rank_offset: int) -> None: + response = await admin_client.post( + "/init_nixl_mx_v2", + json={ + "host": host, + "port": port, + "rank_offset": rank_offset, + "publish_self_as_replica": publish_self_as_replica, + "listen_port": listen_port, + }, + ) + response.raise_for_status() + + await asyncio.gather(*[_init(admin_client, i * gpus_per_server) for i, admin_client in enumerate(admin_clients)]) + + +async def update_weights_v2( + admin_clients: list[AsyncClient], + step: int, + *, + compile_target_filter: list[str] | None = None, + timeout_seconds: float = 300.0, + same_rank_only: bool = True, +) -> list[dict]: + """Drive a v2 (pull-mode) refit on all inference servers. + + Mirrors the existing ``/update_weights`` poke but for the + ``mx_v2`` worker path. Returns the per-server metrics dicts so the + orchestrator can emit per-cycle timing to its dashboards. + """ + + async def _update(admin_client: AsyncClient) -> dict: + response = await admin_client.post( + "/update_weights_v2", + json={ + "step": int(step), + "compile_target_filter": compile_target_filter, + "timeout_seconds": float(timeout_seconds), + "same_rank_only": bool(same_rank_only), + }, + ) + response.raise_for_status() + return response.json() + + return list(await asyncio.gather(*[_update(c) for c in admin_clients])) diff --git a/tests/unit/inference/vllm/__init__.py b/tests/unit/inference/vllm/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/unit/inference/vllm/test_mx_v2_server_endpoints.py b/tests/unit/inference/vllm/test_mx_v2_server_endpoints.py new file mode 100644 index 0000000000..18b53af10a --- /dev/null +++ b/tests/unit/inference/vllm/test_mx_v2_server_endpoints.py @@ -0,0 +1,417 @@ +"""Unit tests for the ``mx_v2`` server-side glue. + +Three pieces tested here: + +1. The ``WORKER_EXTENSION_CLS["mx_v2"]`` entry in server.py — i.e. that + the worker-extension selector points at our new worker extension class. +2. The new HTTP endpoints ``/init_nixl_mx_v2`` and ``/update_weights_v2`` + on server.py — verified to forward to the right ``collective_rpc`` + method names with the right kwargs. +3. The orchestrator-side helpers ``init_nixl_mx_v2_broadcast`` and + ``update_weights_v2`` in client.py — verified to POST to the right + endpoints with the right JSON body. + +Plus the trainer-side selector dispatch (``setup_weight_broadcast`` for +``config.type == "mx_v2"``). + +We use ``importlib.util.spec_from_file_location`` to load each target +file against a stubbed dep graph, so the test runs anywhere torch + +pytest is present without prime-rl needing to be installed. +""" + +from __future__ import annotations + +import importlib.util +import sys +import types +from pathlib import Path +from unittest.mock import AsyncMock, MagicMock + +import pytest + + +_PRIME_RL_ROOT = Path(__file__).resolve().parents[4] +_SERVER_FILE = ( + _PRIME_RL_ROOT / "src" / "prime_rl" / "inference" / "vllm" / "server.py" +) +_CLIENT_FILE = _PRIME_RL_ROOT / "src" / "prime_rl" / "utils" / "client.py" +_BROADCAST_INIT_FILE = ( + _PRIME_RL_ROOT + / "src" + / "prime_rl" + / "trainer" + / "rl" + / "broadcast" + / "__init__.py" +) + + +# ---------------------------------------------------------------------------- +# 1. WORKER_EXTENSION_CLS table — read directly from the source AST so we +# don't have to install the package or stub anywhere near as much +# ---------------------------------------------------------------------------- + + +def _extract_worker_extension_cls(): + """Parse server.py and pull out the WORKER_EXTENSION_CLS dict literal. + + Avoids the import-graph problem entirely — we only need the table. + """ + import ast + + src = _SERVER_FILE.read_text() + tree = ast.parse(src) + for node in ast.walk(tree): + if isinstance(node, ast.Assign) and any( + isinstance(t, ast.Name) and t.id == "WORKER_EXTENSION_CLS" + for t in node.targets + ): + return { + key.value: value.value + for key, value in zip(node.value.keys, node.value.values) + if isinstance(key, ast.Constant) and isinstance(value, ast.Constant) + } + raise RuntimeError("WORKER_EXTENSION_CLS not found in server.py") + + +def test_worker_extension_cls_table_has_mx_v2_entry(): + table = _extract_worker_extension_cls() + assert "mx_v2" in table + assert ( + table["mx_v2"] + == "prime_rl.inference.vllm.worker.nixl_mx_v2.NIXLMxV2WeightUpdateWorker" + ) + + +def test_worker_extension_cls_table_preserves_existing_backends(): + """Adding mx_v2 must not have removed nccl / filesystem / nixl_mx.""" + table = _extract_worker_extension_cls() + assert "nccl" in table + assert "filesystem" in table + assert "nixl_mx" in table + # And nixl_mx vs mx_v2 are two distinct worker classes. + assert table["nixl_mx"] != table["mx_v2"] + + +# ---------------------------------------------------------------------------- +# 2. Server endpoints — load via spec_from_file_location with stubs +# ---------------------------------------------------------------------------- + + +def _install_server_stubs(): + """Stub the heavy server.py deps (vLLM, FastAPI bits, prime_rl imports). + + Just enough to let server.py's module-level statements run; we only need + to call the two new endpoint coroutines. + """ + # FastAPI bits + fake_request_cls = type("Request", (), {}) + fake_apirouter_cls = MagicMock(name="APIRouter") + fake_apirouter = MagicMock(name="apirouter_inst") + # Make APIRouter.post / get return identity decorators so the @router.post + # decorators in server.py work without registering anything. + fake_apirouter.post = lambda *a, **kw: (lambda f: f) + fake_apirouter.get = lambda *a, **kw: (lambda f: f) + fake_apirouter_cls.return_value = fake_apirouter + fake_jsonresponse = MagicMock(name="JSONResponse") + sys.modules["fastapi"] = types.SimpleNamespace( + Request=fake_request_cls, APIRouter=fake_apirouter_cls + ) + sys.modules["fastapi.responses"] = types.SimpleNamespace( + JSONResponse=fake_jsonresponse + ) + + # vllm bits + sys.modules["vllm"] = types.SimpleNamespace() + sys.modules["vllm.engine"] = types.SimpleNamespace() + sys.modules["vllm.engine.protocol"] = types.SimpleNamespace( + EngineClient=type("EngineClient", (), {}) + ) + sys.modules["vllm.entrypoints"] = types.SimpleNamespace() + sys.modules["vllm.entrypoints.openai"] = types.SimpleNamespace() + sys.modules["vllm.entrypoints.openai.api_server"] = types.SimpleNamespace( + State=type("State", (), {}), + init_app_state=MagicMock(), + run_headless=MagicMock(), + ) + sys.modules["vllm.entrypoints.openai.protocol"] = types.SimpleNamespace( + LoadLoRAAdapterRequest=type("LoadLoRAAdapterRequest", (), {}), + ErrorResponse=type("ErrorResponse", (), {}), + ) + sys.modules["vllm.utils"] = types.SimpleNamespace(FlexibleArgumentParser=MagicMock()) + # prime_rl deps used at top of server.py + sys.modules.setdefault("prime_rl", types.ModuleType("prime_rl")) + sys.modules.setdefault("prime_rl.utils", types.ModuleType("prime_rl.utils")) + sys.modules["prime_rl.utils.logger"] = types.SimpleNamespace( + get_logger=MagicMock(return_value=MagicMock(name="logger")), + setup_logger=MagicMock(), + ) + + # PrimeRlServingTokens etc. + sys.modules["prime_rl.inference"] = types.ModuleType("prime_rl.inference") + sys.modules["prime_rl.inference.vllm"] = types.ModuleType("prime_rl.inference.vllm") + sys.modules["prime_rl.inference.vllm.serving_tokens"] = types.SimpleNamespace( + PrimeRlServingTokens=type("PrimeRlServingTokens", (), {}) + ) + + +@pytest.fixture +def server_mod(): + """Load server.py with stubs in place.""" + # Wipe cached state + for k in list(sys.modules.keys()): + if k.startswith("prime_rl") or k.startswith("vllm") or k.startswith("fastapi"): + del sys.modules[k] + + _install_server_stubs() + + spec = importlib.util.spec_from_file_location( + "_test_server_under_test", _SERVER_FILE + ) + mod = importlib.util.module_from_spec(spec) + try: + spec.loader.exec_module(mod) + except Exception as e: + pytest.skip( + f"server.py imports too much to stub cleanly: {e}; this test " + f"runs in CI where prime-rl IS installed" + ) + yield mod + + +@pytest.mark.asyncio +async def test_init_nixl_mx_v2_endpoint_dispatches_collective_rpc(server_mod): + fake_client = MagicMock() + fake_client.collective_rpc = AsyncMock() + fake_request = MagicMock() + fake_request.json = AsyncMock( + return_value={ + "host": "modelexpress-server.kavin.svc.cluster.local", + "port": 8001, + "rank_offset": 4, + "publish_self_as_replica": True, + "listen_port": None, + } + ) + + orig = getattr(server_mod, "engine_client", None) + server_mod.engine_client = lambda r: fake_client + try: + result = await server_mod.init_nixl_mx_v2(fake_request) + finally: + if orig is not None: + server_mod.engine_client = orig + + assert result == {"status": "ok"} + fake_client.collective_rpc.assert_called_once_with( + "init_nixl_mx_v2", + args=( + "modelexpress-server.kavin.svc.cluster.local", + 8001, + 4, + ), + kwargs={"publish_self_as_replica": True, "listen_port": None}, + ) + + +@pytest.mark.asyncio +async def test_update_weights_v2_endpoint_dispatches_collective_rpc(server_mod): + fake_metrics = [ + {"step": 42, "bytes_received": 536_870_912, "bandwidth_gbps": 52.4} + ] + fake_client = MagicMock() + fake_client.collective_rpc = AsyncMock(return_value=fake_metrics) + fake_request = MagicMock() + fake_request.json = AsyncMock( + return_value={ + "step": 42, + "compile_target_filter": ["cutlass_fp8"], + "timeout_seconds": 180.0, + "same_rank_only": True, + } + ) + + orig = getattr(server_mod, "engine_client", None) + server_mod.engine_client = lambda r: fake_client + try: + result = await server_mod.update_weights_v2(fake_request) + finally: + if orig is not None: + server_mod.engine_client = orig + + assert result == {"status": "ok", "metrics": fake_metrics} + fake_client.collective_rpc.assert_called_once_with( + "update_weights_via_mx_v2", + args=(42,), + kwargs={ + "compile_target_filter": ["cutlass_fp8"], + "timeout_seconds": 180.0, + "same_rank_only": True, + }, + ) + + +@pytest.mark.asyncio +async def test_update_weights_v2_endpoint_defaults(server_mod): + fake_client = MagicMock() + fake_client.collective_rpc = AsyncMock(return_value=[]) + fake_request = MagicMock() + fake_request.json = AsyncMock(return_value={"step": 1}) + + orig = getattr(server_mod, "engine_client", None) + server_mod.engine_client = lambda r: fake_client + try: + await server_mod.update_weights_v2(fake_request) + finally: + if orig is not None: + server_mod.engine_client = orig + + kwargs = fake_client.collective_rpc.call_args.kwargs["kwargs"] + assert kwargs["compile_target_filter"] is None + assert kwargs["timeout_seconds"] == 300.0 + assert kwargs["same_rank_only"] is True + + +# ---------------------------------------------------------------------------- +# 3. Orchestrator-side helpers — load client.py with stubs +# ---------------------------------------------------------------------------- + + +def _install_client_stubs(): + sys.modules.setdefault("prime_rl", types.ModuleType("prime_rl")) + sys.modules.setdefault("prime_rl.utils", types.ModuleType("prime_rl.utils")) + sys.modules["prime_rl.utils.logger"] = types.SimpleNamespace( + get_logger=MagicMock(return_value=MagicMock(name="logger")), + setup_logger=MagicMock(), + ) + # httpx AsyncClient stub — client.py imports it + sys.modules["httpx"] = types.SimpleNamespace( + AsyncClient=type("AsyncClient", (), {}) + ) + + +@pytest.fixture +def client_mod(): + for k in list(sys.modules.keys()): + if k.startswith("prime_rl") or k == "httpx": + del sys.modules[k] + _install_client_stubs() + spec = importlib.util.spec_from_file_location( + "_test_client_under_test", _CLIENT_FILE + ) + mod = importlib.util.module_from_spec(spec) + try: + spec.loader.exec_module(mod) + except Exception as e: + pytest.skip( + f"client.py imports too much to stub cleanly: {e}; this test " + f"runs in CI where prime-rl IS installed" + ) + yield mod + + +@pytest.mark.asyncio +async def test_init_nixl_mx_v2_broadcast_posts_to_all_servers(client_mod): + """POSTs /init_nixl_mx_v2 with rank_offset = i * gpus_per_server per server.""" + admin_clients = [] + for _ in range(3): + c = MagicMock() + resp = MagicMock() + resp.raise_for_status = MagicMock() + c.post = AsyncMock(return_value=resp) + admin_clients.append(c) + + await client_mod.init_nixl_mx_v2_broadcast( + admin_clients, + host="mx-server", + port=8001, + inference_world_size=12, + publish_self_as_replica=True, + listen_port=None, + ) + + # gpus_per_server = 12 // 3 = 4 → rank_offsets 0, 4, 8 + expected_offsets = [0, 4, 8] + for c, expected_offset in zip(admin_clients, expected_offsets): + c.post.assert_called_once() + args, kwargs = c.post.call_args + assert args[0] == "/init_nixl_mx_v2" + body = kwargs["json"] + assert body["host"] == "mx-server" + assert body["port"] == 8001 + assert body["rank_offset"] == expected_offset + assert body["publish_self_as_replica"] is True + + +@pytest.mark.asyncio +async def test_update_weights_v2_posts_step_and_returns_metrics(client_mod): + fake_servers = [] + expected_responses = [ + {"status": "ok", "metrics": [{"step": 5, "bandwidth_gbps": 50.0}]}, + {"status": "ok", "metrics": [{"step": 5, "bandwidth_gbps": 48.0}]}, + ] + for resp_body in expected_responses: + c = MagicMock() + resp = MagicMock() + resp.raise_for_status = MagicMock() + resp.json = MagicMock(return_value=resp_body) + c.post = AsyncMock(return_value=resp) + fake_servers.append(c) + + results = await client_mod.update_weights_v2( + fake_servers, + step=5, + compile_target_filter=["cutlass_fp8"], + timeout_seconds=180.0, + same_rank_only=True, + ) + + assert results == expected_responses + for c in fake_servers: + args, kwargs = c.post.call_args + assert args[0] == "/update_weights_v2" + body = kwargs["json"] + assert body["step"] == 5 + assert body["compile_target_filter"] == ["cutlass_fp8"] + assert body["timeout_seconds"] == 180.0 + assert body["same_rank_only"] is True + + +# ---------------------------------------------------------------------------- +# 4. Trainer-side selector dispatch — verify __init__.py routes mx_v2 correctly +# ---------------------------------------------------------------------------- + + +def test_broadcast_init_dispatches_mx_v2_via_ast(): + """The selector in broadcast/__init__.py routes config.type == "mx_v2" + to NIXLMxV2WeightBroadcast. Parse the source directly to avoid the heavy + import graph.""" + import ast + + src = _BROADCAST_INIT_FILE.read_text() + tree = ast.parse(src) + + # Find the setup_weight_broadcast function + func = next( + node + for node in ast.walk(tree) + if isinstance(node, ast.FunctionDef) and node.name == "setup_weight_broadcast" + ) + + # Find the elif branch with `config.type == "mx_v2"` + mx_v2_branch_found = False + for node in ast.walk(func): + if isinstance(node, ast.Compare): + # Detect `config.type == "mx_v2"` + if ( + len(node.comparators) == 1 + and isinstance(node.comparators[0], ast.Constant) + and node.comparators[0].value == "mx_v2" + ): + mx_v2_branch_found = True + break + assert mx_v2_branch_found, "mx_v2 dispatch branch not found in selector" + + # And the branch references NIXLMxV2WeightBroadcast + assert "NIXLMxV2WeightBroadcast" in src + assert "from prime_rl.trainer.rl.broadcast.nixl_mx_v2" in src diff --git a/tests/unit/inference/vllm/worker/__init__.py b/tests/unit/inference/vllm/worker/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/unit/inference/vllm/worker/test_nixl_mx_v2_worker.py b/tests/unit/inference/vllm/worker/test_nixl_mx_v2_worker.py new file mode 100644 index 0000000000..32c330d4f0 --- /dev/null +++ b/tests/unit/inference/vllm/worker/test_nixl_mx_v2_worker.py @@ -0,0 +1,437 @@ +"""Unit tests for ``NIXLMxV2WeightUpdateWorker``. + +Same pattern as ``test_nixl_mx_v2.py`` — load the production module via +``importlib.util.spec_from_file_location`` against a fully-stubbed +dependency graph (vLLM, modelexpress, prime_rl.transport, etc.), so the +test runs anywhere torch + pytest is present. + +The worker has two RPC entry points: + +- ``init_nixl_mx_v2(host, port, rank_offset, *, publish_self_as_replica, listen_port)`` +- ``update_weights_via_mx_v2(step, *, compile_target_filter, timeout_seconds, same_rank_only)`` + +We verify init-info construction, update-info construction, the +``load_weights`` callback path, metrics-dict shape, and the post-load +``update_mla_absorbed_weights`` hook. +""" + +from __future__ import annotations + +import importlib.util +import sys +import types +from pathlib import Path +from unittest.mock import MagicMock + +import pytest + + +_PRIME_RL_ROOT = Path(__file__).resolve().parents[5] # prime-rl root +_WORKER_FILE = ( + _PRIME_RL_ROOT + / "src" + / "prime_rl" + / "inference" + / "vllm" + / "worker" + / "nixl_mx_v2.py" +) + + +def _install_stubs(): + """Insert fake modules so the worker file imports cleanly.""" + mocks: dict[str, MagicMock] = {} + + # ─── modelexpress.vllm_weight_transfer ────────────────────────────── + fake_engine_cls = MagicMock(name="MxWeightTransferEngine_cls") + fake_engine = MagicMock(name="MxWeightTransferEngine_inst") + fake_stats = types.SimpleNamespace( + bytes_received=536_870_912, + tensors_received=64, + elapsed_seconds=0.082, + bandwidth_gbps=52.4, + discovery_seconds=0.014, + source_worker_rank=0, + ) + fake_engine.last_transfer_stats = fake_stats + fake_engine.last_discovery_seconds = 0.014 + fake_engine_cls.return_value = fake_engine + + fake_init_info_cls = MagicMock(name="MxInitInfo") + fake_update_info_cls = MagicMock(name="MxUpdateInfo") + + sys.modules["modelexpress"] = types.SimpleNamespace() + sys.modules["modelexpress.vllm_weight_transfer"] = types.SimpleNamespace( + MxWeightTransferEngine=fake_engine_cls, + MxInitInfo=fake_init_info_cls, + MxUpdateInfo=fake_update_info_cls, + ) + mocks["engine_cls"] = fake_engine_cls + mocks["engine"] = fake_engine + mocks["init_info_cls"] = fake_init_info_cls + mocks["update_info_cls"] = fake_update_info_cls + mocks["stats"] = fake_stats + + # ─── vllm.logger ──────────────────────────────────────────────────── + sys.modules["vllm"] = types.SimpleNamespace() + sys.modules["vllm.logger"] = types.SimpleNamespace( + init_logger=lambda name: MagicMock(name=f"logger({name})") + ) + # vllm.v1.worker.gpu_worker only used inside TYPE_CHECKING so no stub needed + + # ─── prime_rl.inference.vllm.worker.weight_transfer ──────────────── + fake_update_mla = MagicMock(name="update_mla_absorbed_weights") + sys.modules.setdefault("prime_rl", types.ModuleType("prime_rl")) + sys.modules["prime_rl.inference"] = types.ModuleType("prime_rl.inference") + sys.modules["prime_rl.inference.vllm"] = types.ModuleType("prime_rl.inference.vllm") + sys.modules["prime_rl.inference.vllm.worker"] = types.ModuleType( + "prime_rl.inference.vllm.worker" + ) + pkg_wt = types.ModuleType("prime_rl.inference.vllm.worker.weight_transfer") + pkg_wt.update_mla_absorbed_weights = fake_update_mla + # `build_expert_map` is imported by the OLD worker (nixl_mx.py) — not by + # nixl_mx_v2 — so we don't need to stub it for this test. Add a no-op + # in case the test imports the broadcast __init__ which may pull it in. + pkg_wt.build_expert_map = MagicMock(name="build_expert_map", return_value={}) + sys.modules["prime_rl.inference.vllm.worker.weight_transfer"] = pkg_wt + mocks["update_mla"] = fake_update_mla + + # ─── prime_rl.transport.nixl_agent ────────────────────────────────── + fake_make_agent_name = MagicMock(return_value="vllm-inference-r0") + fake_pin_ucx_rail = MagicMock() + sys.modules["prime_rl.transport"] = types.ModuleType("prime_rl.transport") + pkg_na = types.ModuleType("prime_rl.transport.nixl_agent") + pkg_na.make_agent_name = fake_make_agent_name + pkg_na.pin_ucx_rail = fake_pin_ucx_rail + sys.modules["prime_rl.transport.nixl_agent"] = pkg_na + mocks["make_agent_name"] = fake_make_agent_name + mocks["pin_ucx_rail"] = fake_pin_ucx_rail + + return mocks + + +@pytest.fixture +def worker_mod(): + """Load nixl_mx_v2.py worker under fully-stubbed deps.""" + for k in list(sys.modules.keys()): + if k.startswith("prime_rl") or k == "modelexpress" or k.startswith("modelexpress."): + del sys.modules[k] + if k == "vllm" or k.startswith("vllm."): + del sys.modules[k] + + mocks = _install_stubs() + + import torch + if not hasattr(torch.cuda, "synchronize"): + torch.cuda.synchronize = MagicMock() + original_synchronize = torch.cuda.synchronize + torch.cuda.synchronize = MagicMock() + + spec = importlib.util.spec_from_file_location( + "_test_nixl_mx_v2_worker_under_test", _WORKER_FILE + ) + mod = importlib.util.module_from_spec(spec) + spec.loader.exec_module(mod) + + try: + yield (mod, mocks) + finally: + torch.cuda.synchronize = original_synchronize + + +def _make_worker(mod, *, model_name="bench/synthetic-1.5B", device_index=0): + """Build a worker with a mocked vLLM Worker context. + + The `raw_model` property does `assert isinstance(model, Module)` so the + inner model has to be a real ``torch.nn.Module`` subclass. We attach a + `load_weights` method onto it so tests can spy on the callback path. + """ + import torch.nn as nn + + class FakeInnerModel(nn.Module): + def __init__(self): + super().__init__() + self.load_weights = MagicMock(name="load_weights") + + worker = mod.NIXLMxV2WeightUpdateWorker() + worker.device = MagicMock() + worker.device.index = device_index + runner = MagicMock(name="ModelRunner") + runner.model_config = MagicMock(model=model_name) + fake_inner = FakeInnerModel() + runner.model = MagicMock() + runner.model.runnable = fake_inner + worker.model_runner = runner + return worker + + +def test_init_nixl_mx_v2_builds_engine_with_correct_init_info(worker_mod): + mod, mocks = worker_mod + worker = _make_worker(mod, device_index=2) + worker.init_nixl_mx_v2( + host="modelexpress-server.kavin.svc.cluster.local", + port=8001, + rank_offset=4, + publish_self_as_replica=True, + listen_port=None, + ) + + mocks["init_info_cls"].assert_called_once() + init_kwargs = mocks["init_info_cls"].call_args.kwargs + assert ( + init_kwargs["mx_server_url"] + == "modelexpress-server.kavin.svc.cluster.local:8001" + ) + assert init_kwargs["worker_rank"] == 4 + 2 + assert init_kwargs["agent_name"] == "vllm-inference-r0" + assert init_kwargs["device_id"] == 2 + assert init_kwargs["publish_self_as_replica"] is True + + mocks["engine_cls"].assert_called_once() + eng_kwargs = mocks["engine_cls"].call_args.kwargs + assert "init_info" in eng_kwargs + mocks["pin_ucx_rail"].assert_called_once_with(2) + assert worker._global_rank == 6 + + +def test_init_nixl_mx_v2_respects_publish_self_as_replica_false(worker_mod): + mod, mocks = worker_mod + worker = _make_worker(mod) + worker.init_nixl_mx_v2( + host="x", port=8001, rank_offset=0, publish_self_as_replica=False + ) + init_kwargs = mocks["init_info_cls"].call_args.kwargs + assert init_kwargs["publish_self_as_replica"] is False + + +def test_update_weights_via_mx_v2_dispatches_engine_receive(worker_mod): + mod, mocks = worker_mod + worker = _make_worker(mod) + worker.init_nixl_mx_v2(host="x", port=8001, rank_offset=0) + + metrics = worker.update_weights_via_mx_v2( + 42, + compile_target_filter=["cutlass_fp8", "hf_raw"], + timeout_seconds=180.0, + same_rank_only=True, + ) + + mocks["update_info_cls"].assert_called_once() + upd_kwargs = mocks["update_info_cls"].call_args.kwargs + assert upd_kwargs["version"] == 42 + assert upd_kwargs["compile_target_filter"] == {"cutlass_fp8", "hf_raw"} + assert upd_kwargs["timeout_seconds"] == 180.0 + assert upd_kwargs["same_rank_only"] is True + assert upd_kwargs["target_tp_layout"] is None + + mocks["engine"].receive_weights.assert_called_once() + call = mocks["engine"].receive_weights.call_args + assert "load_weights" in call.kwargs + + mocks["update_mla"].assert_called_once() + + assert metrics["step"] == 42 + assert metrics["bytes_received"] == 536_870_912 + assert metrics["tensors_received"] == 64 + assert metrics["bandwidth_gbps"] == pytest.approx(52.4) + assert metrics["discovery_seconds"] == pytest.approx(0.014) + assert metrics["source_worker_rank"] == 0 + + +def test_update_weights_via_mx_v2_no_filter_passes_none(worker_mod): + mod, mocks = worker_mod + worker = _make_worker(mod) + worker.init_nixl_mx_v2(host="x", port=8001, rank_offset=0) + worker.update_weights_via_mx_v2(1, compile_target_filter=None) + upd_kwargs = mocks["update_info_cls"].call_args.kwargs + assert upd_kwargs["compile_target_filter"] is None + + +def test_load_weights_batch_passthrough_when_no_hf_config(worker_mod): + """When _hf_config is None (non-MoE model / probe failed), the translator + falls through to passthrough and forwards the batch unchanged.""" + mod, _ = worker_mod + worker = _make_worker(mod) + worker._hf_config = None # force passthrough + captured_batches = [] + worker.raw_model.load_weights = MagicMock( + side_effect=lambda batch: captured_batches.append(batch) + ) + batch_1 = [("model.layers.0.weight", "TENSOR1")] + batch_2 = [("model.layers.1.weight", "TENSOR2"), ("a", "T3")] + worker._load_weights_batch(batch_1) + worker._load_weights_batch(batch_2) + assert captured_batches == [batch_1, batch_2] + + +def test_translate_tt_to_hf_qkv_split(worker_mod): + """Fused qkv_proj.weight (TT format) splits into q/k/v (HF format) + with the right per-projection row counts derived from head dims.""" + import torch + + mod, _ = worker_mod + worker = _make_worker(mod) + # Qwen3-30B-A3B-like dims: 32 q heads, 4 kv heads, head_dim=128, hidden=2048. + worker._hf_config = { + "model_type": "qwen3_moe", + "num_attention_heads": 32, + "num_kv_heads": 4, + "head_dim": 128, + "num_experts": 128, + "ep_size": 4, + } + worker._global_rank = 0 + q_size = 32 * 128 # 4096 + kv_size = 4 * 128 # 512 + rows = q_size + 2 * kv_size # 5120 + qkv = torch.arange(rows * 2048, dtype=torch.float32).view(rows, 2048) + out = worker._translate_tt_to_hf( + [("model.layers.0.self_attn.qkv_proj.weight", qkv)] + ) + names = [n for n, _ in out] + assert names == [ + "model.layers.0.self_attn.q_proj.weight", + "model.layers.0.self_attn.k_proj.weight", + "model.layers.0.self_attn.v_proj.weight", + ] + assert out[0][1].shape == (q_size, 2048) + assert out[1][1].shape == (kv_size, 2048) + assert out[2][1].shape == (kv_size, 2048) + # Data preserved: first row of q == first row of qkv + assert torch.equal(out[0][1][0], qkv[0]) + assert torch.equal(out[1][1][0], qkv[q_size]) + assert torch.equal(out[2][1][0], qkv[q_size + kv_size]) + + +def test_translate_tt_to_hf_router_rename(worker_mod): + """mlp.router.gate.weight renames to mlp.gate.weight (HF naming).""" + import torch + + mod, _ = worker_mod + worker = _make_worker(mod) + worker._hf_config = { + "model_type": "qwen3_moe", + "num_attention_heads": 32, + "num_kv_heads": 4, + "head_dim": 128, + "num_experts": 128, + "ep_size": 4, + } + worker._global_rank = 0 + gate = torch.randn(128, 2048) + out = worker._translate_tt_to_hf( + [("model.layers.3.mlp.router.gate.weight", gate)] + ) + assert [n for n, _ in out] == ["model.layers.3.mlp.gate.weight"] + assert torch.equal(out[0][1], gate) + + +def test_translate_tt_to_hf_expert_w13_per_expert_split(worker_mod): + """Stacked w13 (gate+up) splits per-expert with the correct global + expert ID derived from rank * num_local + local_id.""" + import torch + + mod, _ = worker_mod + worker = _make_worker(mod) + worker._hf_config = { + "model_type": "qwen3_moe", + "num_attention_heads": 32, + "num_kv_heads": 4, + "head_dim": 128, + "num_experts": 128, + "ep_size": 4, + } + worker._global_rank = 2 # ep_rank=2 → global IDs 64..95 (num_local=32) + moe_dim = 768 + hidden = 2048 + n_local = 32 + w13 = torch.arange(n_local * 2 * moe_dim * hidden, dtype=torch.float32).view( + n_local, 2 * moe_dim, hidden + ) + out = worker._translate_tt_to_hf( + [("model.layers.5.mlp.experts.w13_weight", w13)] + ) + # Each local expert produces TWO tensors (gate + up) + assert len(out) == n_local * 2 + # First emitted should be local-expert-0 → global ID 64 (rank 2 × 32) + first_name, first_t = out[0] + assert first_name == "model.layers.5.mlp.experts.64.gate_proj.weight" + assert first_t.shape == (moe_dim, hidden) + # Second emitted should be local-0's up_proj (global ID 64) + assert out[1][0] == "model.layers.5.mlp.experts.64.up_proj.weight" + # Last local expert (31) → global ID 95 + last_gate_name = f"model.layers.5.mlp.experts.{2 * 32 + 31}.gate_proj.weight" + assert last_gate_name in [n for n, _ in out] + # Data preservation: local-0's gate-slice matches w13[0, :moe_dim] + assert torch.equal(first_t, w13[0, :moe_dim]) + + +def test_translate_tt_to_hf_expert_w2_per_expert(worker_mod): + """w2 (down) splits per-expert with the correct global IDs.""" + import torch + + mod, _ = worker_mod + worker = _make_worker(mod) + worker._hf_config = { + "model_type": "qwen3_moe", + "num_attention_heads": 32, + "num_kv_heads": 4, + "head_dim": 128, + "num_experts": 128, + "ep_size": 4, + } + worker._global_rank = 0 + hidden = 2048 + moe_dim = 768 + n_local = 32 + w2 = torch.randn(n_local, hidden, moe_dim) + out = worker._translate_tt_to_hf([("model.layers.7.mlp.experts.w2_weight", w2)]) + assert len(out) == n_local + assert out[0][0] == "model.layers.7.mlp.experts.0.down_proj.weight" + assert out[-1][0] == "model.layers.7.mlp.experts.31.down_proj.weight" + assert torch.equal(out[0][1], w2[0]) + assert torch.equal(out[31][1], w2[31]) + + +def test_translate_tt_to_hf_passthrough_for_unknown_names(worker_mod): + """Names not in the TT→HF table pass through unchanged (norms, embed, + lm_head, o_proj, q_norm, k_norm, etc.).""" + import torch + + mod, _ = worker_mod + worker = _make_worker(mod) + worker._hf_config = { + "model_type": "qwen3_moe", + "num_attention_heads": 32, + "num_kv_heads": 4, + "head_dim": 128, + "num_experts": 128, + "ep_size": 4, + } + worker._global_rank = 0 + t = torch.randn(2048) + passthrough_cases = [ + ("model.embed_tokens.weight", t), + ("model.norm.weight", t), + ("lm_head.weight", t), + ("model.layers.0.self_attn.o_proj.weight", t), + ("model.layers.0.self_attn.q_norm.weight", t), + ("model.layers.0.self_attn.k_norm.weight", t), + ("model.layers.0.input_layernorm.weight", t), + ("model.layers.0.post_attention_layernorm.weight", t), + ] + out = worker._translate_tt_to_hf(passthrough_cases) + assert out == passthrough_cases # exact same list, unchanged + + +def test_update_weights_via_mx_v2_metrics_safe_when_stats_none(worker_mod): + mod, mocks = worker_mod + mocks["engine"].last_transfer_stats = None + worker = _make_worker(mod) + worker.init_nixl_mx_v2(host="x", port=8001, rank_offset=0) + + metrics = worker.update_weights_via_mx_v2(1) + assert metrics["bytes_received"] == 0 + assert metrics["tensors_received"] == 0 + assert metrics["bandwidth_gbps"] == 0.0 + assert metrics["source_worker_rank"] is None diff --git a/tests/unit/train/models/conversions/test_cutlass_fp8.py b/tests/unit/train/models/conversions/test_cutlass_fp8.py new file mode 100644 index 0000000000..cd1ea3ece7 --- /dev/null +++ b/tests/unit/train/models/conversions/test_cutlass_fp8.py @@ -0,0 +1,311 @@ +"""Tests for cutlass FP8 e4m3 per-channel conversion + registry-extension plumbing. + +Direct-loads the conversions package to bypass the heavy +``prime_rl.trainer`` import chain (CUDA + torchrun + ray + …) so the suite +runs on a plain CPU CI box with only ``torch`` installed. +""" + +from __future__ import annotations + +import importlib +import importlib.util +import sys +import types +from pathlib import Path + +import pytest +import torch + + +_HERE = Path(__file__).resolve().parent +_REPO_ROOT = _HERE.parent.parent.parent.parent.parent +_CONV_PKG_DIR = _REPO_ROOT / "src" / "prime_rl" / "trainer" / "models" / "conversions" +_FP8_PATH = _REPO_ROOT / "src" / "prime_rl" / "trainer" / "models" / "fp8.py" + + +def _direct_load(name: str, path: Path): + spec = importlib.util.spec_from_file_location(name, path) + mod = importlib.util.module_from_spec(spec) + sys.modules[name] = mod + spec.loader.exec_module(mod) + return mod + + +@pytest.fixture(scope="module") +def conv_pkg(): + """Load the conversions package + its dependencies in isolation. + + Order matters: ``conversions/__init__.py`` registers ``cutlass_fp8`` as + a late side-effect import; we need ``prime_rl.trainer.models.fp8`` to + be importable first. + """ + # Synthesize the prime_rl.trainer.models package hierarchy so the + # relative imports inside the conversion modules resolve. + for fqn, path in [ + ("prime_rl", _REPO_ROOT / "src" / "prime_rl"), + ("prime_rl.trainer", _REPO_ROOT / "src" / "prime_rl" / "trainer"), + ("prime_rl.trainer.models", _REPO_ROOT / "src" / "prime_rl" / "trainer" / "models"), + ]: + if fqn in sys.modules: + continue + pkg = types.ModuleType(fqn) + pkg.__path__ = [str(path)] + sys.modules[fqn] = pkg + + # Load fp8.py first — the conversion modules import from it. + _direct_load("prime_rl.trainer.models.fp8", _FP8_PATH) + + # Now load the conversions package, then its submodules. We point at + # the directory's __init__.py explicitly so we don't get the partial + # package from a parent that's already half-loaded. + pkg = _direct_load( + "prime_rl.trainer.models.conversions", _CONV_PKG_DIR / "__init__.py" + ) + return pkg + + +# ---------------------------------------------------------------------------- +# Per-output-channel quantize helper (lives in fp8.py) +# ---------------------------------------------------------------------------- + + +def test_fp8_per_channel_2d_round_trip_shape(conv_pkg): + from prime_rl.trainer.models.fp8 import fp8_per_channel_quantize + + w = torch.randn(64, 256, dtype=torch.bfloat16) + q, s = fp8_per_channel_quantize(w) + assert q.shape == (64, 256) + assert q.dtype == torch.float8_e4m3fn + assert s.shape == (64,) + assert s.dtype == torch.float32 + + +def test_fp8_per_channel_3d_round_trip_shape(conv_pkg): + from prime_rl.trainer.models.fp8 import fp8_per_channel_quantize + + w = torch.randn(8, 64, 256, dtype=torch.bfloat16) + q, s = fp8_per_channel_quantize(w) + assert q.shape == (8, 64, 256) + assert q.dtype == torch.float8_e4m3fn + assert s.shape == (8, 64) + assert s.dtype == torch.float32 + + +def test_fp8_per_channel_rejects_1d(conv_pkg): + from prime_rl.trainer.models.fp8 import fp8_per_channel_quantize + + with pytest.raises(ValueError, match="2D or 3D"): + fp8_per_channel_quantize(torch.randn(64)) + + +def test_fp8_per_channel_dequant_close_to_original(conv_pkg): + """Round-trip accuracy: per-channel scaling has ~1% error band on bf16 inputs.""" + from prime_rl.trainer.models.fp8 import fp8_per_channel_quantize + + torch.manual_seed(0) + w = torch.randn(32, 128, dtype=torch.bfloat16) * 0.1 + q, s = fp8_per_channel_quantize(w) + dequant = q.float() * s.unsqueeze(-1) + # FP8 e4m3 has ~3-bit mantissa → relative error tolerance is generous + rel = (dequant - w.float()).abs() / (w.float().abs() + 1e-6) + assert rel.median().item() < 0.05 # 5 % median error is realistic for fp8 e4m3 + + +def test_fp8_per_channel_into_writes_buffers(conv_pkg): + from prime_rl.trainer.models.fp8 import fp8_per_channel_quantize_into + + w = torch.randn(16, 64, dtype=torch.bfloat16) + out = torch.empty(16, 64, dtype=torch.float8_e4m3fn) + sf = torch.empty(16, dtype=torch.float32) + fp8_per_channel_quantize_into(w, out=out, sf=sf) + # Both buffers should now reflect a real quantization (not the empty pattern). + assert sf.gt(0).all() + assert out.float().abs().max() <= 448.0 # fp8 e4m3 finite range + + +# ---------------------------------------------------------------------------- +# Registry extensions: compile_target + compile_metadata + new entry +# ---------------------------------------------------------------------------- + + +def test_conversion_entry_carries_compile_target(conv_pkg): + entry = conv_pkg.get("bf16_cast") + assert entry.compile_target == conv_pkg.COMPILE_TARGET_HF_RAW + assert entry.compile_metadata == {"dtype": "bfloat16"} + + +def test_fp8_128x128_tagged_deep_gemm(conv_pkg): + entry = conv_pkg.get("fp8_128x128") + assert entry.compile_target == conv_pkg.COMPILE_TARGET_DEEPGEMM_FP8 + assert entry.compile_metadata["block_size"] == [128, 128] + assert entry.compile_metadata["scale_layout"] == "blockwise" + + +def test_cutlass_fp8_entry_registered(conv_pkg): + entry = conv_pkg.get("cutlass_fp8_e4m3_per_channel") + assert entry.requires_scale is True + assert entry.compile_target == conv_pkg.COMPILE_TARGET_CUTLASS_FP8 + assert entry.compile_metadata == { + "dtype": "e4m3", + "scale_layout": "per_channel", + "scale_axis": -1, + "activation_scheme": "dynamic", + } + + +def test_cutlass_fp8_in_registered_names(conv_pkg): + names = conv_pkg.registered_names() + assert "cutlass_fp8_e4m3_per_channel" in names + assert "fp8_128x128" in names + assert "bf16_cast" in names + + +def test_register_default_rule_appends(conv_pkg): + """register_default_rule appends by default and prepends with insert_first=True.""" + + sentinel_name = "bf16_cast" # we know this exists + + def predicate_a(quant): + return quant.get("quant_method") == "test_a" + + def predicate_b(quant): + return quant.get("quant_method") == "test_b" + + # These mutate module state; use unique enough names that they don't + # collide with the real rules. + conv_pkg.register_default_rule(predicate_a, sentinel_name) + conv_pkg.register_default_rule(predicate_b, sentinel_name, insert_first=True) + + # We can't read the table directly without breaking the encapsulation, + # but we can verify behaviorally: predicate_b should be matched before + # the existing rules; predicate_a should be matched after. + import prime_rl.trainer.models.conversions as conv + + rules = conv._DEFAULT_RULES + # predicate_b should now be at index 0 + assert rules[0][0] is predicate_b + # predicate_a should be at the end + assert rules[-1][0] is predicate_a + + +def test_unknown_quant_raises_listing_registered(conv_pkg, fake_hf_config): + """When no rule matches, the error message lists what IS registered.""" + fake_hf_config["quant"] = {"quant_method": "totally_unknown_method"} + with pytest.raises(NotImplementedError, match="registered conversions"): + conv_pkg.select_default_conversion("fake/model") + + +# ---------------------------------------------------------------------------- +# select_default_conversion dispatch (the new table-driven path) +# ---------------------------------------------------------------------------- + + +@pytest.fixture +def fake_hf_config(monkeypatch): + """Stub ``transformers.AutoConfig`` for the test session so + ``select_default_conversion`` runs without an HF download. + + Because the conversions module imports ``AutoConfig`` lazily inside + the function, we have to populate ``sys.modules['transformers']`` + with our stub *before* the function call resolves the import. + """ + holder = {"quant": None} + + class _Fake: + @property + def quantization_config(self): + return holder["quant"] + + class _FakeAutoConfig: + @staticmethod + def from_pretrained(*args, **kwargs): + return _Fake() + + transformers_stub = types.ModuleType("transformers") + transformers_stub.AutoConfig = _FakeAutoConfig + monkeypatch.setitem(sys.modules, "transformers", transformers_stub) + return holder + + +def test_default_no_quant_is_bf16(conv_pkg, fake_hf_config): + fake_hf_config["quant"] = None + assert conv_pkg.select_default_conversion("any/model") == "bf16_cast" + + +def test_default_deep_gemm_fp8(conv_pkg, fake_hf_config): + fake_hf_config["quant"] = { + "quant_method": "fp8", + "weight_block_size": [128, 128], + } + assert conv_pkg.select_default_conversion("any/model") == "fp8_128x128" + + +def test_default_cutlass_fp8_via_explicit_format(conv_pkg, fake_hf_config): + fake_hf_config["quant"] = { + "quant_method": "fp8", + "quant_format": "cutlass", + } + assert ( + conv_pkg.select_default_conversion("any/model") + == "cutlass_fp8_e4m3_per_channel" + ) + + +def test_default_cutlass_fp8_via_dynamic_no_block_size(conv_pkg, fake_hf_config): + fake_hf_config["quant"] = { + "quant_method": "fp8", + "weight_block_size": None, + "activation_scheme": "dynamic", + } + assert ( + conv_pkg.select_default_conversion("any/model") + == "cutlass_fp8_e4m3_per_channel" + ) + + +def test_default_deep_gemm_wins_over_cutlass_when_block_size_set( + conv_pkg, fake_hf_config +): + """Both rules could plausibly fire for a config with block_size=[128,128] + AND activation_scheme="dynamic"; the deep-gemm rule was registered first + and must win.""" + fake_hf_config["quant"] = { + "quant_method": "fp8", + "weight_block_size": [128, 128], + "activation_scheme": "dynamic", + } + assert conv_pkg.select_default_conversion("any/model") == "fp8_128x128" + + +# ---------------------------------------------------------------------------- +# End-to-end fn dispatch: 2D + 3D shapes via the registered conversion entry +# ---------------------------------------------------------------------------- + + +def test_cutlass_fp8_fn_dispatches_2d_linear(conv_pkg): + entry = conv_pkg.get("cutlass_fp8_e4m3_per_channel") + src = torch.randn(32, 128, dtype=torch.bfloat16) + out = torch.empty(32, 128, dtype=torch.float8_e4m3fn) + sf = torch.empty(32, dtype=torch.float32) + entry.fn(src, out, sf) + assert sf.gt(0).all() + assert out.float().abs().max() <= 448.0 + + +def test_cutlass_fp8_fn_dispatches_3d_moe(conv_pkg): + entry = conv_pkg.get("cutlass_fp8_e4m3_per_channel") + src = torch.randn(4, 32, 128, dtype=torch.bfloat16) # E=4 experts + out = torch.empty(4, 32, 128, dtype=torch.float8_e4m3fn) + sf = torch.empty(4, 32, dtype=torch.float32) + entry.fn(src, out, sf) + assert sf.shape == (4, 32) + assert sf.gt(0).all() + assert out.shape == (4, 32, 128) + + +def test_cutlass_fp8_fn_requires_scale(conv_pkg): + entry = conv_pkg.get("cutlass_fp8_e4m3_per_channel") + src = torch.randn(8, 16, dtype=torch.bfloat16) + out = torch.empty(8, 16, dtype=torch.float8_e4m3fn) + with pytest.raises(AssertionError, match="scale_out"): + entry.fn(src, out, None) diff --git a/tests/unit/train/rl/test_nixl_mx_v2.py b/tests/unit/train/rl/test_nixl_mx_v2.py new file mode 100644 index 0000000000..d59f285a98 --- /dev/null +++ b/tests/unit/train/rl/test_nixl_mx_v2.py @@ -0,0 +1,526 @@ +"""Unit tests for ``NIXLMxV2WeightBroadcast``. + +These tests exercise the per-step orchestration logic — slot fill, +publisher add_tensor threading, compile_target tagging, MoE expert +metadata threading, and HSDP barrier gating — without requiring CUDA, +NIXL, a live MX server, or a real model. + +We use ``importlib.util.spec_from_file_location`` to load the production +``nixl_mx_v2.py`` against a fully-stubbed dependency graph (same pattern +as MX-side ``test_vllm_weight_transfer.py``). The test is therefore +runnable anywhere torch + pytest is present, without prime-rl needing to +be installed as a package. +""" + +from __future__ import annotations + +import importlib.util +import sys +import types +from pathlib import Path +from unittest.mock import MagicMock + +import pytest + + +_PRIME_RL_ROOT = Path(__file__).resolve().parents[4] # prime-rl root +_BROADCAST_FILE = ( + _PRIME_RL_ROOT + / "src" + / "prime_rl" + / "trainer" + / "rl" + / "broadcast" + / "nixl_mx_v2.py" +) + + +# ---------------------------------------------------------------------------- +# Stub the prime_rl + modelexpress + transformers + torch.distributed +# dependency graph the broadcast module needs at import time +# ---------------------------------------------------------------------------- + + +def _install_stubs(): + """Insert fake modules into sys.modules so importing nixl_mx_v2 succeeds. + Returns a dict of the live mocks for test inspection.""" + mocks: dict[str, MagicMock] = {} + + # ─── modelexpress.nemo_rl_v2 ──────────────────────────────────────── + fake_publisher_cls = MagicMock(name="MxV2TrainingPublisher_cls") + fake_publisher = MagicMock(name="MxV2TrainingPublisher_inst") + fake_publisher.publish.return_value = "abcd1234efgh5678" + fake_publisher.mark_ready.return_value = True + fake_publisher_cls.return_value = fake_publisher + + fake_layout_cls = MagicMock(name="TrainerWorldLayout_cls") + fake_layout = MagicMock(name="TrainerWorldLayout_inst") + fake_layout.encode.return_value = "fsdp:1,tp:1,pp:1,ep:1" + fake_layout_cls.return_value = fake_layout + + sys.modules["modelexpress"] = types.SimpleNamespace() + sys.modules["modelexpress.nemo_rl_v2"] = types.SimpleNamespace( + MxV2TrainingPublisher=fake_publisher_cls, + TrainerWorldLayout=fake_layout_cls, + ) + mocks["publisher_cls"] = fake_publisher_cls + mocks["publisher"] = fake_publisher + mocks["layout_cls"] = fake_layout_cls + mocks["layout"] = fake_layout + + # ─── transformers.AutoConfig ──────────────────────────────────────── + fake_auto_config = MagicMock(name="AutoConfig") + fake_auto_config.from_pretrained.return_value = MagicMock( + torch_dtype="torch.bfloat16" + ) + sys.modules["transformers"] = types.SimpleNamespace(AutoConfig=fake_auto_config) + mocks["auto_config"] = fake_auto_config + + # ─── prime_rl.configs.trainer.MxV2WeightBroadcastConfig ───────────── + fake_config_cls = MagicMock(name="MxV2WeightBroadcastConfig_cls") + pkg_configs = types.ModuleType("prime_rl.configs") + pkg_configs_trainer = types.ModuleType("prime_rl.configs.trainer") + pkg_configs_trainer.MxV2WeightBroadcastConfig = fake_config_cls + sys.modules.setdefault("prime_rl", types.ModuleType("prime_rl")) + sys.modules["prime_rl.configs"] = pkg_configs + sys.modules["prime_rl.configs.trainer"] = pkg_configs_trainer + + # ─── prime_rl.trainer.models.PreTrainedModelPrimeRL ───────────────── + fake_pretrained_cls = MagicMock(name="PreTrainedModelPrimeRL_cls") + pkg_trainer_models = types.ModuleType("prime_rl.trainer.models") + pkg_trainer_models.PreTrainedModelPrimeRL = fake_pretrained_cls + sys.modules["prime_rl.trainer"] = types.ModuleType("prime_rl.trainer") + sys.modules["prime_rl.trainer.models"] = pkg_trainer_models + + # ─── prime_rl.trainer.models.slots (for SMALL_NON_EXPERT_BYTES) ───── + pkg_slots = types.ModuleType("prime_rl.trainer.models.slots") + pkg_slots.SMALL_NON_EXPERT_BYTES = 2 * 1024 * 1024 # match real default + sys.modules["prime_rl.trainer.models.slots"] = pkg_slots + mocks["slots_mod"] = pkg_slots + + # ─── prime_rl.trainer.models.conversions.select_default_conversion ── + fake_conversion = types.SimpleNamespace( + compile_target="cutlass_fp8", + compile_metadata={"block_size": 128, "scale_layout": "per_channel"}, + ) + fake_select_conversion = MagicMock(return_value=fake_conversion) + pkg_conv = types.ModuleType("prime_rl.trainer.models.conversions") + pkg_conv.select_default_conversion = fake_select_conversion + sys.modules["prime_rl.trainer.models.conversions"] = pkg_conv + mocks["conversion"] = fake_conversion + mocks["select_conversion"] = fake_select_conversion + + # ─── prime_rl.trainer.parallel_dims.ParallelDims ──────────────────── + fake_parallel_dims_cls = MagicMock(name="ParallelDims_cls") + pkg_pd = types.ModuleType("prime_rl.trainer.parallel_dims") + pkg_pd.ParallelDims = fake_parallel_dims_cls + sys.modules["prime_rl.trainer.parallel_dims"] = pkg_pd + + # ─── prime_rl.trainer.rl.broadcast.base.WeightBroadcast ───────────── + class FakeWeightBroadcast: + def __init__(self, output_dir, *args, **kwargs): + self.output_dir = output_dir + # Mimic real base class — set logger so subclass can use it. + self.logger = MagicMock(name="logger") + + pkg_trainer_rl = types.ModuleType("prime_rl.trainer.rl") + pkg_trainer_rl_broadcast = types.ModuleType("prime_rl.trainer.rl.broadcast") + pkg_broadcast_base = types.ModuleType("prime_rl.trainer.rl.broadcast.base") + pkg_broadcast_base.WeightBroadcast = FakeWeightBroadcast + sys.modules["prime_rl.trainer.rl"] = pkg_trainer_rl + sys.modules["prime_rl.trainer.rl.broadcast"] = pkg_trainer_rl_broadcast + sys.modules["prime_rl.trainer.rl.broadcast.base"] = pkg_broadcast_base + mocks["base_cls"] = FakeWeightBroadcast + + # ─── prime_rl.trainer.runs.get_multi_run_manager ──────────────────── + fake_run_manager = types.SimpleNamespace(used_idxs=[], ready_to_update={}) + fake_get_multi_run_manager = MagicMock(return_value=fake_run_manager) + pkg_runs = types.ModuleType("prime_rl.trainer.runs") + pkg_runs.get_multi_run_manager = fake_get_multi_run_manager + sys.modules["prime_rl.trainer.runs"] = pkg_runs + mocks["run_manager"] = fake_run_manager + + # ─── prime_rl.trainer.utils.get_world ────────────────────────────── + fake_world = types.SimpleNamespace(rank=0, is_master=True) + fake_get_world = MagicMock(return_value=fake_world) + pkg_utils = types.ModuleType("prime_rl.trainer.utils") + pkg_utils.get_world = fake_get_world + sys.modules["prime_rl.trainer.utils"] = pkg_utils + mocks["world"] = fake_world + + # ─── prime_rl.transport.classic_cuda_pool / nixl_agent ────────────── + class FakeAlloc: + def __enter__(self): + return None + + def __exit__(self, *args): + return False + + pkg_transport = types.ModuleType("prime_rl.transport") + pkg_transport_classic = types.ModuleType("prime_rl.transport.classic_cuda_pool") + pkg_transport_classic.classic_cuda_alloc = lambda: FakeAlloc() + pkg_transport_nixl_agent = types.ModuleType("prime_rl.transport.nixl_agent") + pkg_transport_nixl_agent.make_agent_name = MagicMock(return_value="trainer-0") + pkg_transport_nixl_agent.pin_ucx_rail = MagicMock() + sys.modules["prime_rl.transport"] = pkg_transport + sys.modules["prime_rl.transport.classic_cuda_pool"] = pkg_transport_classic + sys.modules["prime_rl.transport.nixl_agent"] = pkg_transport_nixl_agent + + return mocks + + +@pytest.fixture +def broadcast_mod(): + """Load nixl_mx_v2.py under fully-stubbed deps. Yields (module, mocks).""" + # Wipe any stale modules so each test gets a fresh patched graph. + for k in list(sys.modules.keys()): + if k.startswith("prime_rl") or k == "modelexpress" or k.startswith("modelexpress."): + del sys.modules[k] + if k == "transformers": + del sys.modules[k] + + mocks = _install_stubs() + + # Patch torch.cuda + torch.distributed before loading the module. + import torch + torch.cuda.current_device = MagicMock(return_value=0) + if hasattr(torch.distributed, "barrier"): + original_barrier = torch.distributed.barrier + torch.distributed.barrier = MagicMock() + else: + original_barrier = None + torch.distributed.barrier = MagicMock() + + spec = importlib.util.spec_from_file_location( + "_test_nixl_mx_v2_under_test", _BROADCAST_FILE + ) + mod = importlib.util.module_from_spec(spec) + spec.loader.exec_module(mod) + + try: + yield (mod, mocks) + finally: + if original_barrier is not None: + torch.distributed.barrier = original_barrier + + +# ---------------------------------------------------------------------------- +# Helpers +# ---------------------------------------------------------------------------- + + +def _make_config(**overrides): + defaults = dict( + type="mx_v2", + host="localhost", + port=8001, + timeout=60, + inference_world_size=1, + inference_model_name="bench/synthetic-1.5B", + same_rank_only=True, + dedup_freshest_per_rank=True, + publish_compile_target=True, + compile_target_filter=None, + publish_self_as_replica=True, + ) + defaults.update(overrides) + return types.SimpleNamespace(**defaults) + + +def _make_parallel_dims( + *, + dp_replicate_enabled: bool = False, + is_primary: bool = True, + fsdp_world_size: int = 1, + tp_size: int = 1, + pp_size: int = 1, + ep_size: int = 1, +): + mesh = MagicMock(name="dp_replicate_mesh") + mesh.get_local_rank.return_value = 0 if is_primary else 1 + pdims = MagicMock(name="ParallelDims") + pdims.dp_replicate_enabled = dp_replicate_enabled + pdims.dp_shard_size = fsdp_world_size + pdims.tp_size = tp_size + pdims.pp_size = pp_size + pdims.ep_size = ep_size + pdims.get_mesh = MagicMock(return_value=mesh) + return pdims + + +def _make_fake_slot(*, name: str, is_expert: bool = False, num_buffers: int = 2): + import torch + + slot = MagicMock(name=f"Slot({name})") + slot.is_expert = is_expert + slot.expert_axis = 0 if is_expert else 0 + slot.owned_expert_ids = (0, 1, 2, 3) if is_expert else () + slot.buffers = [ + (f"{name}.buf_{i}", torch.zeros(4), object()) for i in range(num_buffers) + ] + slot.convert = MagicMock() + return slot + + +def _make_fake_model(slots): + model = MagicMock(name="Model") + model.build_slots = MagicMock(return_value=slots) + model.state_dict = MagicMock(return_value={}) + return model + + +# ---------------------------------------------------------------------------- +# Tests +# ---------------------------------------------------------------------------- + + +def test_construction_does_not_initialize_publisher(broadcast_mod): + mod, mocks = broadcast_mod + bc = mod.NIXLMxV2WeightBroadcast( + output_dir=Path("/tmp/out"), + config=_make_config(), + parallel_dims=_make_parallel_dims(), + ) + assert bc.is_initialized is False + assert bc._publisher is None + assert bc._model_slots is None + mocks["publisher_cls"].assert_not_called() + + +def test_is_primary_hsdp_rank_gates_correctly(broadcast_mod): + mod, _ = broadcast_mod + bc1 = mod.NIXLMxV2WeightBroadcast( + output_dir=Path("/tmp/out"), + config=_make_config(), + parallel_dims=_make_parallel_dims(dp_replicate_enabled=False), + ) + assert bc1.is_primary_hsdp_rank is True + + bc2 = mod.NIXLMxV2WeightBroadcast( + output_dir=Path("/tmp/out"), + config=_make_config(), + parallel_dims=_make_parallel_dims( + dp_replicate_enabled=True, is_primary=True + ), + ) + assert bc2.is_primary_hsdp_rank is True + + bc3 = mod.NIXLMxV2WeightBroadcast( + output_dir=Path("/tmp/out"), + config=_make_config(), + parallel_dims=_make_parallel_dims( + dp_replicate_enabled=True, is_primary=False + ), + ) + assert bc3.is_primary_hsdp_rank is False + + +def test_lazy_init_builds_publisher_with_right_args(broadcast_mod): + mod, mocks = broadcast_mod + bc = mod.NIXLMxV2WeightBroadcast( + output_dir=Path("/tmp/out"), + config=_make_config(host="mx-server", port=8001), + parallel_dims=_make_parallel_dims( + fsdp_world_size=4, tp_size=2, pp_size=1, ep_size=2 + ), + ) + model = _make_fake_model([_make_fake_slot(name="layer0")]) + bc.lazy_init(model) + + mocks["layout_cls"].assert_called_once() + layout_kwargs = mocks["layout_cls"].call_args.kwargs + assert layout_kwargs["fsdp_world_size"] == 4 + assert layout_kwargs["tp_world_size"] == 2 + assert layout_kwargs["pp_world_size"] == 1 + assert layout_kwargs["ep_world_size"] == 2 + + mocks["publisher_cls"].assert_called_once() + pub_kwargs = mocks["publisher_cls"].call_args.kwargs + assert pub_kwargs["mx_server_url"] == "mx-server:8001" + assert pub_kwargs["world_layout"] is mocks["layout"] + + mocks["publisher"].initialize.assert_called_once() + init_kwargs = mocks["publisher"].initialize.call_args.kwargs + assert init_kwargs["model_name"] == "bench/synthetic-1.5B" + + assert bc.is_initialized is True + + +def test_lazy_init_idempotent_on_second_call(broadcast_mod): + mod, mocks = broadcast_mod + bc = mod.NIXLMxV2WeightBroadcast( + output_dir=Path("/tmp/out"), + config=_make_config(), + parallel_dims=_make_parallel_dims(), + ) + model = _make_fake_model([_make_fake_slot(name="layer0")]) + bc.lazy_init(model) + bc.lazy_init(model) + assert mocks["publisher_cls"].call_count == 1 + + +def test_broadcast_weights_threads_compile_target_metadata(broadcast_mod): + mod, mocks = broadcast_mod + bc = mod.NIXLMxV2WeightBroadcast( + output_dir=Path("/tmp/out"), + config=_make_config(publish_compile_target=True), + parallel_dims=_make_parallel_dims(), + ) + slots = [_make_fake_slot(name="layer0", num_buffers=2)] + model = _make_fake_model(slots) + bc.broadcast_weights(model, step=42) + + assert mocks["publisher"].add_tensor.call_count == 2 + for call in mocks["publisher"].add_tensor.call_args_list: + assert call.kwargs["compile_target"] == "cutlass_fp8" + assert call.kwargs["compile_metadata"] == { + "block_size": 128, + "scale_layout": "per_channel", + } + + mocks["publisher"].publish.assert_called_once() + assert mocks["publisher"].publish.call_args.kwargs["version"] == 42 + mocks["publisher"].mark_ready.assert_called_once() + + +def test_broadcast_weights_publish_compile_target_false_uses_hf_raw(broadcast_mod): + mod, mocks = broadcast_mod + bc = mod.NIXLMxV2WeightBroadcast( + output_dir=Path("/tmp/out"), + config=_make_config(publish_compile_target=False), + parallel_dims=_make_parallel_dims(), + ) + slots = [_make_fake_slot(name="layer0", num_buffers=1)] + model = _make_fake_model(slots) + bc.broadcast_weights(model, step=1) + + call = mocks["publisher"].add_tensor.call_args + assert call.kwargs["compile_target"] == "hf_raw" + assert call.kwargs["compile_metadata"] is None + + +def test_broadcast_weights_threads_moe_expert_metadata(broadcast_mod): + mod, mocks = broadcast_mod + bc = mod.NIXLMxV2WeightBroadcast( + output_dir=Path("/tmp/out"), + config=_make_config(), + parallel_dims=_make_parallel_dims(), + ) + slots = [ + _make_fake_slot(name="layer0.dense", is_expert=False, num_buffers=1), + _make_fake_slot(name="layer0.experts", is_expert=True, num_buffers=1), + ] + model = _make_fake_model(slots) + bc.broadcast_weights(model, step=7) + + calls = mocks["publisher"].add_tensor.call_args_list + assert len(calls) == 2 + + dense_call = next( + c for c in calls if c.kwargs["name"].startswith("layer0.dense") + ) + assert dense_call.kwargs["is_expert"] is False + assert dense_call.kwargs["owned_expert_ids"] == () + + expert_call = next( + c for c in calls if c.kwargs["name"].startswith("layer0.experts") + ) + assert expert_call.kwargs["is_expert"] is True + assert expert_call.kwargs["expert_axis"] == 0 + assert expert_call.kwargs["owned_expert_ids"] == (0, 1, 2, 3) + + +def test_broadcast_weights_skips_non_primary_hsdp_rank(broadcast_mod): + mod, mocks = broadcast_mod + bc = mod.NIXLMxV2WeightBroadcast( + output_dir=Path("/tmp/out"), + config=_make_config(), + parallel_dims=_make_parallel_dims( + dp_replicate_enabled=True, is_primary=False + ), + ) + model = _make_fake_model([_make_fake_slot(name="layer0")]) + bc.broadcast_weights(model, step=1) + + mocks["publisher_cls"].assert_not_called() + mocks["publisher"].add_tensor.assert_not_called() + mocks["publisher"].publish.assert_not_called() + + +def test_broadcast_weights_calls_slot_convert(broadcast_mod): + """Each slot's `convert(state_dict)` must be invoked exactly once per + broadcast cycle. GatheredSlot's API takes only the state_dict — the + conversion (compile_target / quantization) is baked in at + `from_spec` creation time, not threaded per-call.""" + mod, _ = broadcast_mod + bc = mod.NIXLMxV2WeightBroadcast( + output_dir=Path("/tmp/out"), + config=_make_config(), + parallel_dims=_make_parallel_dims(), + ) + slots = [ + _make_fake_slot(name="layer0", num_buffers=1), + _make_fake_slot(name="layer1", num_buffers=1), + ] + model = _make_fake_model(slots) + bc.broadcast_weights(model, step=3) + for slot in slots: + slot.convert.assert_called_once() + # convert receives the state_dict (single positional arg). + args = slot.convert.call_args.args + assert isinstance(args[0], dict) + + +def test_lazy_init_forces_gathered_slots_for_pull_mode(broadcast_mod): + """lazy_init must temporarily raise `slots.SMALL_NON_EXPERT_BYTES` to + infinity while `model.build_slots(...)` runs — this forces every + non-expert weight into GatheredSlot (full tensor on each rank via + DTensor.full_tensor()) instead of ShardedSlot (1/N FSDP shard). + Pull-mode + same-rank routing requires the full tensor per rank. + The threshold must be restored after build_slots returns so other + code paths (e.g. nixl_mx push-mode broadcast running in the same + process) aren't perturbed.""" + mod, mocks = broadcast_mod + slots_mod = mocks["slots_mod"] + original_threshold = slots_mod.SMALL_NON_EXPERT_BYTES + + bc = mod.NIXLMxV2WeightBroadcast( + output_dir=Path("/tmp/out"), + config=_make_config(), + parallel_dims=_make_parallel_dims(), + ) + seen_thresholds = [] + model = _make_fake_model([_make_fake_slot(name="layer0", num_buffers=1)]) + # Capture the threshold value AT THE TIME build_slots is called + model.build_slots = MagicMock( + side_effect=lambda *_a, **_kw: ( + seen_thresholds.append(slots_mod.SMALL_NON_EXPERT_BYTES) + or [_make_fake_slot(name="layer0", num_buffers=1)] + ) + ) + bc.lazy_init(model) + + assert seen_thresholds, "build_slots was never called" + assert seen_thresholds[0] > 2 * 1024 * 1024, ( + f"threshold was {seen_thresholds[0]} during build_slots — must be " + f"raised (1<<60) so all non-expert weights become GatheredSlot" + ) + # Restored to original value after lazy_init returns. + assert slots_mod.SMALL_NON_EXPERT_BYTES == original_threshold + + +def test_shutdown_calls_publisher_shutdown_idempotent(broadcast_mod): + mod, mocks = broadcast_mod + bc = mod.NIXLMxV2WeightBroadcast( + output_dir=Path("/tmp/out"), + config=_make_config(), + parallel_dims=_make_parallel_dims(), + ) + model = _make_fake_model([_make_fake_slot(name="layer0", num_buffers=1)]) + bc.broadcast_weights(model, step=1) + + bc.shutdown() + assert mocks["publisher"].shutdown.call_count == 1 + bc.shutdown() + assert mocks["publisher"].shutdown.call_count == 1 + assert bc.is_initialized is False diff --git a/tests/unit/transport/test_mx_rendezvous_phase2.py b/tests/unit/transport/test_mx_rendezvous_phase2.py new file mode 100644 index 0000000000..d5a79f8bc8 --- /dev/null +++ b/tests/unit/transport/test_mx_rendezvous_phase2.py @@ -0,0 +1,197 @@ +"""Phase-2 unit tests for MxRendezvous helpers — no docker-compose required. + +Direct-loads mx_rendezvous.py to bypass prime_rl.transport's heavy +__init__.py import chain. +""" + +from __future__ import annotations + +import importlib.util +import sys +import types +from dataclasses import dataclass +from pathlib import Path +from unittest.mock import MagicMock + +import pytest + + +_HERE = Path(__file__).resolve().parent +_REPO_ROOT = _HERE.parent.parent.parent +_MOD_PATH = _REPO_ROOT / "src" / "prime_rl" / "transport" / "mx_rendezvous.py" + + +@pytest.fixture(scope="module") +def rdzmod(): + if "prime_rl" not in sys.modules: + pkg = types.ModuleType("prime_rl") + pkg.__path__ = [str(_REPO_ROOT / "src" / "prime_rl")] + sys.modules["prime_rl"] = pkg + if "prime_rl.transport" not in sys.modules: + sub = types.ModuleType("prime_rl.transport") + sub.__path__ = [str(_REPO_ROOT / "src" / "prime_rl" / "transport")] + sys.modules["prime_rl.transport"] = sub + + spec = importlib.util.spec_from_file_location( + "prime_rl.transport.mx_rendezvous", _MOD_PATH + ) + mod = importlib.util.module_from_spec(spec) + sys.modules["prime_rl.transport.mx_rendezvous"] = mod + spec.loader.exec_module(mod) + return mod + + +@dataclass +class _FakeInst: + worker_id: str + worker_rank: int + mx_source_id: str = "fake-source" + + +def test_filter_same_rank_keeps_only_matching(rdzmod): + insts = [_FakeInst("w0", 0), _FakeInst("w1", 1), _FakeInst("w2", 2), _FakeInst("w1b", 1)] + kept = rdzmod._filter_same_rank(insts, rank=1) + assert [i.worker_id for i in kept] == ["w1", "w1b"] + + +def test_freshest_per_rank_keeps_largest_updated_at(rdzmod): + insts = [_FakeInst("w0_old", 0), _FakeInst("w0_new", 0), _FakeInst("w1_only", 1), _FakeInst("w0_mid", 0)] + metas = {"w0_old": 100, "w0_new": 300, "w1_only": 200, "w0_mid": 200} + kept = rdzmod._freshest_per_rank(insts, metas=metas) + by_rank = {i.worker_rank: i.worker_id for i in kept} + assert by_rank == {0: "w0_new", 1: "w1_only"} + + +def test_freshest_per_rank_handles_missing_updated_at(rdzmod): + insts = [_FakeInst("ghost", 5), _FakeInst("known", 5)] + metas = {"known": 1} + kept = rdzmod._freshest_per_rank(insts, metas=metas) + assert len(kept) == 1 + assert kept[0].worker_id == "known" + + +def test_freshest_per_rank_returns_lone_unknown_when_no_rival(rdzmod): + insts = [_FakeInst("only_ghost", 7)] + kept = rdzmod._freshest_per_rank(insts, metas={}) + assert len(kept) == 1 + assert kept[0].worker_id == "only_ghost" + + +def test_freshest_per_rank_sorted_by_rank(rdzmod): + insts = [_FakeInst("w2", 2), _FakeInst("w0", 0), _FakeInst("w1", 1)] + kept = rdzmod._freshest_per_rank(insts, metas={"w0": 1, "w1": 1, "w2": 1}) + assert [i.worker_rank for i in kept] == [0, 1, 2] + + +def test_publish_starts_and_close_stops_heartbeat(rdzmod, monkeypatch): + fake_client = MagicMock() + fake_client.publish_metadata.return_value = "mx-source-xyz" + + spawned = [] + + class _FakeHB: + def __init__(self, **kwargs): + self.kwargs = kwargs + self.started = False + self.stopped = False + spawned.append(self) + + def start(self): + self.started = True + + def stop(self): + self.stopped = True + + monkeypatch.setattr(rdzmod, "HeartbeatThread", _FakeHB) + + rdz = rdzmod.MxRendezvous(client=fake_client, role="trainer", rank=2, peer_world_size=4, model_name="m") + sid = rdz.publish(nixl_metadata=b"x", tensors=[]) + assert sid == "mx-source-xyz" + assert len(spawned) == 1 + hb = spawned[0] + assert hb.started + assert hb.kwargs["worker_rank"] == 2 + assert hb.kwargs["mx_source_id"] == "mx-source-xyz" + assert hb.kwargs["nixl_manager"] is None + + rdz.close() + assert hb.stopped + rdz.close() + + +def test_publish_skips_heartbeat_when_disabled(rdzmod, monkeypatch): + fake_client = MagicMock() + fake_client.publish_metadata.return_value = "sid" + + spawned = [] + + class _FakeHB: + def __init__(self, **kwargs): + spawned.append(self) + + def start(self): + pass + + monkeypatch.setattr(rdzmod, "HeartbeatThread", _FakeHB) + rdz = rdzmod.MxRendezvous( + client=fake_client, role="inference", rank=0, peer_world_size=1, model_name="m", enable_heartbeat=False + ) + rdz.publish() + assert spawned == [] + + +def test_publish_swallows_heartbeat_start_failure(rdzmod, monkeypatch): + fake_client = MagicMock() + fake_client.publish_metadata.return_value = "sid" + + class _BrokenHB: + def __init__(self, **kwargs): + raise RuntimeError("can't allocate thread") + + monkeypatch.setattr(rdzmod, "HeartbeatThread", _BrokenHB) + rdz = rdzmod.MxRendezvous(client=fake_client, role="trainer", rank=0, peer_world_size=1, model_name="m") + sid = rdz.publish() + assert sid == "sid" + assert rdz._heartbeat is None + + +def test_collect_updated_at_returns_zero_on_failure(rdzmod): + fake_client = MagicMock() + fake_client.get_metadata.side_effect = RuntimeError("boom") + rdz = rdzmod.MxRendezvous( + client=fake_client, role="trainer", rank=0, peer_world_size=1, model_name="m", enable_heartbeat=False + ) + out = rdz._collect_updated_at([_FakeInst("a", 0), _FakeInst("b", 1)]) + assert out == {"a": 0, "b": 0} + + +def test_collect_updated_at_returns_zero_on_not_found(rdzmod): + fake_client = MagicMock() + + class _Resp: + found = False + worker = MagicMock(updated_at=0) + + fake_client.get_metadata.return_value = _Resp() + rdz = rdzmod.MxRendezvous( + client=fake_client, role="trainer", rank=0, peer_world_size=1, model_name="m", enable_heartbeat=False + ) + out = rdz._collect_updated_at([_FakeInst("x", 0)]) + assert out == {"x": 0} + + +def test_collect_updated_at_returns_real_value(rdzmod): + fake_client = MagicMock() + + class _Resp: + found = True + + def __init__(self): + self.worker = MagicMock(updated_at=42) + + fake_client.get_metadata.return_value = _Resp() + rdz = rdzmod.MxRendezvous( + client=fake_client, role="trainer", rank=0, peer_world_size=1, model_name="m", enable_heartbeat=False + ) + out = rdz._collect_updated_at([_FakeInst("x", 0)]) + assert out == {"x": 42}