Skip to content

[None][feat] Implement Uneven TP Linear for VisualGen models#14875

Open
belgarten-nv wants to merge 4 commits into
NVIDIA:mainfrom
belgarten-nv:user/belgarten/uneven-tp
Open

[None][feat] Implement Uneven TP Linear for VisualGen models#14875
belgarten-nv wants to merge 4 commits into
NVIDIA:mainfrom
belgarten-nv:user/belgarten/uneven-tp

Conversation

@belgarten-nv

@belgarten-nv belgarten-nv commented Jun 2, 2026

Copy link
Copy Markdown
Contributor

Summary by CodeRabbit

  • New Features

    • Added support for uneven tensor-parallel sharding across distributed inference, enabling efficient multi-GPU deployment for models with non-evenly divisible layer dimensions.
    • Enhanced tensor-parallel configuration controls for attention and MLP projections in generative AI models (FLUX, WAN).
  • Tests

    • Added comprehensive test suite for uneven tensor-parallel configurations across multiple quantization methods.
    • Added distributed multi-GPU tests for FLUX and WAN model tensor-parallel correctness.

Description

This PR reworks the Linear layer to support Tensor Parallelism when the size of the TP group does not divide the dimensions of the layer ("uneven TP") and allows users to optionally pass custom sharding layouts. It then uses that work to implement uneven TP in the VisualGen models Wan T2V, Wan I2V, Flux-1, and Flux-2.

Test Coverage

Uneven Linear layer loading and numerics are tested in tests/unittests/_torch/modules/test_linear_uneven_tp.py.

The VisualGen changes are tested alongside VisualGen TP in:

  • tests/unittest/_torch/visual_gen/multi_gpu/test_flux_tp.py
  • tests/unittest/_torch/visual_gen/multi_gpu/test_tp_attention.py
  • tests/unittest/_torch/visual_gen/multi_gpu/test_wan_tp.py

PR Checklist

Please review the following before submitting your PR:

  • PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.

  • PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.

  • Test cases are provided for new code paths (see test instructions)

  • If PR introduces API changes, an appropriate PR label is added - either api-compatible or api-breaking. For api-breaking, include BREAKING in the PR title.

  • Any new dependencies have been scanned for license and vulnerabilities

  • CODEOWNERS updated if ownership changes

  • Documentation updated as needed

  • Update tava architecture diagram if there is a significant design change in PR.

  • The reviewers assigned automatically/manually are appropriate for the PR.

  • Please check this after reviewing the above items as appropriate for this PR.

GitHub Bot Help

To see a list of available CI bot commands, please comment /bot help.

@belgarten-nv belgarten-nv requested review from a team as code owners June 2, 2026 21:47
@belgarten-nv belgarten-nv requested a review from QiJune June 2, 2026 21:47
@belgarten-nv

Copy link
Copy Markdown
Contributor Author

/bot run --disable-fail-fast

@tensorrt-cicd

Copy link
Copy Markdown
Collaborator

PR_Github #51683 [ run ] triggered by Bot. Commit: 1c7e390 Link to invocation

@coderabbitai

coderabbitai Bot commented Jun 2, 2026

Copy link
Copy Markdown
Contributor

Review Change Stack

📝 Walkthrough

Walkthrough

This PR introduces alignment-aware tensor-parallel (TP) sharding infrastructure for TensorRT-LLM PyTorch modules. The changes centralize weight-loading logic in Linear.load_shard, support explicit TP boundary overrides via override_tp_sharding, and propagate these constraints through quantization methods, module layers, and vision models to enable uneven TP partitioning (non-divisible feature dimensions across ranks).

Changes

Uneven TP Sharding

