Skip to content

Commit d6cd918

Browse files
committed
[TRTLLM-13141][fix] Refine SourceIdentity compatibility metadata
Signed-off-by: Chien-Chun Hung <2679986+chienchunhung@users.noreply.github.com>
1 parent c44569c commit d6cd918

9 files changed

Lines changed: 308 additions & 309 deletions

File tree

tensorrt_llm/_torch/memory/gpu_memory_backend.py

Lines changed: 76 additions & 76 deletions
Large diffs are not rendered by default.

tensorrt_llm/_torch/models/checkpoints/mx/checkpoint_loader.py

Lines changed: 57 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -15,17 +15,17 @@
1515

1616
"""MX (ModelExpress) checkpoint loader.
1717
18-
Thin adapter on top of the upstream ``modelexpress`` Python client
19-
(``ai-dynamo/modelexpress``). All NIXL/RDMA mechanics (agent setup,
18+
Thin adapter on top of the upstream `modelexpress` Python client
19+
(`ai-dynamo/modelexpress`). All NIXL/RDMA mechanics (agent setup,
2020
tensor registration, source-target name matching, dtype-cast handling,
21-
PVC fallback, etc.) live in the upstream ``MxLiveWeightLoader`` and
22-
``publish_model_params`` helpers — we only call them at the right
21+
PVC fallback, etc.) live in the upstream `MxLiveWeightLoader` and
22+
`publish_model_params` helpers — we only call them at the right
2323
points in TRT-LLM's loading lifecycle.
2424
2525
When no MX server is reachable (or the upstream library is not
2626
installed), this loader transparently falls back to standard
2727
HuggingFace checkpoint loading (disk -> CPU -> GPU) by way of its
28-
``HfCheckpointLoader`` base class.
28+
`HfCheckpointLoader` base class.
2929
"""
3030

3131
import os
@@ -45,14 +45,14 @@
4545
from tensorrt_llm._torch.weight_sharing import (
4646
IdentityCheckPolicy,
4747
SourceIdentity,
48-
check_source_identity,
48+
check_weight_sharing_compatibility,
4949
)
5050
from tensorrt_llm.logger import logger
5151
from tensorrt_llm.mapping import Mapping
5252

