[TRTLLM-13250][feat] Wave 5: Enable MX post-transform Llama receiver #15432
chienchunhung wants to merge 4 commits into
Conversation
/bot run --disable-fail-fast
PR_Github #54683 [ run ] triggered by Bot. Commit:
PR_Github #54683 [ run ] completed with state
ae210cb to f123c77
/bot run --disable-fail-fast
PR_Github #54888 [ run ] triggered by Bot. Commit:
PR_Github #54888 [ run ] completed with state
f123c77 to 14a4537
/bot run --disable-fail-fast
PR_Github #54955 [ run ] triggered by Bot. Commit:
PR_Github #54955 [ run ] completed with state
📝 Walkthrough
The PR refactors the model weight-loading lifecycle by splitting
Changes
Weight-load Lifecycle Refactor and MX Staged Receiver
Sequence Diagram(s)
sequenceDiagram
participant ModelLoader
participant MXCheckpointLoader
participant MxClient
participant HfCheckpointLoader
participant Model
ModelLoader->>MXCheckpointLoader: load_weights(allow_post_transform_weights=True, source_identity=...)
MXCheckpointLoader->>MxClient: fetch source metadata
MxClient-->>MXCheckpointLoader: metadata (layout, protocol_version, serialized identity)
MXCheckpointLoader->>MXCheckpointLoader: _source_metadata_identity_compatible()
alt metadata compatible + post_transform layout + allowed
MXCheckpointLoader->>MxClient: RDMA P2P weight transfer
MXCheckpointLoader-->>ModelLoader: _post_transform_weights_preloaded=True
ModelLoader->>Model: _setup_aliases()
ModelLoader->>Model: _mark_weights_transformed()
ModelLoader->>Model: _walk_cache_state()
else incompatible or not allowed
MXCheckpointLoader->>HfCheckpointLoader: load_weights (disk fallback)
MXCheckpointLoader-->>ModelLoader: _post_transform_weights_preloaded=False
ModelLoader->>Model: _walk_full_post_load()
end
ModelLoader->>MXCheckpointLoader: post_load_publish(source_identity=...)
MXCheckpointLoader->>MxClient: publish_model_params(metadata={layout, protocol, identity})
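The branch in the diagram above can be read as a small decision function. The sketch below is illustrative only: `SourceMetadata`, `choose_load_path`, the `"post_transform"` layout string, and the equality-based compatibility check are assumptions made for the example, not the PR's actual API. Only the hook names in the returned lists come from the diagram.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass(frozen=True)
class SourceMetadata:
    """Hypothetical stand-in for the metadata an MX source publishes."""
    layout: str
    protocol_version: int
    identity: str


def choose_load_path(
    meta: Optional[SourceMetadata],
    local_identity: str,
    local_protocol: int,
    allow_post_transform: bool,
) -> List[str]:
    """Return the hook names the loader would run, in order (illustrative)."""
    compatible = (
        meta is not None
        and meta.identity == local_identity
        and meta.protocol_version == local_protocol
    )
    if compatible and meta.layout == "post_transform" and allow_post_transform:
        # Transformed bytes arrived over P2P, so the transform step is skipped.
        return ["_setup_aliases", "_mark_weights_transformed", "_walk_cache_state"]
    # Anything else falls back to the disk load plus the full post-load walk.
    return ["_walk_full_post_load"]
```

All three conditions must hold for the fast path; a missing or mismatched metadata record is treated the same as a disallowed post-transform load.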
sequenceDiagram
participant ModelLoader
participant GMSBackend
participant Model
Note over ModelLoader,Model: GMS Read-Only Path (new ordering)
ModelLoader->>Model: _setup_aliases()
ModelLoader->>ModelLoader: _check_gms_source_identity()
ModelLoader->>GMSBackend: materialize_module(model)
ModelLoader->>Model: _walk_cache_state() [cache_derived_state per module]
ModelLoader->>ModelLoader: _post_load_publish(...)
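The read-only ordering above can be sketched with recording fakes. This is a minimal illustration under stated assumptions: `FakeModel`, `FakeBackend`, `load_read_only`, and the choice to raise `ValueError` on an identity mismatch are all hypothetical; the diagram only fixes the order of the steps.

```python
from typing import List


class FakeModel:
    """Hypothetical model that records which hooks ran."""

    def __init__(self, log: List[str]) -> None:
        self.log = log

    def setup_aliases(self) -> None:
        self.log.append("setup_aliases")

    def cache_derived_state(self) -> None:
        self.log.append("cache_derived_state")


class FakeBackend:
    """Hypothetical read-only backend that attaches already-published weights."""

    def __init__(self, log: List[str]) -> None:
        self.log = log

    def materialize_module(self, model: FakeModel) -> None:
        self.log.append("materialize_module")


def load_read_only(model: FakeModel, backend: FakeBackend,
                   expected_identity: str, source_identity: str) -> None:
    """Run the read-only steps in the order the diagram shows."""
    model.setup_aliases()
    if source_identity != expected_identity:
        # Assumption: a mismatch aborts before any weights are attached.
        raise ValueError("source identity mismatch")
    backend.materialize_module(model)
    # No transform step here: the shared weights are already transformed.
    model.cache_derived_state()
```

Note that this path never calls a transform hook; only aliases and derived state are rebuilt on the reader.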
Estimated code review effort: 🎯 5 (Critical) | ⏱️ ~120 minutes
🚥 Pre-merge checks | ✅ 4 | ❌ 1
❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
Actionable comments posted: 6
🧹 Nitpick comments (6)
tensorrt_llm/_torch/models/modeling_qwen3_moe.py (1)
420-427: 🧹 Nitpick | 🔵 Trivial | ⚡ Quick win
Add an explicit `None` return annotation to the renamed hook.
`setup_aliases` is a state-mutating hook and does not return a value.
Proposed fix
- def setup_aliases(self):
+ def setup_aliases(self) -> None:
As per coding guidelines, “Always annotate functions. Make the return type `None` if the function does not return anything.”
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tensorrt_llm/_torch/models/modeling_qwen3_moe.py` around lines 420 - 427, The setup_aliases method is missing an explicit return type annotation. Add `-> None` to the method signature of setup_aliases to indicate that this state-mutating hook does not return a value, as per coding guidelines requiring all functions to have return type annotations.
Source: Coding guidelines
tensorrt_llm/_torch/models/modeling_gpt_oss.py (1)
634-645: 🧹 Nitpick | 🔵 Trivial | ⚡ Quick win
Add an explicit `None` return annotation to the renamed hook.
`setup_aliases` mutates alias fields and returns nothing; annotating it keeps the new lifecycle hook contract clear.
Proposed fix
- def setup_aliases(self):
+ def setup_aliases(self) -> None:
As per coding guidelines, “Always annotate functions. Make the return type `None` if the function does not return anything.”
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tensorrt_llm/_torch/models/modeling_gpt_oss.py` around lines 634 - 645, The setup_aliases method lacks an explicit return type annotation. Add `-> None` to the method signature after the closing parenthesis in the setup_aliases method definition to indicate that this method mutates state but does not return any value, keeping the function contract clear and consistent with coding guidelines that require all functions to be annotated with their return types.
Source: Coding guidelines
tensorrt_llm/_torch/models/modeling_qwen3_next.py (1)
983-990: 🧹 Nitpick | 🔵 Trivial | ⚡ Quick win
Add an explicit `None` return annotation to the renamed hook.
`setup_aliases` only wires aliases, so the new lifecycle method should be annotated as returning `None`.
Proposed fix
- def setup_aliases(self):
+ def setup_aliases(self) -> None:
As per coding guidelines, “Always annotate functions. Make the return type `None` if the function does not return anything.”
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tensorrt_llm/_torch/models/modeling_qwen3_next.py` around lines 983 - 990, The setup_aliases method in the Qwen3 model class lacks an explicit return type annotation. Add a `-> None` return type annotation to the method signature of setup_aliases to explicitly indicate that this method does not return any value, as per the coding guidelines requiring all functions to have return type annotations.
Source: Coding guidelines
tensorrt_llm/_torch/models/checkpoints/mx/checkpoint_loader.py (1)
201-212: 🧹 Nitpick | 🔵 Trivial | ⚡ Quick win
Add return annotations to the new MX lifecycle methods.
These methods expose boolean/side-effect lifecycle contracts, so annotate them explicitly.
Proposed fix
- def is_post_transform_weights_preloaded(self) -> bool:
+ def is_post_transform_weights_preloaded(self) -> bool:
      """Whether the last successful MX preload delivered transformed bytes.
@@
- ) -> None:
+ ) -> None:
      """Publish this instance's weights so other ranks can pull via P2P.
@@
- ) -> None:
+ ) -> None:
      """Publish locally loaded weights as an MX source when appropriate.
As per coding guidelines, “Always annotate functions. Make the return type `None` if the function does not return anything.”
Also applies to: 561-567, 664-670
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tensorrt_llm/_torch/models/checkpoints/mx/checkpoint_loader.py` around lines 201 - 212, Add explicit return type annotations to all three MX lifecycle methods to comply with coding guidelines. The method is_post_transform_weights_preloaded already shows the required -> bool annotation in the diff, but ensure the other two methods at lines 561-567 and 664-670 also have their appropriate return type annotations added (the proposed diff shows `-> None` for both). Verify each method has its return type explicitly annotated rather than relying on implicit type inference.
Source: Coding guidelines
tensorrt_llm/_torch/modules/fused_moe/interface.py (1)
830-841: 🧹 Nitpick | 🔵 Trivial | ⚡ Quick win
Add explicit `-> None` annotations to the new lifecycle hooks.
These hooks do not return values, so annotate them explicitly.
As per coding guidelines, “Always annotate functions. Make the return type `None` if the function does not return anything.”
Proposed fix
- def transform_weights(self):
+ def transform_weights(self) -> None:
      if getattr(self, "_weights_transformed", False):
          return
      self.quant_method.transform_weights(self)
      self._weights_transformed = True

- def cache_derived_state(self):
+ def cache_derived_state(self) -> None:
      self.quant_method.cache_derived_state(self)

- def post_load_weights(self):
+ def post_load_weights(self) -> None:
      self.transform_weights()
      self.cache_derived_state()
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tensorrt_llm/_torch/modules/fused_moe/interface.py` around lines 830 - 841, The methods transform_weights, cache_derived_state, and post_load_weights are missing explicit return type annotations. Add `-> None` to the method signature of each of these three methods since they do not return any values. This follows the coding guideline that all functions must be annotated with their return type, using `None` when the function does not return anything.
Source: Coding guidelines
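The hook split in this comment (transform once, always recompute derived state) can be illustrated with a toy module. This is a sketch under assumptions, not the PR's code: `StagedModule`, its list-based "weights", and the doubling transform are hypothetical placeholders. The `load_weights` reset reflects the inline review comment on `configurable_moe.py`, which asks for the `_weights_transformed` guard to be cleared when fresh weights are loaded.

```python
from typing import List, Optional


class StagedModule:
    """Toy module; weights are a list so the transform is easy to observe."""

    def __init__(self) -> None:
        self.weight: List[int] = []
        self.weight_sum: Optional[int] = None
        self._weights_transformed = False

    def load_weights(self, weights: List[int]) -> None:
        self.weight = list(weights)
        # Fresh, untransformed weights: clear the guard so they get transformed.
        self._weights_transformed = False

    def transform_weights(self) -> None:
        if self._weights_transformed:
            return  # already transformed, or transformed bytes were preloaded
        self.weight = [2 * w for w in self.weight]  # placeholder transform
        self._weights_transformed = True

    def cache_derived_state(self) -> None:
        # Derived state is always recomputed from the current weights.
        self.weight_sum = sum(self.weight)

    def post_load_weights(self) -> None:
        self.transform_weights()
        self.cache_derived_state()
```

Calling `post_load_weights()` twice leaves the weights transformed exactly once, while a second `load_weights()` call makes the next `post_load_weights()` transform the new data.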
tensorrt_llm/_torch/modules/fused_moe/quantization.py (1)
562-580: 🧹 Nitpick | 🔵 Trivial | ⚡ Quick win
Annotate the changed lifecycle hook signatures.
The new/renamed hooks should explicitly return `None`; Line 1015 should also type `module` consistently with the other hooks.
As per coding guidelines, “Always annotate functions. Make the return type `None` if the function does not return anything.”
Proposed signature updates
- def transform_weights(self, module: torch.nn.Module):
+ def transform_weights(self, module: torch.nn.Module) -> None:
- def cache_derived_state(self, module: torch.nn.Module):
+ def cache_derived_state(self, module: torch.nn.Module) -> None:
- def post_load_weights(self, module: torch.nn.Module):
+ def post_load_weights(self, module: torch.nn.Module) -> None:
- def transform_weights(self, module):
+ def transform_weights(self, module: torch.nn.Module) -> None:
Also applies to: 784-787, 1015-1016, 1280-1300, 3106-3111, 5351-5352
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tensorrt_llm/_torch/modules/fused_moe/quantization.py` around lines 562 - 580, Add explicit return type annotations to the lifecycle hook methods transform_weights, cache_derived_state, and post_load_weights by appending -> None to their signatures. Additionally, ensure that the module parameter is consistently typed as torch.nn.Module across all these hook methods and at the other locations mentioned (784-787, 1015-1016, 1280-1300, 3106-3111, 5351-5352). This follows the coding guideline that all functions must be annotated with their return types, using None when the function does not return a value.
Source: Coding guidelines
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@tensorrt_llm/_torch/modules/fused_moe/configurable_moe.py`:
- Around line 655-667: The _weights_transformed guard flag in
transform_weights() prevents weights from being transformed on subsequent load
cycles because the flag is never reset. When load_weights() is called to load
fresh weights, reset the _weights_transformed flag to False so that the guard
check in transform_weights() will allow the newly loaded weights to be
transformed in the subsequent post_load_weights() call.
In `@tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py`:
- Line 1: The file fused_moe_cutlass.py is flagged as executable but contains no
shebang line, which violates the EXE002 lint rule. Since this is library code
and not intended to be run as a standalone script, remove the executable bit
from the file to resolve the violation. This can be done by changing the file
permissions to make it non-executable using your operating system's file
permission tools.
In `@tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py`:
- Line 1: The file fused_moe_wide_ep.py has the executable bit set but contains
no shebang line, triggering Ruff's EXE002 rule. Since this is a library module
and not intended to be directly executed, remove the executable permission from
the file using the appropriate file permission command (such as chmod -x on
Unix-like systems) rather than adding a shebang.
In `@tensorrt_llm/_torch/modules/linear.py`:
- Around line 383-384: The transform_weights method in the Linear class
currently uses a pass statement which triggers a Ruff B027 lint violation.
Replace the pass statement with an explicit return None to maintain the default
no-op behavior while satisfying the linting requirements. This keeps the
optional hook functional while following the tooling's style guidelines.
In `@tests/unittest/_torch/models/checkpoints/mx/test_mx_checkpoint_loader.py`:
- Around line 491-500: In the test assertion section where the metadata is
validated, after the existing assert statement that checks
`_MX_SOURCE_IDENTITY_METADATA_KEY in metadata`, add another assertion to
validate the actual serialized value stored at that metadata key. The assertion
should compare the value of `metadata[_MX_SOURCE_IDENTITY_METADATA_KEY]` against
the expected serialized representation of the source_identity object used in
this test, ensuring the identity payload is not just present but also correct.
- Around line 429-461: The
test_post_transform_mixed_success_falls_back_to_full_disk_load test only
validates the final fallback behavior but does not explicitly verify that
MxLiveWeightLoader.load_weights was actually invoked during the P2P attempt. Add
a mock/patch for MxLiveWeightLoader.load_weights within the context manager
alongside the existing HfCheckpointLoader patch, then add an assertion after the
result checks to verify this method was called once, ensuring the test validates
the intended "attempt P2P then fallback to disk" behavior rather than just the
final outcome.
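The last comment's "attempt P2P, then fall back to disk" check can be sketched with `unittest.mock`. The classes and `load_with_fallback` below are hypothetical stand-ins, not the real `MxLiveWeightLoader` or `HfCheckpointLoader`; the point is only the pattern of patching both loaders and asserting that each was called exactly once.

```python
from unittest import mock


class LiveLoader:
    """Hypothetical stand-in for the P2P weight loader."""

    def load_weights(self) -> dict:
        return {"source": "p2p"}


class DiskLoader:
    """Hypothetical stand-in for the disk checkpoint loader."""

    def load_weights(self) -> dict:
        return {"source": "disk"}


def load_with_fallback(live: LiveLoader, disk: DiskLoader) -> dict:
    """Try the live transfer first; fall back to disk if it fails."""
    try:
        return live.load_weights()
    except RuntimeError:
        return disk.load_weights()


def run_fallback_check() -> dict:
    live, disk = LiveLoader(), DiskLoader()
    with mock.patch.object(
        live, "load_weights", side_effect=RuntimeError("transfer failed")
    ) as p2p, mock.patch.object(
        disk, "load_weights", return_value={"source": "disk"}
    ) as hf:
        result = load_with_fallback(live, disk)
    # Both the failed attempt and the fallback must have happened once.
    p2p.assert_called_once()
    hf.assert_called_once()
    return result
```

Asserting on the P2P mock is what distinguishes "tried P2P and fell back" from "went straight to disk", which the final result alone cannot show.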
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Enterprise
Run ID: 1c1244f5-435d-4fc9-8b2e-865bc2265648
📒 Files selected for processing (33)
- tensorrt_llm/_torch/attention_backend/sparse/dsa.py
- tensorrt_llm/_torch/memory/gpu_memory_backend.py
- tensorrt_llm/_torch/models/checkpoints/base_checkpoint_loader.py
- tensorrt_llm/_torch/models/checkpoints/mx/checkpoint_loader.py
- tensorrt_llm/_torch/models/modeling_deepseekv3.py
- tensorrt_llm/_torch/models/modeling_exaone_moe.py
- tensorrt_llm/_torch/models/modeling_glm.py
- tensorrt_llm/_torch/models/modeling_gpt_oss.py
- tensorrt_llm/_torch/models/modeling_llama.py
- tensorrt_llm/_torch/models/modeling_llama_min_latency.py
- tensorrt_llm/_torch/models/modeling_qwen3_moe.py
- tensorrt_llm/_torch/models/modeling_qwen3_next.py
- tensorrt_llm/_torch/modules/attention.py
- tensorrt_llm/_torch/modules/fused_moe/configurable_moe.py
- tensorrt_llm/_torch/modules/fused_moe/fused_moe_cute_dsl_b12x.py
- tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py
- tensorrt_llm/_torch/modules/fused_moe/fused_moe_densegemm.py
- tensorrt_llm/_torch/modules/fused_moe/fused_moe_triton.py
- tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py
- tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py
- tensorrt_llm/_torch/modules/fused_moe/interface.py
- tensorrt_llm/_torch/modules/fused_moe/mega_moe/mega_moe_deepgemm.py
- tensorrt_llm/_torch/modules/fused_moe/quantization.py
- tensorrt_llm/_torch/modules/linear.py
- tensorrt_llm/_torch/modules/mamba/mamba2_mixer.py
- tensorrt_llm/_torch/pyexecutor/model_loader.py
- tests/unittest/_torch/attention/sparse/test_dsa_indexer.py
- tests/unittest/_torch/models/checkpoints/mx/test_mx_checkpoint_loader.py
- tests/unittest/_torch/modules/mamba/test_mamba2_mixer.py
- tests/unittest/_torch/modules/moe/test_moe_backend.py
- tests/unittest/_torch/pyexecutor/test_model_loader_gms.py
- tests/unittest/_torch/pyexecutor/test_model_loader_mx.py
- tests/unittest/_torch/weight_sharing/test_mx_source_identity_gate.py
14a4537 to 5599299
/bot run
PR_Github #55162 [ run ] triggered by Bot. Commit:
/bot run
PR_Github #55179 [ run ] triggered by Bot. Commit:
PR_Github #55162 [ run ] completed with state
PR_Github #55179 [ run ] completed with state
28e8066 to a2a79e7
/bot run
PR_Github #55285 [ run ] triggered by Bot. Commit:
Signed-off-by: Chien-Chun Hung <2679986+chienchunhung@users.noreply.github.com>
9088c0b to 77eddfa
/bot run
PR_Github #55305 [ run ] triggered by Bot. Commit:
PR_Github #55285 [ run ] completed with state
PR_Github #55305 [ run ] completed with state
Summary
Stacked on Wave 4 / #15387.
This implements Wave 5 of the staged post-load hook rollout for MX:
`setup_aliases()` + `cache_derived_state()`
Dependency / prerequisite stack
This PR is Wave 5 in the staged post-load hooks rollout. The foundation PRs #14770 and #14878 are already merged. The wave PRs should merge in sequence; after each upstream wave lands, rebase the next wave onto `main` so review and CI focus on that wave's delta.
Arrows point from prerequisite to dependent. PR numbers in graph nodes are clickable.
graph TD
PR14770["<a href='https://github.com/NVIDIA/TensorRT-LLM/pull/14770'>#14770</a>: staged-hook contract (merged)"]
PR14878["<a href='https://github.com/NVIDIA/TensorRT-LLM/pull/14878'>#14878</a>: GMS SourceIdentity gate (merged)"]
PR15014["<a href='https://github.com/NVIDIA/TensorRT-LLM/pull/15014'>#15014</a>: Wave 1 aliases + GMS RO load (open)"]
PR15288["<a href='https://github.com/NVIDIA/TensorRT-LLM/pull/15288'>#15288</a>: Wave 2 Linear/Attention transforms (draft)"]
PR15386["<a href='https://github.com/NVIDIA/TensorRT-LLM/pull/15386'>#15386</a>: Wave 3 MoE/Mamba staged hooks (draft)"]
PR15387["<a href='https://github.com/NVIDIA/TensorRT-LLM/pull/15387'>#15387</a>: Wave 4 MX receiver cutover (draft)"]
PR15432["<a href='https://github.com/NVIDIA/TensorRT-LLM/pull/15432'>#15432</a>: Wave 5 MX publisher + Llama receiver (this PR, draft)"]
VERIFY["post-migration verification / demo (planned)"]
PR14770 -->|satisfied| PR15014
PR14878 -->|satisfied| PR15014
PR15014 -->|blocking| PR15288
PR15288 -->|blocking| PR15386
PR15386 -->|blocking| PR15387
PR15387 -->|blocking| PR15432
PR15432 -.->|planned| VERIFY
classDef merged fill:#dcfce7,stroke:#16a34a,color:#14532d;
classDef inflight fill:#dbeafe,stroke:#2563eb,color:#1e3a8a;
classDef draft fill:#ffedd5,stroke:#f97316,color:#7c2d12;
classDef current fill:#ede9fe,stroke:#7c3aed,color:#3b0764,stroke-width:3px;
classDef downstream fill:#f3f4f6,stroke:#6b7280,color:#374151,stroke-dasharray:5 5;
linkStyle 0,1 stroke:#16a34a,stroke-width:2px;
linkStyle 2,3,4,5 stroke:#ea580c,stroke-width:3px;
linkStyle 6 stroke:#6b7280,stroke-width:2px,stroke-dasharray:5 5;
class PR14770,PR14878 merged;
class PR15014 inflight;
class PR15288,PR15386,PR15387 draft;
class PR15432 current;
class VERIFY downstream;
Immediate merge dependency for this PR: #15387 must land first; after Wave 5 lands, run the post-migration verification/demo for the completed staged-hook rollout.
Validation
- `git diff --check`
- `python -m py_compile tensorrt_llm/_torch/models/checkpoints/mx/checkpoint_loader.py tensorrt_llm/_torch/pyexecutor/model_loader.py tests/unittest/_torch/models/checkpoints/mx/test_mx_checkpoint_loader.py tests/unittest/_torch/pyexecutor/test_model_loader_gms.py tests/unittest/_torch/pyexecutor/test_model_loader_mx.py tests/unittest/_torch/weight_sharing/test_mx_source_identity_gate.py`
- `waive list check` and `validate-test-lists` skipped locally because `scripts/check_test_list.py` fails under this hook interpreter with `TypeError: unsupported operand type(s) for |: 'type' and 'NoneType'`
- Focused pytest collection is blocked in this local environment by missing `transformers` before tests are collected.
Summary by CodeRabbit
New Features
Improvements