Layer / File(s) Summary
Linear module TP sharding foundation
tensorrt_llm/_torch/modules/linear.py
Linear.__init__ gains override_tp_sharding parameter; new _calc_shard, _auto_tp_sharding, and load_shard methods centralize TP-aware weight slicing; base LinearMethod.get_tp_alignment() defines alignment constraints used by sharding selection.
Quantization methods alignment and weight loading
tensorrt_llm/_torch/modules/linear.py
All quantization method classes gain get_tp_alignment() methods; FP8 rowwise/blockscales, NVFP4, W4A8 variants, weight-only, and AWQ methods refactor weight and scale loading to use module.load_shard(...) with parameterized shard naming, scale span, and elm packing instead of explicit TP parameters.
Module-level sharding (GatedMLP, Attention, RMSNormTPAware)
tensorrt_llm/_torch/modules/gated_mlp.py, tensorrt_llm/_torch/visual_gen/modules/attention.py, tensorrt_llm/_torch/visual_gen/modules/rms_norm.py
GatedMLP uses Linear._calc_shard for intermediate size; Attention computes per-rank head ranges via _calculate_tp_parameters and passes explicit override_tp_sharding to Q/K/V projections and normalization; RMSNormTPAware accepts and uses override_tp_sharding for configurable TP boundaries.
Model-level sharding wiring (FLUX and WAN)
tensorrt_llm/_torch/visual_gen/models/flux/attention.py, tensorrt_llm/_torch/visual_gen/models/flux/joint_proj.py, tensorrt_llm/_torch/visual_gen/models/flux/transformer_flux.py, tensorrt_llm/_torch/visual_gen/models/wan/transformer_wan.py
FLUX joint attention, joint projection, and transformer block modules receive explicit override_tp_sharding ranges for Q/K/V and gate/up shards; WAN I2V cross-attention projections and normalization receive override_tp_sharding aligned with attention KV layout.
Test utilities and comprehensive test suite
tests/unittest/_torch/visual_gen/multi_gpu/tp_shard_utils.py, tests/unittest/_torch/modules/test_linear_uneven_tp.py, tests/unittest/_torch/visual_gen/multi_gpu/test_flux_tp.py, tests/unittest/_torch/visual_gen/multi_gpu/test_tp_attention.py, tests/unittest/_torch/visual_gen/multi_gpu/test_wan_tp.py
New tp_shard_utils module provides TP shard calculation and weight-sharding helpers; comprehensive test_linear_uneven_tp.py validates uneven TP across all quantization modes with fused QKV/gate-up loading tests; FLUX/WAN/attention tests updated to use shared utilities and add TP=3 uneven scenarios.

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Suggested reviewers

  • QiJune
  • bo-nv
  • liji-nv
  • yuxianq
  • reasonsolo
  • JunyiXu-nv
🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 30.95% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (4 passed)
Check name Status Explanation
Title check ✅ Passed The PR title clearly identifies the main change: implementing uneven tensor parallel support for Linear layers in VisualGen models, following the required format with ticket prefix and feature type.
Linked Issues check ✅ Passed Check skipped because no linked issues were found for this pull request.
Out of Scope Changes check ✅ Passed Check skipped because no linked issues were found for this pull request.
Description check ✅ Passed The PR description covers the main objectives and test coverage, though the PR title template appears incomplete with '[None]' as placeholder.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests

Comment @coderabbitai help to get the list of available commands.

@coderabbitai coderabbitai Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

🧹 Nitpick comments (3)
tensorrt_llm/_torch/modules/linear.py (2)

2350-2355: 💤 Low value

Unused variable weight_dtype.

The variable is unpacked but never used. Consider prefixing with underscore to indicate intentional discard.

