Skip to content

Commit 1c7e390

Browse files
committed
uneven tp initial impl
Signed-off-by: Brenden Elgarten <belgarten@nvidia.com>
1 parent 702e39d commit 1c7e390

13 files changed

Lines changed: 2562 additions & 590 deletions

File tree

tensorrt_llm/_torch/modules/gated_mlp.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,13 +62,25 @@ def __init__(
6262

6363
# Calculate local intermediate size after tensor parallel sharding
6464
tp_size = mapping.tp_size
65-
local_intermediate_size = self.intermediate_size // tp_size
65+
66+
local_intermediate_start = Linear._calc_shard(self.intermediate_size,
67+
mapping.tp_size,
68+
mapping.tp_rank)
69+
local_intermediate_end = Linear._calc_shard(self.intermediate_size,
70+
mapping.tp_size,
71+
mapping.tp_rank + 1)
72+
local_intermediate_size = local_intermediate_end - local_intermediate_start
6673

6774
gateup_shard_indices_mapping = {
6875
'gate': (0, local_intermediate_size),
6976
'up': (local_intermediate_size, local_intermediate_size),
7077
}
7178

79+
override_tp_sharding = {
80+
'gate': (local_intermediate_start, local_intermediate_end),
81+
'up': (local_intermediate_start, local_intermediate_end),
82+
}
83+
7284
self.gate_up_proj = Linear(
7385
self.hidden_size,
7486
self.intermediate_size * 2,
@@ -87,6 +99,7 @@ def __init__(
8799
disable_deep_gemm=disable_deep_gemm,
88100
fused_weight_shard_indices_mapping=gateup_shard_indices_mapping,
89101
use_custom_cublas_mm=use_custom_cublas_mm,
102+
override_tp_sharding=override_tp_sharding,
90103
)
91104

92105
if is_shared_expert:

tensorrt_llm/_torch/modules/linear.py

Lines changed: 492 additions & 299 deletions
Large diffs are not rendered by default.

tensorrt_llm/_torch/visual_gen/models/flux/attention.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,11 @@ def __init__(
9999
mapping=config.mapping,
100100
tensor_parallel_mode=TensorParallelMode.COLUMN,
101101
reduce_output=False,
102+
override_tp_sharding={
103+
"q": (self.local_q_dim_start, self.local_q_dim_end),
104+
"k": (self.local_kv_dim_start, self.local_kv_dim_end),
105+
"v": (self.local_kv_dim_start, self.local_kv_dim_end),
106+
},
102107
)
103108

104109
# Need not pass any mapping info since this is intra-head normalization
@@ -128,6 +133,7 @@ def __init__(
128133
allreduce_strategy=config.allreduce_strategy,
129134
tensor_parallel_mode=TensorParallelMode.ROW,
130135
reduce_output=True,
136+
override_tp_sharding=(self.local_kv_dim_start, self.local_kv_dim_end),
131137
)
132138

133139
def apply_qk_norm(self, q: torch.Tensor, k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -345,6 +351,7 @@ def __init__(
345351
skip_create_weights_in_init=self.skip_create_weights_in_init,
346352
force_dynamic_quantization=self.force_dynamic_quantization,
347353
config=config,
354+
attn_shard=(self.local_q_dim_start, self.local_q_dim_end),
348355
)
349356

350357
def _init_qkv_proj(self):
@@ -361,6 +368,11 @@ def _init_qkv_proj(self):
361368
skip_create_weights_in_init=self.skip_create_weights_in_init,
362369
force_dynamic_quantization=self.force_dynamic_quantization,
363370
mapping=self.mapping,
371+
override_qkv_sharding={
372+
"q": (self.local_q_dim_start, self.local_q_dim_end),
373+
"k": (self.local_kv_dim_start, self.local_kv_dim_end),
374+
"v": (self.local_kv_dim_start, self.local_kv_dim_end),
375+
},
364376
)
365377

366378
def _apply_norm_rope_unfused(

tensorrt_llm/_torch/visual_gen/models/flux/joint_proj.py

Lines changed: 32 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -53,13 +53,19 @@ def __init__(
5353
skip_create_weights_in_init: bool = False,
5454
force_dynamic_quantization: bool = False,
5555
config: Optional[DiffusionModelConfig] = None,
56+
attn_shard: Optional[tuple[int, int]] = None,
5657
):
5758
super().__init__()
5859
mapping = config.mapping if config else None
5960
self.tp_size = getattr(mapping, "tp_size", 1)
6061
self.tp_rank = getattr(mapping, "tp_rank", 0)
6162
self.attn_dim = attn_dim
6263
self.has_bias = bias
64+
self.attn_shard = attn_shard
65+
66+
assert attn_dim % self.tp_size == 0 or self.attn_shard, (
67+
"Explicit attention sharding required for uneven TP"
68+
)
6369

6470
if self.tp_size == 1:
6571
self.proj = Linear(
@@ -84,6 +90,7 @@ def __init__(
8490
mapping=config.mapping,
8591
tensor_parallel_mode=TensorParallelMode.ROW,
8692
reduce_output=False,
93+
override_tp_sharding=self.attn_shard,
8794
)
8895
self.mlp_proj = Linear(
8996
mlp_dim,
@@ -162,10 +169,12 @@ def __init__(
162169
skip_create_weights_in_init: bool = False,
163170
force_dynamic_quantization: bool = False,
164171
mapping: Optional[Mapping] = None,
172+
override_qkv_sharding=None,
165173
):
166174
super().__init__()
167175

168176
self.tp_size = mapping.tp_size if mapping else 1
177+
self.tp_rank = mapping.tp_rank if mapping else 0
169178

170179
# Store full (pre-TP) dims for weight loading (splitting checkpoint weight)
171180
self.full_q_dim = q_dim
@@ -188,9 +197,12 @@ def __init__(
188197
self.local_qkv_dim = q_dim + 2 * kv_dim
189198
self.local_mlp_dim = mlp_dim
190199
else:
191-
local_q_dim = q_dim // self.tp_size
192-
local_kv_dim = kv_dim // self.tp_size
193-
shard_mlp_hidden_dim = self.mlp_hidden_dim // self.tp_size
200+
201+
def range_size(r):
202+
return r[1] - r[0]
203+
204+
local_q_dim = range_size(override_qkv_sharding["q"])
205+
local_kv_dim = range_size(override_qkv_sharding["k"])
194206
# QKV: column-parallel with fused Q/K/V sharding
195207
self.qkv_proj = Linear(
196208
in_dim,
@@ -211,8 +223,17 @@ def __init__(
211223
mapping=mapping,
212224
tensor_parallel_mode=TensorParallelMode.COLUMN,
213225
reduce_output=False,
226+
override_tp_sharding=override_qkv_sharding,
227+
)
228+
229+
local_mlp_hidden_start = Linear._calc_shard(
230+
self.mlp_hidden_dim, self.tp_size, self.tp_rank
231+
)
232+
local_mlp_hidden_end = Linear._calc_shard(
233+
self.mlp_hidden_dim, self.tp_size, self.tp_rank + 1
214234
)
215-
# MLP gate+up: column-parallel with fused gate/up sharding
235+
local_mlp_hidden_size = local_mlp_hidden_end - local_mlp_hidden_start
236+
216237
self.mlp_proj = Linear(
217238
in_dim,
218239
mlp_dim,
@@ -225,15 +246,19 @@ def __init__(
225246
weight_mode=WeightMode.FUSED_GATE_UP_LINEAR,
226247
),
227248
fused_weight_shard_indices_mapping={
228-
"gate": (0, shard_mlp_hidden_dim),
229-
"up": (shard_mlp_hidden_dim, shard_mlp_hidden_dim),
249+
"gate": (0, local_mlp_hidden_size),
250+
"up": (local_mlp_hidden_size, local_mlp_hidden_size),
230251
},
231252
mapping=mapping,
232253
tensor_parallel_mode=TensorParallelMode.COLUMN,
233254
reduce_output=False,
255+
override_tp_sharding={
256+
"gate": (local_mlp_hidden_start, local_mlp_hidden_end),
257+
"up": (local_mlp_hidden_start, local_mlp_hidden_end),
258+
},
234259
)
235260
self.local_qkv_dim = (q_dim + 2 * kv_dim) // self.tp_size
236-
self.local_mlp_dim = mlp_dim // self.tp_size
261+
self.local_mlp_dim = local_mlp_hidden_size
237262

238263
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
239264
"""Returns (qkv, mlp_gate_up) with local (post-TP) sizes."""

tensorrt_llm/_torch/visual_gen/models/flux/transformer_flux.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -465,11 +465,21 @@ def __init__(
465465
)
466466
self.act_mlp = _gelu_tanh_eager
467467

468-
kv_dim = num_attention_heads * attention_head_dim
468+
# Attention (no added_kv_proj_dim since tokens are already concatenated)
469+
self.attn = FluxJointAttention(
470+
hidden_size=dim,
471+
num_attention_heads=num_attention_heads,
472+
head_dim=attention_head_dim,
473+
bias=True,
474+
eps=1e-6,
475+
pre_only=True, # No output projection in attention
476+
config=config,
477+
layer_idx=layer_idx,
478+
)
469479

470480
# MLP + Attn Output projection, requires special handling for TP
471481
self.proj_out = FluxJointAttnMLPProj(
472-
attn_dim=kv_dim,
482+
attn_dim=self.attn.q_dim,
473483
mlp_dim=self.mlp_hidden_dim,
474484
out_dim=dim,
475485
bias=True,
@@ -478,18 +488,8 @@ def __init__(
478488
skip_create_weights_in_init=skip_create_weights,
479489
force_dynamic_quantization=force_dynamic_quant,
480490
config=config,
481-
)
482-
483-
# Attention (no added_kv_proj_dim since tokens are already concatenated)
484-
self.attn = FluxJointAttention(
485-
hidden_size=dim,
486-
num_attention_heads=num_attention_heads,
487-
head_dim=attention_head_dim,
488-
bias=True,
489-
eps=1e-6,
490-
pre_only=True, # No output projection in attention
491-
config=config,
492-
layer_idx=layer_idx,
491+
# need explicit shard because we are aligned on head boundaries
492+
attn_shard=(self.attn.local_q_dim_start, self.attn.local_q_dim_end),
493493
)
494494

495495
def forward(

tensorrt_llm/_torch/visual_gen/models/wan/transformer_wan.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -356,6 +356,7 @@ def __init__(
356356
force_dynamic_quantization=force_dynamic_quant,
357357
tensor_parallel_mode=tp_mode,
358358
reduce_output=False,
359+
override_tp_sharding=(self.attn2.local_kv_dim_start, self.attn2.local_kv_dim_end),
359360
)
360361
self.add_v_proj = Linear(
361362
added_kv_proj_dim,
@@ -367,6 +368,7 @@ def __init__(
367368
force_dynamic_quantization=force_dynamic_quant,
368369
tensor_parallel_mode=tp_mode,
369370
reduce_output=False,
371+
override_tp_sharding=(self.attn2.local_kv_dim_start, self.attn2.local_kv_dim_end),
370372
)
371373
self.norm_added_k = RMSNormTPAware(
372374
hidden_size=hidden_size,
@@ -375,6 +377,7 @@ def __init__(
375377
has_weights=True,
376378
enable_tp=(tp_size > 1),
377379
mapping=model_config.mapping,
380+
override_tp_sharding=(self.attn2.local_kv_dim_start, self.attn2.local_kv_dim_end),
378381
)
379382

380383
# Use torch.empty().normal_(std=...) instead of torch.randn()/scale for MetaInitMode compatibility

tensorrt_llm/_torch/visual_gen/modules/attention.py

Lines changed: 59 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -71,10 +71,7 @@ def __init__(
7171
self.bias = bias
7272

7373
self.tp_size = self.mapping.tp_size if self.mapping else 1
74-
assert (
75-
self.num_attention_heads % self.tp_size == 0
76-
and self.num_key_value_heads % self.tp_size == 0
77-
), "TP size must divide the number of Query and KV Heads"
74+
self.tp_rank = self.mapping.tp_rank if self.mapping else 0
7875

7976
# Fused QK Norm + RoPE: each model class opts in via fuse_qk_norm_rope.
8077
# Backed by torch.ops.trtllm.fused_dit_qk_norm_rope which auto-dispatches:
@@ -108,11 +105,7 @@ def __init__(
108105
self.q_dim = self.num_attention_heads * self.head_dim
109106
self.kv_dim = self.num_key_value_heads * self.head_dim
110107

111-
self.local_num_attention_heads = self.num_attention_heads // self.tp_size
112-
self.local_num_key_value_heads = self.num_key_value_heads // self.tp_size
113-
self.local_q_dim = self.local_num_attention_heads * self.head_dim
114-
self.local_kv_dim = self.local_num_key_value_heads * self.head_dim
115-
108+
self._calculate_tp_parameters(ulysses_size if enable_ulysses else None)
116109
self._init_qkv_proj()
117110

118111
attention_metadata_state = getattr(config, "attention_metadata_state", None)
@@ -124,13 +117,20 @@ def __init__(
124117
q_norm_dim = self.head_dim if qk_norm_mode == "per_head" else self.q_dim
125118
k_norm_dim = self.head_dim if qk_norm_mode == "per_head" else self.kv_dim
126119
enable_tp_rms = self.tp_size > 1 and qk_norm_mode == "full"
120+
121+
q_start = self.local_q_dim_start
122+
q_end = self.local_q_dim_end
123+
k_start = self.local_kv_dim_start
124+
k_end = self.local_kv_dim_end
125+
127126
self.norm_q = RMSNormTPAware(
128127
hidden_size=q_norm_dim,
129128
eps=self.eps,
130129
dtype=self.dtype,
131130
has_weights=True,
132131
enable_tp=enable_tp_rms,
133132
mapping=self.mapping,
133+
override_tp_sharding=(q_start, q_end) if qk_norm_mode == "full" else None,
134134
)
135135
self.norm_k = RMSNormTPAware(
136136
hidden_size=k_norm_dim,
@@ -139,6 +139,7 @@ def __init__(
139139
has_weights=True,
140140
enable_tp=enable_tp_rms,
141141
mapping=self.mapping,
142+
override_tp_sharding=(k_start, k_end) if qk_norm_mode == "full" else None,
142143
)
143144

144145
# TODO: Use weight mapper to create just a Linear module
@@ -156,6 +157,7 @@ def __init__(
156157
tensor_parallel_mode=TensorParallelMode.ROW if self.tp_size > 1 else None,
157158
reduce_output=(self.tp_size > 1),
158159
allreduce_strategy=self.allreduce_strategy,
160+
override_tp_sharding=(self.local_q_dim_start, self.local_q_dim_end),
159161
)
160162
]
161163
)
@@ -231,6 +233,46 @@ def __init__(
231233

232234
self.attn = UlyssesAttention(self.attn, process_group=vgm.ulysses_group)
233235

236+
def _calculate_tp_parameters(self, ulysses_size: Optional[int]):
237+
assert self.num_attention_heads % self.num_key_value_heads == 0
238+
gqa_ratio = self.num_attention_heads // self.num_key_value_heads
239+
240+
if not ulysses_size:
241+
ulysses_size = 1
242+
243+
assert self.num_key_value_heads % ulysses_size == 0
244+
# Note: this is intentionally stronger than `num_kv_head >= ulysses_size * tp_size`
245+
assert self.num_key_value_heads // ulysses_size >= self.tp_size
246+
247+
def _calc_shard(full, size, rank):
248+
full //= ulysses_size
249+
shard = (full // size) * rank + min(full % size, rank)
250+
return shard * ulysses_size
251+
252+
self.local_key_value_head_start = _calc_shard(
253+
self.num_key_value_heads, self.tp_size, self.tp_rank
254+
)
255+
self.local_key_value_head_end = _calc_shard(
256+
self.num_key_value_heads, self.tp_size, self.tp_rank + 1
257+
)
258+
self.local_num_key_value_heads = (
259+
self.local_key_value_head_end - self.local_key_value_head_start
260+
)
261+
262+
self.local_attention_head_start = gqa_ratio * self.local_key_value_head_start
263+
self.local_attention_head_end = gqa_ratio * self.local_key_value_head_end
264+
self.local_num_attention_heads = (
265+
self.local_attention_head_end - self.local_attention_head_start
266+
)
267+
268+
self.local_q_dim_start = self.local_attention_head_start * self.head_dim
269+
self.local_q_dim_end = self.local_attention_head_end * self.head_dim
270+
self.local_q_dim = self.local_q_dim_end - self.local_q_dim_start
271+
272+
self.local_kv_dim_start = self.local_key_value_head_start * self.head_dim
273+
self.local_kv_dim_end = self.local_key_value_head_end * self.head_dim
274+
self.local_kv_dim = self.local_kv_dim_end - self.local_kv_dim_start
275+
234276
def _init_qkv_proj(self) -> None:
235277
tp_mode = TensorParallelMode.COLUMN if self.tp_size > 1 else None
236278

@@ -258,6 +300,11 @@ def _init_qkv_proj(self) -> None:
258300
},
259301
tensor_parallel_mode=tp_mode,
260302
reduce_output=False,
303+
override_tp_sharding={
304+
"q": (self.local_q_dim_start, self.local_q_dim_end),
305+
"k": (self.local_kv_dim_start, self.local_kv_dim_end),
306+
"v": (self.local_kv_dim_start, self.local_kv_dim_end),
307+
},
261308
)
262309
else:
263310
self.to_q = Linear(
@@ -271,6 +318,7 @@ def _init_qkv_proj(self) -> None:
271318
force_dynamic_quantization=self.force_dynamic_quantization,
272319
tensor_parallel_mode=tp_mode,
273320
reduce_output=False,
321+
override_tp_sharding=(self.local_q_dim_start, self.local_q_dim_end),
274322
)
275323
self.to_k = Linear(
276324
self.hidden_size,
@@ -283,6 +331,7 @@ def _init_qkv_proj(self) -> None:
283331
force_dynamic_quantization=self.force_dynamic_quantization,
284332
tensor_parallel_mode=tp_mode,
285333
reduce_output=False,
334+
override_tp_sharding=(self.local_kv_dim_start, self.local_kv_dim_end),
286335
)
287336
self.to_v = Linear(
288337
self.hidden_size,
@@ -295,6 +344,7 @@ def _init_qkv_proj(self) -> None:
295344
force_dynamic_quantization=self.force_dynamic_quantization,
296345
tensor_parallel_mode=tp_mode,
297346
reduce_output=False,
347+
override_tp_sharding=(self.local_kv_dim_start, self.local_kv_dim_end),
298348
)
299349

300350
def get_qkv(

0 commit comments

Comments
 (0)