1515
1616"""MX (ModelExpress) checkpoint loader.
1717
18- Thin adapter on top of the upstream `` modelexpress` ` Python client
19- (`` ai-dynamo/modelexpress` `). All NIXL/RDMA mechanics (agent setup,
18+ Thin adapter on top of the upstream `modelexpress` Python client
19+ (`ai-dynamo/modelexpress`). All NIXL/RDMA mechanics (agent setup,
2020tensor registration, source-target name matching, dtype-cast handling,
21- PVC fallback, etc.) live in the upstream `` MxLiveWeightLoader` ` and
22- `` publish_model_params` ` helpers — we only call them at the right
21+ PVC fallback, etc.) live in the upstream `MxLiveWeightLoader` and
22+ `publish_model_params` helpers — we only call them at the right
2323points in TRT-LLM's loading lifecycle.
2424
2525When no MX server is reachable (or the upstream library is not
2626installed), this loader transparently falls back to standard
2727HuggingFace checkpoint loading (disk -> CPU -> GPU) by way of its
28- `` HfCheckpointLoader` ` base class.
28+ `HfCheckpointLoader` base class.
2929"""
3030
3131import os
4545from tensorrt_llm ._torch .weight_sharing import (
4646 IdentityCheckPolicy ,
4747 SourceIdentity ,
48- check_source_identity ,
48+ check_weight_sharing_compatibility ,
4949)
5050from tensorrt_llm .logger import logger
5151from tensorrt_llm .mapping import Mapping
5252
53- # Defensive default for the upstream `` MX_SOURCE_QUERY_TIMEOUT` ` env var.
54- # The upstream `` MxLiveWeightLoader` ` polls the MX server every 5 s for up
55- # to `` MX_SOURCE_QUERY_TIMEOUT` ` seconds (default 3600 = 1 hour) waiting
53+ # Defensive default for the upstream `MX_SOURCE_QUERY_TIMEOUT` env var.
54+ # The upstream `MxLiveWeightLoader` polls the MX server every 5 s for up
55+ # to `MX_SOURCE_QUERY_TIMEOUT` seconds (default 3600 = 1 hour) waiting
5656# for a source. On a cold cluster (no donor up yet), this means the very
5757# first replica blocks for an hour before falling back to disk. We cap
5858# the default at 30 s so first-replica startup degrades gracefully; users
@@ -83,19 +83,19 @@ def _temporary_env(key: str, value: Optional[str]):
8383class MXCheckpointLoader (HfCheckpointLoader ):
8484 """Checkpoint loader for MX (ModelExpress) P2P weight transfer.
8585
86- When an MX server is reachable AND the upstream `` modelexpress` `
86+ When an MX server is reachable AND the upstream `modelexpress`
8787 library is installed, weights are transferred directly from a
8888 source instance via NIXL/RDMA, bypassing disk I/O. The source
89- publishes its weights *before* `` post_load_weights()` ` runs so
89+ publishes its weights *before* `post_load_weights()` runs so
9090 targets receive raw loaded state and can run their own
9191 post-load transforms.
9292
9393 When the MX server or library is unavailable, this loader
9494 transparently falls back to standard HuggingFace checkpoint
95- loading via the parent `` HfCheckpointLoader` `.
95+ loading via the parent `HfCheckpointLoader`.
9696
9797 All transport-level mechanics (NIXL, dtype casts, source matching,
98- fallback) are delegated to `` modelexpress.trtllm_live_transfer` `
98+ fallback) are delegated to `modelexpress.trtllm_live_transfer`
9999 so that this class stays a thin adapter — when the MX wire
100100 protocol or transport evolves, only the upstream library needs
101101 to track it.
@@ -121,10 +121,10 @@ def __init__(
121121 # _checkpoint_format directly does not see a stale value.
122122 self ._checkpoint_format = "MX"
123123 self ._mx_server_url = mx_server_url
124- # `` model_name` ` is the human-readable identity to publish/look up
124+ # `model_name` is the human-readable identity to publish/look up
125125 # under on the MX server. Typically the user-supplied
126- # `` llm_args.model`` (a Hub ID like `` "Qwen/Qwen2.5-72B-Instruct"` `
127- # or a local path). `` publish_as_source()` ` resolves it via
126+ # `llm_args.model` (a Hub ID like `"Qwen/Qwen2.5-72B-Instruct"`
127+ # or a local path). `publish_as_source()` resolves it via
128128 # :func:`_resolve_mx_model_name` (with HF-snapshot path fallback).
129129 self ._model_name = str (model_name ) if model_name is not None else None
130130 self ._query_timeout_s = query_timeout_s
@@ -146,9 +146,9 @@ def mx_server_url(self) -> Optional[str]:
146146 def model_name (self ) -> Optional [str ]:
147147 """Explicit model identity passed to the constructor (if any).
148148
149- Note this is the *as-configured* value (e.g. `` llm_args.model` `),
149+ Note this is the *as-configured* value (e.g. `llm_args.model`),
150150 not the final resolved identity that ends up in the published
151- `` MODEL_NAME` `. The full resolution (with env var and basename
151+ `MODEL_NAME`. The full resolution (with env var and basename
152152 fallbacks) happens inside :meth:`publish_as_source`.
153153 """
154154 return self ._model_name
@@ -160,42 +160,42 @@ def query_timeout_s(self) -> Optional[int]:
160160 def is_weights_preloaded (self ) -> bool :
161161 """Whether the last :meth:`load_weights` call wired weights directly into the model.
162162
163- Reports the result of the most recent `` load_weights()` ` invocation
164- on this loader instance. `` ModelLoader` ` consults this signal to
163+ Reports the result of the most recent `load_weights()` invocation
164+ on this loader instance. `ModelLoader` consults this signal to
165165 decide whether to run the standard weight-mapping pipeline:
166166
167- - `` True` `: MX P2P transfer succeeded; weights already live in
167+ - `True`: MX P2P transfer succeeded; weights already live in
168168 model parameter buffers via direct writes from the upstream
169- `` MxLiveWeightLoader` `. The mapping pipeline is skipped for
169+ `MxLiveWeightLoader`. The mapping pipeline is skipped for
170170 all parameters covered by P2P.
171- - `` False` `: either P2P was never attempted (no MX server URL,
171+ - `False`: either P2P was never attempted (no MX server URL,
172172 no model reference, library missing) or it failed and we
173173 fell back to disk; weights still need to flow through
174- `` model.load_weights(...)` ` via the standard mapper.
174+ `model.load_weights(...)` via the standard mapper.
175175
176176 Note this is a per-loader-instance flag, not a global one. The
177- flag is reset to `` False`` at the start of each `` load_weights` `
177+ flag is reset to `False` at the start of each `load_weights`
178178 call, so the value is only meaningful immediately after a
179179 successful call.
180180
181181 Returns:
182- `` True`` iff the last `` load_weights` ` populated the model
183- via P2P; `` False` ` before any call and on any fallback path.
182+ `True` iff the last `load_weights` populated the model
183+ via P2P; `False` before any call and on any fallback path.
184184 """
185185 return self ._p2p_succeeded
186186
187187 def load_weights (self , checkpoint_dir : str , mapping : Mapping , ** kwargs ) -> dict [str , Any ]:
188188 """Load weights, preferring MX P2P transfer when available.
189189
190190 Delegates the actual transfer to the upstream
191- `` modelexpress.trtllm_live_transfer.MxLiveWeightLoader` `,
191+ `modelexpress.trtllm_live_transfer.MxLiveWeightLoader`,
192192 which handles NIXL setup, source discovery, name matching,
193193 dtype casting, and PVC fallback for size-mismatched tensors.
194194
195195 Args:
196196 checkpoint_dir: Path to the HF checkpoint directory.
197197 mapping: Distributed mapping configuration.
198- **kwargs: Additional keyword arguments. When `` model` ` is
198+ **kwargs: Additional keyword arguments. When `model` is
199199 passed it is used as the target for direct P2P writes.
200200
201201 Returns:
@@ -297,7 +297,7 @@ def load_weights(self, checkpoint_dir: str, mapping: Mapping, **kwargs) -> dict[
297297 def _resolve_query_timeout_override (
298298 self , checkpoint_dir : str , MxClient : Type [Any ], build_identity : Callable [..., Any ]
299299 ) -> Optional [str ]:
300- """Return temporary `` MX_SOURCE_QUERY_TIMEOUT` ` override, if any."""
300+ """Return temporary `MX_SOURCE_QUERY_TIMEOUT` override, if any."""
301301 if self ._query_timeout_s is not None :
302302 return str (self ._query_timeout_s )
303303
@@ -343,7 +343,7 @@ def _source_identity_compatible(
343343 """Whether the MX source's identity is compatible with this receiver.
344344
345345 Compares the receiver's local :class:`SourceIdentity` against the
346- publisher's via ``check_source_identity`` with the `` WARN_FALLBACK` `
346+ publisher's via `check_weight_sharing_compatibility` with the `WARN_FALLBACK`
347347 policy.
348348
349349 Args:
@@ -354,13 +354,13 @@ def _source_identity_compatible(
354354 (forwarded to the fetch seam).
355355
356356 Returns:
357- `` True` ` to proceed with P2P only when both identities are present
358- and compatible. `` False` ` when either identity is missing or the
357+ `True` to proceed with P2P only when both identities are present
358+ and compatible. `False` when either identity is missing or the
359359 identities mismatch, so the caller falls back to disk loading.
360360 """
361361 local_identity = self ._local_source_identity
362362 source_identity = self ._fetch_source_identity (checkpoint_dir , MxClient , build_identity )
363- decision = check_source_identity (
363+ decision = check_weight_sharing_compatibility (
364364 local_identity ,
365365 source_identity ,
366366 IdentityCheckPolicy .WARN_FALLBACK ,
@@ -378,7 +378,7 @@ def _fetch_source_identity(
378378 build_identity: Builder used to derive the publisher identity.
379379
380380 Returns:
381- The publisher's identity, or `` None` ` when it cannot be fetched
381+ The publisher's identity, or `None` when it cannot be fetched
382382 yet (the compatibility gate then rejects P2P and falls back).
383383 """
384384 # TODO(SOURCE-IDENTITY/MX-2): read the publisher's identity from the MX
@@ -406,21 +406,21 @@ def publish_as_source(
406406 ) -> None :
407407 """Publish this instance's weights so other ranks can pull via P2P.
408408
409- Called by the integration in `` model_loader.py` ` *before*
410- `` post_load_weights()` ` so targets receive raw loaded state and
409+ Called by the integration in `model_loader.py` *before*
410+ `post_load_weights()` so targets receive raw loaded state and
411411 can apply their own post-load transforms.
412412
413413 Delegates to the upstream
414- `` modelexpress.trtllm_live_transfer.publish_model_params` `
414+ `modelexpress.trtllm_live_transfer.publish_model_params`
415415 helper, which handles the per-rank NIXL setup, tensor
416416 registration, and gRPC publish.
417417
418418 Args:
419419 model: The model whose weights to publish.
420420 checkpoint_dir: Checkpoint directory. Used as a last-resort
421- fallback for resolving the `` MODEL_NAME` ` identity when
422- neither `` model_name` ` was passed to the constructor nor
423- `` MODEL_NAME` ` is set in the environment.
421+ fallback for resolving the `MODEL_NAME` identity when
422+ neither `model_name` was passed to the constructor nor
423+ `MODEL_NAME` is set in the environment.
424424 """
425425
426426 if self ._mx_server_url is None :
@@ -515,14 +515,14 @@ def _resolve_mx_model_name(model_name_arg: Optional[str], checkpoint_dir: Option
515515
516516 Resolution order (first non-empty wins):
517517
518- 1. `` model_name_arg` ` — the explicit value passed at construction
519- time (typically `` llm_args.model` `: a Hub ID like
520- `` "Qwen/Qwen2.5-72B-Instruct"` ` or a local path).
521- 2. `` MODEL_NAME` ` env var — upstream's existing convention.
522- 3. `` checkpoint_dir` ` basename, with HF-snapshot path fallback so
523- `` .../models--<org>--<name>/snapshots/<sha>/` ` resolves to
524- `` "<org>/<name>"` ` instead of the commit hash.
525- 4. Literal `` "unknown"` ` — matches upstream's own sentinel.
518+ 1. `model_name_arg` — the explicit value passed at construction
519+ time (typically `llm_args.model`: a Hub ID like
520+ `"Qwen/Qwen2.5-72B-Instruct"` or a local path).
521+ 2. `MODEL_NAME` env var — upstream's existing convention.
522+ 3. `checkpoint_dir` basename, with HF-snapshot path fallback so
523+ `.../models--<org>--<name>/snapshots/<sha>/` resolves to
524+ `"<org>/<name>"` instead of the commit hash.
525+ 4. Literal `"unknown"` — matches upstream's own sentinel.
526526 """
527527 candidate = model_name_arg or os .environ .get ("MODEL_NAME" ) or checkpoint_dir
528528 if not candidate :
@@ -533,18 +533,18 @@ def _resolve_mx_model_name(model_name_arg: Optional[str], checkpoint_dir: Option
533533def _normalize_model_identity (s : str ) -> str :
534534 """Convert a model identifier to a stable, human-readable name.
535535
536- Hub IDs (`` "org/name"` `) and arbitrary user-provided strings are
536+ Hub IDs (`"org/name"`) and arbitrary user-provided strings are
537537 returned unchanged. Filesystem paths are reduced to a basename, with
538- HuggingFace cache snapshot layouts (`` snapshots/<commit-sha>/` `)
539- walked up to recover the original `` "org/name"` ` identity.
538+ HuggingFace cache snapshot layouts (`snapshots/<commit-sha>/`)
539+ walked up to recover the original `"org/name"` identity.
540540 """
541541 if not s :
542542 return "unknown"
543543
544- # Heuristic: a Hub ID is bare `` "name"`` or `` "org/name"` `. Anything
544+ # Heuristic: a Hub ID is bare `"name"` or `"org/name"`. Anything
545545 # that starts with a path separator/expansion or contains more than
546546 # one "/" is treated as a path. Single-"/" strings remain ambiguous;
547- # avoid an NFS `` exists` ` probe for common Hub IDs and only touch the
547+ # avoid an NFS `exists` probe for common Hub IDs and only touch the
548548 # filesystem when the string has explicit local-path syntax.
549549 looks_like_path = s .startswith (("/" , "./" , "../" , "~" )) or s .count ("/" ) > 1
550550 if not looks_like_path :
@@ -553,9 +553,9 @@ def _normalize_model_identity(s: str) -> str:
553553 p = Path (s ).expanduser ()
554554 name = p .name
555555 if name and "snapshots" in p .parts :
556- # HF cache layout: `` .../models--<org>--<name>/snapshots/<sha>/` `.
557- # Walk up to find the `` models--<org>--<name>` ` directory and
558- # un-mangle it back to `` "<org>/<name>"` `.
556+ # HF cache layout: `.../models--<org>--<name>/snapshots/<sha>/`.
557+ # Walk up to find the `models--<org>--<name>` directory and
558+ # un-mangle it back to `"<org>/<name>"`.
559559 for ancestor in p .parents :
560560 if ancestor .name .startswith ("models--" ):
561561 return ancestor .name [len ("models--" ) :].replace ("--" , "/" )
0 commit comments