🔧 Suggested fix
-        weight_dtype, weight_id = get_weight_dtype_and_id(module)
+        _, weight_id = get_weight_dtype_and_id(module)
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tensorrt_llm/_torch/modules/linear.py` around lines 2350 - 2355, The tuple
returned by get_weight_dtype_and_id is unpacked into weight_dtype and weight_id
but weight_dtype is never used; change the unpack to discard the unused value
(e.g., prefix with an underscore such as _weight_dtype or use a single
underscore) so it’s clear it’s intentional, keeping weight_id for the subsequent
elm_packing calculation that uses module.tp_mode and TensorParallelMode.COLUMN
and the call to load_weights_vanilla_helper.

3126-3129: 💤 Low value

Consider adding a defensive assertion when tp_sharding is a dict.

If tp_sharding is a dict (for fused modes), name must be provided. A KeyError would occur if name is None, which could be confusing. While fused modes should always provide name, an explicit assertion documents this contract.

🔧 Proposed assertion
         elif isinstance(self.tp_sharding, dict):
+            assert name is not None, (
+                "name is required when tp_sharding is a dict (fused modes)")
             slice_start, slice_end = self.tp_sharding[name]
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tensorrt_llm/_torch/modules/linear.py` around lines 3126 - 3129, When
handling tp_sharding as a dict in the block that assigns slice_start/slice_end,
add a defensive assertion that name is provided and is a valid key: assert name
is not None and name in self.tp_sharding, with a clear message like "tp_sharding
is a dict but 'name' is missing or invalid for fused mode"; then read
slice_start, slice_end = self.tp_sharding[name] as before to avoid an unclear
KeyError. This involves editing the branch that currently checks
isinstance(self.tp_sharding, dict) in the code that assigns slice_start and
slice_end.
tests/unittest/_torch/modules/test_linear_uneven_tp.py (1)

476-486: 💤 Low value

Consider adding strict=True to zip for explicit length validation.

The zip(outputs, per_rank_ranges) should have matching lengths, and strict=True would catch any mismatch during test development. However, this is a minor enhancement.

