[TRTLLM-13247][feat] Wave 2: stage Linear and Attention transforms#15288
chienchunhung wants to merge 2 commits into
379b212 to d309240
Superseded stack note: this branch was rebased again after the initial draft PR setup. The current stack is now documented in #15288 (comment).
/bot run --disable-fail-fast
PR_Github #53738 [ run ] triggered by Bot. Commit:
PR_Github #53738 [ run ] completed with state
d309240 to 67267df
/bot run --disable-fail-fast
CI investigation update: the failed build
A fresh
PR_Github #53933 [ run ] triggered by Bot. Commit:
PR_Github #53933 [ run ] completed with state
/bot run --disable-fail-fast --stage-list "DGX_B200-4_GPUs-PyTorch-Ray-1, DGX_B200-8_GPUs-PyTorch-1"
PR_Github #53990 [ run ] triggered by Bot. Commit:
PR_Github #53990 [ run ] completed with state
/bot run --disable-fail-fast
PR_Github #54020 [ run ] triggered by Bot. Commit:
PR_Github #54020 [ run ] completed with state
/bot run --disable-fail-fast --stage-list "SBSA-Linux"
PR_Github #54145 [ run ] triggered by Bot. Commit:
PR_Github #54145 [ run ] completed with state
67267df to bfebf3a
/bot run
PR_Github #54338 [ run ] triggered by Bot. Commit:
PR_Github #54338 [ run ] completed with state
/bot run --disable-fail-fast --stage-list "DGX_B200-PyTorch-1"
PR_Github #54369 [ run ] triggered by Bot. Commit:
PR_Github #54369 [ run ] completed with state
/bot run --disable-fail-fast
📝 Walkthrough
The PR splits the
Changes
Lifecycle Hook Refactor
Sequence Diagram(s)
sequenceDiagram
rect rgba(173, 216, 230, 0.5)
Note over ModelLoader,GMSBackend: GMS RO Load Path
end
participant ModelLoader
participant CheckpointLoader
participant GMSBackend
participant Model
ModelLoader->>CheckpointLoader: post_load_apply(weights_preloaded=True)
ModelLoader->>Model: _setup_aliases() — recursive walk, skip _weights_removed
ModelLoader->>GMSBackend: _check_gms_source_identity() — SourceIdentity gate
ModelLoader->>GMSBackend: materialize_module(model) — bind real tensors
ModelLoader->>Model: _walk_cache_state() — refresh derived state
ModelLoader->>CheckpointLoader: post_load_publish()
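The call order in the diagram above can be sketched in plain Python. This is only an illustration under stated assumptions: `Recorder` and `gms_ro_load` are hypothetical names that log each step instead of touching real weights, and they are not the TensorRT-LLM loader API.

```python
from typing import List


class Recorder:
    """Hypothetical stand-in that logs each load step by name."""

    def __init__(self) -> None:
        self.calls: List[str] = []

    def step(self, name: str) -> None:
        self.calls.append(name)


def gms_ro_load(rec: Recorder, source_identity_ok: bool = True) -> List[str]:
    """Replay the read-only load order from the diagram (illustrative only)."""
    rec.step("post_load_apply(weights_preloaded=True)")
    rec.step("setup_aliases")          # recursive walk, skipping removed weights
    rec.step("check_source_identity")  # gate that runs before any tensor is bound
    if not source_identity_ok:
        raise RuntimeError("SourceIdentity mismatch: refusing to bind tensors")
    rec.step("materialize_module")     # bind real tensors
    rec.step("cache_derived_state")    # refresh state derived from bound tensors
    rec.step("post_load_publish")
    return rec.calls


if __name__ == "__main__":
    for name in gms_ro_load(Recorder()):
        print(name)
```

The point of the ordering is that aliases are set up and the identity gate passes before tensors are bound, and derived state is refreshed only after binding.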
Estimated code review effort
🎯 3 (Moderate) | ⏱️ ~20 minutes
🚥 Pre-merge checks | ✅ 4 | ❌ 1
❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
Actionable comments posted: 4
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@tensorrt_llm/_torch/models/modeling_gpt_oss.py`:
- Line 634: The setup_aliases method is missing an explicit return type
annotation which violates the coding guidelines requiring all functions to be
annotated with their return types. Add -> None after the closing parenthesis of
the setup_aliases method signature to explicitly indicate that this method does
not return any value. This should be placed between the closing parenthesis and
the colon in the method definition.
In `@tensorrt_llm/_torch/models/modeling_qwen3_next.py`:
- Line 983: Add an explicit `-> None` return type annotation to the
`setup_aliases` method definition. Locate the method definition for
`setup_aliases` and modify it from `def setup_aliases(self):` to `def
setup_aliases(self) -> None:` to comply with the coding guideline that requires
all functions to have return type annotations.
In `@tensorrt_llm/_torch/modules/linear.py`:
- Around line 3145-3149: The `_weights_transformed` flag in the Linear class
becomes inaccurate when using GMS RO (Read-Only) materialization because
`materialize_module()` binds already transformed parameters but the flag remains
False, causing layout transforms to be incorrectly re-applied later. Add a RO
cache-state hook in the Linear module that sets `_weights_transformed = True`
when weights are materialized through the RO path, ensuring the flag truthfully
reflects the actual state of weight transformation. Apply the same state
handling logic to the MLA module to maintain consistency across both modules.
- Around line 383-384: The transform_weights method in LinearMethodBase class
currently violates Ruff rule B027 because it only contains a pass statement in a
concrete (non-abstract) method. Since this is an intentional optional hook that
should remain concrete rather than abstract, replace the pass statement with a
non-empty body such as ellipsis (...) or a docstring to satisfy the Ruff linter
while maintaining the optional hook functionality.
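The state handling discussed in the `linear.py` comments can be illustrated with a minimal, pure-Python sketch. All class names here (`QuantMethodSketch`, `TransposeMethod`, `LinearSketch`) and the `bind_materialized()` method are hypothetical and stand in for the real modules; the "layout transform" is a simple transpose of a nested list, chosen only so that an accidental second application is visible.

```python
from typing import List

Matrix = List[List[int]]


class QuantMethodSketch:
    """Hypothetical quant-method base with an optional, no-op layout hook."""

    def transform_weights(self, module: "LinearSketch") -> None:
        """Optional hook: subclasses override to rewrite the weight layout."""


class TransposeMethod(QuantMethodSketch):
    """Illustrative transform: applying it twice would undo the layout."""

    def transform_weights(self, module: "LinearSketch") -> None:
        module.weight = [list(col) for col in zip(*module.weight)]


class LinearSketch:
    def __init__(self, quant_method: QuantMethodSketch) -> None:
        self.quant_method = quant_method
        self.weight: Matrix = []
        self._weights_transformed = False

    def load_weights(self, weight: Matrix) -> None:
        # Fresh weights are in checkpoint layout, so the flag is reset.
        self.weight = [list(row) for row in weight]
        self._weights_transformed = False

    def bind_materialized(self, weight: Matrix) -> None:
        # Assumed read-only path: tensors arrive already transformed,
        # so the flag is set to match the actual state.
        self.weight = [list(row) for row in weight]
        self._weights_transformed = True

    def transform_weights(self) -> None:
        if self._weights_transformed:
            return  # guard: never re-apply a layout transform
        self.quant_method.transform_weights(self)
        self._weights_transformed = True

    def post_load_weights(self) -> None:
        # Backward-compatible shim that delegates to the staged hook.
        self.transform_weights()
```

With this shape, a caller that still walks `post_load_weights()` after a read-only bind leaves the weights untouched, because the flag already reflects the transformed state. Whether a docstring-only body is enough to silence a particular lint rule is not verified here.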
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Enterprise
Run ID: 402ced13-7870-45cc-83ca-6fa005ee6211
📒 Files selected for processing (13)
- tensorrt_llm/_torch/memory/gpu_memory_backend.py
- tensorrt_llm/_torch/models/modeling_deepseekv3.py
- tensorrt_llm/_torch/models/modeling_exaone_moe.py
- tensorrt_llm/_torch/models/modeling_glm.py
- tensorrt_llm/_torch/models/modeling_gpt_oss.py
- tensorrt_llm/_torch/models/modeling_llama.py
- tensorrt_llm/_torch/models/modeling_qwen3_moe.py
- tensorrt_llm/_torch/models/modeling_qwen3_next.py
- tensorrt_llm/_torch/modules/attention.py
- tensorrt_llm/_torch/modules/linear.py
- tensorrt_llm/_torch/pyexecutor/model_loader.py
- tests/unittest/_torch/pyexecutor/test_model_loader_gms.py
- tests/unittest/_torch/pyexecutor/test_model_loader_mx.py
cf883fc to e5d3175
/bot run
PR_Github #55161 [ run ] triggered by Bot. Commit:
PR_Github #55161 [ run ] completed with state
abce27d to 896e764
/bot run
PR_Github #55288 [ run ] triggered by Bot. Commit:
/bot run
PR_Github #55288 [ run ] completed with state
Signed-off-by: Chien-Chun Hung <2679986+chienchunhung@users.noreply.github.com>
Signed-off-by: Chien-Chun Hung <2679986+chienchunhung@users.noreply.github.com>
71027a6 to 9450f16
/bot run
PR_Github #55595 [ run ] triggered by Bot. Commit:
PR_Github #55595 [ run ] completed with state
/bot run --disable-fail-fast
PR_Github #55610 [ run ] triggered by Bot. Commit:
PR_Github #55610 [ run ] completed with state
Summary
Wave 2 of the staged post-load hooks rollout, stacked on #15014.
This migrates the remaining Linear and MLA tensor-layout post-load work into `transform_weights()` with `_weights_transformed` guards, while keeping `post_load_weights()` as the backward-compatible shim for existing full post-load walks.
What Changed
- `Linear.transform_weights()` and a quant-method-level `transform_weights()` hook, with `post_load_weights()` delegating through the staged hook.
- `post_load_weights()` implementations into `transform_weights()`.
- `_weights_transformed` state for Linear and MLA, reset when fresh Linear weights or auxiliary MLA weight tensors are created/loaded.
- `MLA.transform_weights()` and kept `MLA.post_load_weights()` as a shim.
- `setup_aliases()`, `materialize_module()`, then `cache_derived_state()`; writer-only tensor layout changes belong in `transform_weights()`.
- `source_identity` call shape.
Dependency / prerequisite stack
This PR is Wave 2 in the staged post-load hooks rollout. The foundation PRs #14770 and #14878 are already merged. The wave PRs should merge in sequence; after each upstream wave lands, rebase the next wave onto `main` so review and CI focus on that wave's delta.
Arrows point from prerequisite to dependent. PR numbers in graph nodes are clickable.
graph TD
    PR14770["<a href='https://github.com/NVIDIA/TensorRT-LLM/pull/14770'>#14770</a>: staged-hook contract (merged)"]
    PR14878["<a href='https://github.com/NVIDIA/TensorRT-LLM/pull/14878'>#14878</a>: GMS SourceIdentity gate (merged)"]
    PR15014["<a href='https://github.com/NVIDIA/TensorRT-LLM/pull/15014'>#15014</a>: Wave 1 aliases + GMS RO load (open)"]
    PR15288["<a href='https://github.com/NVIDIA/TensorRT-LLM/pull/15288'>#15288</a>: Wave 2 Linear/Attention transforms (this PR, draft)"]
    PR15386["<a href='https://github.com/NVIDIA/TensorRT-LLM/pull/15386'>#15386</a>: Wave 3 MoE/Mamba staged hooks (draft)"]
    PR15387["<a href='https://github.com/NVIDIA/TensorRT-LLM/pull/15387'>#15387</a>: Wave 4 MX receiver cutover (draft)"]
    PR15432["<a href='https://github.com/NVIDIA/TensorRT-LLM/pull/15432'>#15432</a>: Wave 5 MX publisher + Llama receiver (draft)"]
    VERIFY["post-migration verification / demo (planned)"]
    PR14770 -->|satisfied| PR15014
    PR14878 -->|satisfied| PR15014
    PR15014 -->|blocking| PR15288
    PR15288 -->|blocking| PR15386
    PR15386 -->|blocking| PR15387
    PR15387 -->|blocking| PR15432
    PR15432 -.->|planned| VERIFY
    classDef merged fill:#dcfce7,stroke:#16a34a,color:#14532d;
    classDef inflight fill:#dbeafe,stroke:#2563eb,color:#1e3a8a;
    classDef draft fill:#ffedd5,stroke:#f97316,color:#7c2d12;
    classDef current fill:#ede9fe,stroke:#7c3aed,color:#3b0764,stroke-width:3px;
    classDef downstream fill:#f3f4f6,stroke:#6b7280,color:#374151,stroke-dasharray:5 5;
    linkStyle 0,1 stroke:#16a34a,stroke-width:2px;
    linkStyle 2,3,4,5 stroke:#ea580c,stroke-width:3px;
    linkStyle 6 stroke:#6b7280,stroke-width:2px,stroke-dasharray:5 5;
    class PR14770,PR14878 merged;
    class PR15014 inflight;
    class PR15386,PR15387,PR15432 draft;
    class PR15288 current;
    class VERIFY downstream;
Immediate merge dependency for this PR: #15014 must land first; after it lands, rebase this branch onto `main` so the PR diff collapses to the Wave 2 delta.
Test Plan
- `PYTHONPYCACHEPREFIX=/tmp/trtllm-wave2-pycache python3 -m py_compile tensorrt_llm/_torch/modules/linear.py tensorrt_llm/_torch/modules/attention.py tensorrt_llm/_torch/memory/gpu_memory_backend.py tests/unittest/_torch/pyexecutor/test_model_loader_mx.py tests/unittest/_torch/pyexecutor/test_model_loader_gms.py`
- `git diff --check`
- `PATH=/Users/chienchunh/.cache/codex-runtimes/codex-primary-runtime/dependencies/python/bin:$PATH pre-commit run --files tensorrt_llm/_torch/memory/gpu_memory_backend.py tensorrt_llm/_torch/modules/attention.py tensorrt_llm/_torch/modules/linear.py tests/unittest/_torch/pyexecutor/test_model_loader_gms.py tests/unittest/_torch/pyexecutor/test_model_loader_mx.py`
- `transformers` is not installed: `PYTHONPATH=. PYTHONPYCACHEPREFIX=/tmp/trtllm-wave2-pycache pytest tests/unittest/_torch/pyexecutor/test_model_loader_mx.py tests/unittest/_torch/pyexecutor/test_model_loader_gms.py`
Next Steps
- `main` so the PR diff collapses to the Wave 2 commit only.
Summary by CodeRabbit
Release Notes
Bug Fixes
Refactor