53-
# Defensive default for the upstream ``MX_SOURCE_QUERY_TIMEOUT`` env var.
54-
# The upstream ``MxLiveWeightLoader`` polls the MX server every 5 s for up
55-
# to ``MX_SOURCE_QUERY_TIMEOUT`` seconds (default 3600 = 1 hour) waiting
53+
# Defensive default for the upstream `MX_SOURCE_QUERY_TIMEOUT` env var.
54+
# The upstream `MxLiveWeightLoader` polls the MX server every 5 s for up
55+
# to `MX_SOURCE_QUERY_TIMEOUT` seconds (default 3600 = 1 hour) waiting
5656
# for a source. On a cold cluster (no donor up yet), this means the very
5757
# first replica blocks for an hour before falling back to disk. We cap
5858
# the default at 30 s so first-replica startup degrades gracefully; users
@@ -83,19 +83,19 @@ def _temporary_env(key: str, value: Optional[str]):
8383
class MXCheckpointLoader(HfCheckpointLoader):
8484
"""Checkpoint loader for MX (ModelExpress) P2P weight transfer.
8585
86-
When an MX server is reachable AND the upstream ``modelexpress``
86+
When an MX server is reachable AND the upstream `modelexpress`
8787
library is installed, weights are transferred directly from a
8888
source instance via NIXL/RDMA, bypassing disk I/O. The source
89-
publishes its weights *before* ``post_load_weights()`` runs so
89+
publishes its weights *before* `post_load_weights()` runs so
9090
targets receive raw loaded state and can run their own
9191
post-load transforms.
9292
9393
When the MX server or library is unavailable, this loader
9494
transparently falls back to standard HuggingFace checkpoint
95-
loading via the parent ``HfCheckpointLoader``.
95+
loading via the parent `HfCheckpointLoader`.
9696
9797
All transport-level mechanics (NIXL, dtype casts, source matching,
98-
fallback) are delegated to ``modelexpress.trtllm_live_transfer``
98+
fallback) are delegated to `modelexpress.trtllm_live_transfer`
9999
so that this class stays a thin adapter — when the MX wire
100100
protocol or transport evolves, only the upstream library needs
101101
to track it.
@@ -121,10 +121,10 @@ def __init__(
121121
# _checkpoint_format directly does not see a stale value.
122122
self._checkpoint_format = "MX"
123123
self._mx_server_url = mx_server_url
124-
# ``model_name`` is the human-readable identity to publish/look up
124+
# `model_name` is the human-readable identity to publish/look up
125125
# under on the MX server. Typically the user-supplied
126-
# ``llm_args.model`` (a Hub ID like ``"Qwen/Qwen2.5-72B-Instruct"``
127-
# or a local path). ``publish_as_source()`` resolves it via
126+
# `llm_args.model` (a Hub ID like `"Qwen/Qwen2.5-72B-Instruct"`
127+
# or a local path). `publish_as_source()` resolves it via
128128
# :func:`_resolve_mx_model_name` (with HF-snapshot path fallback).
129129
self._model_name = str(model_name) if model_name is not None else None
130130
self._query_timeout_s = query_timeout_s
@@ -146,9 +146,9 @@ def mx_server_url(self) -> Optional[str]:
146146
def model_name(self) -> Optional[str]:
147147
"""Explicit model identity passed to the constructor (if any).
148148
149-
Note this is the *as-configured* value (e.g. ``llm_args.model``),
149+
Note this is the *as-configured* value (e.g. `llm_args.model`),
150150
not the final resolved identity that ends up in the published
151-
``MODEL_NAME``. The full resolution (with env var and basename
151+
`MODEL_NAME`. The full resolution (with env var and basename
152152
fallbacks) happens inside :meth:`publish_as_source`.
153153
"""
154154
return self._model_name
@@ -160,42 +160,42 @@ def query_timeout_s(self) -> Optional[int]:
160160
def is_weights_preloaded(self) -> bool:
161161
"""Whether the last :meth:`load_weights` call wired weights directly into the model.
162162
163-
Reports the result of the most recent ``load_weights()`` invocation
164-
on this loader instance. ``ModelLoader`` consults this signal to
163+
Reports the result of the most recent `load_weights()` invocation
164+
on this loader instance. `ModelLoader` consults this signal to
165165
decide whether to run the standard weight-mapping pipeline:
166166
167-
- ``True``: MX P2P transfer succeeded; weights already live in
167+
- `True`: MX P2P transfer succeeded; weights already live in
168168
model parameter buffers via direct writes from the upstream
169-
``MxLiveWeightLoader``. The mapping pipeline is skipped for
169+
`MxLiveWeightLoader`. The mapping pipeline is skipped for
170170
all parameters covered by P2P.
171-
- ``False``: either P2P was never attempted (no MX server URL,
171+
- `False`: either P2P was never attempted (no MX server URL,
172172
no model reference, library missing) or it failed and we
173173
fell back to disk; weights still need to flow through
174-
``model.load_weights(...)`` via the standard mapper.
174+
`model.load_weights(...)` via the standard mapper.
175175
176176
Note this is a per-loader-instance flag, not a global one. The
177-
flag is reset to ``False`` at the start of each ``load_weights``
177+
flag is reset to `False` at the start of each `load_weights`
178178
call, so the value is only meaningful immediately after a
179179
successful call.
180180
181181
Returns:
182-
``True`` iff the last ``load_weights`` populated the model
183-
via P2P; ``False`` before any call and on any fallback path.
182+
`True` iff the last `load_weights` populated the model
183+
via P2P; `False` before any call and on any fallback path.
184184
"""
185185
return self._p2p_succeeded
186186

187187
def load_weights(self, checkpoint_dir: str, mapping: Mapping, **kwargs) -> dict[str, Any]:
188188
"""Load weights, preferring MX P2P transfer when available.
189189
190190
Delegates the actual transfer to the upstream
191-
``modelexpress.trtllm_live_transfer.MxLiveWeightLoader``,
191+
`modelexpress.trtllm_live_transfer.MxLiveWeightLoader`,
192192
which handles NIXL setup, source discovery, name matching,
193193
dtype casting, and PVC fallback for size-mismatched tensors.
194194
195195
Args:
196196
checkpoint_dir: Path to the HF checkpoint directory.
197197
mapping: Distributed mapping configuration.
198-
**kwargs: Additional keyword arguments. When ``model`` is
198+
**kwargs: Additional keyword arguments. When `model` is
199199
passed it is used as the target for direct P2P writes.
200200
201201
Returns:
@@ -297,7 +297,7 @@ def load_weights(self, checkpoint_dir: str, mapping: Mapping, **kwargs) -> dict[
297297
def _resolve_query_timeout_override(
298298
self, checkpoint_dir: str, MxClient: Type[Any], build_identity: Callable[..., Any]
299299
) -> Optional[str]:
300-
"""Return temporary ``MX_SOURCE_QUERY_TIMEOUT`` override, if any."""
300+
"""Return temporary `MX_SOURCE_QUERY_TIMEOUT` override, if any."""
301301
if self._query_timeout_s is not None:
302302
return str(self._query_timeout_s)
303303

@@ -343,7 +343,7 @@ def _source_identity_compatible(
343343
"""Whether the MX source's identity is compatible with this receiver.
344344
345345
Compares the receiver's local :class:`SourceIdentity` against the
346-
publisher's via ``check_source_identity`` with the ``WARN_FALLBACK``
346+
publisher's via `check_weight_sharing_compatibility` with the `WARN_FALLBACK`
347347
policy.
348348
349349
Args:
@@ -354,13 +354,13 @@ def _source_identity_compatible(
354354
(forwarded to the fetch seam).
355355
356356
Returns:
357-
``True`` to proceed with P2P only when both identities are present
358-
and compatible. ``False`` when either identity is missing or the
357+
`True` to proceed with P2P only when both identities are present
358+
and compatible. `False` when either identity is missing or the
359359
identities mismatch, so the caller falls back to disk loading.
360360
"""
361361
local_identity = self._local_source_identity
362362
source_identity = self._fetch_source_identity(checkpoint_dir, MxClient, build_identity)
363-
decision = check_source_identity(
363+
decision = check_weight_sharing_compatibility(
364364
local_identity,
365365
source_identity,
366366
IdentityCheckPolicy.WARN_FALLBACK,
@@ -378,7 +378,7 @@ def _fetch_source_identity(
378378
build_identity: Builder used to derive the publisher identity.
379379
380380
Returns:
381-
The publisher's identity, or ``None`` when it cannot be fetched
381+
The publisher's identity, or `None` when it cannot be fetched
382382
yet (the compatibility gate then rejects P2P and falls back).
383383
"""
384384
# TODO(SOURCE-IDENTITY/MX-2): read the publisher's identity from the MX
@@ -406,21 +406,21 @@ def publish_as_source(
406406
) -> None:
407407
"""Publish this instance's weights so other ranks can pull via P2P.
408408
409-
Called by the integration in ``model_loader.py`` *before*
410-
``post_load_weights()`` so targets receive raw loaded state and
409+
Called by the integration in `model_loader.py` *before*
410+
`post_load_weights()` so targets receive raw loaded state and
411411
can apply their own post-load transforms.
412412
413413
Delegates to the upstream
414-
``modelexpress.trtllm_live_transfer.publish_model_params``
414+
`modelexpress.trtllm_live_transfer.publish_model_params`
415415
helper, which handles the per-rank NIXL setup, tensor
416416
registration, and gRPC publish.
417417
418418
Args:
419419
model: The model whose weights to publish.
420420
checkpoint_dir: Checkpoint directory. Used as a last-resort
421-
fallback for resolving the ``MODEL_NAME`` identity when
422-
neither ``model_name`` was passed to the constructor nor
423-
``MODEL_NAME`` is set in the environment.
421+
fallback for resolving the `MODEL_NAME` identity when
422+
neither `model_name` was passed to the constructor nor
423+
`MODEL_NAME` is set in the environment.
424424
"""
425425

426426
if self._mx_server_url is None:
@@ -515,14 +515,14 @@ def _resolve_mx_model_name(model_name_arg: Optional[str], checkpoint_dir: Option
515515
516516
Resolution order (first non-empty wins):
517517
518-
1. ``model_name_arg`` — the explicit value passed at construction
519-
time (typically ``llm_args.model``: a Hub ID like
520-
``"Qwen/Qwen2.5-72B-Instruct"`` or a local path).
521-
2. ``MODEL_NAME`` env var — upstream's existing convention.
522-
3. ``checkpoint_dir`` basename, with HF-snapshot path fallback so
523-
``.../models--<org>--<name>/snapshots/<sha>/`` resolves to
524-
``"<org>/<name>"`` instead of the commit hash.
525-
4. Literal ``"unknown"`` — matches upstream's own sentinel.
518+
1. `model_name_arg` — the explicit value passed at construction
519+
time (typically `llm_args.model`: a Hub ID like
520+
`"Qwen/Qwen2.5-72B-Instruct"` or a local path).
521+
2. `MODEL_NAME` env var — upstream's existing convention.
522+
3. `checkpoint_dir` basename, with HF-snapshot path fallback so
523+
`.../models--<org>--<name>/snapshots/<sha>/` resolves to
524+
`"<org>/<name>"` instead of the commit hash.
525+
4. Literal `"unknown"` — matches upstream's own sentinel.
526526
"""
527527
candidate = model_name_arg or os.environ.get("MODEL_NAME") or checkpoint_dir
528528
if not candidate:
@@ -533,18 +533,18 @@ def _resolve_mx_model_name(model_name_arg: Optional[str], checkpoint_dir: Option
533533
def _normalize_model_identity(s: str) -> str:
534534
"""Convert a model identifier to a stable, human-readable name.
535535
536-
Hub IDs (``"org/name"``) and arbitrary user-provided strings are
536+
Hub IDs (`"org/name"`) and arbitrary user-provided strings are
537537
returned unchanged. Filesystem paths are reduced to a basename, with
538-
HuggingFace cache snapshot layouts (``snapshots/<commit-sha>/``)
539-
walked up to recover the original ``"org/name"`` identity.
538+
HuggingFace cache snapshot layouts (`snapshots/<commit-sha>/`)
539+
walked up to recover the original `"org/name"` identity.
540540
"""
541541
if not s:
542542
return "unknown"
543543

544-
# Heuristic: a Hub ID is bare ``"name"`` or ``"org/name"``. Anything
544+
# Heuristic: a Hub ID is bare `"name"` or `"org/name"`. Anything
545545
# that starts with a path separator/expansion or contains more than
546546
# one "/" is treated as a path. Single-"/" strings remain ambiguous;
547-
# avoid an NFS ``exists`` probe for common Hub IDs and only touch the
547+
# avoid an NFS `exists` probe for common Hub IDs and only touch the
548548
# filesystem when the string has explicit local-path syntax.
549549
looks_like_path = s.startswith(("/", "./", "../", "~")) or s.count("/") > 1
550550
if not looks_like_path:
@@ -553,9 +553,9 @@ def _normalize_model_identity(s: str) -> str:
553553
p = Path(s).expanduser()
554554
name = p.name
555555
if name and "snapshots" in p.parts:
556-
# HF cache layout: ``.../models--<org>--<name>/snapshots/<sha>/``.
557-
# Walk up to find the ``models--<org>--<name>`` directory and
558-
# un-mangle it back to ``"<org>/<name>"``.
556+
# HF cache layout: `.../models--<org>--<name>/snapshots/<sha>/`.
557+
# Walk up to find the `models--<org>--<name>` directory and
558+
# un-mangle it back to `"<org>/<name>"`.
559559
for ancestor in p.parents:
560560
if ancestor.name.startswith("models--"):
561561
return ancestor.name[len("models--") :].replace("--", "/")

0 commit comments

Comments
 (0)