♻️ Proposed fix
-    for output, ranges in zip(outputs, per_rank_ranges):
+    for output, ranges in zip(outputs, per_rank_ranges, strict=True):
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tests/unittest/_torch/modules/test_linear_uneven_tp.py` around lines 476 -
486, In _assemble_fused_outputs, ensure the lengths of outputs and
per_rank_ranges are validated by replacing zip(outputs, per_rank_ranges) with
zip(outputs, per_rank_ranges, strict=True) so a length mismatch raises
immediately; update the zip call within the function _assemble_fused_outputs to
include strict=True (requires Python 3.10+), leaving the rest of the
slicing/concatenation logic unchanged.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@tensorrt_llm/_torch/visual_gen/models/flux/joint_proj.py`:
- Around line 260-261: The assignment to self.local_qkv_dim is incorrect for
uneven tensor-parallel (TP) shards: replace the simple global division
(currently self.local_qkv_dim = (q_dim + 2 * kv_dim) // self.tp_size) with a
computed local sum using the shard-specific sizes so local_qkv_dim = local_q_dim
+ 2 * local_kv_dim (where local_q_dim and local_kv_dim are derived from the TP
override/range computation used elsewhere); update the initialization in the
constructor so forward() uses the correct per-shard split sizes when slicing QKV
tensors.

---

Nitpick comments:
In `@tensorrt_llm/_torch/modules/linear.py`:
- Around line 2350-2355: The tuple returned by get_weight_dtype_and_id is
unpacked into weight_dtype and weight_id but weight_dtype is never used; change
the unpack to discard the unused value (e.g., prefix with an underscore such as
_weight_dtype or use a single underscore) so it’s clear it’s intentional,
keeping weight_id for the subsequent elm_packing calculation that uses
module.tp_mode and TensorParallelMode.COLUMN and the call to
load_weights_vanilla_helper.
- Around line 3126-3129: When handling tp_sharding as a dict in the block that
assigns slice_start/slice_end, add a defensive assertion that name is provided
and is a valid key: assert name is not None and name in self.tp_sharding, with a
clear message like "tp_sharding is a dict but 'name' is missing or invalid for
fused mode"; then read slice_start, slice_end = self.tp_sharding[name] as before
to avoid an unclear KeyError. This involves editing the branch that currently
checks isinstance(self.tp_sharding, dict) in the code that assigns slice_start
and slice_end.

In `@tests/unittest/_torch/modules/test_linear_uneven_tp.py`:
- Around line 476-486: In _assemble_fused_outputs, ensure the lengths of outputs
and per_rank_ranges are validated by replacing zip(outputs, per_rank_ranges)
with zip(outputs, per_rank_ranges, strict=True) so a length mismatch raises
immediately; update the zip call within the function _assemble_fused_outputs to
include strict=True (requires Python 3.10+), leaving the rest of the
slicing/concatenation logic unchanged.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: ea513628-a60f-4a66-8091-b7ffc1a0299b

📥 Commits

Reviewing files that changed from the base of the PR and between cd38dfb and 1c7e390.

📒 Files selected for processing (13)
  • tensorrt_llm/_torch/modules/gated_mlp.py
  • tensorrt_llm/_torch/modules/linear.py
  • tensorrt_llm/_torch/visual_gen/models/flux/attention.py
  • tensorrt_llm/_torch/visual_gen/models/flux/joint_proj.py
  • tensorrt_llm/_torch/visual_gen/models/flux/transformer_flux.py
  • tensorrt_llm/_torch/visual_gen/models/wan/transformer_wan.py
  • tensorrt_llm/_torch/visual_gen/modules/attention.py
  • tensorrt_llm/_torch/visual_gen/modules/rms_norm.py
  • tests/unittest/_torch/modules/test_linear_uneven_tp.py
  • tests/unittest/_torch/visual_gen/multi_gpu/test_flux_tp.py
  • tests/unittest/_torch/visual_gen/multi_gpu/test_tp_attention.py
  • tests/unittest/_torch/visual_gen/multi_gpu/test_wan_tp.py
  • tests/unittest/_torch/visual_gen/multi_gpu/tp_shard_utils.py

Comment thread tensorrt_llm/_torch/visual_gen/models/flux/joint_proj.py Outdated
@tensorrt-cicd

Copy link
Copy Markdown
Collaborator

PR_Github #51683 [ run ] completed with state SUCCESS. Commit: 1c7e390
/LLM/main/L0_MergeRequest_PR pipeline #41063 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

CI Agent Failure Analysis

Link to invocation

h_to_4h_type = LoraModuleType.MLP_H_TO_4H
gate_type = LoraModuleType.MLP_GATE

self.down_lora = LoraLayer([down_type], [self.hidden_size])

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will this work with lora layers? Can you raise an error if someone tries to use uneven model dims with TP with lora?

Comment on lines 74 to +83
gateup_shard_indices_mapping = {
'gate': (0, local_intermediate_size),
'up': (local_intermediate_size, local_intermediate_size),
}

override_tp_sharding = {
'gate': (local_intermediate_start, local_intermediate_end),
'up': (local_intermediate_start, local_intermediate_end),
}

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could you add a comment explaining the difference b/w the 2

Comment thread tensorrt_llm/_torch/modules/linear.py Outdated

def load_shard(
self,
weights: Dict,

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

looks like weights can be a torch.Tensor too?

_uneven_tp_unsupported = {QuantAlgo.NVFP4_ARC}
_quant_algo = quant_config.quant_algo if quant_config else None
if override_tp_sharding is not None:
assert _quant_algo not in _uneven_tp_unsupported

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

raise error


return self._calculate_local_features_helper(out_features)

def load_shard(

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could you add docstrings for all methods and also in load_weight_shard pointing to this new method

For VANILLA mode only. Fused modes with non-divisible dims require
explicit override_tp_sharding from the model layer.
"""
assert self.weights_loading_config.weight_mode == WeightMode.VANILLA, (

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

does it make sense to include FUSED_GATE_UP mode in auto sharding? it should be simpler than FUSE_QKV I think

@belgarten-nv belgarten-nv Jun 4, 2026

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The reason we can't do QKV fusing is because we don't know the full size of each Q, K, V matrix. You could always assume the size of the Gate is always the same as the size of Up matrix (which is thus out_features / 2), but I'm not certain this is always true. This becomes a problem since this could be misused really for any two simultaneous projections, not just a swiglu Gate/Up setup.

# Note: this is intentionally stronger than `num_kv_head >= ulysses_size * tp_size`
assert self.num_key_value_heads // ulysses_size >= self.tp_size

def _calc_shard(full, size, rank):

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

just use rank as arg to make it clearer

self.local_q_dim = self.local_num_attention_heads * self.head_dim
self.local_kv_dim = self.local_num_key_value_heads * self.head_dim

self._calculate_tp_parameters(ulysses_size if enable_ulysses else None)

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

would it be better to move this to a utility class with accessors to make it easier for new models?

self.has_bias = bias
self.attn_shard = attn_shard

assert attn_dim % self.tp_size == 0 or self.attn_shard, (

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

assert attn_dim % self.tp_size == 0 or self.attn_shard is not None,

local_q_dim = q_dim // self.tp_size
local_kv_dim = kv_dim // self.tp_size
shard_mlp_hidden_dim = self.mlp_hidden_dim // self.tp_size

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

assert override_qkv_sharding is not None

Signed-off-by: Brenden Elgarten <belgarten@nvidia.com>
Signed-off-by: Brenden Elgarten <belgarten@nvidia.com>
@belgarten-nv belgarten-nv force-pushed the user/belgarten/uneven-tp branch from 5475ecc to bbbfed0 Compare June 22, 2026 19:48
@belgarten-nv belgarten-nv requested a review from a team as a code owner June 22, 2026 19:48
@belgarten-nv

Copy link
Copy Markdown
Contributor Author

/bot run

@tensorrt-cicd

Copy link
Copy Markdown
Collaborator

PR_Github #55079 [ run ] triggered by Bot. Commit: bbbfed0 Link to invocation

@tensorrt-cicd

Copy link
Copy Markdown
Collaborator

PR_Github #55079 [ run ] completed with state SUCCESS. Commit: bbbfed0
/LLM/main/L0_MergeRequest_PR pipeline #44066 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

@belgarten-nv

Copy link
Copy Markdown
Contributor Author

/bot run --disable-reuse-test --disable-fail-fast

@tensorrt-cicd

Copy link
Copy Markdown
Collaborator

PR_Github #55086 [ run ] triggered by Bot. Commit: bbbfed0 Link to invocation

Signed-off-by: Brenden Elgarten <belgarten@nvidia.com>
@belgarten-nv belgarten-nv force-pushed the user/belgarten/uneven-tp branch from bbbfed0 to 81eb3d5 Compare June 23, 2026 00:13
@belgarten-nv

Copy link
Copy Markdown
Contributor Author

/bot kill

@tensorrt-cicd

Copy link
Copy Markdown
Collaborator

PR_Github #55106 [ kill ] triggered by Bot. Commit: 81eb3d5 Link to invocation

@tensorrt-cicd

Copy link
Copy Markdown
Collaborator

PR_Github #55086 [ run ] completed with state ABORTED. Commit: bbbfed0

Link to invocation

@tensorrt-cicd

Copy link
Copy Markdown
Collaborator

PR_Github #55106 [ kill ] completed with state SUCCESS. Commit: 81eb3d5
Successfully killed previous jobs for commit 81eb3d5

Link to invocation

@belgarten-nv

Copy link
Copy Markdown
Contributor Author

/bot run --disable-fail-fast

@tensorrt-cicd

Copy link
Copy Markdown
Collaborator

PR_Github #55108 [ run ] triggered by Bot. Commit: 81eb3d5 Link to invocation

@tensorrt-cicd

Copy link
Copy Markdown
Collaborator

PR_Github #55108 [ run ] completed with state SUCCESS. Commit: 81eb3d5
/LLM/main/L0_MergeRequest_PR pipeline #44092 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

Signed-off-by: Brenden Elgarten <belgarten@nvidia.com>
@belgarten-nv

Copy link
Copy Markdown
Contributor Author

/bot run --disable-fail-fast

@tensorrt-cicd

Copy link
Copy Markdown
Collaborator

PR_Github #55268 [ run ] triggered by Bot. Commit: 3a6e883 Link to invocation

@tensorrt-cicd

Copy link
Copy Markdown
Collaborator

PR_Github #55268 [ run ] completed with state FAILURE. Commit: 3a6e883
/LLM/main/L0_MergeRequest_PR pipeline #44222 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

CI Agent Failure Analysis

Link to invocation

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants