[None][feat] Implement Uneven TP Linear for VisualGen models#14875
[None][feat] Implement Uneven TP Linear for VisualGen models#14875belgarten-nv wants to merge 4 commits into
Conversation
|
/bot run --disable-fail-fast |
|
PR_Github #51683 [ run ] triggered by Bot. Commit: |
📝 WalkthroughWalkthroughThis PR introduces alignment-aware tensor-parallel (TP) sharding infrastructure for TensorRT-LLM PyTorch modules. The changes centralize weight-loading logic in ChangesUneven TP Sharding
Estimated code review effort🎯 4 (Complex) | ⏱️ ~60 minutes Suggested reviewers
🚥 Pre-merge checks | ✅ 4 | ❌ 1❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing Touches🧪 Generate unit tests (beta)
Comment |
There was a problem hiding this comment.
Actionable comments posted: 1
🧹 Nitpick comments (3)
tensorrt_llm/_torch/modules/linear.py (2)
2350-2355: 💤 Low valueUnused variable
weight_dtype.The variable is unpacked but never used. Consider prefixing with underscore to indicate intentional discard.
🔧 Suggested fix
- weight_dtype, weight_id = get_weight_dtype_and_id(module) + _, weight_id = get_weight_dtype_and_id(module)🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tensorrt_llm/_torch/modules/linear.py` around lines 2350 - 2355, The tuple returned by get_weight_dtype_and_id is unpacked into weight_dtype and weight_id but weight_dtype is never used; change the unpack to discard the unused value (e.g., prefix with an underscore such as _weight_dtype or use a single underscore) so it’s clear it’s intentional, keeping weight_id for the subsequent elm_packing calculation that uses module.tp_mode and TensorParallelMode.COLUMN and the call to load_weights_vanilla_helper.
3126-3129: 💤 Low valueConsider adding a defensive assertion when
tp_shardingis a dict.If
tp_shardingis a dict (for fused modes),namemust be provided. AKeyErrorwould occur ifnameisNone, which could be confusing. While fused modes should always providename, an explicit assertion documents this contract.🔧 Proposed assertion
elif isinstance(self.tp_sharding, dict): + assert name is not None, ( + "name is required when tp_sharding is a dict (fused modes)") slice_start, slice_end = self.tp_sharding[name]🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tensorrt_llm/_torch/modules/linear.py` around lines 3126 - 3129, When handling tp_sharding as a dict in the block that assigns slice_start/slice_end, add a defensive assertion that name is provided and is a valid key: assert name is not None and name in self.tp_sharding, with a clear message like "tp_sharding is a dict but 'name' is missing or invalid for fused mode"; then read slice_start, slice_end = self.tp_sharding[name] as before to avoid an unclear KeyError. This involves editing the branch that currently checks isinstance(self.tp_sharding, dict) in the code that assigns slice_start and slice_end.tests/unittest/_torch/modules/test_linear_uneven_tp.py (1)
476-486: 💤 Low valueConsider adding
strict=Trueto zip for explicit length validation.The
zip(outputs, per_rank_ranges)should have matching lengths, andstrict=Truewould catch any mismatch during test development. However, this is a minor enhancement.♻️ Proposed fix
- for output, ranges in zip(outputs, per_rank_ranges): + for output, ranges in zip(outputs, per_rank_ranges, strict=True):🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tests/unittest/_torch/modules/test_linear_uneven_tp.py` around lines 476 - 486, In _assemble_fused_outputs, ensure the lengths of outputs and per_rank_ranges are validated by replacing zip(outputs, per_rank_ranges) with zip(outputs, per_rank_ranges, strict=True) so a length mismatch raises immediately; update the zip call within the function _assemble_fused_outputs to include strict=True (requires Python 3.10+), leaving the rest of the slicing/concatenation logic unchanged.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@tensorrt_llm/_torch/visual_gen/models/flux/joint_proj.py`:
- Around line 260-261: The assignment to self.local_qkv_dim is incorrect for
uneven tensor-parallel (TP) shards: replace the simple global division
(currently self.local_qkv_dim = (q_dim + 2 * kv_dim) // self.tp_size) with a
computed local sum using the shard-specific sizes so local_qkv_dim = local_q_dim
+ 2 * local_kv_dim (where local_q_dim and local_kv_dim are derived from the TP
override/range computation used elsewhere); update the initialization in the
constructor so forward() uses the correct per-shard split sizes when slicing QKV
tensors.
---
Nitpick comments:
In `@tensorrt_llm/_torch/modules/linear.py`:
- Around line 2350-2355: The tuple returned by get_weight_dtype_and_id is
unpacked into weight_dtype and weight_id but weight_dtype is never used; change
the unpack to discard the unused value (e.g., prefix with an underscore such as
_weight_dtype or use a single underscore) so it’s clear it’s intentional,
keeping weight_id for the subsequent elm_packing calculation that uses
module.tp_mode and TensorParallelMode.COLUMN and the call to
load_weights_vanilla_helper.
- Around line 3126-3129: When handling tp_sharding as a dict in the block that
assigns slice_start/slice_end, add a defensive assertion that name is provided
and is a valid key: assert name is not None and name in self.tp_sharding, with a
clear message like "tp_sharding is a dict but 'name' is missing or invalid for
fused mode"; then read slice_start, slice_end = self.tp_sharding[name] as before
to avoid an unclear KeyError. This involves editing the branch that currently
checks isinstance(self.tp_sharding, dict) in the code that assigns slice_start
and slice_end.
In `@tests/unittest/_torch/modules/test_linear_uneven_tp.py`:
- Around line 476-486: In _assemble_fused_outputs, ensure the lengths of outputs
and per_rank_ranges are validated by replacing zip(outputs, per_rank_ranges)
with zip(outputs, per_rank_ranges, strict=True) so a length mismatch raises
immediately; update the zip call within the function _assemble_fused_outputs to
include strict=True (requires Python 3.10+), leaving the rest of the
slicing/concatenation logic unchanged.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Enterprise
Run ID: ea513628-a60f-4a66-8091-b7ffc1a0299b
📒 Files selected for processing (13)
tensorrt_llm/_torch/modules/gated_mlp.pytensorrt_llm/_torch/modules/linear.pytensorrt_llm/_torch/visual_gen/models/flux/attention.pytensorrt_llm/_torch/visual_gen/models/flux/joint_proj.pytensorrt_llm/_torch/visual_gen/models/flux/transformer_flux.pytensorrt_llm/_torch/visual_gen/models/wan/transformer_wan.pytensorrt_llm/_torch/visual_gen/modules/attention.pytensorrt_llm/_torch/visual_gen/modules/rms_norm.pytests/unittest/_torch/modules/test_linear_uneven_tp.pytests/unittest/_torch/visual_gen/multi_gpu/test_flux_tp.pytests/unittest/_torch/visual_gen/multi_gpu/test_tp_attention.pytests/unittest/_torch/visual_gen/multi_gpu/test_wan_tp.pytests/unittest/_torch/visual_gen/multi_gpu/tp_shard_utils.py
|
PR_Github #51683 [ run ] completed with state
|
| h_to_4h_type = LoraModuleType.MLP_H_TO_4H | ||
| gate_type = LoraModuleType.MLP_GATE | ||
|
|
||
| self.down_lora = LoraLayer([down_type], [self.hidden_size]) |
There was a problem hiding this comment.
Will this work with lora layers? Can you raise an error if someone tries to use uneven model dims with TP with lora?
| gateup_shard_indices_mapping = { | ||
| 'gate': (0, local_intermediate_size), | ||
| 'up': (local_intermediate_size, local_intermediate_size), | ||
| } | ||
|
|
||
| override_tp_sharding = { | ||
| 'gate': (local_intermediate_start, local_intermediate_end), | ||
| 'up': (local_intermediate_start, local_intermediate_end), | ||
| } | ||
|
|
There was a problem hiding this comment.
could you add a comment explaining the difference b/w the 2
|
|
||
| def load_shard( | ||
| self, | ||
| weights: Dict, |
There was a problem hiding this comment.
looks like weights can be a torch.Tensor too?
| _uneven_tp_unsupported = {QuantAlgo.NVFP4_ARC} | ||
| _quant_algo = quant_config.quant_algo if quant_config else None | ||
| if override_tp_sharding is not None: | ||
| assert _quant_algo not in _uneven_tp_unsupported |
|
|
||
| return self._calculate_local_features_helper(out_features) | ||
|
|
||
| def load_shard( |
There was a problem hiding this comment.
could you add docstrings for all methods and also in load_weight_shard pointing to this new method
| For VANILLA mode only. Fused modes with non-divisible dims require | ||
| explicit override_tp_sharding from the model layer. | ||
| """ | ||
| assert self.weights_loading_config.weight_mode == WeightMode.VANILLA, ( |
There was a problem hiding this comment.
does it make sense to include FUSED_GATE_UP mode in auto sharding? it should be simpler than FUSE_QKV I think
There was a problem hiding this comment.
The reason we can't do QKV fusing is because we don't know the full size of each Q, K, V matrix. You could always assume the size of the Gate is always the same as the size of Up matrix (which is thus out_features / 2), but I'm not certain this is always true. This becomes a problem since this could be misused really for any two simultaneous projections, not just a swiglu Gate/Up setup.
| # Note: this is intentionally stronger than `num_kv_head >= ulysses_size * tp_size` | ||
| assert self.num_key_value_heads // ulysses_size >= self.tp_size | ||
|
|
||
| def _calc_shard(full, size, rank): |
There was a problem hiding this comment.
just use rank as arg to make it clearer
| self.local_q_dim = self.local_num_attention_heads * self.head_dim | ||
| self.local_kv_dim = self.local_num_key_value_heads * self.head_dim | ||
|
|
||
| self._calculate_tp_parameters(ulysses_size if enable_ulysses else None) |
There was a problem hiding this comment.
would it be better to move this to a utility class with accessors to make it easier for new models?
| self.has_bias = bias | ||
| self.attn_shard = attn_shard | ||
|
|
||
| assert attn_dim % self.tp_size == 0 or self.attn_shard, ( |
There was a problem hiding this comment.
assert attn_dim % self.tp_size == 0 or self.attn_shard is not None,
| local_q_dim = q_dim // self.tp_size | ||
| local_kv_dim = kv_dim // self.tp_size | ||
| shard_mlp_hidden_dim = self.mlp_hidden_dim // self.tp_size | ||
|
|
There was a problem hiding this comment.
assert override_qkv_sharding is not None
Signed-off-by: Brenden Elgarten <belgarten@nvidia.com>
Signed-off-by: Brenden Elgarten <belgarten@nvidia.com>
5475ecc to
bbbfed0
Compare
|
/bot run |
|
PR_Github #55079 [ run ] triggered by Bot. Commit: |
|
PR_Github #55079 [ run ] completed with state
|
|
/bot run --disable-reuse-test --disable-fail-fast |
|
PR_Github #55086 [ run ] triggered by Bot. Commit: |
Signed-off-by: Brenden Elgarten <belgarten@nvidia.com>
bbbfed0 to
81eb3d5
Compare
|
/bot kill |
|
PR_Github #55106 [ kill ] triggered by Bot. Commit: |
|
PR_Github #55086 [ run ] completed with state |
|
PR_Github #55106 [ kill ] completed with state |
|
/bot run --disable-fail-fast |
|
PR_Github #55108 [ run ] triggered by Bot. Commit: |
|
PR_Github #55108 [ run ] completed with state
|
Signed-off-by: Brenden Elgarten <belgarten@nvidia.com>
|
/bot run --disable-fail-fast |
|
PR_Github #55268 [ run ] triggered by Bot. Commit: |
|
PR_Github #55268 [ run ] completed with state
|
Summary by CodeRabbit
New Features
Tests
Description
This PR reworks the Linear layer to support Tensor Parallelism when the size of the TP group does not divide the dimensions of the layer ("uneven TP") and allows users to optionally pass custom sharding layouts. It then uses that work to implement uneven TP in the VisualGen models Wan T2V, Wan I2V, Flux-1, and Flux-2.
Test Coverage
Uneven Linear layer loading and numerics are tested in
tests/unittests/_torch/modules/test_linear_uneven_tp.py.The VisualGen changes are tested alongside VisualGen TP in:
tests/unittest/_torch/visual_gen/multi_gpu/test_flux_tp.pytests/unittest/_torch/visual_gen/multi_gpu/test_tp_attention.pytests/unittest/_torch/visual_gen/multi_gpu/test_wan_tp.pyPR Checklist
Please review the following before submitting your PR:
PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.
PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.
Test cases are provided for new code paths (see test instructions)
If PR introduces API changes, an appropriate PR label is added - either
api-compatibleorapi-breaking. Forapi-breaking, includeBREAKINGin the PR title.Any new dependencies have been scanned for license and vulnerabilities
CODEOWNERS updated if ownership changes
Documentation updated as needed
Update tava architecture diagram if there is a significant design change in PR.
The reviewers assigned automatically/manually are appropriate for the PR.
Please check this after reviewing the above items as appropriate for this PR.
GitHub Bot Help
To see a list of available CI bot commands, please comment
/bot help.