Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 26 additions & 1 deletion tensorrt_llm/_torch/modules/gated_mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,13 +62,33 @@ def __init__(

# Calculate local intermediate size after tensor parallel sharding
tp_size = mapping.tp_size
local_intermediate_size = self.intermediate_size // tp_size

local_intermediate_start = Linear._calc_shard(self.intermediate_size,
mapping.tp_size,
mapping.tp_rank)
local_intermediate_end = Linear._calc_shard(self.intermediate_size,
mapping.tp_size,
mapping.tp_rank + 1)
local_intermediate_size = local_intermediate_end - local_intermediate_start

self._uneven_tp_blocks_lora = (mapping.tp_size > 1
and self.intermediate_size %
mapping.tp_size != 0)

# gateup_shard_indices_mapping is the local offset and size for each sub-weight
# in this rank's concatenated (gate || up) buffer.
# override_tp_sharding is the absolute range of the global weight from which
# this rank pulls each sub-weight.
gateup_shard_indices_mapping = {
'gate': (0, local_intermediate_size),
'up': (local_intermediate_size, local_intermediate_size),
}

override_tp_sharding = {
'gate': (local_intermediate_start, local_intermediate_end),
'up': (local_intermediate_start, local_intermediate_end),
}

Comment on lines 82 to +91

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could you add a comment explaining the difference b/w the 2

self.gate_up_proj = Linear(
self.hidden_size,
self.intermediate_size * 2,
Expand All @@ -87,6 +107,7 @@ def __init__(
disable_deep_gemm=disable_deep_gemm,
fused_weight_shard_indices_mapping=gateup_shard_indices_mapping,
use_custom_cublas_mm=use_custom_cublas_mm,
override_tp_sharding=override_tp_sharding,
)

if is_shared_expert:
Expand Down Expand Up @@ -284,6 +305,10 @@ def forward_lora(
) -> torch.Tensor:
assert lora_params is not None
assert self.layer_idx is not None, "layer_idx is required for lora"
if self._uneven_tp_blocks_lora:
raise NotImplementedError(
"LoRA is not supported with uneven TP for GatedMLP "
"(intermediate_size not divisible by tp_size).")

h1 = self.gate_up_proj(x)

Expand Down
814 changes: 516 additions & 298 deletions tensorrt_llm/_torch/modules/linear.py

Large diffs are not rendered by default.

12 changes: 12 additions & 0 deletions tensorrt_llm/_torch/visual_gen/models/flux/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,11 @@ def __init__(
mapping=config.mapping,
tensor_parallel_mode=TensorParallelMode.COLUMN,
reduce_output=False,
override_tp_sharding={
"q": (self.local_q_dim_start, self.local_q_dim_end),
"k": (self.local_kv_dim_start, self.local_kv_dim_end),
"v": (self.local_kv_dim_start, self.local_kv_dim_end),
},
)

# Need not pass any mapping info since this is intra-head normalization
Expand Down Expand Up @@ -130,6 +135,7 @@ def __init__(
allreduce_strategy=config.allreduce_strategy,
tensor_parallel_mode=TensorParallelMode.ROW,
reduce_output=True,
override_tp_sharding=(self.local_kv_dim_start, self.local_kv_dim_end),
)

def apply_qk_norm(self, q: torch.Tensor, k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
Expand Down Expand Up @@ -350,6 +356,7 @@ def __init__(
skip_create_weights_in_init=self.skip_create_weights_in_init,
force_dynamic_quantization=self.force_dynamic_quantization,
config=config,
attn_shard=(self.local_q_dim_start, self.local_q_dim_end),
)

def _init_qkv_proj(self):
Expand All @@ -366,6 +373,11 @@ def _init_qkv_proj(self):
skip_create_weights_in_init=self.skip_create_weights_in_init,
force_dynamic_quantization=self.force_dynamic_quantization,
mapping=self.mapping,
override_qkv_sharding={
"q": (self.local_q_dim_start, self.local_q_dim_end),
"k": (self.local_kv_dim_start, self.local_kv_dim_end),
"v": (self.local_kv_dim_start, self.local_kv_dim_end),
},
)

def _apply_norm_rope_unfused(
Expand Down
44 changes: 36 additions & 8 deletions tensorrt_llm/_torch/visual_gen/models/flux/joint_proj.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,13 +53,19 @@ def __init__(
skip_create_weights_in_init: bool = False,
force_dynamic_quantization: bool = False,
config: Optional[DiffusionModelConfig] = None,
attn_shard: Optional[tuple[int, int]] = None,
):
super().__init__()
mapping = config.mapping if config else None
self.tp_size = getattr(mapping, "tp_size", 1)
self.tp_rank = getattr(mapping, "tp_rank", 0)
self.attn_dim = attn_dim
self.has_bias = bias
self.attn_shard = attn_shard

assert attn_dim % self.tp_size == 0 or self.attn_shard is not None, (
"Explicit attention sharding required for uneven TP"
)

if self.tp_size == 1:
self.proj = Linear(
Expand All @@ -84,6 +90,7 @@ def __init__(
mapping=config.mapping,
tensor_parallel_mode=TensorParallelMode.ROW,
reduce_output=False,
override_tp_sharding=self.attn_shard,
)
self.mlp_proj = Linear(
mlp_dim,
Expand Down Expand Up @@ -162,10 +169,12 @@ def __init__(
skip_create_weights_in_init: bool = False,
force_dynamic_quantization: bool = False,
mapping: Optional[Mapping] = None,
override_qkv_sharding=None,
):
super().__init__()

self.tp_size = mapping.tp_size if mapping else 1
self.tp_rank = mapping.tp_rank if mapping else 0

# Store full (pre-TP) dims for weight loading (splitting checkpoint weight)
self.full_q_dim = q_dim
Expand All @@ -188,9 +197,15 @@ def __init__(
self.local_qkv_dim = q_dim + 2 * kv_dim
self.local_mlp_dim = mlp_dim
else:
local_q_dim = q_dim // self.tp_size
local_kv_dim = kv_dim // self.tp_size
shard_mlp_hidden_dim = self.mlp_hidden_dim // self.tp_size
assert override_qkv_sharding is not None, (
"override_qkv_sharding required when tp_size > 1"
)

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

assert override_qkv_sharding is not None

def range_size(r):
return r[1] - r[0]

local_q_dim = range_size(override_qkv_sharding["q"])
local_kv_dim = range_size(override_qkv_sharding["k"])
# QKV: column-parallel with fused Q/K/V sharding
self.qkv_proj = Linear(
in_dim,
Expand All @@ -211,8 +226,17 @@ def __init__(
mapping=mapping,
tensor_parallel_mode=TensorParallelMode.COLUMN,
reduce_output=False,
override_tp_sharding=override_qkv_sharding,
)
# MLP gate+up: column-parallel with fused gate/up sharding

local_mlp_hidden_start = Linear._calc_shard(
self.mlp_hidden_dim, self.tp_size, self.tp_rank
)
local_mlp_hidden_end = Linear._calc_shard(
self.mlp_hidden_dim, self.tp_size, self.tp_rank + 1
)
local_mlp_hidden_size = local_mlp_hidden_end - local_mlp_hidden_start

self.mlp_proj = Linear(
in_dim,
mlp_dim,
Expand All @@ -225,15 +249,19 @@ def __init__(
weight_mode=WeightMode.FUSED_GATE_UP_LINEAR,
),
fused_weight_shard_indices_mapping={
"gate": (0, shard_mlp_hidden_dim),
"up": (shard_mlp_hidden_dim, shard_mlp_hidden_dim),
"gate": (0, local_mlp_hidden_size),
"up": (local_mlp_hidden_size, local_mlp_hidden_size),
},
mapping=mapping,
tensor_parallel_mode=TensorParallelMode.COLUMN,
reduce_output=False,
override_tp_sharding={
"gate": (local_mlp_hidden_start, local_mlp_hidden_end),
"up": (local_mlp_hidden_start, local_mlp_hidden_end),
},
)
self.local_qkv_dim = (q_dim + 2 * kv_dim) // self.tp_size
self.local_mlp_dim = mlp_dim // self.tp_size
self.local_qkv_dim = local_q_dim + 2 * local_kv_dim
self.local_mlp_dim = local_mlp_hidden_size

def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""Returns (qkv, mlp_gate_up) with local (post-TP) sizes."""
Expand Down
37 changes: 26 additions & 11 deletions tensorrt_llm/_torch/visual_gen/models/flux/transformer_flux.py
Original file line number Diff line number Diff line change
Expand Up @@ -467,11 +467,22 @@ def __init__(
)
self.act_mlp = _gelu_tanh_eager

kv_dim = num_attention_heads * attention_head_dim
# Attention (no added_kv_proj_dim since tokens are already concatenated)
self.attn = FluxJointAttention(
hidden_size=dim,
num_attention_heads=num_attention_heads,
head_dim=attention_head_dim,
bias=True,
eps=1e-6,
pre_only=True, # No output projection in attention
config=config,
layer_idx=layer_idx,
module_name=f"single_transformer_blocks.{layer_idx}.attn",
)

# MLP + Attn Output projection, requires special handling for TP
self.proj_out = FluxJointAttnMLPProj(
attn_dim=kv_dim,
attn_dim=self.attn.q_dim,
mlp_dim=self.mlp_hidden_dim,
out_dim=dim,
bias=True,
Expand All @@ -480,19 +491,23 @@ def __init__(
skip_create_weights_in_init=skip_create_weights,
force_dynamic_quantization=force_dynamic_quant,
config=config,
# need explicit shard because we are aligned on head boundaries
attn_shard=(self.attn.local_q_dim_start, self.attn.local_q_dim_end),
)

# Attention (no added_kv_proj_dim since tokens are already concatenated)
self.attn = FluxJointAttention(
hidden_size=dim,
num_attention_heads=num_attention_heads,
head_dim=attention_head_dim,
# MLP + Attn Output projection, requires special handling for TP
self.proj_out = FluxJointAttnMLPProj(
attn_dim=self.attn.q_dim,
mlp_dim=self.mlp_hidden_dim,
out_dim=dim,
bias=True,
eps=1e-6,
pre_only=True, # No output projection in attention
dtype=dtype,
quant_config=quant_config,
skip_create_weights_in_init=skip_create_weights,
force_dynamic_quantization=force_dynamic_quant,
config=config,
layer_idx=layer_idx,
module_name=f"single_transformer_blocks.{layer_idx}.attn",
# need explicit shard because we are aligned on head boundaries
attn_shard=(self.attn.local_q_dim_start, self.attn.local_q_dim_end),
)

def forward(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -371,6 +371,7 @@ def __init__(
force_dynamic_quantization=force_dynamic_quant,
tensor_parallel_mode=tp_mode,
reduce_output=False,
override_tp_sharding=(self.attn2.local_kv_dim_start, self.attn2.local_kv_dim_end),
)
self.add_v_proj = Linear(
added_kv_proj_dim,
Expand All @@ -382,6 +383,7 @@ def __init__(
force_dynamic_quantization=force_dynamic_quant,
tensor_parallel_mode=tp_mode,
reduce_output=False,
override_tp_sharding=(self.attn2.local_kv_dim_start, self.attn2.local_kv_dim_end),
)
self.norm_added_k = RMSNormTPAware(
hidden_size=hidden_size,
Expand All @@ -390,6 +392,7 @@ def __init__(
has_weights=True,
enable_tp=(tp_size > 1),
mapping=model_config.mapping,
override_tp_sharding=(self.attn2.local_kv_dim_start, self.attn2.local_kv_dim_end),
)

# Use torch.empty().normal_(std=...) instead of torch.randn()/scale for MetaInitMode compatibility
Expand Down
Loading
Loading