From 0b73b90b69e811e370efe9e7671fd086918dcdbd Mon Sep 17 00:00:00 2001 From: Brenden Elgarten Date: Tue, 2 Jun 2026 21:39:42 +0000 Subject: [PATCH 1/4] uneven tp initial impl Signed-off-by: Brenden Elgarten --- tensorrt_llm/_torch/modules/gated_mlp.py | 15 +- tensorrt_llm/_torch/modules/linear.py | 791 ++++++---- .../visual_gen/models/flux/attention.py | 12 + .../visual_gen/models/flux/joint_proj.py | 39 +- .../models/flux/transformer_flux.py | 30 +- .../visual_gen/models/wan/transformer_wan.py | 3 + .../_torch/visual_gen/modules/attention.py | 68 +- .../_torch/visual_gen/modules/rms_norm.py | 37 +- .../_torch/modules/test_linear_uneven_tp.py | 1371 +++++++++++++++++ .../visual_gen/multi_gpu/test_flux_tp.py | 224 ++- .../visual_gen/multi_gpu/test_tp_attention.py | 114 +- .../visual_gen/multi_gpu/test_wan_tp.py | 153 +- .../visual_gen/multi_gpu/tp_shard_utils.py | 297 ++++ 13 files changed, 2563 insertions(+), 591 deletions(-) create mode 100644 tests/unittest/_torch/modules/test_linear_uneven_tp.py create mode 100644 tests/unittest/_torch/visual_gen/multi_gpu/tp_shard_utils.py diff --git a/tensorrt_llm/_torch/modules/gated_mlp.py b/tensorrt_llm/_torch/modules/gated_mlp.py index 01293a83faf0..2b49fa195a08 100644 --- a/tensorrt_llm/_torch/modules/gated_mlp.py +++ b/tensorrt_llm/_torch/modules/gated_mlp.py @@ -62,13 +62,25 @@ def __init__( # Calculate local intermediate size after tensor parallel sharding tp_size = mapping.tp_size - local_intermediate_size = self.intermediate_size // tp_size + + local_intermediate_start = Linear._calc_shard(self.intermediate_size, + mapping.tp_size, + mapping.tp_rank) + local_intermediate_end = Linear._calc_shard(self.intermediate_size, + mapping.tp_size, + mapping.tp_rank + 1) + local_intermediate_size = local_intermediate_end - local_intermediate_start gateup_shard_indices_mapping = { 'gate': (0, local_intermediate_size), 'up': (local_intermediate_size, local_intermediate_size), } + override_tp_sharding = { + 'gate': (local_intermediate_start, local_intermediate_end), + 'up': (local_intermediate_start, local_intermediate_end), + } + self.gate_up_proj = Linear( self.hidden_size, self.intermediate_size * 2, @@ -87,6 +99,7 @@ def __init__( disable_deep_gemm=disable_deep_gemm, fused_weight_shard_indices_mapping=gateup_shard_indices_mapping, use_custom_cublas_mm=use_custom_cublas_mm, + override_tp_sharding=override_tp_sharding, ) if is_shared_expert: diff --git a/tensorrt_llm/_torch/modules/linear.py b/tensorrt_llm/_torch/modules/linear.py index aae3f3a65ab2..d51babad8630 100644 --- a/tensorrt_llm/_torch/modules/linear.py +++ b/tensorrt_llm/_torch/modules/linear.py @@ -177,7 +177,8 @@ def load_weights_vanilla_helper(module: Linear, weights: List[Dict], weight_transform=lambda x: x, bias_transform=lambda x: x, - allow_partial_loading: bool = False): + allow_partial_loading: bool = False, + elm_packing: int = 1): assert len(weights) == 1 if not allow_partial_loading: assert "weight" in weights[0] @@ -185,9 +186,10 @@ def load_weights_vanilla_helper(module: Linear, assert "bias" in weights[0] device = torch.device('cuda') - weight = load_weight_shard(weights[0]['weight'], module.tp_size, - module.tp_rank, module.tp_mode, - device) if "weight" in weights[0] else None + weight = module.load_shard(weights[0], + 'weight', + device=device, + elm_packing=elm_packing) if weight is not None: if module.has_weight_only_quant: @@ -202,9 +204,7 @@ def load_weights_vanilla_helper(module: Linear, copy_weight(module.weight, weight_transform(weight)) if module.bias is not None: - bias = load_weight_shard(weights[0]['bias'], module.tp_size, - module.tp_rank, module.tp_mode, - device) if "bias" in weights[0] else None + bias = module.load_shard(weights[0], 'bias', device=device) if bias is not None: copy_weight(module.bias, bias_transform(bias)) @@ -214,7 +214,8 @@ def load_weights_fused_qkv_helper( weights: List[Dict], weight_transform=lambda x: x, bias_transform=lambda x: x, - allow_partial_loading: bool = False + allow_partial_loading: bool = False, + elm_packing: int = 1, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: if not allow_partial_loading: assert all('weight' in weights[i] for i in range(3)) @@ -226,26 +227,27 @@ def load_weights_fused_qkv_helper( ) is not None, "Fused weight shard indices mapping is required in partial loading" device = torch.device('cuda') - q_weight = load_weight_shard(weights[0]['weight'], module.tp_size, - module.tp_rank, module.tp_mode, - device) if "weight" in weights[0] else None - k_weight = load_weight_shard(weights[1]['weight'], module.tp_size, - module.tp_rank, module.tp_mode, - device) if "weight" in weights[1] else None - v_weight = load_weight_shard(weights[2]['weight'], module.tp_size, - module.tp_rank, module.tp_mode, - device) if "weight" in weights[2] else None + q_weight = module.load_shard(weights[0], + 'weight', + device=device, + name='q', + elm_packing=elm_packing) + k_weight = module.load_shard(weights[1], + 'weight', + device=device, + name='k', + elm_packing=elm_packing) + v_weight = module.load_shard(weights[2], + 'weight', + device=device, + name='v', + elm_packing=elm_packing) if module.bias is not None: - q_bias = load_weight_shard(weights[0]['bias'], module.tp_size, - module.tp_rank, module.tp_mode, - device) if "bias" in weights[0] else None - k_bias = load_weight_shard(weights[1]['bias'], module.tp_size, - module.tp_rank, module.tp_mode, - device) if "bias" in weights[1] else None - v_bias = load_weight_shard(weights[2]['bias'], module.tp_size, - module.tp_rank, module.tp_mode, - device) if "bias" in weights[2] else None + q_bias = module.load_shard(weights[0], 'bias', device=device, name='q') + k_bias = module.load_shard(weights[1], 'bias', device=device, name='k') + v_bias = module.load_shard(weights[2], 'bias', device=device, name='v') + if not allow_partial_loading: copy_weight(module.bias, bias_transform(torch.cat((q_bias, k_bias, v_bias)))) @@ -263,11 +265,12 @@ def load_weights_fused_qkv_helper( def load_weights_fused_gate_up_helper( - module: Linear, - weights: List[Dict], - weight_transform=lambda x: x, - bias_transform=lambda x: x, - allow_partial_loading: bool = False + module: Linear, + weights: List[Dict], + weight_transform=lambda x: x, + bias_transform=lambda x: x, + allow_partial_loading: bool = False, + elm_packing: int = 1, ) -> tuple[torch.Tensor, torch.Tensor]: if not allow_partial_loading: assert all('weight' in weights[i] for i in range(2)) @@ -279,19 +282,26 @@ def load_weights_fused_gate_up_helper( ) is not None, "Fused weight shard indices mapping is required in partial loading" device = torch.device('cuda') - gate_weight = load_weight_shard(weights[0]['weight'], module.tp_size, - module.tp_rank, module.tp_mode, - device) if "weight" in weights[0] else None - up_weight = load_weight_shard(weights[1]['weight'], module.tp_size, - module.tp_rank, module.tp_mode, - device) if "weight" in weights[1] else None + gate_weight = module.load_shard(weights[0], + 'weight', + device=device, + name='gate', + elm_packing=elm_packing) + up_weight = module.load_shard(weights[1], + 'weight', + device=device, + name='up', + elm_packing=elm_packing) + if module.bias is not None: - gate_bias = load_weight_shard(weights[0]['bias'], module.tp_size, - module.tp_rank, module.tp_mode, - device) if "bias" in weights[0] else None - up_bias = load_weight_shard(weights[1]['bias'], module.tp_size, - module.tp_rank, module.tp_mode, - device) if "bias" in weights[1] else None + gate_bias = module.load_shard(weights[0], + 'bias', + device=device, + name='gate') + up_bias = module.load_shard(weights[1], + 'bias', + device=device, + name='up') if not allow_partial_loading: copy_weight(module.bias, bias_transform(torch.cat((gate_bias, up_bias)))) @@ -461,6 +471,12 @@ def pre_reload_weights(self, module: Linear): requires_grad=False) module.register_parameter(param_name, param) + def get_tp_alignment(self, + tp_mode: Optional[TensorParallelMode], + quant_config: Optional[QuantConfig] = None) -> int: + """ Alignment required for TP shard boundaries. """ + return 1 + class UnquantizedLinearMethod(LinearMethodBase): @@ -706,18 +722,12 @@ def load_weight_scales(self, shard_keys: list[str] = None): input_scales, weight_scales = {}, {} if shard_keys is None: - for w in weights: - if "input_scale" in w: - input_scales[None] = w["input_scale"][...].reshape([]) - if "weight_scale" in w: - weight_scales[None] = w["weight_scale"][...].reshape([]) - else: - for shard_key, w in zip(shard_keys, weights): - if "input_scale" in w: - input_scales[shard_key] = w["input_scale"][...].reshape([]) - if "weight_scale" in w: - weight_scales[shard_key] = w["weight_scale"][...].reshape( - []) + shard_keys = [None] + for shard_key, w in zip(shard_keys, weights): + if "input_scale" in w: + input_scales[shard_key] = w["input_scale"][...].reshape([]) + if "weight_scale" in w: + weight_scales[shard_key] = w["weight_scale"][...].reshape([]) return input_scales, weight_scales def load_weights_vanilla(self, @@ -1003,9 +1013,7 @@ def load_weights_vanilla(self, module, weights, allow_partial_loading=allow_partial_loading) scale_name = self._get_scale_name(weights) if scale_name in weights[0]: - weight_scale = load_weight_shard(weights[0][scale_name], - module.tp_size, module.tp_rank, - module.tp_mode) + weight_scale = module.load_shard(weights[0], scale_name) copy_weight(module.weight_scale, weight_scale) if "input_scale" in weights[0]: copy_weight(module.input_scale, weights[0]["input_scale"]) @@ -1017,16 +1025,12 @@ def load_weights_fused_qkv_linear(self, allow_partial_loading: bool = False): super().load_weights_fused_qkv_linear( module, weights, allow_partial_loading=allow_partial_loading) + scale_name = self._get_scale_name(weights) - q_scale = load_weight_shard( - weights[0][scale_name], module.tp_size, module.tp_rank, - module.tp_mode) if scale_name in weights[0] else None - k_scale = load_weight_shard( - weights[1][scale_name], module.tp_size, module.tp_rank, - module.tp_mode) if scale_name in weights[1] else None - v_scale = load_weight_shard( - weights[2][scale_name], module.tp_size, module.tp_rank, - module.tp_mode) if scale_name in weights[2] else None + q_scale = module.load_shard(weights[0], scale_name, name='q') + k_scale = module.load_shard(weights[1], scale_name, name='k') + v_scale = module.load_shard(weights[2], scale_name, name='v') + for shard_key, scale in zip( module.fused_weight_shard_indices_mapping.keys(), [q_scale, k_scale, v_scale]): @@ -1043,13 +1047,11 @@ def load_weights_fused_gate_up_linear( allow_partial_loading: bool = False) -> None: super().load_weights_fused_gate_up_linear( module, weights, allow_partial_loading=allow_partial_loading) + scale_name = self._get_scale_name(weights) - gate_scale = load_weight_shard( - weights[0][scale_name], module.tp_size, module.tp_rank, - module.tp_mode) if scale_name in weights[0] else None - up_scale = load_weight_shard( - weights[1][scale_name], module.tp_size, module.tp_rank, - module.tp_mode) if scale_name in weights[1] else None + gate_scale = module.load_shard(weights[0], scale_name, name='gate') + up_scale = module.load_shard(weights[1], scale_name, name='up') + for shard_key, scale in zip( module.fused_weight_shard_indices_mapping.keys(), [gate_scale, up_scale]): @@ -1065,6 +1067,9 @@ class FP8BlockScalesLinearMethod(UnquantizedLinearMethod): # fp8_block_scaling_gemm does not support writing into an NCCL window buffer. supports_nccl_symmetric_memory_window_output: ClassVar[bool] = False + def get_tp_alignment(self, tp_mode, quant_config=None): + return 128 + def create_weights(self, module: Linear, in_features: int, out_features: int, bias: bool, dtype: torch.dtype): weight_shape = (out_features, in_features) @@ -1157,8 +1162,8 @@ def load_weights_vanilla(self, # modelopt fp8_pb_wo can have 2 extra singleton dimensions if full_weight_scale.dim() == 4: full_weight_scale = full_weight_scale.squeeze(1).squeeze(-1) - weight_scale = load_weight_shard(full_weight_scale, module.tp_size, - module.tp_rank, module.tp_mode) + + weight_scale = module.load_shard(full_weight_scale, scale_span=128) copy_weight(module.weight_scale, weight_scale) if "input_scale" in weights[0]: copy_weight(module.input_scale, weights[0]["input_scale"]) @@ -1202,8 +1207,8 @@ def load_weights_fused_qkv_linear( ] scales = [ - load_weight_shard(s, module.tp_size, module.tp_rank, module.tp_mode) - if s is not None else None for s in full_scales_squeezed + module.load_shard(s, scale_span=128) if s is not None else None + for s in full_scales_squeezed ] processed_mapping = self.remap_fused_shard_indices_by_divisible_factor( module.fused_weight_shard_indices_mapping, 128) @@ -1230,8 +1235,8 @@ def load_weights_fused_gate_up_linear( for s in full_scales ] scales = [ - load_weight_shard(s, module.tp_size, module.tp_rank, module.tp_mode) - if s is not None else None for s in full_scales_squeezed + module.load_shard(s, scale_span=128) if s is not None else None + for s in full_scales_squeezed ] processed_mapping = self.remap_fused_shard_indices_by_divisible_factor( module.fused_weight_shard_indices_mapping, 128) @@ -1273,6 +1278,13 @@ class NVFP4LinearMethod(LinearMethodBase): # construction; LLM paths leave it False to avoid host overhead. use_tunable_quantize: bool = False + def get_tp_alignment(self, tp_mode, quant_config=None): + # 32-element alignment for both modes. ROW shards in_features which + # is packed 2:1, so 32 → 16 packed, meeting GEMM col_alignment=16. + # COLUMN must also be 32 because column output feeds row input, and + # row's packed weight K dimension needs 16-alignment. + return 32 + def create_weights(self, module: Linear, in_features: int, out_features: int, bias: bool, dtype: torch.dtype): module.scaling_vector_size = 16 @@ -1474,17 +1486,17 @@ def load_weight_scales(self, """ device = torch.device("cuda") + scale_span = 16 if module.tp_mode == TensorParallelMode.ROW else 1 # Per-shard weight_scale: load, TP-shard, store in tmp dict keyed by shard if shard_keys is not None: if not hasattr(module, "tmp_nvfp4_weight_scales"): module.tmp_nvfp4_weight_scales = {} for shard_key, w in zip(shard_keys, weights): if "weight_scale" in w: - ws = load_weight_shard(w["weight_scale"], - module.tp_size, - module.tp_rank, - module.tp_mode, - device=device).contiguous() + ws = module.load_shard(w["weight_scale"], + device=device, + scale_span=scale_span, + name=shard_key).contiguous() assert ws.dtype == torch.float8_e4m3fn module.tmp_nvfp4_weight_scales[shard_key] = ws.view( fp4_utils.float4_sf_dtype) @@ -1492,11 +1504,9 @@ def load_weight_scales(self, # Vanilla: single weight_scale, load + interleave directly w = weights[0] if "weight_scale" in w: - ws = load_weight_shard(w["weight_scale"], - module.tp_size, - module.tp_rank, - module.tp_mode, - device=device).contiguous() + ws = module.load_shard(w["weight_scale"], + device=device, + scale_span=scale_span).contiguous() ws = ws.view(fp4_utils.float4_sf_dtype) ws = torch.ops.trtllm.block_scale_interleave(ws) copy_weight(module.weight_scale, ws) @@ -1592,9 +1602,12 @@ def load_weights_vanilla(self, module: Linear, weights: List[Dict], allow_partial_loading: bool = False) -> None: + + elm_packing = 2 if module.tp_mode == TensorParallelMode.ROW else 1 load_weights_vanilla_helper(module, weights, - allow_partial_loading=allow_partial_loading) + allow_partial_loading=allow_partial_loading, + elm_packing=elm_packing) # Load scales (vanilla = no shard_keys) self.load_weight_scales(module, weights, shard_keys=None) @@ -1602,13 +1615,14 @@ def load_weights_vanilla(self, # Load pre_quant_scale if it exists (for NVFP4_AWQ) if "pre_quant_scale" in weights[0]: device = module.weight.device - pre_quant_scale = load_weight_shard( + # scale_span is flipped because flip_tp=True + act_scale_span = 1 if module.tp_mode == TensorParallelMode.ROW else 16 + pre_quant_scale = module.load_shard( weights[0]["pre_quant_scale"], - module.tp_size, - module.tp_rank, # pre_quant_scale applies to activation as opposed to weight, so flip tp_mode the other way around - TensorParallelMode.flip(module.tp_mode), - device, + flip_tp=True, + device=device, + scale_span=act_scale_span, ) module.pre_quant_scale = Parameter( @@ -1622,8 +1636,12 @@ def load_weights_fused_qkv_linear( module: Linear, weights: List[Dict], allow_partial_loading: bool = False) -> None: + elm_packing = 2 if module.tp_mode == TensorParallelMode.ROW else 1 q_weight, k_weight, v_weight = load_weights_fused_qkv_helper( - module, weights, allow_partial_loading=allow_partial_loading) + module, + weights, + allow_partial_loading=allow_partial_loading, + elm_packing=elm_packing) weight_mode = module.weights_loading_config.weight_mode @@ -1695,8 +1713,12 @@ def load_weights_fused_gate_up_linear( module: Linear, weights: List[Dict], allow_partial_loading: bool = False) -> None: + elm_packing = 2 if module.tp_mode == TensorParallelMode.ROW else 1 gate_weight, up_weight = load_weights_fused_gate_up_helper( - module, weights, allow_partial_loading=allow_partial_loading) + module, + weights, + allow_partial_loading=allow_partial_loading, + elm_packing=elm_packing) weight_mode = module.weights_loading_config.weight_mode device = torch.device("cuda") @@ -1715,14 +1737,13 @@ def load_weights_fused_gate_up_linear( # Load pre_quant_scale if it exists (for NVFP4_AWQ) # NOTE: pre_quant_scale is the same for gate and up since modelopt checks which layer shared the same input if "pre_quant_scale" in weights[0]: - device = module.weight.device - pre_quant_scale = load_weight_shard( + act_scale_span = 1 if module.tp_mode == TensorParallelMode.ROW else 16 + pre_quant_scale = module.load_shard( weights[0]["pre_quant_scale"], - module.tp_size, - module.tp_rank, # pre_quant_scale applies to activation as opposed to weight, so flip tp_mode the other way around - TensorParallelMode.flip(module.tp_mode), - device, + flip_tp=True, + device=module.weight.device, + scale_span=act_scale_span, ) module.pre_quant_scale = Parameter( @@ -1950,6 +1971,10 @@ def apply(self, module: Linear, input: torch.Tensor, class W4A8NVFP4FP8LinearMethod(LinearMethodBase): + def get_tp_alignment(self, tp_mode, quant_config=None): + # Same as NVFP4: 32-element alignment for both modes. + return 32 + def create_weights(self, module: Linear, in_features: int, out_features: int, bias: bool, dtype: torch.dtype): module.epilogue_tile_m = 128 @@ -2018,10 +2043,9 @@ def apply(self, module: Linear, input: torch.Tensor, def load_weight_scales( self, + module: Linear, weights: List[Dict], - tp_size: int = 1, - tp_rank: int = 0, - tp_mode: Optional[TensorParallelMode] = None, + shard_keys: Optional[List[str]] = None, ): # For concatenated weights (qkv_proj / up_gate_proj), the global scaling factors and input scaling factors should be shared. input_scale = None @@ -2029,6 +2053,27 @@ def load_weight_scales( weight_scale = [] device = torch.device("cuda") + scale_span = 32 if module.tp_mode == TensorParallelMode.ROW else 1 + + if shard_keys is not None: + for shard_key, w in zip(shard_keys, weights): + if "weight_scale" in w: + ws = module.load_shard(w["weight_scale"], + device=device, + scale_span=scale_span, + name=shard_key).contiguous() + assert ws.dtype == torch.float8_e4m3fn + weight_scale.append( + ws.view(dtype=fp4_utils.float4_sf_dtype)) + else: + for w in weights: + if "weight_scale" in w: + ws = module.load_shard(w["weight_scale"], + device=device, + scale_span=scale_span).contiguous() + assert ws.dtype == torch.float8_e4m3fn + weight_scale.append( + ws.view(dtype=fp4_utils.float4_sf_dtype)) for w in weights: if "input_scale" in w: @@ -2037,14 +2082,6 @@ def load_weight_scales( else: assert input_scale == w["input_scale"][ ...], "The input_scale should be same for all the weights" - if "weight_scale" in w: - ws = load_weight_shard(w["weight_scale"], - tp_size, - tp_rank, - tp_mode, - device=device).contiguous() - assert ws.dtype == torch.float8_e4m3fn - weight_scale.append(ws.view(dtype=fp4_utils.float4_sf_dtype)) if "weight_scale_2" in w: if weight_scale_2 is None: weight_scale_2 = w["weight_scale_2"][...] @@ -2060,16 +2097,16 @@ def load_weight_scales( return input_scale, weight_scale, weight_scale_2, alpha def load_weights_vanilla(self, module: Linear, weights: List[Dict]) -> None: + elm_packing = 2 if module.tp_mode == TensorParallelMode.ROW else 1 # FIXME: this depends on the kernel internals load_weights_vanilla_helper( - module, weights, - lambda w: fp4_utils.shuffle_matrix_a(w, module.epilogue_tile_m)) + module, + weights, + lambda w: fp4_utils.shuffle_matrix_a(w, module.epilogue_tile_m), + elm_packing=elm_packing) input_scale, weight_scale, weight_scale_2, alpha = self.load_weight_scales( - weights, - tp_size=module.tp_size, - tp_rank=module.tp_rank, - tp_mode=module.tp_mode) + module, weights) assert len(weights) == 1 weight_scale = weight_scale[0] @@ -2085,14 +2122,13 @@ def load_weights_vanilla(self, module: Linear, weights: List[Dict]) -> None: def load_weights_fused_qkv_linear(self, module: Linear, weights: List[Dict]) -> None: + elm_packing = 2 if module.tp_mode == TensorParallelMode.ROW else 1 q_weight, k_weight, v_weight = load_weights_fused_qkv_helper( - module, weights) + module, weights, elm_packing=elm_packing) + weight_mode = module.weights_loading_config.weight_mode input_scale, weight_scales, weight_scale_2, alpha = self.load_weight_scales( - weights, - tp_size=module.tp_size, - tp_rank=module.tp_rank, - tp_mode=module.tp_mode) + module, weights, shard_keys=weight_mode.shard_keys) # Swizzle weight scales after concatenation weight_scale = torch.cat(weight_scales, 0) # Shuffle and Swizzle weight scale @@ -2112,18 +2148,17 @@ def load_weights_fused_qkv_linear(self, module: Linear, def load_weights_fused_gate_up_linear(self, module: Linear, weights: List[Dict]) -> None: + elm_packing = 2 if module.tp_mode == TensorParallelMode.ROW else 1 gate_weight, up_weight = load_weights_fused_gate_up_helper( - module, weights) + module, weights, elm_packing=elm_packing) fused_weight = torch.cat((gate_weight, up_weight)) fused_weight = fp4_utils.shuffle_matrix_a(fused_weight, module.epilogue_tile_m) copy_weight(module.weight, fused_weight) + weight_mode = module.weights_loading_config.weight_mode input_scale, weight_scales, weight_scale_2, alpha = self.load_weight_scales( - weights, - tp_size=module.tp_size, - tp_rank=module.tp_rank, - tp_mode=module.tp_mode) + module, weights, shard_keys=weight_mode.shard_keys) # Swizzle weight scales after concatenation weight_scale = torch.cat(weight_scales, 0) # Shuffle and Swizzle weight scale @@ -2139,6 +2174,10 @@ def load_weights_fused_gate_up_linear(self, module: Linear, class W4A8MXFP4FP8LinearMethod(LinearMethodBase): + def get_tp_alignment(self, tp_mode, quant_config=None): + # Same as NVFP4: 32-element alignment for both modes. + return 32 + def create_weights(self, module: Linear, in_features: int, out_features: int, bias: bool, dtype: torch.dtype): module.scaling_vector_size = 32 @@ -2185,32 +2224,30 @@ def apply(self, module: Linear, input: torch.Tensor, return output def load_weight_scales(self, + module: Linear, weights: List[Dict], - tp_size: int = 1, - tp_rank: int = 0, - tp_mode: Optional[TensorParallelMode] = None): - # For concatenated weights (qkv_proj / up_gate_proj), the global scaling factors and input scaling factors should be shared. + shard_keys: Optional[List[str]] = None): weight_scale = [] device = torch.device("cuda") - for w in weights: + scale_span = 32 if module.tp_mode == TensorParallelMode.ROW else 1 + + if shard_keys is None: + shard_keys = [None] + for shard_key, w in zip(shard_keys, weights): if "weight_scale" in w: - ws = load_weight_shard(w["weight_scale"], - tp_size, - tp_rank, - tp_mode, - device=device).contiguous() - # Should be E8M0 for MXFP4 + ws = module.load_shard(w["weight_scale"], + device=device, + scale_span=scale_span, + name=shard_key).contiguous() assert ws.dtype == torch.uint8 weight_scale.append(ws.view(fp4_utils.float4_sf_dtype)) return weight_scale def load_weights_vanilla(self, module: Linear, weights: List[Dict]) -> None: - load_weights_vanilla_helper(module, weights) + elm_packing = 2 if module.tp_mode == TensorParallelMode.ROW else 1 + load_weights_vanilla_helper(module, weights, elm_packing=elm_packing) - weight_scale = self.load_weight_scales(weights, - tp_size=module.tp_size, - tp_rank=module.tp_rank, - tp_mode=module.tp_mode) + weight_scale = self.load_weight_scales(module, weights) assert len(weights) == 1 weight_scale = weight_scale[0] # Swizzle weight scale @@ -2219,31 +2256,30 @@ def load_weights_vanilla(self, module: Linear, weights: List[Dict]) -> None: def load_weights_fused_qkv_linear(self, module: Linear, weights: List[Dict]) -> None: + elm_packing = 2 if module.tp_mode == TensorParallelMode.ROW else 1 q_weight, k_weight, v_weight = load_weights_fused_qkv_helper( - module, weights) + module, weights, elm_packing=elm_packing) fused_weight = torch.cat((q_weight, k_weight, v_weight)) copy_weight(module.weight, fused_weight) - weight_scale = self.load_weight_scales(weights, - tp_size=module.tp_size, - tp_rank=module.tp_rank, - tp_mode=module.tp_mode) + weight_mode = module.weights_loading_config.weight_mode + weight_scale = self.load_weight_scales( + module, weights, shard_keys=weight_mode.shard_keys) weight_scale = torch.cat(weight_scale, 0) weight_scale = torch.ops.trtllm.block_scale_interleave(weight_scale) copy_weight(module.weight_scale, weight_scale) def load_weights_fused_gate_up_linear(self, module: Linear, weights: List[Dict]) -> None: + elm_packing = 2 if module.tp_mode == TensorParallelMode.ROW else 1 gate_weight, up_weight = load_weights_fused_gate_up_helper( - module, weights) + module, weights, elm_packing=elm_packing) fused_weight = torch.cat((gate_weight, up_weight)) copy_weight(module.weight, fused_weight) - weight_scale = self.load_weight_scales(weights, - tp_size=module.tp_size, - tp_rank=module.tp_rank, - tp_mode=module.tp_mode) - # Swizzle weight scales after concatenation + weight_mode = module.weights_loading_config.weight_mode + weight_scale = self.load_weight_scales( + module, weights, shard_keys=weight_mode.shard_keys) weight_scale = torch.cat(weight_scale, 0) weight_scale = torch.ops.trtllm.block_scale_interleave(weight_scale) copy_weight(module.weight_scale, weight_scale) @@ -2251,6 +2287,17 @@ def load_weights_fused_gate_up_linear(self, module: Linear, class WeightOnlyQuantLinearMethod(LinearMethodBase): + def get_tp_alignment(self, tp_mode, quant_config=None): + # preprocess_weights_for_mixed_gemm requires: + # - ROW (in_features): % B_ROWS_PER_MMA (32 for INT4, 16 for INT8) + # - COLUMN (out_features): % MMA_SHAPE_N (8) + # COLUMN must also satisfy ROW of next layer in a column->row pipeline. + if quant_config is not None and quant_config.layer_quant_mode.is_int4_weight_only( + ): + return 32 + else: + return 16 + def create_weights(self, module: Linear, in_features: int, out_features: int, bias: bool, dtype: torch.dtype) -> None: @@ -2289,69 +2336,61 @@ def apply(self, module: Linear, input: torch.Tensor, def load_weight_scales( self, + module: Linear, weights: List[Dict], - tp_size: int = 1, - tp_rank: int = 0, - tp_mode: Optional[TensorParallelMode] = None) -> List[torch.Tensor]: + shard_keys: Optional[List[Optional[str]]] = None + ) -> List[torch.Tensor]: device = torch.device("cuda") - q_weight_scale = load_weight_shard(weights[0]['weight_scale'], - tp_size, - tp_rank, - tp_mode, - device=device) - k_weight_scale = load_weight_shard(weights[1]['weight_scale'], - tp_size, - tp_rank, - tp_mode, - device=device) - v_weight_scale = load_weight_shard(weights[2]['weight_scale'], - tp_size, - tp_rank, - tp_mode, - device=device) - weight_scales = [q_weight_scale, k_weight_scale, v_weight_scale] - + if shard_keys is None: + shard_keys = [None] + weight_scales = [] + for shard_key, w in zip(shard_keys, weights): + ws = module.load_shard(w, + 'weight_scale', + device=device, + name=shard_key) + weight_scales.append(ws) return weight_scales def load_weights_vanilla(self, module: Linear, weights: List[Dict]) -> None: - load_weights_vanilla_helper(module, weights) + weight_dtype, weight_id = get_weight_dtype_and_id(module) + # INT4 checkpoint tensors are packed 2:1 along the output dimension + # before preprocessing, so COLUMN logical shard boundaries must be + # converted to packed coordinates. ROW shards the unpacked K dimension. + elm_packing = weight_id if module.tp_mode == TensorParallelMode.COLUMN else 1 + load_weights_vanilla_helper(module, weights, elm_packing=elm_packing) - device = torch.device('cuda') - weight_scale = load_weight_shard(weights[0]['weight_scale'], - module.tp_size, module.tp_rank, - module.tp_mode, device) - - copy_weight(module.weight_scale, weight_scale) + weight_scales = self.load_weight_scales(module, weights) + assert len(weight_scales) == 1 + copy_weight(module.weight_scale, weight_scales[0]) def load_weights_fused_qkv_linear(self, module: Linear, weights: List[Dict]) -> None: + weight_dtype, weight_id = get_weight_dtype_and_id(module) + elm_packing = weight_id if module.tp_mode == TensorParallelMode.COLUMN else 1 q_weight, k_weight, v_weight = load_weights_fused_qkv_helper( - module, weights) + module, weights, elm_packing=elm_packing) fused_weight = torch.cat((q_weight, k_weight, v_weight)) - weight_dtype, _ = get_weight_dtype_and_id(module) fused_weight = preprocess_weights_for_mixed_gemm( fused_weight.to(torch.int8).T.contiguous().cpu(), weight_dtype, torch.float16).cuda().contiguous() copy_weight(module.weight, fused_weight) - weight_scales = self.load_weight_scales(weights, - tp_size=module.tp_size, - tp_rank=module.tp_rank, - tp_mode=module.tp_mode) - - # Create concatenated weight scale tensor + weight_mode = module.weights_loading_config.weight_mode + weight_scales = self.load_weight_scales( + module, weights, shard_keys=weight_mode.shard_keys) cat_weight_scale = torch.cat(weight_scales, dim=0) copy_weight(module.weight_scale, cat_weight_scale) def load_weights_fused_gate_up_linear(self, module: Linear, weights: List[Dict]) -> None: - device = torch.device('cuda') - weight_dtype, _ = get_weight_dtype_and_id(module) + weight_dtype, weight_id = get_weight_dtype_and_id(module) + elm_packing = weight_id if module.tp_mode == TensorParallelMode.COLUMN else 1 gate_weight, up_weight = load_weights_fused_gate_up_helper( - module, weights) + module, weights, elm_packing=elm_packing) fused_weight = torch.cat((gate_weight, up_weight)) @@ -2361,18 +2400,22 @@ def load_weights_fused_gate_up_linear(self, module: Linear, copy_weight(module.weight, fused_weight) - left_scale = load_weight_shard(weights[0]['weight_scale'], - module.tp_size, module.tp_rank, - module.tp_mode, device).contiguous() - right_scale = load_weight_shard(weights[1]['weight_scale'], - module.tp_size, module.tp_rank, - module.tp_mode, device).contiguous() - fused_scale = torch.cat([left_scale, right_scale], dim=0) + weight_mode = module.weights_loading_config.weight_mode + weight_scales = self.load_weight_scales( + module, weights, shard_keys=weight_mode.shard_keys) + fused_scale = torch.cat(weight_scales, dim=0) copy_weight(module.weight_scale, fused_scale) class W4A16_AWQ_LinearMethod(LinearMethodBase): + def get_tp_alignment(self, tp_mode, quant_config=None): + if quant_config is None: + return 1 + # ROW shards input groups directly. COLUMN shards output features, which + # feed the next row-parallel layer's grouped input dimension in MLPs. + return quant_config.group_size + def create_weights(self, module: Linear, in_features: int, out_features: int, bias: bool, dtype: torch.dtype) -> None: @@ -2423,60 +2466,51 @@ def apply(self, module: Linear, input: torch.Tensor, def load_weight_scales( self, + module: Linear, weights: List[Dict], - tp_size: int = 1, - tp_rank: int = 0, - tp_mode: Optional[TensorParallelMode] = None) -> List[torch.Tensor]: + shard_keys: Optional[List[str]] = None) -> List[torch.Tensor]: device = torch.device("cuda") - q_weight_scale = load_weight_shard(weights[0]['weight_scale'], - tp_size, - tp_rank, - tp_mode, - device=device) - k_weight_scale = load_weight_shard(weights[1]['weight_scale'], - tp_size, - tp_rank, - tp_mode, - device=device) - v_weight_scale = load_weight_shard(weights[2]['weight_scale'], - tp_size, - tp_rank, - tp_mode, - device=device) - weight_scales = [q_weight_scale, k_weight_scale, v_weight_scale] - + scale_span = module.quant_config.group_size if module.tp_mode == TensorParallelMode.ROW else 1 + weight_scales = [] + if shard_keys is None: + shard_keys = [None] + for shard_key, w in zip(shard_keys, weights): + weight_scales.append( + module.load_shard(w, + 'weight_scale', + device=device, + name=shard_key, + scale_span=scale_span)) return weight_scales def load_weights_vanilla(self, module: Linear, weights: List[Dict]) -> None: - load_weights_vanilla_helper(module, weights) + elm_packing = 2 if module.tp_mode == TensorParallelMode.COLUMN else 1 + load_weights_vanilla_helper(module, weights, elm_packing=elm_packing) # Use the same device as the weight tensor # as we register pre_quant_scale after sharded model weights are moved to respective gpus device = module.weight.device - pre_quant_scale = load_weight_shard( + pre_quant_scale = module.load_shard( weights[0]["pre_quant_scale"], - module.tp_size, - module.tp_rank, # pre_quant_scale applies to activation as opposed to weight, so flip tp_mode the other way around - TensorParallelMode.flip(module.tp_mode), - device, + flip_tp=True, + device=device, ) module.pre_quant_scale = Parameter( torch.ones((module.in_features, ), dtype=pre_quant_scale.dtype), requires_grad=False).to(device=device) - weight_scale = load_weight_shard(weights[0]['weight_scale'], - module.tp_size, module.tp_rank, - module.tp_mode, device) + weight_scale = self.load_weight_scales(module, weights)[0] copy_weight(module.pre_quant_scale, pre_quant_scale) copy_weight(module.weight_scale, weight_scale.T.contiguous()) def load_weights_fused_qkv_linear(self, module: Linear, weights: List[Dict]) -> None: + elm_packing = 2 if module.tp_mode == TensorParallelMode.COLUMN else 1 q_weight, k_weight, v_weight = load_weights_fused_qkv_helper( - module, weights) + module, weights, elm_packing=elm_packing) fused_weight = torch.cat((q_weight, k_weight, v_weight)) fused_weight = preprocess_weights_for_mixed_gemm( @@ -2485,8 +2519,9 @@ def load_weights_fused_qkv_linear(self, module: Linear, copy_weight(module.weight, fused_weight) - weight_scales = self.load_weight_scales(weights, module.tp_size, - module.tp_rank, module.tp_mode) + weight_mode = module.weights_loading_config.weight_mode + weight_scales = self.load_weight_scales( + module, weights, shard_keys=weight_mode.shard_keys) # Create concatenated weight scale tensor cat_weight_scale = torch.cat(weight_scales, dim=0).T.contiguous() @@ -2494,9 +2529,9 @@ def load_weights_fused_qkv_linear(self, module: Linear, def load_weights_fused_gate_up_linear(self, module: Linear, weights: List[Dict]) -> None: - device = torch.device('cuda') + elm_packing = 2 if module.tp_mode == TensorParallelMode.COLUMN else 1 gate_weight, up_weight = load_weights_fused_gate_up_helper( - module, weights) + module, weights, elm_packing=elm_packing) fused_weight = torch.cat((gate_weight, up_weight)) fused_weight = preprocess_weights_for_mixed_gemm( @@ -2505,18 +2540,22 @@ def load_weights_fused_gate_up_linear(self, module: Linear, copy_weight(module.weight, fused_weight) - left_scale = load_weight_shard(weights[0]['weight_scale'], - module.tp_size, module.tp_rank, - module.tp_mode, device).contiguous() - right_scale = load_weight_shard(weights[1]['weight_scale'], - module.tp_size, module.tp_rank, - module.tp_mode, device).contiguous() - fused_scale = torch.cat([left_scale, right_scale], dim=0).T.contiguous() + weight_mode = module.weights_loading_config.weight_mode + weight_scales = self.load_weight_scales( + module, weights, shard_keys=weight_mode.shard_keys) + fused_scale = torch.cat(weight_scales, dim=0).T.contiguous() copy_weight(module.weight_scale, fused_scale) class W4A8_AWQ_LinearMethod(LinearMethodBase): + def get_tp_alignment(self, tp_mode, quant_config=None): + if quant_config is None: + return 1 + # Same grouped INT4 weight layout as W4A16_AWQ. ROW must keep input + # groups intact, and COLUMN output shards feed row-parallel grouped K. + return quant_config.group_size + def create_weights(self, module: Linear, in_features: int, out_features: int, bias: bool, dtype: torch.dtype): # Quantized weights @@ -2596,18 +2635,20 @@ def apply(self, module: Linear, input: torch.Tensor, return output def load_weight_scales_w4a8(self, + module: Linear, weights: List[Dict], - tp_size: int = 1, - tp_rank: int = 0, - tp_mode: Optional[TensorParallelMode] = None): + shard_keys: Optional[List[str]] = None): # For concatenated weights (qkv_proj / up_gate_proj), the global scaling factors and input scaling factors should be shared. input_scale = None weight_scale_2 = None weight_scale = [] device = torch.device("cuda") + scale_span = module.quant_config.group_size if module.tp_mode == TensorParallelMode.ROW else 1 + if shard_keys is None: + shard_keys = [None] - for w in weights: + for shard_key, w in zip(shard_keys, weights): if "input_scale" in w: if input_scale is None: input_scale = w["input_scale"][...] @@ -2615,11 +2656,11 @@ def load_weight_scales_w4a8(self, assert input_scale == w["input_scale"][ ...], "The input_scale should be same for all the weights" if "weight_scale" in w: - ws = load_weight_shard(w["weight_scale"], - tp_size, - tp_rank, - tp_mode, - device=device) + ws = module.load_shard(w, + 'weight_scale', + device=device, + name=shard_key, + scale_span=scale_span) weight_scale.append(ws.to(torch.float16)) if "weight_scale_2" in w: @@ -2635,18 +2676,17 @@ def load_weight_scales_w4a8(self, return input_scale, weight_scale, alpha, weight_scale_2 def load_weights_vanilla(self, module: Linear, weights: List[Dict]): - load_weights_vanilla_helper(module, weights) + elm_packing = 2 if module.tp_mode == TensorParallelMode.COLUMN else 1 + load_weights_vanilla_helper(module, weights, elm_packing=elm_packing) # Use the same device as the weight tensor # as we register pre_quant_scale after sharded model weights are moved to respective gpus device = module.weight.device - pre_quant_scale = load_weight_shard( + pre_quant_scale = module.load_shard( weights[0]["pre_quant_scale"], - module.tp_size, - module.tp_rank, # pre_quant_scale applies to activation as opposed to weight, so flip tp_mode the other way around - TensorParallelMode.flip(module.tp_mode), - device, + flip_tp=True, + device=device, ) assert pre_quant_scale.dtype == module.dtype @@ -2658,10 +2698,7 @@ def load_weights_vanilla(self, module: Linear, weights: List[Dict]): copy_weight(module.pre_quant_scale, pre_quant_scale) input_scale, weight_scale, alpha, weight_scale_2 = self.load_weight_scales_w4a8( - weights=weights, - tp_size=module.tp_size, - tp_rank=module.tp_rank, - tp_mode=module.tp_mode) + module, weights) assert len(weight_scale) == 1, "there should be only one weight scale" @@ -2678,8 +2715,9 @@ def load_weights_vanilla(self, module: Linear, weights: List[Dict]): def load_weights_fused_qkv_linear(self, module: Linear, weights: List[Dict]): + elm_packing = 2 if module.tp_mode == TensorParallelMode.COLUMN else 1 q_weight, k_weight, v_weight = load_weights_fused_qkv_helper( - module, weights) + module, weights, elm_packing=elm_packing) fused_weight = torch.cat((q_weight, k_weight, v_weight)) fused_weight = preprocess_weights_for_mixed_gemm( @@ -2688,11 +2726,9 @@ def load_weights_fused_qkv_linear(self, module: Linear, copy_weight(module.weight, fused_weight) + weight_mode = module.weights_loading_config.weight_mode input_scale, weight_scales, alpha, weight_scale_2 = self.load_weight_scales_w4a8( - weights=weights, - tp_size=module.tp_size, - tp_rank=module.tp_rank, - tp_mode=module.tp_mode) + module, weights, shard_keys=weight_mode.shard_keys) # Create concatenated weight scale tensor cat_weight_scale = (torch.cat(weight_scales, dim=0).T / @@ -2708,13 +2744,11 @@ def load_weights_fused_qkv_linear(self, module: Linear, # Use the same device as the weight tensor # as we register pre_quant_scale after sharded model weights are moved to respective gpus device = module.weight.device - pre_quant_scale = load_weight_shard( + pre_quant_scale = module.load_shard( weights[0]["pre_quant_scale"], - module.tp_size, - module.tp_rank, # pre_quant_scale applies to activation as opposed to weight, so flip tp_mode the other way around - TensorParallelMode.flip(module.tp_mode), - device, + flip_tp=True, + device=device, ) module.pre_quant_scale = Parameter( @@ -2726,8 +2760,9 @@ def load_weights_fused_qkv_linear(self, module: Linear, def load_weights_fused_gate_up_linear(self, module: Linear, weights: List[Dict]): + elm_packing = 2 if module.tp_mode == TensorParallelMode.COLUMN else 1 gate_weight, up_weight = load_weights_fused_gate_up_helper( - module, weights) + module, weights, elm_packing=elm_packing) fused_weight = torch.cat((gate_weight, up_weight)) fused_weight = preprocess_weights_for_mixed_gemm( @@ -2736,11 +2771,9 @@ def load_weights_fused_gate_up_linear(self, module: Linear, copy_weight(module.weight, fused_weight) + weight_mode = module.weights_loading_config.weight_mode input_scale, weight_scale, alpha, weight_scale_2 = self.load_weight_scales_w4a8( - weights=weights, - tp_size=module.tp_size, - tp_rank=module.tp_rank, - tp_mode=module.tp_mode) + module, weights, shard_keys=weight_mode.shard_keys) fused_scale = (torch.cat(weight_scale, dim=0).T / weight_scale_2).contiguous() @@ -2754,13 +2787,11 @@ def load_weights_fused_gate_up_linear(self, module: Linear, # Use the same device as the weight tensor # as we register pre_quant_scale after sharded model weights are moved to respective gpus device = module.weight.device - pre_quant_scale = load_weight_shard( + pre_quant_scale = module.load_shard( weights[0]["pre_quant_scale"], - module.tp_size, - module.tp_rank, # pre_quant_scale applies to activation as opposed to weight, so flip tp_mode the other way around - TensorParallelMode.flip(module.tp_mode), - device, + flip_tp=True, + device=device, ) # NOTE:Create this tensor in load_weights, since not all layer have this tensor and memory is not allocated for it (same as W4A16) @@ -2850,6 +2881,9 @@ def __init__( fused_weight_shard_indices_mapping: Optional[dict] = None, nvfp4_allowed_backends: Optional[List[str]] = None, enable_gemm_allreduce_fusion: bool = True, + override_tp_sharding: Optional[Union[tuple[int, int], + Dict[str, tuple[int, + int]]]] = None, ): """ Args: @@ -2889,25 +2923,36 @@ def __init__( 'cutlass', 'cublaslt', 'cuda_core' ] - local_in_features = in_features - local_out_features = out_features - - if self.tp_mode == TensorParallelMode.ROW: - assert in_features % self.tp_size == 0, ( - f'in_features {in_features} must be divisible by tp_size {self.tp_size}' - ) - local_in_features = in_features // self.tp_size - elif self.tp_mode == TensorParallelMode.COLUMN: - assert out_features % self.tp_size == 0, ( - f'out_features {out_features} must be divisible by tp_size {self.tp_size}' - ) - local_out_features = out_features // self.tp_size - reduce_output = False if self.mapping.enable_attention_dp else reduce_output + assert self.tp_mode in (TensorParallelMode.ROW, + TensorParallelMode.COLUMN, None) + + # Init TP sharding either from override or auto generated + _uneven_tp_unsupported = {QuantAlgo.NVFP4_ARC} + _quant_algo = quant_config.quant_algo if quant_config else None + if override_tp_sharding is not None: + assert _quant_algo not in _uneven_tp_unsupported + self.tp_sharding = override_tp_sharding + elif self.tp_size > 1 and self.tp_mode is not None \ + and self.weights_loading_config.weight_mode == WeightMode.VANILLA \ + and _quant_algo not in _uneven_tp_unsupported: + features = in_features if self.tp_mode == TensorParallelMode.ROW else out_features + self.tp_sharding = self._auto_tp_sharding(features, quant_config) else: - assert self.tp_mode is None, f'unsupported tensor parallel mode: {self.tp_mode}' + self.tp_sharding = None + if self.tp_size > 1 and self.tp_mode is not None: + features = in_features if self.tp_mode == TensorParallelMode.ROW else out_features + assert features % self.tp_size == 0, ( + f"Uneven TP not supported for this configuration " + f"(weight_mode={self.weights_loading_config.weight_mode}, " + f"quant_algo={_quant_algo}). " + f"features={features} must be divisible by tp_size={self.tp_size}." + ) + + self.in_features = self.calculate_local_in_features(in_features) + self.out_features = self.calculate_local_out_features(out_features) - self.in_features = local_in_features - self.out_features = local_out_features + if self.tp_mode == TensorParallelMode.COLUMN: + reduce_output = False if self.mapping.enable_attention_dp else reduce_output self.all_reduce = AllReduce(mapping=self.mapping, strategy=allreduce_strategy, @@ -2954,6 +2999,154 @@ def __init__( def get_quant_method(self, quant_config: Optional[QuantConfig] = None): return get_quant_method(quant_config) + @staticmethod + def _calc_shard(total, tp_size, rank): + return (total // tp_size) * rank + min(total % tp_size, rank) + + def _auto_tp_sharding(self, features, quant_config): + """Auto-generate tp_sharding tuple based on quant alignment requirements. + + For VANILLA mode only. Fused modes with non-divisible dims require + explicit override_tp_sharding from the model layer. + """ + assert self.weights_loading_config.weight_mode == WeightMode.VANILLA, ( + f"_auto_tp_sharding only supports VANILLA mode, got " + f"{self.weights_loading_config.weight_mode}. Fused modes require " + f"explicit override_tp_sharding.") + alignment = get_quant_method(quant_config).get_tp_alignment( + self.tp_mode, quant_config) + if alignment <= 1: + # No alignment constraint — use standard element-level distribution + start = self._calc_shard(features, self.tp_size, self.tp_rank) + end = self._calc_shard(features, self.tp_size, self.tp_rank + 1) + else: + # Distribute whole alignment-sized blocks across ranks + assert features % alignment == 0, ( + f"Feature dim ({features}) must be divisible by quant alignment " + f"({alignment}) for TP sharding") + num_blocks = features // alignment + block_start = self._calc_shard(num_blocks, self.tp_size, + self.tp_rank) + block_end = self._calc_shard(num_blocks, self.tp_size, + self.tp_rank + 1) + start = block_start * alignment + end = block_end * alignment + return (start, end) + + def _calculate_local_features_helper(self, features): + if isinstance(self.tp_sharding, tuple): + assert self.weights_loading_config.weight_mode == WeightMode.VANILLA + start, end = self.tp_sharding + return end - start + elif isinstance(self.tp_sharding, dict): + assert self.weights_loading_config.weight_mode in ( + WeightMode.FUSED_GATE_UP_LINEAR, WeightMode.FUSED_QKV_LINEAR) + return sum(end - start for start, end in self.tp_sharding.values()) + else: + assert features % self.tp_size == 0 or self.weights_loading_config.weight_mode == WeightMode.VANILLA + start = self._calc_shard(features, self.tp_size, self.tp_rank) + end = self._calc_shard(features, self.tp_size, self.tp_rank + 1) + return end - start + + def calculate_local_in_features(self, in_features): + if self.tp_mode != TensorParallelMode.ROW: + return in_features + + return self._calculate_local_features_helper(in_features) + + def calculate_local_out_features(self, out_features): + if self.tp_mode != TensorParallelMode.COLUMN: + return out_features + + return self._calculate_local_features_helper(out_features) + + def load_shard( + self, + weights: Dict, + label: Optional[str] = None, + device: torch.device = torch.device('cpu'), + name: Optional[str] = None, + flip_tp: bool = False, # for input activation scales + scale_span: Optional[int] = None, + # number of elms in a given "slot", used for fp4 since + # 2 are packed in each 8 bit element of the tensor + elm_packing: int = 1, + ) -> torch.Tensor: + if label: + if label not in weights: + return None + weight = weights[label] + else: + weight = weights + + # Skip device transfers on integrated GPUs to conserve shared memory + if weight.device.type != device.type and is_device_integrated(): + # For integrated GPU systems (e.g., DGX Spark), CPU and GPU share limited physical memory. + # Avoiding device transfers reduces memory consumption and unnecessary data copies, + # enabling support for larger models on memory-constrained systems. + logger.warning_once( + f"[Linear.load_shard] Skipping device transfer from {weight.device} to {device} on integrated GPU to conserve shared memory.", + key="load_weight_shard_skip_device_transfer_with_integrated_gpu" + ) + device = weight.device + if isinstance(weight, torch.Tensor): + tensor_shape = weight.shape + + def maybe_convert_to_torch_tensor(tensor: torch.Tensor, + indices: list[slice] + | None = None): + if indices is None: + # Avoid unnecessary copy + return tensor.to(device) + else: + return tensor[indices].to(device) + + # WAR to check whether it is a safetensor slice since safetensor didn't register the type to the module + # safetensors slice, supports lazy loading, type(weight) is `builtin.PySafeSlice` + elif hasattr(weight, "get_shape"): + tensor_shape = weight.get_shape() + + def maybe_convert_to_torch_tensor( + tensor, indices: Union[slice, tuple[slice]] = slice(None)): + return tensor[indices].to(device) + else: + raise ValueError(f'unsupported weight type: {type(weight)}') + if self.tp_mode is None or self.tp_size <= 1: + return maybe_convert_to_torch_tensor(weight) + + tp_mode = TensorParallelMode.flip( + self.tp_mode) if flip_tp else self.tp_mode + split_dim = TensorParallelMode.split_dim(tp_mode) + + if len(tensor_shape) == 1 and split_dim == 1: + return maybe_convert_to_torch_tensor(weight) + + width = tensor_shape[split_dim] + if width == 1: + return maybe_convert_to_torch_tensor(weight) + + if self.tp_sharding is None: + slice_start = self._calc_shard(width, self.tp_size, self.tp_rank) + slice_end = self._calc_shard(width, self.tp_size, self.tp_rank + 1) + else: + if isinstance(self.tp_sharding, tuple): + slice_start, slice_end = self.tp_sharding + elif isinstance(self.tp_sharding, dict): + slice_start, slice_end = self.tp_sharding[name] + + if scale_span: + assert slice_end % scale_span == 0 and slice_start % scale_span == 0 + slice_start //= scale_span + slice_end //= scale_span + + assert slice_start % elm_packing == 0 and slice_end % elm_packing == 0 + slice_start //= elm_packing + slice_end //= elm_packing + + slice_obj = [slice(d) for d in tensor_shape] + slice_obj[split_dim] = slice(slice_start, slice_end) + return maybe_convert_to_torch_tensor(weight, tuple(slice_obj)) + def create_weights(self): if self._weights_created: return diff --git a/tensorrt_llm/_torch/visual_gen/models/flux/attention.py b/tensorrt_llm/_torch/visual_gen/models/flux/attention.py index 72c471df3404..1ad617f5cba9 100644 --- a/tensorrt_llm/_torch/visual_gen/models/flux/attention.py +++ b/tensorrt_llm/_torch/visual_gen/models/flux/attention.py @@ -101,6 +101,11 @@ def __init__( mapping=config.mapping, tensor_parallel_mode=TensorParallelMode.COLUMN, reduce_output=False, + override_tp_sharding={ + "q": (self.local_q_dim_start, self.local_q_dim_end), + "k": (self.local_kv_dim_start, self.local_kv_dim_end), + "v": (self.local_kv_dim_start, self.local_kv_dim_end), + }, ) # Need not pass any mapping info since this is intra-head normalization @@ -130,6 +135,7 @@ def __init__( allreduce_strategy=config.allreduce_strategy, tensor_parallel_mode=TensorParallelMode.ROW, reduce_output=True, + override_tp_sharding=(self.local_kv_dim_start, self.local_kv_dim_end), ) def apply_qk_norm(self, q: torch.Tensor, k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: @@ -350,6 +356,7 @@ def __init__( skip_create_weights_in_init=self.skip_create_weights_in_init, force_dynamic_quantization=self.force_dynamic_quantization, config=config, + attn_shard=(self.local_q_dim_start, self.local_q_dim_end), ) def _init_qkv_proj(self): @@ -366,6 +373,11 @@ def _init_qkv_proj(self): skip_create_weights_in_init=self.skip_create_weights_in_init, force_dynamic_quantization=self.force_dynamic_quantization, mapping=self.mapping, + override_qkv_sharding={ + "q": (self.local_q_dim_start, self.local_q_dim_end), + "k": (self.local_kv_dim_start, self.local_kv_dim_end), + "v": (self.local_kv_dim_start, self.local_kv_dim_end), + }, ) def _apply_norm_rope_unfused( diff --git a/tensorrt_llm/_torch/visual_gen/models/flux/joint_proj.py b/tensorrt_llm/_torch/visual_gen/models/flux/joint_proj.py index edb908c2d80a..f27f69dec0bf 100644 --- a/tensorrt_llm/_torch/visual_gen/models/flux/joint_proj.py +++ b/tensorrt_llm/_torch/visual_gen/models/flux/joint_proj.py @@ -53,6 +53,7 @@ def __init__( skip_create_weights_in_init: bool = False, force_dynamic_quantization: bool = False, config: Optional[DiffusionModelConfig] = None, + attn_shard: Optional[tuple[int, int]] = None, ): super().__init__() mapping = config.mapping if config else None @@ -60,6 +61,11 @@ def __init__( self.tp_rank = getattr(mapping, "tp_rank", 0) self.attn_dim = attn_dim self.has_bias = bias + self.attn_shard = attn_shard + + assert attn_dim % self.tp_size == 0 or self.attn_shard, ( + "Explicit attention sharding required for uneven TP" + ) if self.tp_size == 1: self.proj = Linear( @@ -84,6 +90,7 @@ def __init__( mapping=config.mapping, tensor_parallel_mode=TensorParallelMode.ROW, reduce_output=False, + override_tp_sharding=self.attn_shard, ) self.mlp_proj = Linear( mlp_dim, @@ -162,10 +169,12 @@ def __init__( skip_create_weights_in_init: bool = False, force_dynamic_quantization: bool = False, mapping: Optional[Mapping] = None, + override_qkv_sharding=None, ): super().__init__() self.tp_size = mapping.tp_size if mapping else 1 + self.tp_rank = mapping.tp_rank if mapping else 0 # Store full (pre-TP) dims for weight loading (splitting checkpoint weight) self.full_q_dim = q_dim @@ -188,9 +197,12 @@ def __init__( self.local_qkv_dim = q_dim + 2 * kv_dim self.local_mlp_dim = mlp_dim else: - local_q_dim = q_dim // self.tp_size - local_kv_dim = kv_dim // self.tp_size - shard_mlp_hidden_dim = self.mlp_hidden_dim // self.tp_size + + def range_size(r): + return r[1] - r[0] + + local_q_dim = range_size(override_qkv_sharding["q"]) + local_kv_dim = range_size(override_qkv_sharding["k"]) # QKV: column-parallel with fused Q/K/V sharding self.qkv_proj = Linear( in_dim, @@ -211,8 +223,17 @@ def __init__( mapping=mapping, tensor_parallel_mode=TensorParallelMode.COLUMN, reduce_output=False, + override_tp_sharding=override_qkv_sharding, + ) + + local_mlp_hidden_start = Linear._calc_shard( + self.mlp_hidden_dim, self.tp_size, self.tp_rank + ) + local_mlp_hidden_end = Linear._calc_shard( + self.mlp_hidden_dim, self.tp_size, self.tp_rank + 1 ) - # MLP gate+up: column-parallel with fused gate/up sharding + local_mlp_hidden_size = local_mlp_hidden_end - local_mlp_hidden_start + self.mlp_proj = Linear( in_dim, mlp_dim, @@ -225,15 +246,19 @@ def __init__( weight_mode=WeightMode.FUSED_GATE_UP_LINEAR, ), fused_weight_shard_indices_mapping={ - "gate": (0, shard_mlp_hidden_dim), - "up": (shard_mlp_hidden_dim, shard_mlp_hidden_dim), + "gate": (0, local_mlp_hidden_size), + "up": (local_mlp_hidden_size, local_mlp_hidden_size), }, mapping=mapping, tensor_parallel_mode=TensorParallelMode.COLUMN, reduce_output=False, + override_tp_sharding={ + "gate": (local_mlp_hidden_start, local_mlp_hidden_end), + "up": (local_mlp_hidden_start, local_mlp_hidden_end), + }, ) self.local_qkv_dim = (q_dim + 2 * kv_dim) // self.tp_size - self.local_mlp_dim = mlp_dim // self.tp_size + self.local_mlp_dim = local_mlp_hidden_size def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: """Returns (qkv, mlp_gate_up) with local (post-TP) sizes.""" diff --git a/tensorrt_llm/_torch/visual_gen/models/flux/transformer_flux.py b/tensorrt_llm/_torch/visual_gen/models/flux/transformer_flux.py index d6551918141d..527dce31df23 100644 --- a/tensorrt_llm/_torch/visual_gen/models/flux/transformer_flux.py +++ b/tensorrt_llm/_torch/visual_gen/models/flux/transformer_flux.py @@ -467,21 +467,6 @@ def __init__( ) self.act_mlp = _gelu_tanh_eager - kv_dim = num_attention_heads * attention_head_dim - - # MLP + Attn Output projection, requires special handling for TP - self.proj_out = FluxJointAttnMLPProj( - attn_dim=kv_dim, - mlp_dim=self.mlp_hidden_dim, - out_dim=dim, - bias=True, - dtype=dtype, - quant_config=quant_config, - skip_create_weights_in_init=skip_create_weights, - force_dynamic_quantization=force_dynamic_quant, - config=config, - ) - # Attention (no added_kv_proj_dim since tokens are already concatenated) self.attn = FluxJointAttention( hidden_size=dim, @@ -495,6 +480,21 @@ def __init__( module_name=f"single_transformer_blocks.{layer_idx}.attn", ) + # MLP + Attn Output projection, requires special handling for TP + self.proj_out = FluxJointAttnMLPProj( + attn_dim=self.attn.q_dim, + mlp_dim=self.mlp_hidden_dim, + out_dim=dim, + bias=True, + dtype=dtype, + quant_config=quant_config, + skip_create_weights_in_init=skip_create_weights, + force_dynamic_quantization=force_dynamic_quant, + config=config, + # need explicit shard because we are aligned on head boundaries + attn_shard=(self.attn.local_q_dim_start, self.attn.local_q_dim_end), + ) + def forward( self, hidden_states: torch.Tensor, diff --git a/tensorrt_llm/_torch/visual_gen/models/wan/transformer_wan.py b/tensorrt_llm/_torch/visual_gen/models/wan/transformer_wan.py index 7e86f23d7c70..d84e189d4c27 100644 --- a/tensorrt_llm/_torch/visual_gen/models/wan/transformer_wan.py +++ b/tensorrt_llm/_torch/visual_gen/models/wan/transformer_wan.py @@ -371,6 +371,7 @@ def __init__( force_dynamic_quantization=force_dynamic_quant, tensor_parallel_mode=tp_mode, reduce_output=False, + override_tp_sharding=(self.attn2.local_kv_dim_start, self.attn2.local_kv_dim_end), ) self.add_v_proj = Linear( added_kv_proj_dim, @@ -382,6 +383,7 @@ def __init__( force_dynamic_quantization=force_dynamic_quant, tensor_parallel_mode=tp_mode, reduce_output=False, + override_tp_sharding=(self.attn2.local_kv_dim_start, self.attn2.local_kv_dim_end), ) self.norm_added_k = RMSNormTPAware( hidden_size=hidden_size, @@ -390,6 +392,7 @@ def __init__( has_weights=True, enable_tp=(tp_size > 1), mapping=model_config.mapping, + override_tp_sharding=(self.attn2.local_kv_dim_start, self.attn2.local_kv_dim_end), ) # Use torch.empty().normal_(std=...) instead of torch.randn()/scale for MetaInitMode compatibility diff --git a/tensorrt_llm/_torch/visual_gen/modules/attention.py b/tensorrt_llm/_torch/visual_gen/modules/attention.py index eafac1e328c4..aee7ccd1d821 100644 --- a/tensorrt_llm/_torch/visual_gen/modules/attention.py +++ b/tensorrt_llm/_torch/visual_gen/modules/attention.py @@ -75,10 +75,7 @@ def __init__( self.bias = bias self.tp_size = self.mapping.tp_size if self.mapping else 1 - assert ( - self.num_attention_heads % self.tp_size == 0 - and self.num_key_value_heads % self.tp_size == 0 - ), "TP size must divide the number of Query and KV Heads" + self.tp_rank = self.mapping.tp_rank if self.mapping else 0 # Fused QK Norm + RoPE: each model class opts in via fuse_qk_norm_rope. # Backed by torch.ops.trtllm.fused_dit_qk_norm_rope which auto-dispatches: @@ -112,11 +109,7 @@ def __init__( self.q_dim = self.num_attention_heads * self.head_dim self.kv_dim = self.num_key_value_heads * self.head_dim - self.local_num_attention_heads = self.num_attention_heads // self.tp_size - self.local_num_key_value_heads = self.num_key_value_heads // self.tp_size - self.local_q_dim = self.local_num_attention_heads * self.head_dim - self.local_kv_dim = self.local_num_key_value_heads * self.head_dim - + self._calculate_tp_parameters(ulysses_size if enable_ulysses else None) self._init_qkv_proj() # Structural eligibility for SEPARATE_QKV self-attn quantize dedup. @@ -142,6 +135,12 @@ def __init__( q_norm_dim = self.head_dim if qk_norm_mode == "per_head" else self.q_dim k_norm_dim = self.head_dim if qk_norm_mode == "per_head" else self.kv_dim enable_tp_rms = self.tp_size > 1 and qk_norm_mode == "full" + + q_start = self.local_q_dim_start + q_end = self.local_q_dim_end + k_start = self.local_kv_dim_start + k_end = self.local_kv_dim_end + self.norm_q = RMSNormTPAware( hidden_size=q_norm_dim, eps=self.eps, @@ -149,6 +148,7 @@ def __init__( has_weights=True, enable_tp=enable_tp_rms, mapping=self.mapping, + override_tp_sharding=(q_start, q_end) if qk_norm_mode == "full" else None, ) self.norm_k = RMSNormTPAware( hidden_size=k_norm_dim, @@ -157,6 +157,7 @@ def __init__( has_weights=True, enable_tp=enable_tp_rms, mapping=self.mapping, + override_tp_sharding=(k_start, k_end) if qk_norm_mode == "full" else None, ) # TODO: Use weight mapper to create just a Linear module @@ -174,6 +175,7 @@ def __init__( tensor_parallel_mode=TensorParallelMode.ROW if self.tp_size > 1 else None, reduce_output=(self.tp_size > 1), allreduce_strategy=self.allreduce_strategy, + override_tp_sharding=(self.local_q_dim_start, self.local_q_dim_end), ) ] ) @@ -249,6 +251,46 @@ def _qualified_module_name( prefix = f"{component_name}." return module_name if module_name.startswith(prefix) else f"{prefix}{module_name}" + def _calculate_tp_parameters(self, ulysses_size: Optional[int]): + assert self.num_attention_heads % self.num_key_value_heads == 0 + gqa_ratio = self.num_attention_heads // self.num_key_value_heads + + if not ulysses_size: + ulysses_size = 1 + + assert self.num_key_value_heads % ulysses_size == 0 + # Note: this is intentionally stronger than `num_kv_head >= ulysses_size * tp_size` + assert self.num_key_value_heads // ulysses_size >= self.tp_size + + def _calc_shard(full, size, rank): + full //= ulysses_size + shard = (full // size) * rank + min(full % size, rank) + return shard * ulysses_size + + self.local_key_value_head_start = _calc_shard( + self.num_key_value_heads, self.tp_size, self.tp_rank + ) + self.local_key_value_head_end = _calc_shard( + self.num_key_value_heads, self.tp_size, self.tp_rank + 1 + ) + self.local_num_key_value_heads = ( + self.local_key_value_head_end - self.local_key_value_head_start + ) + + self.local_attention_head_start = gqa_ratio * self.local_key_value_head_start + self.local_attention_head_end = gqa_ratio * self.local_key_value_head_end + self.local_num_attention_heads = ( + self.local_attention_head_end - self.local_attention_head_start + ) + + self.local_q_dim_start = self.local_attention_head_start * self.head_dim + self.local_q_dim_end = self.local_attention_head_end * self.head_dim + self.local_q_dim = self.local_q_dim_end - self.local_q_dim_start + + self.local_kv_dim_start = self.local_key_value_head_start * self.head_dim + self.local_kv_dim_end = self.local_key_value_head_end * self.head_dim + self.local_kv_dim = self.local_kv_dim_end - self.local_kv_dim_start + def _init_qkv_proj(self) -> None: tp_mode = TensorParallelMode.COLUMN if self.tp_size > 1 else None @@ -276,6 +318,11 @@ def _init_qkv_proj(self) -> None: }, tensor_parallel_mode=tp_mode, reduce_output=False, + override_tp_sharding={ + "q": (self.local_q_dim_start, self.local_q_dim_end), + "k": (self.local_kv_dim_start, self.local_kv_dim_end), + "v": (self.local_kv_dim_start, self.local_kv_dim_end), + }, ) else: self.to_q = Linear( @@ -289,6 +336,7 @@ def _init_qkv_proj(self) -> None: force_dynamic_quantization=self.force_dynamic_quantization, tensor_parallel_mode=tp_mode, reduce_output=False, + override_tp_sharding=(self.local_q_dim_start, self.local_q_dim_end), ) self.to_k = Linear( self.hidden_size, @@ -301,6 +349,7 @@ def _init_qkv_proj(self) -> None: force_dynamic_quantization=self.force_dynamic_quantization, tensor_parallel_mode=tp_mode, reduce_output=False, + override_tp_sharding=(self.local_kv_dim_start, self.local_kv_dim_end), ) self.to_v = Linear( self.hidden_size, @@ -313,6 +362,7 @@ def _init_qkv_proj(self) -> None: force_dynamic_quantization=self.force_dynamic_quantization, tensor_parallel_mode=tp_mode, reduce_output=False, + override_tp_sharding=(self.local_kv_dim_start, self.local_kv_dim_end), ) def get_qkv( diff --git a/tensorrt_llm/_torch/visual_gen/modules/rms_norm.py b/tensorrt_llm/_torch/visual_gen/modules/rms_norm.py index b58239ca8b9b..76da6b2ad18e 100644 --- a/tensorrt_llm/_torch/visual_gen/modules/rms_norm.py +++ b/tensorrt_llm/_torch/visual_gen/modules/rms_norm.py @@ -19,6 +19,7 @@ from torch import nn from tensorrt_llm._torch.distributed import AllReduce +from tensorrt_llm._torch.modules.linear import Linear # for Linear._calc_shard from tensorrt_llm.functional import AllReduceStrategy from tensorrt_llm.mapping import Mapping @@ -36,6 +37,7 @@ def __init__( enable_tp: bool = False, mapping: Optional[Mapping] = None, allreduce_strategy: AllReduceStrategy = AllReduceStrategy.NCCL, + override_tp_sharding: Optional[tuple] = None, ): super().__init__() @@ -45,30 +47,43 @@ def __init__( self.mapping = mapping self.enable_tp = enable_tp + self.hidden_size = hidden_size + if enable_tp: assert mapping is not None - self.full_size = hidden_size - shard = hidden_size // mapping.tp_size - start = shard * mapping.tp_rank - end = min(shard * (mapping.tp_rank + 1), hidden_size) - hidden_size = end - start + if override_tp_sharding: + self.tp_sharding = override_tp_sharding + else: + start = Linear._calc_shard(self.hidden_size, mapping.tp_size, mapping.tp_rank) + end = Linear._calc_shard(self.hidden_size, mapping.tp_size, mapping.tp_rank + 1) + self.tp_sharding = (start, end) + + start, end = self.tp_sharding + self.local_hidden_size = end - start self.allreduce = AllReduce( mapping=mapping, strategy=allreduce_strategy, dtype=torch.float32 ) else: + self.local_hidden_size = self.hidden_size self.allreduce = None if use_gemma and not has_weights: raise ValueError("has_weights must be True if use_gemma is True") if has_weights: if not use_gemma: - self.weight = nn.Parameter(torch.ones(hidden_size, dtype=dtype, device=device)) + self.weight = nn.Parameter( + torch.ones(self.local_hidden_size, dtype=dtype, device=device) + ) else: - self.weight = nn.Parameter(torch.zeros(hidden_size, dtype=dtype, device=device)) + self.weight = nn.Parameter( + torch.zeros(self.local_hidden_size, dtype=dtype, device=device) + ) else: self.register_buffer( - "weight", torch.ones(hidden_size, dtype=dtype, device=device), persistent=False + "weight", + torch.ones(self.local_hidden_size, dtype=dtype, device=device), + persistent=False, ) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: @@ -78,7 +93,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: x2 = hidden_states.pow(2) if self.allreduce: x2_sum = x2.sum(-1, keepdim=True) - variance = self.allreduce(x2_sum) / self.full_size + variance = self.allreduce(x2_sum) / self.hidden_size else: variance = x2.mean(-1, keepdim=True) @@ -95,9 +110,7 @@ def load_weights(self, weights: torch.Tensor): if param is None or param_name not in weights: continue if param_name == "weight" and self.enable_tp: - shard = self.full_size // self.mapping.tp_size - start = shard * self.mapping.tp_rank - end = min(shard * (self.mapping.tp_rank + 1), self.full_size) + start, end = self.tp_sharding data = weights[param_name][..., start:end] else: data = weights[param_name] diff --git a/tests/unittest/_torch/modules/test_linear_uneven_tp.py b/tests/unittest/_torch/modules/test_linear_uneven_tp.py new file mode 100644 index 000000000000..f2010940750c --- /dev/null +++ b/tests/unittest/_torch/modules/test_linear_uneven_tp.py @@ -0,0 +1,1371 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math + +import pytest +import torch + +from tensorrt_llm._torch.modules.linear import ( + Linear, + TensorParallelMode, + WeightMode, + WeightsLoadingConfig, +) +from tensorrt_llm._utils import get_sm_version, is_sm_100f +from tensorrt_llm.mapping import Mapping +from tensorrt_llm.models.modeling_utils import QuantConfig +from tensorrt_llm.quantization.mode import QuantAlgo + + +class FakeMapping(Mapping): + def __new__(cls, *args, **kwargs): + return object.__new__(cls) + + def __init__(self, world_size, rank): + super().__init__( + world_size=world_size, + rank=rank, + tp_size=world_size, + ) + self.tp_rank = rank + + +@pytest.fixture(autouse=True) +def seed(): + torch.manual_seed(42) + + +def build_weights(in_features, out_features, quant_algo, bias=True): + if quant_algo == QuantAlgo.NO_QUANT: + w = { + "weight": torch.randn(out_features, in_features) + * torch.rsqrt(torch.tensor(float(in_features))) + } + if bias: + w["bias"] = torch.randn(out_features) + return [w] + elif quant_algo == QuantAlgo.FP8: + fp32_weight = torch.randn(out_features, in_features) * torch.rsqrt( + torch.tensor(float(in_features)) + ) + max_fp8 = torch.finfo(torch.float8_e4m3fn).max + weight_scale = fp32_weight.abs().max() / max_fp8 + fp8_weight = (fp32_weight / weight_scale).to(torch.float8_e4m3fn) + w = { + "weight": fp8_weight, + "weight_scale": weight_scale, + "input_scale": torch.tensor(1.0, dtype=torch.float32), + } + if bias: + w["bias"] = torch.randn(out_features) + return [w] + elif quant_algo == QuantAlgo.FP8_PER_CHANNEL_PER_TOKEN: + fp32_weight = torch.randn(out_features, in_features) * torch.rsqrt( + torch.tensor(float(in_features)) + ) + max_fp8 = torch.finfo(torch.float8_e4m3fn).max + # Per-row scale: one scale per output row + row_max = fp32_weight.abs().amax(dim=1) + weight_scale = row_max / max_fp8 + fp8_weight = (fp32_weight / weight_scale.unsqueeze(1)).to(torch.float8_e4m3fn) + w = { + "weight": fp8_weight, + "weight_scale": weight_scale, + } + if bias: + w["bias"] = torch.randn(out_features) + return [w] + elif quant_algo == QuantAlgo.FP8_BLOCK_SCALES: + fp32_weight = torch.randn(out_features, in_features) * torch.rsqrt( + torch.tensor(float(in_features)) + ) + max_fp8 = torch.finfo(torch.float8_e4m3fn).max + # Per 128-element block scales + scale_rows = math.ceil(out_features / 128) + scale_cols = math.ceil(in_features / 128) + weight_scale = torch.empty(scale_rows, scale_cols, dtype=torch.float32) + fp8_weight = torch.empty(out_features, in_features, dtype=torch.float8_e4m3fn) + for r in range(scale_rows): + for c in range(scale_cols): + r_start, r_end = r * 128, min((r + 1) * 128, out_features) + c_start, c_end = c * 128, min((c + 1) * 128, in_features) + block = fp32_weight[r_start:r_end, c_start:c_end] + block_max = block.abs().max() + s = block_max / max_fp8 + weight_scale[r, c] = s + fp8_weight[r_start:r_end, c_start:c_end] = (block / s).to(torch.float8_e4m3fn) + w = { + "weight": fp8_weight, + "weight_scale": weight_scale, + } + if bias: + w["bias"] = torch.randn(out_features) + return [w] + elif quant_algo == QuantAlgo.NVFP4: + FP8_MAX, E2M1_MAX = 448.0, 6.0 + scaling_vector_size = 16 + fp32_weight = torch.randn( + out_features, in_features, device="cuda", dtype=torch.bfloat16 + ) * torch.rsqrt(torch.tensor(float(in_features))) + weight_amax = fp32_weight.abs().max().float() + weight_scale_2 = weight_amax / (FP8_MAX * E2M1_MAX) + input_scale = torch.tensor(FP8_MAX * E2M1_MAX, dtype=torch.float32) + # Quantize weight to FP4 using the TRTLLM op (NVFP4: sfVecSize=16, UE8M0=False) + global_scale = torch.tensor(FP8_MAX * E2M1_MAX / weight_amax, device="cuda") + fp4_weight, fp4_weight_scale = torch.ops.trtllm.fp4_quantize( + fp32_weight, + global_scale, + scaling_vector_size, + sfUseUE8M0=False, + isSfSwizzledLayout=False, + ) + fp4_weight_scale = fp4_weight_scale.reshape( + out_features, in_features // scaling_vector_size + ) + w = { + "weight": fp4_weight.cpu(), + "weight_scale": fp4_weight_scale.cpu(), + "input_scale": input_scale, + "weight_scale_2": weight_scale_2, + } + if bias: + w["bias"] = torch.randn(out_features) + return [w] + elif quant_algo == QuantAlgo.W4A8_NVFP4_FP8: + scaling_vector_size = 32 + import tensorrt_llm.quantization.utils.fp4_utils as fp4_utils + + # FP4 E2M1: 1.0 = 0b0010. Super-diagonal: M[i, i+1] = 1. + # Packed pairs: element i+1 sits in byte (i+1)//2, nibble (i+1)%2. + # This is an easy way to generate synthetic data that will not cause + # overflows but still requires cross-gpu communication (ie not block diagonal) + packed_cols = in_features // 2 + raw = torch.zeros(out_features, packed_cols, dtype=torch.uint8) + for i in range(min(out_features, in_features - 1)): + j = i + 1 + byte_idx = j // 2 + if j % 2 == 0: + raw[i, byte_idx] = 0x02 # low nibble + else: + raw[i, byte_idx] = 0x20 # high nibble + fp4_weight = raw.view(fp4_utils.float4_e2m1x2) + scale_shape = (out_features, in_features // scaling_vector_size) + fp4_weight_scale = torch.ones(scale_shape, dtype=torch.float32).to(torch.float8_e4m3fn) + input_scale = torch.tensor(1.0, dtype=torch.float32) + weight_scale_2 = torch.tensor(1.0, dtype=torch.float32) + w = { + "weight": fp4_weight, + "weight_scale": fp4_weight_scale, + "input_scale": input_scale, + "weight_scale_2": weight_scale_2, + } + if bias: + w["bias"] = torch.zeros(out_features) + return [w] + elif quant_algo in (QuantAlgo.W4A8_MXFP4_FP8, QuantAlgo.W4A8_MXFP4_MXFP8): + scaling_vector_size = 32 + fp32_weight = torch.randn( + out_features, in_features, device="cuda", dtype=torch.bfloat16 + ) * torch.rsqrt(torch.tensor(float(in_features))) + # MXFP4: sfVecSize=32, UE8M0=True, no globalScale needed + fp4_weight, fp4_weight_scale = torch.ops.trtllm.fp4_quantize( + fp32_weight, None, scaling_vector_size, sfUseUE8M0=True, isSfSwizzledLayout=False + ) + fp4_weight_scale = fp4_weight_scale.reshape( + out_features, in_features // scaling_vector_size + ) + w = { + "weight": fp4_weight.cpu(), + "weight_scale": fp4_weight_scale.cpu(), + } + if bias: + w["bias"] = torch.randn(out_features) + return [w] + elif quant_algo == QuantAlgo.W8A16: + # Match the existing THOP weight-only linear test: quantize a logical + # (in_features, out_features) matrix and store checkpoint weight as + # (out_features, in_features). + fp32_weight = torch.randn(in_features, out_features) * torch.rsqrt( + torch.tensor(float(in_features)) + ) + quant_weight, _, weight_scale = ( + torch.ops.trtllm._symmetric_quantize_last_axis_of_batched_matrix( + fp32_weight.cpu(), torch.int8 + ) + ) + w = { + "weight": quant_weight.T.contiguous(), + "weight_scale": weight_scale, + } + if bias: + w["bias"] = torch.randn(out_features) + return [w] + elif quant_algo == QuantAlgo.W4A16: + # INT4 weight-only checkpoint stores the output dimension packed 2:1: + # quant_weight is (in_features, out_features // 2), so transposed + # checkpoint weight is (out_features // 2, in_features). + fp32_weight = torch.randn(in_features, out_features) * torch.rsqrt( + torch.tensor(float(in_features)) + ) + quant_weight, _, weight_scale = ( + torch.ops.trtllm._symmetric_quantize_last_axis_of_batched_matrix( + fp32_weight.cpu(), torch.quint4x2 + ) + ) + w = { + "weight": quant_weight.T.contiguous(), + "weight_scale": weight_scale, + } + if bias: + w["bias"] = torch.randn(out_features) + return [w] + elif quant_algo in (QuantAlgo.W4A16_AWQ, QuantAlgo.W4A8_AWQ): + group_size = 128 + dtype = torch.float16 if quant_algo == QuantAlgo.W4A16_AWQ else torch.bfloat16 + # Checkpoint weight is packed along output dim. Use a sparse synthetic + # super-diagonal so sharded and full GEMMs compare stably. + raw_weight = torch.zeros(in_features, out_features // 2, dtype=torch.uint8, device="cuda") + for i in range(min(in_features, out_features - 1)): + j = i + 1 + byte_idx = j // 2 + if j % 2 == 0: + raw_weight[i, byte_idx] = 0x01 # low nibble + else: + raw_weight[i, byte_idx] = 0x10 # high nibble + pre_quant_scale = torch.ones(in_features, dtype=dtype, device="cuda") + scale_dtype = torch.float32 if quant_algo == QuantAlgo.W4A16_AWQ else torch.float16 + weight_scale = torch.ones( + in_features // group_size, out_features, dtype=scale_dtype, device="cuda" + ) + w = { + "weight": raw_weight.T.contiguous(), + "weight_scale": weight_scale.T.contiguous(), + "pre_quant_scale": pre_quant_scale, + } + if quant_algo == QuantAlgo.W4A8_AWQ: + w["input_scale"] = torch.tensor(1.0, dtype=torch.float32) + w["weight_scale_2"] = torch.tensor(1.0, dtype=torch.float32) + if bias: + w["bias"] = torch.zeros(out_features) + return [w] + else: + raise NotImplementedError(f"Test does not support QuantAlgo {quant_algo}") + + +DEFAULT_DTYPES = { + QuantAlgo.NO_QUANT: torch.float32, + QuantAlgo.FP8: torch.float32, + QuantAlgo.FP8_PER_CHANNEL_PER_TOKEN: torch.bfloat16, + QuantAlgo.FP8_BLOCK_SCALES: torch.bfloat16, + QuantAlgo.NVFP4: torch.bfloat16, + QuantAlgo.W4A8_NVFP4_FP8: torch.bfloat16, + QuantAlgo.W4A8_MXFP4_FP8: torch.bfloat16, + QuantAlgo.W4A8_MXFP4_MXFP8: torch.bfloat16, + QuantAlgo.W8A16: torch.float16, + QuantAlgo.W4A16: torch.float16, + QuantAlgo.W4A16_AWQ: torch.float16, + QuantAlgo.W4A8_AWQ: torch.bfloat16, +} + + +def build_linears( + in_features, + out_features, + world_size, + quant_algo, + bias=True, + dtype=None, + overrides=None, + **kwargs, +): + """Build one Linear per rank, load shared weights. + + Args: + overrides: Optional list of per-rank override_tp_sharding tuples. + Length must equal world_size. Overrides auto tp_sharding. + """ + weights = build_weights(in_features, out_features, quant_algo, bias=bias) + if dtype is None: + dtype = DEFAULT_DTYPES[quant_algo] + if overrides is not None: + assert len(overrides) == world_size + linears = [] + for rank in range(world_size): + mapping = FakeMapping(world_size, rank) + if quant_algo in (QuantAlgo.W4A16_AWQ, QuantAlgo.W4A8_AWQ): + quant_config = QuantConfig(quant_algo=quant_algo, group_size=128, has_zero_point=False) + else: + quant_config = QuantConfig(quant_algo=quant_algo) + override = overrides[rank] if overrides is not None else None + linear = Linear( + in_features, + out_features, + bias=bias, + dtype=dtype, + mapping=mapping, + quant_config=quant_config, + reduce_output=False, + override_tp_sharding=override, + **kwargs, + ) + linear.load_weights(weights) + linear.post_load_weights() + linear.cuda() + linears.append(linear) + return linears, weights + + +def _fused_shard_indices_mapping(shard_keys, ranges): + mapping = {} + offset = 0 + for key in shard_keys: + start, end = ranges[key] + size = end - start + mapping[key] = (offset, size) + offset += size + return mapping + + +def _prepare_fused_weights_for_loading(weights, quant_algo): + if quant_algo == QuantAlgo.NVFP4: + shared_weight_scale_2 = weights[0]["weight_scale_2"].clone() + for weight in weights: + weight["weight_scale"] = weight["weight_scale"].view(torch.float8_e4m3fn) + weight["weight_scale_2"] = shared_weight_scale_2.clone() + + +def build_fused_linears( + in_features, + sub_out_features, + world_size, + quant_algo, + weight_mode, + shard_keys, + overrides=None, + allow_partial_loading=False, +): + weights = [ + build_weights(in_features, sub_out_features, quant_algo, bias=True)[0] for _ in shard_keys + ] + _prepare_fused_weights_for_loading(weights, quant_algo) + dtype = DEFAULT_DTYPES[quant_algo] + if overrides is not None: + assert len(overrides) == world_size + linears = [] + for rank in range(world_size): + mapping = FakeMapping(world_size, rank) + quant_config = QuantConfig(quant_algo=quant_algo) + override = overrides[rank] if overrides is not None else None + shard_indices_mapping = ( + _fused_shard_indices_mapping(shard_keys, override) + if override is not None and (allow_partial_loading or quant_algo == QuantAlgo.NVFP4) + else None + ) + linear = Linear( + in_features, + sub_out_features * len(shard_keys), + bias=True, + dtype=dtype, + mapping=mapping, + quant_config=quant_config, + weights_loading_config=WeightsLoadingConfig(weight_mode=weight_mode), + reduce_output=False, + override_tp_sharding=override, + fused_weight_shard_indices_mapping=shard_indices_mapping, + tensor_parallel_mode=TensorParallelMode.COLUMN, + ) + linear.load_weights(weights, allow_partial_loading=allow_partial_loading) + if allow_partial_loading: + linear.process_weights_after_loading() + linear.post_load_weights() + linear.cuda() + linears.append(linear) + return linears, weights + + +def build_fused_reference( + in_features, sub_out_features, quant_algo, weight_mode, shard_keys, weights +): + dtype = DEFAULT_DTYPES[quant_algo] + mapping = FakeMapping(1, 0) + quant_config = QuantConfig(quant_algo=quant_algo) + shard_indices_mapping = _fused_shard_indices_mapping( + shard_keys, {key: (0, sub_out_features) for key in shard_keys} + ) + ref = Linear( + in_features, + sub_out_features * len(shard_keys), + bias=True, + dtype=dtype, + mapping=mapping, + quant_config=quant_config, + weights_loading_config=WeightsLoadingConfig(weight_mode=weight_mode), + reduce_output=False, + fused_weight_shard_indices_mapping=shard_indices_mapping, + tensor_parallel_mode=TensorParallelMode.COLUMN, + ) + ref.load_weights(weights) + ref.post_load_weights() + ref.cuda() + return ref + + +def build_reference(in_features, out_features, quant_algo, weights, bias=True, **kwargs): + """Build a single tp_size=1 linear loaded with the given weights.""" + dtype = DEFAULT_DTYPES.get(quant_algo, torch.float32) + mapping = FakeMapping(1, 0) + quant_config = QuantConfig(quant_algo=quant_algo) + ref = Linear( + in_features, + out_features, + bias=bias, + dtype=dtype, + mapping=mapping, + quant_config=quant_config, + reduce_output=False, + **kwargs, + ) + ref.load_weights(weights) + ref.post_load_weights() + ref.cuda() + return ref + + +def _legacy_even_slice(total, tp_size, rank): + shard = total // tp_size + return rank * shard, (rank + 1) * shard + + +def _assert_same_storage(actual, expected): + torch.testing.assert_close( + actual.detach().cpu().view(torch.uint8), + expected.detach().cpu().view(torch.uint8), + rtol=0, + atol=0, + ) + + +def _check_fused_weight_reconstruction(linears, weights, shard_keys, per_rank_ranges): + for rank, linear in enumerate(linears): + expected_weights = [] + expected_biases = [] + for key in shard_keys: + start, end = per_rank_ranges[rank][key] + expected_weights.append(weights[shard_keys.index(key)]["weight"][start:end]) + expected_biases.append(weights[shard_keys.index(key)]["bias"][start:end]) + + expected_weight = torch.cat(expected_weights, dim=0) + expected_bias = torch.cat(expected_biases, dim=0).to(linear.bias.dtype) + _assert_same_storage(linear.weight, expected_weight) + torch.testing.assert_close(linear.bias.detach().cpu(), expected_bias) + + +def _assemble_fused_outputs(outputs, shard_keys, per_rank_ranges): + per_key_outputs = {key: [] for key in shard_keys} + for output, ranges in zip(outputs, per_rank_ranges): + offset = 0 + for key in shard_keys: + start, end = ranges[key] + size = end - start + per_key_outputs[key].append(output[..., offset : offset + size]) + offset += size + + return torch.cat([torch.cat(per_key_outputs[key], dim=-1) for key in shard_keys], dim=-1) + + +def _check_fused_forward(linears, ref, shard_keys, per_rank_ranges, quant_algo): + x = torch.randn(2, ref.in_features, device="cuda", dtype=DEFAULT_DTYPES[quant_algo]) + outputs = [linear(x) for linear in linears] + result = _assemble_fused_outputs(outputs, shard_keys, per_rank_ranges) + expected = ref(x) + if quant_algo == QuantAlgo.NO_QUANT: + torch.testing.assert_close(result, expected, rtol=1e-4, atol=1e-4) + else: + torch.testing.assert_close(result, expected, rtol=0.2, atol=0.2) + + +# ── Test parametrizations ── + +# Pipeline: unified input → column(in→hidden) → row(hidden→out) → sum +# (in_features, hidden, out_features, tp_size) +PIPELINE_CASES = [ + # even + (32, 32, 32, 2), + (32, 64, 32, 4), + # uneven hidden (column out and row in both split unevenly) + (32, 10, 32, 3), + (16, 7, 16, 2), + (16, 13, 16, 4), + (8, 5, 8, 3), +] + + +class TestMLP: + """Unified input → ColumnParallel → RowParallel(no allreduce) → sum.""" + + @pytest.mark.parametrize("in_features,hidden,out_features,tp_size", PIPELINE_CASES) + def test_pipeline(self, in_features, hidden, out_features, tp_size): + col_linears, col_weights = build_linears( + in_features, + hidden, + tp_size, + QuantAlgo.NO_QUANT, + bias=True, + tensor_parallel_mode=TensorParallelMode.COLUMN, + ) + row_linears, row_weights = build_linears( + hidden, + out_features, + tp_size, + QuantAlgo.NO_QUANT, + bias=True, + tensor_parallel_mode=TensorParallelMode.ROW, + ) + + x = torch.randn(2, in_features, device="cuda") + + partial_outputs = [] + for rank in range(tp_size): + col_out = col_linears[rank](x) + row_out = row_linears[rank](col_out) + partial_outputs.append(row_out) + + result = sum(partial_outputs) + + w_col = col_weights[0]["weight"].cuda() + b_col = col_weights[0]["bias"].cuda() + w_row = row_weights[0]["weight"].cuda() + b_row = row_weights[0]["bias"].cuda() + expected = (x @ w_col.t() + b_col) @ w_row.t() + b_row + torch.testing.assert_close(result, expected, rtol=1e-4, atol=1e-4) + + @pytest.mark.parametrize("in_features,hidden,out_features,tp_size", PIPELINE_CASES) + def test_pipeline_no_bias(self, in_features, hidden, out_features, tp_size): + col_linears, col_weights = build_linears( + in_features, + hidden, + tp_size, + QuantAlgo.NO_QUANT, + bias=False, + tensor_parallel_mode=TensorParallelMode.COLUMN, + ) + row_linears, row_weights = build_linears( + hidden, + out_features, + tp_size, + QuantAlgo.NO_QUANT, + bias=False, + tensor_parallel_mode=TensorParallelMode.ROW, + ) + + x = torch.randn(2, in_features, device="cuda") + + partial_outputs = [] + for rank in range(tp_size): + col_out = col_linears[rank](x) + row_out = row_linears[rank](col_out) + partial_outputs.append(row_out) + + result = sum(partial_outputs) + + w_col = col_weights[0]["weight"].cuda() + w_row = row_weights[0]["weight"].cuda() + expected = (x @ w_col.t()) @ w_row.t() + torch.testing.assert_close(result, expected, rtol=1e-4, atol=1e-4) + + +@pytest.mark.skipif( + get_sm_version() < 89, + reason="FP8 per-tensor is supported on SM 89+ GPUs", +) +class TestFP8QDQMLP: + """FP8QDQ: unified input → ColumnParallel → RowParallel → sum.""" + + @pytest.mark.parametrize( + "in_features,hidden,out_features,tp_size", + [ + (32, 32, 32, 2), + (32, 64, 32, 4), + (64, 48, 64, 3), + ], + ) + def test_pipeline(self, in_features, hidden, out_features, tp_size): + col_linears, col_weights = build_linears( + in_features, + hidden, + tp_size, + QuantAlgo.FP8, + bias=True, + tensor_parallel_mode=TensorParallelMode.COLUMN, + ) + row_linears, row_weights = build_linears( + hidden, + out_features, + tp_size, + QuantAlgo.FP8, + bias=True, + tensor_parallel_mode=TensorParallelMode.ROW, + ) + col_ref = build_reference( + in_features, + hidden, + QuantAlgo.FP8, + weights=col_weights, + bias=True, + tensor_parallel_mode=TensorParallelMode.COLUMN, + ) + row_ref = build_reference( + hidden, + out_features, + QuantAlgo.FP8, + weights=row_weights, + bias=True, + tensor_parallel_mode=TensorParallelMode.ROW, + ) + + x = torch.randn(2, in_features, device="cuda") + + partial_outputs = [] + for rank in range(tp_size): + col_out = col_linears[rank](x) + row_out = row_linears[rank](col_out) + partial_outputs.append(row_out) + result = sum(partial_outputs) + + expected = row_ref(col_ref(x)) + torch.testing.assert_close(result, expected, rtol=1e-3, atol=1e-3) + + +FP8R = QuantAlgo.FP8_PER_CHANNEL_PER_TOKEN + + +@pytest.mark.skipif( + get_sm_version() != 90, + reason="FP8 rowwise is supported on Hopper GPUs", +) +class TestFP8RowwiseMLP: + """FP8 Rowwise: unified input → ColumnParallel → RowParallel → sum.""" + + @pytest.mark.parametrize( + "in_features,hidden,out_features,tp_size", + [ + (32, 32, 32, 2), + (32, 64, 32, 4), + (64, 48, 64, 3), + ], + ) + def test_pipeline(self, in_features, hidden, out_features, tp_size): + col_linears, col_weights = build_linears( + in_features, + hidden, + tp_size, + FP8R, + bias=True, + tensor_parallel_mode=TensorParallelMode.COLUMN, + ) + row_linears, row_weights = build_linears( + hidden, + out_features, + tp_size, + FP8R, + bias=True, + tensor_parallel_mode=TensorParallelMode.ROW, + ) + col_ref = build_reference( + in_features, + hidden, + FP8R, + weights=col_weights, + bias=True, + tensor_parallel_mode=TensorParallelMode.COLUMN, + ) + row_ref = build_reference( + hidden, + out_features, + FP8R, + weights=row_weights, + bias=True, + tensor_parallel_mode=TensorParallelMode.ROW, + ) + + # fp8 input bypasses dynamic per-token quantization in both column + # and row linears. Scale keeps values in fp8 normal range. + x = (torch.randn(2, in_features, device="cuda") * 0.1).to(torch.float8_e4m3fn) + + # Column outputs bf16; cast to fp8 before row to avoid requantization + partial_outputs = [] + for rank in range(tp_size): + col_out = col_linears[rank](x).to(torch.float8_e4m3fn) + row_out = row_linears[rank](col_out) + partial_outputs.append(row_out) + result = sum(partial_outputs) + + col_ref_out = col_ref(x).to(torch.float8_e4m3fn) + expected = row_ref(col_ref_out) + # atol accounts for bf16 accumulation order differences between + # sharded and full GEMM (max observed diff ~0.008) + torch.testing.assert_close(result, expected, rtol=1e-2, atol=1e-2) + + +FP8BS = QuantAlgo.FP8_BLOCK_SCALES + +# All shard boundaries must be 128-aligned (scale_span=128 assertion in +# load_shard). The tp_size=3 case distributes five 128-blocks as 2,2,1. +FP8BS_PIPELINE_CASES = [ + (256, 256, 256, 2), + (512, 512, 512, 4), + (640, 640, 640, 3), +] + + +@pytest.mark.skipif( + not (get_sm_version() == 90 or is_sm_100f()), + reason="FP8 block scales are supported on Hopper and SM 100 family GPUs", +) +class TestFP8BlockScalesMLP: + """FP8 Block Scales: column → row pipeline.""" + + @pytest.mark.parametrize("in_features,hidden,out_features,tp_size", FP8BS_PIPELINE_CASES) + def test_pipeline(self, in_features, hidden, out_features, tp_size): + col_linears, col_weights = build_linears( + in_features, + hidden, + tp_size, + FP8BS, + bias=True, + tensor_parallel_mode=TensorParallelMode.COLUMN, + ) + row_linears, row_weights = build_linears( + hidden, + out_features, + tp_size, + FP8BS, + bias=True, + tensor_parallel_mode=TensorParallelMode.ROW, + ) + col_ref = build_reference( + in_features, + hidden, + FP8BS, + weights=col_weights, + bias=True, + tensor_parallel_mode=TensorParallelMode.COLUMN, + ) + row_ref = build_reference( + hidden, + out_features, + FP8BS, + weights=row_weights, + bias=True, + tensor_parallel_mode=TensorParallelMode.ROW, + ) + + x = torch.randn(2, in_features, device="cuda", dtype=torch.bfloat16) + + partial_outputs = [] + for rank in range(tp_size): + col_out = col_linears[rank](x) + row_out = row_linears[rank](col_out) + partial_outputs.append(row_out) + result = sum(partial_outputs) + + expected = row_ref(col_ref(x)) + torch.testing.assert_close(result, expected, rtol=1e-2, atol=1e-2) + + +@pytest.mark.skipif( + not is_sm_100f(), + reason="This test is only supported on SM 100 family GPUs", +) +class TestNVFP4MLP: + """NVFP4: column → row pipeline. ROW requires 16-aligned shard boundaries.""" + + @pytest.mark.parametrize( + "in_features,hidden,out_features,tp_size", + [ + (256, 256, 256, 2), # even + (256, 256, 256, 3), # uneven: ROW shards 16 blocks → 6,5,5 + (256, 256, 256, 4), # even + ], + ) + def test_pipeline(self, in_features, hidden, out_features, tp_size): + col_linears, col_weights = build_linears( + in_features, + hidden, + tp_size, + QuantAlgo.NVFP4, + bias=True, + tensor_parallel_mode=TensorParallelMode.COLUMN, + ) + row_linears, row_weights = build_linears( + hidden, + out_features, + tp_size, + QuantAlgo.NVFP4, + bias=True, + tensor_parallel_mode=TensorParallelMode.ROW, + ) + col_ref = build_reference( + in_features, + hidden, + QuantAlgo.NVFP4, + weights=col_weights, + bias=True, + tensor_parallel_mode=TensorParallelMode.COLUMN, + ) + row_ref = build_reference( + hidden, + out_features, + QuantAlgo.NVFP4, + weights=row_weights, + bias=True, + tensor_parallel_mode=TensorParallelMode.ROW, + ) + + x = torch.randn(2, in_features, device="cuda", dtype=torch.bfloat16) + + partial_outputs = [] + for rank in range(tp_size): + col_out = col_linears[rank](x) + row_out = row_linears[rank](col_out) + partial_outputs.append(row_out) + result = sum(partial_outputs) + + expected = row_ref(col_ref(x)) + torch.testing.assert_close(result, expected, rtol=1e-2, atol=1e-2) + + +@pytest.mark.skipif( + not is_sm_100f(), + reason="This test is only supported on SM 100 family GPUs", +) +class TestW4A8MXFP4FP8MLP: + """W4A8 MXFP4/FP8: column → row pipeline. + + CUTLASS MXFP8xMXFP4 kernel requires shard dims divisible by 128. + Uneven test uses explicit overrides with 128-aligned splits. + """ + + def test_pipeline_even(self): + self._run_pipeline(256, 2) + + def test_pipeline_uneven(self): + overrides = [(0, 256), (256, 512), (512, 640)] + self._run_pipeline(640, 3, overrides=overrides) + + def _run_pipeline(self, dim, tp_size, overrides=None): + col_linears, col_weights = build_linears( + dim, + dim, + tp_size, + QuantAlgo.W4A8_MXFP4_FP8, + bias=True, + tensor_parallel_mode=TensorParallelMode.COLUMN, + overrides=overrides, + ) + row_linears, row_weights = build_linears( + dim, + dim, + tp_size, + QuantAlgo.W4A8_MXFP4_FP8, + bias=True, + tensor_parallel_mode=TensorParallelMode.ROW, + overrides=overrides, + ) + col_ref = build_reference( + dim, + dim, + QuantAlgo.W4A8_MXFP4_FP8, + weights=col_weights, + bias=True, + tensor_parallel_mode=TensorParallelMode.COLUMN, + ) + row_ref = build_reference( + dim, + dim, + QuantAlgo.W4A8_MXFP4_FP8, + weights=row_weights, + bias=True, + tensor_parallel_mode=TensorParallelMode.ROW, + ) + x = torch.randn(2, dim, device="cuda", dtype=torch.bfloat16) + partial_outputs = [] + for rank in range(tp_size): + partial_outputs.append(row_linears[rank](col_linears[rank](x))) + result = sum(partial_outputs) + expected = row_ref(col_ref(x)) + torch.testing.assert_close(result, expected, rtol=0.2, atol=0.2) + + +@pytest.mark.skipif( + not is_sm_100f(), + reason="This test is only supported on SM 100 family GPUs", +) +class TestW4A8NVFP4FP8MLP: + """W4A8 NVFP4/FP8: column → row pipeline. + + Uses synthetic weights with reinterpreted scale dtype. + Same 128-aligned override requirement as MXFP4. + """ + + def test_pipeline_even(self): + self._run_pipeline(256, 2) + + def test_pipeline_uneven(self): + overrides = [(0, 256), (256, 512), (512, 640)] + self._run_pipeline(640, 3, overrides=overrides) + + def _run_pipeline(self, dim, tp_size, overrides=None): + col_linears, col_weights = build_linears( + dim, + dim, + tp_size, + QuantAlgo.W4A8_NVFP4_FP8, + bias=True, + tensor_parallel_mode=TensorParallelMode.COLUMN, + overrides=overrides, + ) + row_linears, row_weights = build_linears( + dim, + dim, + tp_size, + QuantAlgo.W4A8_NVFP4_FP8, + bias=True, + tensor_parallel_mode=TensorParallelMode.ROW, + overrides=overrides, + ) + col_ref = build_reference( + dim, + dim, + QuantAlgo.W4A8_NVFP4_FP8, + weights=col_weights, + bias=True, + tensor_parallel_mode=TensorParallelMode.COLUMN, + ) + row_ref = build_reference( + dim, + dim, + QuantAlgo.W4A8_NVFP4_FP8, + weights=row_weights, + bias=True, + tensor_parallel_mode=TensorParallelMode.ROW, + ) + x = torch.randn(2, dim, device="cuda", dtype=torch.bfloat16) + partial_outputs = [] + for rank in range(tp_size): + partial_outputs.append(row_linears[rank](col_linears[rank](x))) + result = sum(partial_outputs) + expected = row_ref(col_ref(x)) + torch.testing.assert_close(result, expected, rtol=0.2, atol=0.2) + + +@pytest.mark.skipif( + not is_sm_100f(), + reason="This test is only supported on SM 100 family GPUs", +) +class TestW4A8MXFP4MXFP8MLP: + """W4A8 MXFP4/MXFP8: inherits W4A8MXFP4FP8, uses mxfp8_quantize for activation.""" + + def test_pipeline_even(self): + self._run_pipeline(256, 2) + + def test_pipeline_uneven(self): + overrides = [(0, 256), (256, 512), (512, 640)] + self._run_pipeline(640, 3, overrides=overrides) + + def _run_pipeline(self, dim, tp_size, overrides=None): + algo = QuantAlgo.W4A8_MXFP4_MXFP8 + col_linears, col_weights = build_linears( + dim, + dim, + tp_size, + algo, + bias=True, + tensor_parallel_mode=TensorParallelMode.COLUMN, + overrides=overrides, + ) + row_linears, row_weights = build_linears( + dim, + dim, + tp_size, + algo, + bias=True, + tensor_parallel_mode=TensorParallelMode.ROW, + overrides=overrides, + ) + col_ref = build_reference( + dim, + dim, + algo, + weights=col_weights, + bias=True, + tensor_parallel_mode=TensorParallelMode.COLUMN, + ) + row_ref = build_reference( + dim, + dim, + algo, + weights=row_weights, + bias=True, + tensor_parallel_mode=TensorParallelMode.ROW, + ) + x = torch.randn(2, dim, device="cuda", dtype=torch.bfloat16) + partial_outputs = [] + for rank in range(tp_size): + partial_outputs.append(row_linears[rank](col_linears[rank](x))) + result = sum(partial_outputs) + expected = row_ref(col_ref(x)) + torch.testing.assert_close(result, expected, rtol=0.2, atol=0.2) + + +@pytest.mark.skipif( + get_sm_version() < 80, + reason="Weight-only INT8/INT4 is supported on Ampere+ GPUs", +) +class TestWeightOnlyQuantMLP: + """Weight-only INT8 and INT4 quantization.""" + + @pytest.mark.parametrize("quant_algo", [QuantAlgo.W8A16, QuantAlgo.W4A16]) + @pytest.mark.parametrize( + "in_features,hidden,out_features,tp_size", + [ + (256, 256, 256, 2), # even + (256, 256, 256, 3), # uneven + ], + ) + def test_pipeline(self, in_features, hidden, out_features, tp_size, quant_algo): + col_linears, col_weights = build_linears( + in_features, + hidden, + tp_size, + quant_algo, + bias=True, + tensor_parallel_mode=TensorParallelMode.COLUMN, + ) + row_linears, row_weights = build_linears( + hidden, + out_features, + tp_size, + quant_algo, + bias=True, + tensor_parallel_mode=TensorParallelMode.ROW, + ) + col_ref = build_reference( + in_features, + hidden, + quant_algo, + weights=col_weights, + bias=True, + tensor_parallel_mode=TensorParallelMode.COLUMN, + ) + row_ref = build_reference( + hidden, + out_features, + quant_algo, + weights=row_weights, + bias=True, + tensor_parallel_mode=TensorParallelMode.ROW, + ) + + x = torch.randn(2, in_features, device="cuda", dtype=torch.float16) + + partial_outputs = [] + for rank in range(tp_size): + col_out = col_linears[rank](x) + row_out = row_linears[rank](col_out) + partial_outputs.append(row_out) + result = sum(partial_outputs) + + expected = row_ref(col_ref(x)) + torch.testing.assert_close(result, expected, rtol=1e-2, atol=1e-2) + + +class _AWQMLPMixin: + quant_algo = None + + @pytest.mark.parametrize( + "dim,tp_size", + [ + (256, 2), # even + (640, 3), # uneven: 128-group shards -> 256,256,128 + ], + ) + def test_pipeline(self, dim, tp_size): + quant_algo = self.quant_algo + col_linears, col_weights = build_linears( + dim, + dim, + tp_size, + quant_algo, + bias=True, + tensor_parallel_mode=TensorParallelMode.COLUMN, + ) + row_linears, row_weights = build_linears( + dim, + dim, + tp_size, + quant_algo, + bias=True, + tensor_parallel_mode=TensorParallelMode.ROW, + ) + col_ref = build_reference( + dim, + dim, + quant_algo, + weights=col_weights, + bias=True, + tensor_parallel_mode=TensorParallelMode.COLUMN, + ) + row_ref = build_reference( + dim, + dim, + quant_algo, + weights=row_weights, + bias=True, + tensor_parallel_mode=TensorParallelMode.ROW, + ) + + x = torch.randn(2, dim, device="cuda", dtype=DEFAULT_DTYPES[quant_algo]) + partial_outputs = [] + for rank in range(tp_size): + partial_outputs.append(row_linears[rank](col_linears[rank](x))) + result = sum(partial_outputs) + expected = row_ref(col_ref(x)) + torch.testing.assert_close(result, expected, rtol=1e-2, atol=1e-2) + + +@pytest.mark.skipif( + get_sm_version() < 80, + reason="W4A16 AWQ is supported on Ampere+ GPUs", +) +class TestW4A16AWQMLP(_AWQMLPMixin): + """W4A16 AWQ with grouped scales.""" + + quant_algo = QuantAlgo.W4A16_AWQ + + +@pytest.mark.skipif( + not (get_sm_version() in (89, 90) or is_sm_100f()), + reason="W4A8 AWQ is supported on Ada, Hopper, and SM 100 family GPUs", +) +class TestW4A8AWQMLP(_AWQMLPMixin): + """W4A8 AWQ with grouped scales.""" + + quant_algo = QuantAlgo.W4A8_AWQ + + +FUSED_QUANT_ALGO = QuantAlgo.W4A8_MXFP4_FP8 + + +class TestFusedLinearLoading: + """Fused QKV and Gate/Up loading for legacy even and override uneven TP.""" + + @pytest.mark.parametrize( + "weight_mode,quant_algo,sub_out,tp_size", + [ + (WeightMode.FUSED_QKV_LINEAR, QuantAlgo.NO_QUANT, 96, 3), + (WeightMode.FUSED_GATE_UP_LINEAR, QuantAlgo.NO_QUANT, 96, 3), + ], + ) + def test_even_no_override(self, weight_mode, quant_algo, sub_out, tp_size): + shard_keys = weight_mode.shard_keys + linears, weights = build_fused_linears( + 256, + sub_out, + tp_size, + quant_algo, + weight_mode, + shard_keys, + ) + ranges = [] + for rank in range(tp_size): + start, end = _legacy_even_slice(sub_out, tp_size, rank) + ranges.append({key: (start, end) for key in shard_keys}) + + ref = build_fused_reference(256, sub_out, quant_algo, weight_mode, shard_keys, weights) + _check_fused_weight_reconstruction(linears, weights, shard_keys, ranges) + _check_fused_forward(linears, ref, shard_keys, ranges, quant_algo) + if quant_algo != QuantAlgo.NO_QUANT: + for linear in linears: + assert linear.weight_scale.numel() > 0 + + @pytest.mark.parametrize("quant_algo", [QuantAlgo.NO_QUANT]) + @pytest.mark.parametrize( + "weight_mode", + [ + WeightMode.FUSED_QKV_LINEAR, + WeightMode.FUSED_GATE_UP_LINEAR, + ], + ) + def test_uneven_override(self, weight_mode, quant_algo): + shard_keys = weight_mode.shard_keys + sub_out = 640 + tp_size = 3 + boundaries = [(0, 256), (256, 512), (512, 640)] + overrides = [{key: boundary for key in shard_keys} for boundary in boundaries] + linears, weights = build_fused_linears( + 256, + sub_out, + tp_size, + quant_algo, + weight_mode, + shard_keys, + overrides=overrides, + ) + ranges = [{key: boundary for key in shard_keys} for boundary in boundaries] + + ref = build_fused_reference(256, sub_out, quant_algo, weight_mode, shard_keys, weights) + _check_fused_weight_reconstruction(linears, weights, shard_keys, ranges) + _check_fused_forward(linears, ref, shard_keys, ranges, quant_algo) + if quant_algo != QuantAlgo.NO_QUANT: + for linear in linears: + assert linear.weight_scale.numel() > 0 + + @pytest.mark.parametrize("quant_algo", [QuantAlgo.NO_QUANT]) + @pytest.mark.parametrize( + "weight_mode", + [ + WeightMode.FUSED_QKV_LINEAR, + WeightMode.FUSED_GATE_UP_LINEAR, + ], + ) + def test_uneven_override_partial_loading(self, weight_mode, quant_algo): + shard_keys = weight_mode.shard_keys + sub_out = 640 + tp_size = 3 + boundaries = [(0, 256), (256, 512), (512, 640)] + overrides = [{key: boundary for key in shard_keys} for boundary in boundaries] + linears, weights = build_fused_linears( + 256, + sub_out, + tp_size, + quant_algo, + weight_mode, + shard_keys, + overrides=overrides, + allow_partial_loading=True, + ) + ranges = [{key: boundary for key in shard_keys} for boundary in boundaries] + + ref = build_fused_reference(256, sub_out, quant_algo, weight_mode, shard_keys, weights) + _check_fused_weight_reconstruction(linears, weights, shard_keys, ranges) + _check_fused_forward(linears, ref, shard_keys, ranges, quant_algo) + if quant_algo != QuantAlgo.NO_QUANT: + for linear in linears: + assert linear.weight_scale.numel() > 0 + + +@pytest.mark.skipif( + not is_sm_100f(), + reason="Fused FP4/NVFP4 loading is supported on SM 100 family GPUs", +) +class TestFusedQuantizedLinearLoading: + """Quantized fused QKV and Gate/Up loading for override uneven TP.""" + + @pytest.mark.parametrize( + "weight_mode,quant_algo,sub_out,tp_size", + [ + (WeightMode.FUSED_QKV_LINEAR, FUSED_QUANT_ALGO, 256, 2), + (WeightMode.FUSED_GATE_UP_LINEAR, FUSED_QUANT_ALGO, 256, 2), + ], + ) + def test_even_no_override(self, weight_mode, quant_algo, sub_out, tp_size): + shard_keys = weight_mode.shard_keys + linears, weights = build_fused_linears( + 256, + sub_out, + tp_size, + quant_algo, + weight_mode, + shard_keys, + ) + ranges = [] + for rank in range(tp_size): + start, end = _legacy_even_slice(sub_out, tp_size, rank) + ranges.append({key: (start, end) for key in shard_keys}) + + ref = build_fused_reference(256, sub_out, quant_algo, weight_mode, shard_keys, weights) + _check_fused_weight_reconstruction(linears, weights, shard_keys, ranges) + _check_fused_forward(linears, ref, shard_keys, ranges, quant_algo) + for linear in linears: + assert linear.weight_scale.numel() > 0 + + @pytest.mark.parametrize("quant_algo", [FUSED_QUANT_ALGO]) + @pytest.mark.parametrize( + "weight_mode", + [ + WeightMode.FUSED_QKV_LINEAR, + WeightMode.FUSED_GATE_UP_LINEAR, + ], + ) + def test_uneven_override(self, weight_mode, quant_algo): + shard_keys = weight_mode.shard_keys + sub_out = 640 + tp_size = 3 + boundaries = [(0, 256), (256, 512), (512, 640)] + overrides = [{key: boundary for key in shard_keys} for boundary in boundaries] + linears, weights = build_fused_linears( + 256, + sub_out, + tp_size, + quant_algo, + weight_mode, + shard_keys, + overrides=overrides, + ) + ranges = [{key: boundary for key in shard_keys} for boundary in boundaries] + + ref = build_fused_reference(256, sub_out, quant_algo, weight_mode, shard_keys, weights) + _check_fused_weight_reconstruction(linears, weights, shard_keys, ranges) + _check_fused_forward(linears, ref, shard_keys, ranges, quant_algo) + for linear in linears: + assert linear.weight_scale.numel() > 0 + + @pytest.mark.parametrize("quant_algo", [QuantAlgo.NVFP4]) + @pytest.mark.parametrize( + "weight_mode", + [ + WeightMode.FUSED_QKV_LINEAR, + WeightMode.FUSED_GATE_UP_LINEAR, + ], + ) + def test_uneven_override_partial_loading(self, weight_mode, quant_algo): + shard_keys = weight_mode.shard_keys + sub_out = 640 + tp_size = 3 + boundaries = [(0, 256), (256, 512), (512, 640)] + overrides = [{key: boundary for key in shard_keys} for boundary in boundaries] + linears, weights = build_fused_linears( + 256, + sub_out, + tp_size, + quant_algo, + weight_mode, + shard_keys, + overrides=overrides, + allow_partial_loading=True, + ) + ranges = [{key: boundary for key in shard_keys} for boundary in boundaries] + + ref = build_fused_reference(256, sub_out, quant_algo, weight_mode, shard_keys, weights) + _check_fused_weight_reconstruction(linears, weights, shard_keys, ranges) + _check_fused_forward(linears, ref, shard_keys, ranges, quant_algo) + for linear in linears: + assert linear.weight_scale.numel() > 0 + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/tests/unittest/_torch/visual_gen/multi_gpu/test_flux_tp.py b/tests/unittest/_torch/visual_gen/multi_gpu/test_flux_tp.py index 54b462a4cc62..93e4a2474a6d 100644 --- a/tests/unittest/_torch/visual_gen/multi_gpu/test_flux_tp.py +++ b/tests/unittest/_torch/visual_gen/multi_gpu/test_flux_tp.py @@ -49,6 +49,13 @@ from tensorrt_llm._utils import get_free_port from tensorrt_llm.models.modeling_utils import QuantConfig + from .tp_shard_utils import ( + copy_tp_parameter, + shard_dim1, + shard_fused_gate_up, + shard_fused_qkv_by_heads, + ) + MODULES_AVAILABLE = True except ImportError: MODULES_AVAILABLE = False @@ -150,6 +157,18 @@ def run_test_in_distributed(world_size: int, test_fn: Callable, use_cuda: bool = timestep_guidance_channels=256, ) +# TP=3 uneven configs: 8 heads already gives an uneven 3+3+2 attention split. +_FLUX1_UNEVEN_TP3_CONFIG = { + **_FLUX1_TEST_CONFIG, + # FLUX.1 FFN intermediate is 512 * 4 = 2048, also uneven over TP=3. +} + +_FLUX2_UNEVEN_TP3_CONFIG = { + **_FLUX2_TEST_CONFIG, + "mlp_ratio": 3.5, + # Keep heads unchanged; make FLUX.2 MLP hidden dim 512 * 3.5 = 1792, uneven over TP=3. +} + def _make_model_config(pretrained_dict, tp_size=1, ulysses_size=1, backend="VANILLA"): """Create DiffusionModelConfig for testing with TP and/or Ulysses.""" @@ -195,72 +214,30 @@ def _stabilize_model_weights(model): # ============================================================================= -# TP weight sharding helpers +# TP weight sharding helpers (see tp_shard_utils.py) # ============================================================================= -def _shard_dim0(tensor, tp_rank, tp_size): - """Shard a tensor along dim 0.""" - chunk = tensor.shape[0] // tp_size - return tensor[tp_rank * chunk : (tp_rank + 1) * chunk].contiguous() - - -def _shard_dim1(tensor, tp_rank, tp_size): - """Shard a tensor along dim 1.""" - chunk = tensor.shape[1] // tp_size - return tensor[:, tp_rank * chunk : (tp_rank + 1) * chunk].contiguous() - - -def _shard_fused_qkv(tensor, tp_rank, tp_size, q_dim, kv_dim): - """Shard a fused QKV weight [q_dim + 2*kv_dim, ...] preserving Q/K/V structure.""" - q, k, v = tensor.split([q_dim, kv_dim, kv_dim], dim=0) - return torch.cat( - [ - _shard_dim0(q, tp_rank, tp_size), - _shard_dim0(k, tp_rank, tp_size), - _shard_dim0(v, tp_rank, tp_size), - ], - dim=0, - ) - - -def _shard_fused_gate_up(tensor, tp_rank, tp_size): - """Shard a fused gate_up weight [2*intermediate, ...] preserving gate/up structure.""" - half = tensor.shape[0] // 2 - gate, up = tensor.split([half, half], dim=0) - return torch.cat( - [ - _shard_dim0(gate, tp_rank, tp_size), - _shard_dim0(up, tp_rank, tp_size), - ], - dim=0, - ) - - -def _copy_ref_weights_to_tp(ref_model, tp_model, tp_rank, tp_size): - """Copy weights from a TP=1 reference model into a TP model with correct sharding. - - Handles column-parallel (QKV, MLP up/gate), row-parallel (output projs), - fused QKV/gate_up weights, and wrapper projectors (FluxJointAttnMLPProj, - FluxJointQKVMLPProj) that have different sub-module structure at TP>1. - """ +def _copy_ref_weights_to_tp(ref_model, tp_model, tp_rank, tp_size, config_dict): + """Copy weights from a TP=1 reference model into a TP model with correct sharding.""" ref_params = dict(ref_model.named_parameters()) - - # First handle wrapper projectors whose sub-module names differ between TP=1 and TP>1. - # At TP=1: single .proj Linear. At TP>1: split into sub-Linears. + num_heads = config_dict["num_attention_heads"] + head_dim = config_dict["attention_head_dim"] + vgm = getattr(tp_model.model_config, "visual_gen_mapping", None) + ulysses_size = vgm.ulysses_size if vgm is not None else 1 handled_tp_params = set() for tp_name, tp_module in tp_model.named_modules(): if isinstance(tp_module, FluxJointAttnMLPProj) and tp_module.tp_size > 1: - # TP model has .attn_proj + .mlp_proj; ref has .proj - ref_w = ref_params[f"{tp_name}.proj.weight"] # [out, attn_dim + mlp_dim] + ref_w = ref_params[f"{tp_name}.proj.weight"] w_attn = ref_w[:, : tp_module.attn_dim] w_mlp = ref_w[:, tp_module.attn_dim :] + attn_start, attn_end = tp_module.attn_shard tp_model.get_parameter(f"{tp_name}.attn_proj.weight").data.copy_( - _shard_dim1(w_attn, tp_rank, tp_size) + w_attn[:, attn_start:attn_end].contiguous() ) tp_model.get_parameter(f"{tp_name}.mlp_proj.weight").data.copy_( - _shard_dim1(w_mlp, tp_rank, tp_size) + shard_dim1(w_mlp, tp_rank, tp_size) ) handled_tp_params.update( [ @@ -274,20 +251,25 @@ def _copy_ref_weights_to_tp(ref_model, tp_model, tp_rank, tp_size): handled_tp_params.add(f"{tp_name}.bias") elif isinstance(tp_module, FluxJointQKVMLPProj) and tp_module.tp_size > 1: - # TP model has .qkv_proj + .mlp_proj; ref has .proj - ref_w = ref_params[f"{tp_name}.proj.weight"] # [qkv+mlp, hidden] + ref_w = ref_params[f"{tp_name}.proj.weight"] w_qkv = ref_w[: tp_module.full_qkv_dim] w_mlp = ref_w[tp_module.full_qkv_dim :] - # QKV: split into Q/K/V, shard each, re-fuse tp_model.get_parameter(f"{tp_name}.qkv_proj.weight").data.copy_( - _shard_fused_qkv( - w_qkv, tp_rank, tp_size, tp_module.full_q_dim, tp_module.full_kv_dim + shard_fused_qkv_by_heads( + w_qkv, + tp_rank, + tp_size, + num_heads, + num_heads, + head_dim, + tp_module.full_q_dim, + tp_module.full_kv_dim, + ulysses_size, ) ) - # MLP: split gate/up, shard each, re-fuse tp_model.get_parameter(f"{tp_name}.mlp_proj.weight").data.copy_( - _shard_fused_gate_up(w_mlp, tp_rank, tp_size) + shard_fused_gate_up(w_mlp, tp_rank, tp_size) ) handled_tp_params.update( [ @@ -295,18 +277,25 @@ def _copy_ref_weights_to_tp(ref_model, tp_model, tp_rank, tp_size): f"{tp_name}.mlp_proj.weight", ] ) - # Handle bias if present if f"{tp_name}.proj.bias" in ref_params: ref_b = ref_params[f"{tp_name}.proj.bias"] b_qkv = ref_b[: tp_module.full_qkv_dim] b_mlp = ref_b[tp_module.full_qkv_dim :] tp_model.get_parameter(f"{tp_name}.qkv_proj.bias").data.copy_( - _shard_fused_qkv( - b_qkv, tp_rank, tp_size, tp_module.full_q_dim, tp_module.full_kv_dim + shard_fused_qkv_by_heads( + b_qkv, + tp_rank, + tp_size, + num_heads, + num_heads, + head_dim, + tp_module.full_q_dim, + tp_module.full_kv_dim, + ulysses_size, ) ) tp_model.get_parameter(f"{tp_name}.mlp_proj.bias").data.copy_( - _shard_fused_gate_up(b_mlp, tp_rank, tp_size) + shard_fused_gate_up(b_mlp, tp_rank, tp_size) ) handled_tp_params.update( [ @@ -315,49 +304,20 @@ def _copy_ref_weights_to_tp(ref_model, tp_model, tp_rank, tp_size): ] ) - # Now handle all remaining parameters by shape comparison. with torch.no_grad(): for tp_name, tp_param in tp_model.named_parameters(): - if tp_name in handled_tp_params: - continue - if tp_name not in ref_params: + if tp_name in handled_tp_params or tp_name not in ref_params: continue - - ref_param = ref_params[tp_name] - - if tp_param.shape == ref_param.shape: - # Replicated parameter (norms, embeddings, etc.) - tp_param.data.copy_(ref_param.data) - elif tp_param.ndim >= 2 and tp_param.shape[1] == ref_param.shape[1]: - # Column parallel: dim 0 is smaller (output dim sharded) - if "qkv_proj" in tp_name or "add_qkv_proj" in tp_name: - # Fused QKV: figure out q_dim from total (q=k=v for FLUX) - q_dim = ref_param.shape[0] // 3 - tp_param.data.copy_( - _shard_fused_qkv(ref_param.data, tp_rank, tp_size, q_dim, q_dim) - ) - elif "gate_up_proj" in tp_name: - tp_param.data.copy_(_shard_fused_gate_up(ref_param.data, tp_rank, tp_size)) - else: - tp_param.data.copy_(_shard_dim0(ref_param.data, tp_rank, tp_size)) - elif tp_param.ndim >= 2 and tp_param.shape[0] == ref_param.shape[0]: - # Row parallel: dim 1 is smaller (input dim sharded) - tp_param.data.copy_(_shard_dim1(ref_param.data, tp_rank, tp_size)) - elif tp_param.ndim == 1 and tp_param.shape[0] < ref_param.shape[0]: - # 1D bias for column parallel - if "qkv_proj" in tp_name or "add_qkv_proj" in tp_name: - q_dim = ref_param.shape[0] // 3 - tp_param.data.copy_( - _shard_fused_qkv(ref_param.data, tp_rank, tp_size, q_dim, q_dim) - ) - elif "gate_up_proj" in tp_name: - tp_param.data.copy_(_shard_fused_gate_up(ref_param.data, tp_rank, tp_size)) - else: - tp_param.data.copy_(_shard_dim0(ref_param.data, tp_rank, tp_size)) - else: - raise ValueError( - f"Cannot shard {tp_name}: ref={ref_param.shape}, tp={tp_param.shape}" - ) + copy_tp_parameter( + tp_name, + ref_params[tp_name], + tp_param, + tp_rank, + tp_size, + num_heads, + head_dim, + ulysses_size=ulysses_size, + ) # ============================================================================= @@ -410,6 +370,16 @@ def _logic_flux1_tp_forward(rank, world_size): def _logic_flux1_tp_vs_single_gpu(rank, world_size): """FLUX.1: TP 2-GPU output matches single-GPU reference.""" + _logic_flux1_tp_vs_single_gpu_with_config(rank, world_size, _FLUX1_TEST_CONFIG) + + +def _logic_flux1_tp3_uneven_vs_single_gpu(rank, world_size): + """FLUX.1: TP=3 with uneven head/MLP dims matches single-GPU reference.""" + _logic_flux1_tp_vs_single_gpu_with_config(rank, world_size, _FLUX1_UNEVEN_TP3_CONFIG) + + +def _logic_flux1_tp_vs_single_gpu_with_config(rank, world_size, config_dict): + """FLUX.1: TP output matches single-GPU reference.""" from tensorrt_llm._torch.visual_gen.models.flux.transformer_flux import FluxTransformer2DModel device = torch.device(f"cuda:{rank}") @@ -418,22 +388,25 @@ def _logic_flux1_tp_vs_single_gpu(rank, world_size): batch = 1 img_seq = 16 txt_seq = 8 + in_channels = 64 # Create single-GPU reference model torch.manual_seed(123) - ref_config = _make_model_config(_FLUX1_TEST_CONFIG, tp_size=1) + ref_config = _make_model_config(config_dict, tp_size=1) ref_model = FluxTransformer2DModel(ref_config).to(device).to(compute_dtype) _stabilize_model_weights(ref_model) # Create TP model and copy sharded weights from ref torch.manual_seed(123) - tp_config = _make_model_config(_FLUX1_TEST_CONFIG, tp_size=world_size) + tp_config = _make_model_config(config_dict, tp_size=world_size) tp_model = FluxTransformer2DModel(tp_config).to(device).to(compute_dtype) - _copy_ref_weights_to_tp(ref_model, tp_model, rank, world_size) + _copy_ref_weights_to_tp(ref_model, tp_model, rank, world_size, config_dict) # Same inputs on all ranks torch.manual_seed(456) - hidden_states = torch.randn(batch, img_seq, 64, device=device, dtype=compute_dtype) * 0.1 + hidden_states = ( + torch.randn(batch, img_seq, in_channels, device=device, dtype=compute_dtype) * 0.1 + ) encoder_hidden_states = ( torch.randn(batch, txt_seq, 256, device=device, dtype=compute_dtype) * 0.1 ) @@ -517,6 +490,16 @@ def _logic_flux2_tp_forward(rank, world_size): def _logic_flux2_tp_vs_single_gpu(rank, world_size): """FLUX.2: TP 2-GPU output matches single-GPU reference.""" + _logic_flux2_tp_vs_single_gpu_with_config(rank, world_size, _FLUX2_TEST_CONFIG) + + +def _logic_flux2_tp3_uneven_vs_single_gpu(rank, world_size): + """FLUX.2: TP=3 with uneven head/MLP dims matches single-GPU reference.""" + _logic_flux2_tp_vs_single_gpu_with_config(rank, world_size, _FLUX2_UNEVEN_TP3_CONFIG) + + +def _logic_flux2_tp_vs_single_gpu_with_config(rank, world_size, config_dict): + """FLUX.2: TP output matches single-GPU reference.""" from tensorrt_llm._torch.visual_gen.models.flux.transformer_flux2 import Flux2Transformer2DModel device = torch.device(f"cuda:{rank}") @@ -525,22 +508,25 @@ def _logic_flux2_tp_vs_single_gpu(rank, world_size): batch = 1 img_seq = 16 txt_seq = 8 + in_channels = 128 # Create single-GPU reference model torch.manual_seed(123) - ref_config = _make_model_config(_FLUX2_TEST_CONFIG, tp_size=1) + ref_config = _make_model_config(config_dict, tp_size=1) ref_model = Flux2Transformer2DModel(ref_config).to(device).to(compute_dtype) _stabilize_model_weights(ref_model) # Create TP model and copy sharded weights from ref torch.manual_seed(123) - tp_config = _make_model_config(_FLUX2_TEST_CONFIG, tp_size=world_size) + tp_config = _make_model_config(config_dict, tp_size=world_size) tp_model = Flux2Transformer2DModel(tp_config).to(device).to(compute_dtype) - _copy_ref_weights_to_tp(ref_model, tp_model, rank, world_size) + _copy_ref_weights_to_tp(ref_model, tp_model, rank, world_size, config_dict) # Same inputs on all ranks torch.manual_seed(456) - hidden_states = torch.randn(batch, img_seq, 128, device=device, dtype=compute_dtype) * 0.1 + hidden_states = ( + torch.randn(batch, img_seq, in_channels, device=device, dtype=compute_dtype) * 0.1 + ) encoder_hidden_states = ( torch.randn(batch, txt_seq, 256, device=device, dtype=compute_dtype) * 0.1 ) @@ -608,7 +594,7 @@ def _logic_flux2_tp_ulysses_vs_single_gpu(rank, world_size): ) combined_model = Flux2Transformer2DModel(combined_config).to(device).to(compute_dtype) vgm = combined_config.visual_gen_mapping - _copy_ref_weights_to_tp(ref_model, combined_model, vgm.tp_rank, tp_size) + _copy_ref_weights_to_tp(ref_model, combined_model, vgm.tp_rank, tp_size, _FLUX2_TEST_CONFIG) # Same inputs on all ranks (Ulysses shards at runtime) torch.manual_seed(456) @@ -682,5 +668,17 @@ def test_flux2_tp_ulysses_vs_single_gpu(self): run_test_in_distributed(world_size=4, test_fn=_logic_flux2_tp_ulysses_vs_single_gpu) +class TestFluxUnevenTP3: + """TP=3 tests where head count and MLP dims are not divisible by tp_size.""" + + def test_flux1_tp3_uneven_vs_single_gpu(self): + """FLUX.1 TP=3 (8 heads, uneven MLP) matches single-GPU reference.""" + run_test_in_distributed(world_size=3, test_fn=_logic_flux1_tp3_uneven_vs_single_gpu) + + def test_flux2_tp3_uneven_vs_single_gpu(self): + """FLUX.2 TP=3 (8 heads, uneven MLP) matches single-GPU reference.""" + run_test_in_distributed(world_size=3, test_fn=_logic_flux2_tp3_uneven_vs_single_gpu) + + if __name__ == "__main__": pytest.main([__file__, "-v"]) diff --git a/tests/unittest/_torch/visual_gen/multi_gpu/test_tp_attention.py b/tests/unittest/_torch/visual_gen/multi_gpu/test_tp_attention.py index 4783dcd4db1a..29e0dba6a9fc 100644 --- a/tests/unittest/_torch/visual_gen/multi_gpu/test_tp_attention.py +++ b/tests/unittest/_torch/visual_gen/multi_gpu/test_tp_attention.py @@ -155,70 +155,75 @@ def _shard_tp_weights(ref_attn, tp_attn, tp_rank, tp_size, qkv_mode=QKVMode.FUSE RMSNorm (TP-enabled): split weight """ with torch.no_grad(): - if qkv_mode == QKVMode.FUSE_QKV: - # Fused QKV: weight is [q_dim + 2*kv_dim, hidden_size] - full_w = ref_attn.qkv_proj.weight.data - q_dim = ref_attn.q_dim - kv_dim = ref_attn.kv_dim - q_w, k_w, v_w = full_w.split([q_dim, kv_dim, kv_dim], dim=0) + q_start, q_end = tp_attn.local_q_dim_start, tp_attn.local_q_dim_end + kv_start, kv_end = tp_attn.local_kv_dim_start, tp_attn.local_kv_dim_end - q_shard = _shard_dim0(q_w, tp_rank, tp_size) - k_shard = _shard_dim0(k_w, tp_rank, tp_size) - v_shard = _shard_dim0(v_w, tp_rank, tp_size) - tp_attn.qkv_proj.weight.data.copy_(torch.cat([q_shard, k_shard, v_shard], dim=0)) + if qkv_mode == QKVMode.FUSE_QKV: + q_w, k_w, v_w = ref_attn.qkv_proj.weight.data.split( + [ref_attn.q_dim, ref_attn.kv_dim, ref_attn.kv_dim], dim=0 + ) + tp_attn.qkv_proj.weight.data.copy_( + torch.cat( + [ + q_w[q_start:q_end], + k_w[kv_start:kv_end], + v_w[kv_start:kv_end], + ], + dim=0, + ).contiguous() + ) if ref_attn.qkv_proj.bias is not None: - full_b = ref_attn.qkv_proj.bias.data - q_b, k_b, v_b = full_b.split([q_dim, kv_dim, kv_dim], dim=0) + q_b, k_b, v_b = ref_attn.qkv_proj.bias.data.split( + [ref_attn.q_dim, ref_attn.kv_dim, ref_attn.kv_dim], dim=0 + ) tp_attn.qkv_proj.bias.data.copy_( torch.cat( [ - _shard_dim0(q_b, tp_rank, tp_size), - _shard_dim0(k_b, tp_rank, tp_size), - _shard_dim0(v_b, tp_rank, tp_size), + q_b[q_start:q_end], + k_b[kv_start:kv_end], + v_b[kv_start:kv_end], ], dim=0, - ) + ).contiguous() ) else: - for name in ("to_q", "to_k", "to_v"): + for name, bounds in ( + ("to_q", (q_start, q_end)), + ("to_k", (kv_start, kv_end)), + ("to_v", (kv_start, kv_end)), + ): ref_proj = getattr(ref_attn, name) tp_proj = getattr(tp_attn, name) - tp_proj.weight.data.copy_(_shard_dim0(ref_proj.weight.data, tp_rank, tp_size)) + start, end = bounds + tp_proj.weight.data.copy_(ref_proj.weight.data[start:end].contiguous()) if ref_proj.bias is not None: - tp_proj.bias.data.copy_(_shard_dim0(ref_proj.bias.data, tp_rank, tp_size)) + tp_proj.bias.data.copy_(ref_proj.bias.data[start:end].contiguous()) - # Output projection: row-parallel (split input dim = dim 1) + # Output projection: row-parallel (split input dim = dim 1, head-aligned) ref_out = ref_attn.to_out[0] tp_out = tp_attn.to_out[0] - shard_size = math.ceil(ref_out.weight.shape[1] / tp_size) - start = tp_rank * shard_size - end = min(start + shard_size, ref_out.weight.shape[1]) - tp_out.weight.data.copy_(ref_out.weight.data[:, start:end].contiguous()) + q_start, q_end = tp_attn.local_q_dim_start, tp_attn.local_q_dim_end + tp_out.weight.data.copy_(ref_out.weight.data[:, q_start:q_end].contiguous()) if ref_out.bias is not None: tp_out.bias.data.copy_(ref_out.bias.data) - # QK norm weights (if TP-enabled, they're sharded) + # QK norm weights (if TP-enabled, use Attention head-based shard bounds) if hasattr(ref_attn, "norm_q") and hasattr(tp_attn, "norm_q"): if tp_attn.norm_q.enable_tp: - shard_size = ref_attn.norm_q.weight.shape[0] // tp_size - start = tp_rank * shard_size - end = start + shard_size - tp_attn.norm_q.weight.data.copy_(ref_attn.norm_q.weight.data[start:end]) - tp_attn.norm_k.weight.data.copy_(ref_attn.norm_k.weight.data[start:end]) + tp_attn.norm_q.weight.data.copy_( + ref_attn.norm_q.weight.data[tp_attn.local_q_dim_start : tp_attn.local_q_dim_end] + ) + tp_attn.norm_k.weight.data.copy_( + ref_attn.norm_k.weight.data[ + tp_attn.local_kv_dim_start : tp_attn.local_kv_dim_end + ] + ) else: tp_attn.norm_q.weight.data.copy_(ref_attn.norm_q.weight.data) tp_attn.norm_k.weight.data.copy_(ref_attn.norm_k.weight.data) -def _shard_dim0(tensor, tp_rank, tp_size): - """Shard a tensor along dim 0 (works for both 1D bias and 2D weight).""" - shard_size = math.ceil(tensor.shape[0] / tp_size) - start = tp_rank * shard_size - end = min(start + shard_size, tensor.shape[0]) - return tensor[start:end].contiguous() - - # ============================================================================= # Manual F.sdpa reference # ============================================================================= @@ -398,16 +403,18 @@ def _logic_tp_hidden_512(rank, world_size): _run_tp_with_params(rank, world_size, batch=2, seq=16, hidden_size=512, num_heads=4) -def _logic_tp_heads_not_divisible(rank, world_size): - """TP when num_heads % tp_size != 0. Expected to fail until uneven sharding is implemented.""" - _run_tp_with_params(rank, world_size, batch=2, seq=16, hidden_size=320, num_heads=5) +def _logic_tp_size_3_uneven_heads(rank, world_size): + """TP=3 when num_heads and hidden_size are not divisible by tp_size.""" + _run_tp_with_params(rank, world_size, batch=2, seq=16, hidden_size=512, num_heads=8) def _logic_tp_world_size_4(rank, world_size): _run_tp_with_params(rank, world_size, batch=2, seq=16, hidden_size=512, num_heads=16) -def _logic_tp_ulysses_combined(rank, world_size, ulysses_size, tp_size): +def _logic_tp_ulysses_combined( + rank, world_size, ulysses_size, tp_size, hidden_size=512, num_heads=16 +): """TP + Ulysses combined matches F.sdpa reference on the full sequence. 4 GPUs: tp_size=2, ulysses_size=2. @@ -416,8 +423,6 @@ def _logic_tp_ulysses_combined(rank, world_size, ulysses_size, tp_size): assert tp_size * ulysses_size == world_size device = torch.device(f"cuda:{rank}") - hidden_size = 512 - num_heads = 16 head_dim = hidden_size // num_heads batch = 2 seq_per_rank = 8 @@ -459,6 +464,18 @@ def _logic_tp_ulysses_combined(rank, world_size, ulysses_size, tp_size): torch.testing.assert_close(combined_out, expected_shard, rtol=1e-2, atol=1e-2) +def _logic_tp3_ulysses_uneven_combined(rank, world_size): + """TP=3 + Ulysses=2 with 8 attention heads (4+2+2 TP split).""" + _logic_tp_ulysses_combined( + rank, + world_size, + ulysses_size=2, + tp_size=3, + hidden_size=512, + num_heads=8, + ) + + # ============================================================================= # Test classes # ============================================================================= @@ -489,9 +506,9 @@ def test_hidden_128(self): def test_hidden_512(self): _run(2, _logic_tp_hidden_512) - @pytest.mark.xfail(reason="Uneven head sharding not yet implemented", raises=Exception) - def test_tp_heads_not_divisible(self): - _run(2, _logic_tp_heads_not_divisible) + def test_tp_size_3_uneven_heads(self): + """TP=3 with 8 heads (4+2+2 split) matches F.sdpa reference.""" + _run(3, _logic_tp_size_3_uneven_heads) def test_tp_world_size_4(self): _run(4, _logic_tp_world_size_4) @@ -518,6 +535,9 @@ def test_tp_4_ulysses_2(self): world = ulysses_size * tp_size _run(world, _logic_tp_ulysses_combined, ulysses_size, tp_size) + def test_tp_3_ulysses_2_uneven(self): + _run(6, _logic_tp3_ulysses_uneven_combined) + if __name__ == "__main__": pytest.main([__file__, "-v"]) diff --git a/tests/unittest/_torch/visual_gen/multi_gpu/test_wan_tp.py b/tests/unittest/_torch/visual_gen/multi_gpu/test_wan_tp.py index ecd98ba203ff..6013b492054e 100644 --- a/tests/unittest/_torch/visual_gen/multi_gpu/test_wan_tp.py +++ b/tests/unittest/_torch/visual_gen/multi_gpu/test_wan_tp.py @@ -45,6 +45,8 @@ from tensorrt_llm._utils import get_free_port from tensorrt_llm.models.modeling_utils import QuantConfig + from .tp_shard_utils import copy_tp_parameter + MODULES_AVAILABLE = True except ImportError: MODULES_AVAILABLE = False @@ -143,6 +145,11 @@ def run_test_in_distributed(world_size: int, test_fn: Callable, use_cuda: bool = added_kv_proj_dim=256, # add_k_proj input dim = hidden_size (image embeds projected to hidden_size before blocks) ) +# Existing WAN configs are already uneven with TP=3: +# 4 attention heads split as 2+1+1, and ffn_dim=512 is not divisible by 3. +_WAN_UNEVEN_TP3_CONFIG = dict(_WAN_T2V_TEST_CONFIG) +_WAN_I2V_UNEVEN_TP3_CONFIG = dict(_WAN_I2V_TEST_CONFIG) + # ============================================================================= # Model config + weight helpers @@ -193,95 +200,32 @@ def _stabilize_model_weights(model): # ============================================================================= -# TP weight sharding helpers +# TP weight sharding helpers (see tp_shard_utils.py) # ============================================================================= -def _shard_dim0(tensor, tp_rank, tp_size): - """Shard a tensor along dim 0.""" - chunk = tensor.shape[0] // tp_size - return tensor[tp_rank * chunk : (tp_rank + 1) * chunk].contiguous() - - -def _shard_dim1(tensor, tp_rank, tp_size): - """Shard a tensor along dim 1.""" - chunk = tensor.shape[1] // tp_size - return tensor[:, tp_rank * chunk : (tp_rank + 1) * chunk].contiguous() - - -def _shard_fused_qkv(tensor, tp_rank, tp_size, q_dim, kv_dim): - """Shard a fused QKV weight [q_dim + 2*kv_dim, ...] preserving Q/K/V structure.""" - q, k, v = tensor.split([q_dim, kv_dim, kv_dim], dim=0) - return torch.cat( - [ - _shard_dim0(q, tp_rank, tp_size), - _shard_dim0(k, tp_rank, tp_size), - _shard_dim0(v, tp_rank, tp_size), - ], - dim=0, - ) - - -def _shard_fused_gate_up(tensor, tp_rank, tp_size): - """Shard a fused gate_up weight [2*intermediate, ...] preserving gate/up structure.""" - half = tensor.shape[0] // 2 - gate, up = tensor.split([half, half], dim=0) - return torch.cat( - [ - _shard_dim0(gate, tp_rank, tp_size), - _shard_dim0(up, tp_rank, tp_size), - ], - dim=0, - ) - - -def _copy_ref_weights_to_tp(ref_model, tp_model, tp_rank, tp_size): - """Copy weights from a TP=1 reference model into a TP model with correct sharding. - - Handles column-parallel (QKV, MLP up/gate), row-parallel (output projs), - fused QKV/gate_up weights, and replicated parameters (norms, embeddings). - """ +def _copy_ref_weights_to_tp(ref_model, tp_model, tp_rank, tp_size, config_dict): + """Copy weights from a TP=1 reference model into a TP model with correct sharding.""" ref_params = dict(ref_model.named_parameters()) + num_heads = config_dict["num_attention_heads"] + head_dim = config_dict["attention_head_dim"] + vgm = getattr(tp_model.model_config, "visual_gen_mapping", None) + ulysses_size = vgm.ulysses_size if vgm is not None else 1 with torch.no_grad(): for tp_name, tp_param in tp_model.named_parameters(): if tp_name not in ref_params: continue - - ref_param = ref_params[tp_name] - - if tp_param.shape == ref_param.shape: - # Replicated parameter (norms, embeddings, etc.) - tp_param.data.copy_(ref_param.data) - elif tp_param.ndim >= 2 and tp_param.shape[1] == ref_param.shape[1]: - # Column parallel: dim 0 is smaller (output dim sharded) - if "qkv_proj" in tp_name or "add_qkv_proj" in tp_name: - q_dim = ref_param.shape[0] // 3 - tp_param.data.copy_( - _shard_fused_qkv(ref_param.data, tp_rank, tp_size, q_dim, q_dim) - ) - elif "gate_up_proj" in tp_name: - tp_param.data.copy_(_shard_fused_gate_up(ref_param.data, tp_rank, tp_size)) - else: - tp_param.data.copy_(_shard_dim0(ref_param.data, tp_rank, tp_size)) - elif tp_param.ndim >= 2 and tp_param.shape[0] == ref_param.shape[0]: - # Row parallel: dim 1 is smaller (input dim sharded) - tp_param.data.copy_(_shard_dim1(ref_param.data, tp_rank, tp_size)) - elif tp_param.ndim == 1 and tp_param.shape[0] < ref_param.shape[0]: - # 1D bias for column parallel - if "qkv_proj" in tp_name or "add_qkv_proj" in tp_name: - q_dim = ref_param.shape[0] // 3 - tp_param.data.copy_( - _shard_fused_qkv(ref_param.data, tp_rank, tp_size, q_dim, q_dim) - ) - elif "gate_up_proj" in tp_name: - tp_param.data.copy_(_shard_fused_gate_up(ref_param.data, tp_rank, tp_size)) - else: - tp_param.data.copy_(_shard_dim0(ref_param.data, tp_rank, tp_size)) - else: - raise ValueError( - f"Cannot shard {tp_name}: ref={ref_param.shape}, tp={tp_param.shape}" - ) + copy_tp_parameter( + tp_name, + ref_params[tp_name], + tp_param, + tp_rank, + tp_size, + num_heads, + head_dim, + ulysses_size=ulysses_size, + ) # ============================================================================= @@ -330,6 +274,16 @@ def _logic_wan_t2v_tp_forward(rank, world_size): def _logic_wan_t2v_tp_vs_single_gpu(rank, world_size): """WAN T2V: TP 2-GPU output matches single-GPU reference.""" + _logic_wan_t2v_tp_vs_single_gpu_with_config(rank, world_size, _WAN_T2V_TEST_CONFIG) + + +def _logic_wan_t2v_tp3_uneven_vs_single_gpu(rank, world_size): + """WAN T2V: TP=3 with uneven head/FFN dims matches single-GPU reference.""" + _logic_wan_t2v_tp_vs_single_gpu_with_config(rank, world_size, _WAN_UNEVEN_TP3_CONFIG) + + +def _logic_wan_t2v_tp_vs_single_gpu_with_config(rank, world_size, config_dict): + """WAN T2V: TP output matches single-GPU reference.""" from tensorrt_llm._torch.visual_gen.models.wan.transformer_wan import WanTransformer3DModel device = torch.device(f"cuda:{rank}") @@ -342,15 +296,15 @@ def _logic_wan_t2v_tp_vs_single_gpu(rank, world_size): # Create single-GPU reference model torch.manual_seed(123) - ref_config = _make_model_config(_WAN_T2V_TEST_CONFIG, tp_size=1) + ref_config = _make_model_config(config_dict, tp_size=1) ref_model = WanTransformer3DModel(ref_config).to(device).to(compute_dtype) _stabilize_model_weights(ref_model) # Create TP model and copy sharded weights from ref torch.manual_seed(123) - tp_config = _make_model_config(_WAN_T2V_TEST_CONFIG, tp_size=world_size) + tp_config = _make_model_config(config_dict, tp_size=world_size) tp_model = WanTransformer3DModel(tp_config).to(device).to(compute_dtype) - _copy_ref_weights_to_tp(ref_model, tp_model, rank, world_size) + _copy_ref_weights_to_tp(ref_model, tp_model, rank, world_size, config_dict) # Same inputs on all ranks torch.manual_seed(456) @@ -410,7 +364,7 @@ def _logic_wan_t2v_tp_ulysses_vs_single_gpu(rank, world_size): ) combined_model = WanTransformer3DModel(combined_config).to(device).to(compute_dtype) vgm = combined_config.visual_gen_mapping - _copy_ref_weights_to_tp(ref_model, combined_model, vgm.tp_rank, tp_size) + _copy_ref_weights_to_tp(ref_model, combined_model, vgm.tp_rank, tp_size, _WAN_T2V_TEST_CONFIG) # Same inputs on all ranks (Ulysses shards at runtime) torch.manual_seed(456) @@ -494,6 +448,16 @@ def _logic_wan_i2v_tp_forward(rank, world_size): def _logic_wan_i2v_tp_vs_single_gpu(rank, world_size): """WAN I2V: TP 2-GPU output matches single-GPU reference.""" + _logic_wan_i2v_tp_vs_single_gpu_with_config(rank, world_size, _WAN_I2V_TEST_CONFIG) + + +def _logic_wan_i2v_tp3_uneven_vs_single_gpu(rank, world_size): + """WAN I2V: TP=3 with uneven head/FFN dims matches single-GPU reference.""" + _logic_wan_i2v_tp_vs_single_gpu_with_config(rank, world_size, _WAN_I2V_UNEVEN_TP3_CONFIG) + + +def _logic_wan_i2v_tp_vs_single_gpu_with_config(rank, world_size, config_dict): + """WAN I2V: TP output matches single-GPU reference.""" from tensorrt_llm._torch.visual_gen.models.wan.transformer_wan import WanTransformer3DModel device = torch.device(f"cuda:{rank}") @@ -507,15 +471,15 @@ def _logic_wan_i2v_tp_vs_single_gpu(rank, world_size): # Create single-GPU reference model torch.manual_seed(123) - ref_config = _make_model_config(_WAN_I2V_TEST_CONFIG, tp_size=1) + ref_config = _make_model_config(config_dict, tp_size=1) ref_model = WanTransformer3DModel(ref_config).to(device).to(compute_dtype) _stabilize_model_weights(ref_model) # Create TP model and copy sharded weights from ref torch.manual_seed(123) - tp_config = _make_model_config(_WAN_I2V_TEST_CONFIG, tp_size=world_size) + tp_config = _make_model_config(config_dict, tp_size=world_size) tp_model = WanTransformer3DModel(tp_config).to(device).to(compute_dtype) - _copy_ref_weights_to_tp(ref_model, tp_model, rank, world_size) + _copy_ref_weights_to_tp(ref_model, tp_model, rank, world_size, config_dict) # Same inputs on all ranks torch.manual_seed(456) @@ -526,7 +490,8 @@ def _logic_wan_i2v_tp_vs_single_gpu(rank, world_size): torch.randn(batch, txt_seq, 128, device=device, dtype=compute_dtype) * 0.1 ) encoder_hidden_states_image = ( - torch.randn(batch, img_seq, 64, device=device, dtype=compute_dtype) * 0.1 + torch.randn(batch, img_seq, config_dict["image_dim"], device=device, dtype=compute_dtype) + * 0.1 ) timestep = torch.tensor([0.5], device=device, dtype=compute_dtype) @@ -590,5 +555,17 @@ def test_wan_i2v_tp_vs_single_gpu(self): run_test_in_distributed(world_size=2, test_fn=_logic_wan_i2v_tp_vs_single_gpu) +class TestWanUnevenTP3: + """TP=3 tests where head count and FFN dims are not divisible by tp_size.""" + + def test_wan_t2v_tp3_uneven_vs_single_gpu(self): + """WAN T2V TP=3 (4 heads, uneven FFN) matches single-GPU reference.""" + run_test_in_distributed(world_size=3, test_fn=_logic_wan_t2v_tp3_uneven_vs_single_gpu) + + def test_wan_i2v_tp3_uneven_vs_single_gpu(self): + """WAN I2V TP=3 (4 heads, uneven FFN) matches single-GPU reference.""" + run_test_in_distributed(world_size=3, test_fn=_logic_wan_i2v_tp3_uneven_vs_single_gpu) + + if __name__ == "__main__": pytest.main([__file__, "-v"]) diff --git a/tests/unittest/_torch/visual_gen/multi_gpu/tp_shard_utils.py b/tests/unittest/_torch/visual_gen/multi_gpu/tp_shard_utils.py new file mode 100644 index 000000000000..7abbb121214e --- /dev/null +++ b/tests/unittest/_torch/visual_gen/multi_gpu/tp_shard_utils.py @@ -0,0 +1,297 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""TP weight-sharding helpers for VisualGen multi-GPU unit tests. + +Mirrors ``Linear._calc_shard`` for MLP dims and ``Attention.shard_start`` for +head-aligned Q/K/V shards so reference weights match TP module layouts. +""" + +from __future__ import annotations + +import torch + + +def calc_shard(total: int, tp_size: int, rank: int) -> int: + """Start index for *rank* when splitting *total* elements across *tp_size* ranks.""" + return (total // tp_size) * rank + min(total % tp_size, rank) + + +def qkv_head_bounds( + tp_rank: int, + tp_size: int, + num_attention_heads: int, + num_key_value_heads: int, + head_dim: int, + ulysses_size: int = 1, +) -> tuple[int, int, int, int]: + """Return (q_start, q_end, kv_start, kv_end) feature bounds for one TP rank.""" + gqa_ratio = num_attention_heads // num_key_value_heads + kv_heads_per_ulysses = num_key_value_heads // ulysses_size + kv_head_start = calc_shard(kv_heads_per_ulysses, tp_size, tp_rank) * ulysses_size + kv_head_end = calc_shard(kv_heads_per_ulysses, tp_size, tp_rank + 1) * ulysses_size + attn_head_start = kv_head_start * gqa_ratio + attn_head_end = kv_head_end * gqa_ratio + q_start = attn_head_start * head_dim + q_end = attn_head_end * head_dim + kv_start = kv_head_start * head_dim + kv_end = kv_head_end * head_dim + return q_start, q_end, kv_start, kv_end + + +def shard_dim0(tensor: torch.Tensor, tp_rank: int, tp_size: int) -> torch.Tensor: + """Shard a tensor along dim 0 (column-parallel output / 1D bias).""" + start = calc_shard(tensor.shape[0], tp_size, tp_rank) + end = calc_shard(tensor.shape[0], tp_size, tp_rank + 1) + return tensor[start:end].contiguous() + + +def shard_dim1(tensor: torch.Tensor, tp_rank: int, tp_size: int) -> torch.Tensor: + """Shard a tensor along dim 1 (row-parallel input).""" + start = calc_shard(tensor.shape[1], tp_size, tp_rank) + end = calc_shard(tensor.shape[1], tp_size, tp_rank + 1) + return tensor[:, start:end].contiguous() + + +def shard_fused_qkv_by_heads( + tensor: torch.Tensor, + tp_rank: int, + tp_size: int, + num_attention_heads: int, + num_key_value_heads: int, + head_dim: int, + q_dim: int, + kv_dim: int, + ulysses_size: int = 1, +) -> torch.Tensor: + """Shard fused QKV using Attention head boundaries, not flat row splits.""" + q, k, v = tensor.split([q_dim, kv_dim, kv_dim], dim=0) + q_start, q_end, kv_start, kv_end = qkv_head_bounds( + tp_rank, tp_size, num_attention_heads, num_key_value_heads, head_dim, ulysses_size + ) + return torch.cat([q[q_start:q_end], k[kv_start:kv_end], v[kv_start:kv_end]], dim=0).contiguous() + + +def shard_kv_dim0( + tensor: torch.Tensor, + tp_rank: int, + tp_size: int, + num_key_value_heads: int, + head_dim: int, + ulysses_size: int = 1, +) -> torch.Tensor: + """Shard column-parallel K/V (or KV-norm) weights along head boundaries.""" + _, _, kv_start, kv_end = qkv_head_bounds( + tp_rank, tp_size, num_key_value_heads, num_key_value_heads, head_dim, ulysses_size + ) + return tensor[kv_start:kv_end].contiguous() + + +def shard_q_dim0( + tensor: torch.Tensor, + tp_rank: int, + tp_size: int, + num_attention_heads: int, + head_dim: int, + num_key_value_heads: int | None = None, + ulysses_size: int = 1, +) -> torch.Tensor: + """Shard column-parallel Q weights along head boundaries.""" + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + q_start, q_end, _, _ = qkv_head_bounds( + tp_rank, tp_size, num_attention_heads, num_key_value_heads, head_dim, ulysses_size + ) + return tensor[q_start:q_end].contiguous() + + +def shard_q_dim1( + tensor: torch.Tensor, + tp_rank: int, + tp_size: int, + num_attention_heads: int, + head_dim: int, + num_key_value_heads: int | None = None, + ulysses_size: int = 1, +) -> torch.Tensor: + """Shard row-parallel output-proj weights along Q head boundaries.""" + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + q_start, q_end, _, _ = qkv_head_bounds( + tp_rank, tp_size, num_attention_heads, num_key_value_heads, head_dim, ulysses_size + ) + return tensor[:, q_start:q_end].contiguous() + + +def shard_kv_dim1( + tensor: torch.Tensor, + tp_rank: int, + tp_size: int, + num_key_value_heads: int, + head_dim: int, + ulysses_size: int = 1, +) -> torch.Tensor: + """Shard row-parallel K/V output weights along head boundaries.""" + _, _, kv_start, kv_end = qkv_head_bounds( + tp_rank, tp_size, num_key_value_heads, num_key_value_heads, head_dim, ulysses_size + ) + return tensor[:, kv_start:kv_end].contiguous() + + +def shard_fused_gate_up(tensor: torch.Tensor, tp_rank: int, tp_size: int) -> torch.Tensor: + """Shard fused gate_up weight [2*intermediate, ...] preserving gate/up structure.""" + half = tensor.shape[0] // 2 + gate, up = tensor.split([half, half], dim=0) + return torch.cat( + [ + shard_dim0(gate, tp_rank, tp_size), + shard_dim0(up, tp_rank, tp_size), + ], + dim=0, + ) + + +def copy_tp_parameter( + tp_name: str, + ref_param: torch.Tensor, + tp_param: torch.Tensor, + tp_rank: int, + tp_size: int, + num_attention_heads: int, + head_dim: int, + num_key_value_heads: int | None = None, + ulysses_size: int = 1, +) -> None: + """Copy one reference parameter into its TP-sharded counterpart.""" + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + if tp_param.shape == ref_param.shape: + tp_param.data.copy_(ref_param.data) + return + + if tp_param.ndim >= 2 and tp_param.shape[1] == ref_param.shape[1]: + if "qkv_proj" in tp_name or "add_qkv_proj" in tp_name: + q_dim = num_attention_heads * head_dim + kv_dim = num_key_value_heads * head_dim + tp_param.data.copy_( + shard_fused_qkv_by_heads( + ref_param.data, + tp_rank, + tp_size, + num_attention_heads, + num_key_value_heads, + head_dim, + q_dim, + kv_dim, + ulysses_size, + ) + ) + elif "to_q" in tp_name: + tp_param.data.copy_( + shard_q_dim0( + ref_param.data, + tp_rank, + tp_size, + num_attention_heads, + head_dim, + num_key_value_heads, + ulysses_size, + ) + ) + elif ( + "to_k" in tp_name + or "to_v" in tp_name + or "add_k_proj" in tp_name + or "add_v_proj" in tp_name + ): + tp_param.data.copy_( + shard_kv_dim0( + ref_param.data, tp_rank, tp_size, num_key_value_heads, head_dim, ulysses_size + ) + ) + elif "gate_up_proj" in tp_name: + tp_param.data.copy_(shard_fused_gate_up(ref_param.data, tp_rank, tp_size)) + else: + tp_param.data.copy_(shard_dim0(ref_param.data, tp_rank, tp_size)) + elif tp_param.ndim >= 2 and tp_param.shape[0] == ref_param.shape[0]: + if "to_add_out" in tp_name: + tp_param.data.copy_( + shard_kv_dim1( + ref_param.data, tp_rank, tp_size, num_key_value_heads, head_dim, ulysses_size + ) + ) + elif "to_out" in tp_name: + tp_param.data.copy_( + shard_q_dim1( + ref_param.data, + tp_rank, + tp_size, + num_attention_heads, + head_dim, + num_key_value_heads, + ulysses_size, + ) + ) + else: + tp_param.data.copy_(shard_dim1(ref_param.data, tp_rank, tp_size)) + elif tp_param.ndim == 1 and tp_param.shape[0] < ref_param.shape[0]: + if "qkv_proj" in tp_name or "add_qkv_proj" in tp_name: + q_dim = num_attention_heads * head_dim + kv_dim = num_key_value_heads * head_dim + tp_param.data.copy_( + shard_fused_qkv_by_heads( + ref_param.data, + tp_rank, + tp_size, + num_attention_heads, + num_key_value_heads, + head_dim, + q_dim, + kv_dim, + ulysses_size, + ) + ) + elif "to_q" in tp_name or "norm_q" in tp_name: + tp_param.data.copy_( + shard_q_dim0( + ref_param.data, + tp_rank, + tp_size, + num_attention_heads, + head_dim, + num_key_value_heads, + ulysses_size, + ) + ) + elif ( + "to_k" in tp_name + or "to_v" in tp_name + or "add_k_proj" in tp_name + or "add_v_proj" in tp_name + or "norm_added_k" in tp_name + or "norm_k" in tp_name + ): + tp_param.data.copy_( + shard_kv_dim0( + ref_param.data, tp_rank, tp_size, num_key_value_heads, head_dim, ulysses_size + ) + ) + elif "gate_up_proj" in tp_name: + tp_param.data.copy_(shard_fused_gate_up(ref_param.data, tp_rank, tp_size)) + else: + tp_param.data.copy_(shard_dim0(ref_param.data, tp_rank, tp_size)) + else: + raise ValueError(f"Cannot shard {tp_name}: ref={ref_param.shape}, tp={tp_param.shape}") From d7f3f8987caffc922cd5515cb6e25d3e413a5e2c Mon Sep 17 00:00:00 2001 From: Brenden Elgarten Date: Tue, 2 Jun 2026 22:04:17 +0000 Subject: [PATCH 2/4] fix coderabbit comments Signed-off-by: Brenden Elgarten --- tensorrt_llm/_torch/visual_gen/models/flux/joint_proj.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorrt_llm/_torch/visual_gen/models/flux/joint_proj.py b/tensorrt_llm/_torch/visual_gen/models/flux/joint_proj.py index f27f69dec0bf..db6e8335143e 100644 --- a/tensorrt_llm/_torch/visual_gen/models/flux/joint_proj.py +++ b/tensorrt_llm/_torch/visual_gen/models/flux/joint_proj.py @@ -257,7 +257,7 @@ def range_size(r): "up": (local_mlp_hidden_start, local_mlp_hidden_end), }, ) - self.local_qkv_dim = (q_dim + 2 * kv_dim) // self.tp_size + self.local_qkv_dim = local_q_dim + 2 * local_kv_dim self.local_mlp_dim = local_mlp_hidden_size def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: From 81eb3d5ede542327a63d107110959c507d0c055b Mon Sep 17 00:00:00 2001 From: Brenden Elgarten Date: Mon, 22 Jun 2026 19:46:31 +0000 Subject: [PATCH 3/4] address Shreyas' comments, add lpips test Signed-off-by: Brenden Elgarten --- tensorrt_llm/_torch/modules/gated_mlp.py | 12 +++++++ tensorrt_llm/_torch/modules/linear.py | 31 ++++++++++++++++--- .../visual_gen/models/flux/joint_proj.py | 5 ++- .../models/flux/transformer_flux.py | 15 +++++++++ .../_torch/visual_gen/modules/attention.py | 17 ++++------ .../visual_gen/test_visual_gen_multi_gpu.py | 1 + 6 files changed, 64 insertions(+), 17 deletions(-) diff --git a/tensorrt_llm/_torch/modules/gated_mlp.py b/tensorrt_llm/_torch/modules/gated_mlp.py index 2b49fa195a08..e8528e23c1ca 100644 --- a/tensorrt_llm/_torch/modules/gated_mlp.py +++ b/tensorrt_llm/_torch/modules/gated_mlp.py @@ -71,6 +71,14 @@ def __init__( mapping.tp_rank + 1) local_intermediate_size = local_intermediate_end - local_intermediate_start + self._uneven_tp_blocks_lora = (mapping.tp_size > 1 + and self.intermediate_size % + mapping.tp_size != 0) + + # gateup_shard_indices_mapping is the local offset and size for each sub-weight + # in this rank's concatenated (gate || up) buffer. + # override_tp_sharding is the absolute range of the global weight from which + # this rank pulls each sub-weight. gateup_shard_indices_mapping = { 'gate': (0, local_intermediate_size), 'up': (local_intermediate_size, local_intermediate_size), @@ -297,6 +305,10 @@ def forward_lora( ) -> torch.Tensor: assert lora_params is not None assert self.layer_idx is not None, "layer_idx is required for lora" + if self._uneven_tp_blocks_lora: + raise NotImplementedError( + "LoRA is not supported with uneven TP for GatedMLP " + "(intermediate_size not divisible by tp_size).") h1 = self.gate_up_proj(x) diff --git a/tensorrt_llm/_torch/modules/linear.py b/tensorrt_llm/_torch/modules/linear.py index d51babad8630..acd621ef4cc3 100644 --- a/tensorrt_llm/_torch/modules/linear.py +++ b/tensorrt_llm/_torch/modules/linear.py @@ -106,6 +106,12 @@ def load_weight_shard( device: torch.device = torch.device('cpu'), return_slice_indices: bool = False, ) -> torch.Tensor: + """Legacy weight shard helper using ceil-divide sharding. + + `Linear.load_shard` is preferred — it respects uneven-TP overrides, fused + QKV/gate-up sharding dicts, and quant-specific scale/packing semantics that + this function does not. + """ # Skip device transfers on integrated GPUs to conserve shared memory if weight.device.type != device.type and is_device_integrated(): # For integrated GPU systems (e.g., DGX Spark), CPU and GPU share limited physical memory. @@ -2923,8 +2929,11 @@ def __init__( 'cutlass', 'cublaslt', 'cuda_core' ] - assert self.tp_mode in (TensorParallelMode.ROW, - TensorParallelMode.COLUMN, None) + if self.tp_mode not in (TensorParallelMode.ROW, + TensorParallelMode.COLUMN, None): + raise ValueError( + f"Invalid tp_mode {self.tp_mode!r}; expected ROW, COLUMN, or None." + ) # Init TP sharding either from override or auto generated _uneven_tp_unsupported = {QuantAlgo.NVFP4_ARC} @@ -3006,8 +3015,10 @@ def _calc_shard(total, tp_size, rank): def _auto_tp_sharding(self, features, quant_config): """Auto-generate tp_sharding tuple based on quant alignment requirements. - For VANILLA mode only. Fused modes with non-divisible dims require - explicit override_tp_sharding from the model layer. + VANILLA mode only. Fused modes (FUSED_QKV, FUSED_GATE_UP) require explicit + override_tp_sharding because individual sub-weight sizes (Q vs K vs V; gate + vs up) are not knowable here — they aren't always equal (e.g. GQA), and + cross-rank consistency must be decided by the caller. """ assert self.weights_loading_config.weight_mode == WeightMode.VANILLA, ( f"_auto_tp_sharding only supports VANILLA mode, got " @@ -3049,12 +3060,14 @@ def _calculate_local_features_helper(self, features): return end - start def calculate_local_in_features(self, in_features): + """Local input feature count after TP sharding (full size if not row-parallel).""" if self.tp_mode != TensorParallelMode.ROW: return in_features return self._calculate_local_features_helper(in_features) def calculate_local_out_features(self, out_features): + """Local output feature count after TP sharding (full size if not column-parallel).""" if self.tp_mode != TensorParallelMode.COLUMN: return out_features @@ -3062,7 +3075,7 @@ def calculate_local_out_features(self, out_features): def load_shard( self, - weights: Dict, + weights: Union[Dict, torch.Tensor], label: Optional[str] = None, device: torch.device = torch.device('cpu'), name: Optional[str] = None, @@ -3072,6 +3085,14 @@ def load_shard( # 2 are packed in each 8 bit element of the tensor elm_packing: int = 1, ) -> torch.Tensor: + """Slice a weight tensor for this rank's TP shard. + + Unified entry point for module-aware weight loading: respects + `self.tp_sharding` (uneven TP overrides, fused QKV/gate-up dicts) and + quant-specific knobs (`scale_span`, `elm_packing`). Pass `weights` as a + dict with `label` to pick the entry, or as a bare tensor when there's + only one. Supersedes the free function `load_weight_shard`. + """ if label: if label not in weights: return None diff --git a/tensorrt_llm/_torch/visual_gen/models/flux/joint_proj.py b/tensorrt_llm/_torch/visual_gen/models/flux/joint_proj.py index db6e8335143e..db6d2e336818 100644 --- a/tensorrt_llm/_torch/visual_gen/models/flux/joint_proj.py +++ b/tensorrt_llm/_torch/visual_gen/models/flux/joint_proj.py @@ -63,7 +63,7 @@ def __init__( self.has_bias = bias self.attn_shard = attn_shard - assert attn_dim % self.tp_size == 0 or self.attn_shard, ( + assert attn_dim % self.tp_size == 0 or self.attn_shard is not None, ( "Explicit attention sharding required for uneven TP" ) @@ -197,6 +197,9 @@ def __init__( self.local_qkv_dim = q_dim + 2 * kv_dim self.local_mlp_dim = mlp_dim else: + assert override_qkv_sharding is not None, ( + "override_qkv_sharding required when tp_size > 1" + ) def range_size(r): return r[1] - r[0] diff --git a/tensorrt_llm/_torch/visual_gen/models/flux/transformer_flux.py b/tensorrt_llm/_torch/visual_gen/models/flux/transformer_flux.py index 527dce31df23..092016ba2977 100644 --- a/tensorrt_llm/_torch/visual_gen/models/flux/transformer_flux.py +++ b/tensorrt_llm/_torch/visual_gen/models/flux/transformer_flux.py @@ -495,6 +495,21 @@ def __init__( attn_shard=(self.attn.local_q_dim_start, self.attn.local_q_dim_end), ) + # MLP + Attn Output projection, requires special handling for TP + self.proj_out = FluxJointAttnMLPProj( + attn_dim=self.attn.q_dim, + mlp_dim=self.mlp_hidden_dim, + out_dim=dim, + bias=True, + dtype=dtype, + quant_config=quant_config, + skip_create_weights_in_init=skip_create_weights, + force_dynamic_quantization=force_dynamic_quant, + config=config, + # need explicit shard because we are aligned on head boundaries + attn_shard=(self.attn.local_q_dim_start, self.attn.local_q_dim_end), + ) + def forward( self, hidden_states: torch.Tensor, diff --git a/tensorrt_llm/_torch/visual_gen/modules/attention.py b/tensorrt_llm/_torch/visual_gen/modules/attention.py index aee7ccd1d821..e0571ce99969 100644 --- a/tensorrt_llm/_torch/visual_gen/modules/attention.py +++ b/tensorrt_llm/_torch/visual_gen/modules/attention.py @@ -109,7 +109,7 @@ def __init__( self.q_dim = self.num_attention_heads * self.head_dim self.kv_dim = self.num_key_value_heads * self.head_dim - self._calculate_tp_parameters(ulysses_size if enable_ulysses else None) + self._calculate_tp_parameters(ulysses_size if enable_sequence_parallel else None) self._init_qkv_proj() # Structural eligibility for SEPARATE_QKV self-attn quantize dedup. @@ -259,19 +259,14 @@ def _calculate_tp_parameters(self, ulysses_size: Optional[int]): ulysses_size = 1 assert self.num_key_value_heads % ulysses_size == 0 - # Note: this is intentionally stronger than `num_kv_head >= ulysses_size * tp_size` assert self.num_key_value_heads // ulysses_size >= self.tp_size - def _calc_shard(full, size, rank): - full //= ulysses_size - shard = (full // size) * rank + min(full % size, rank) - return shard * ulysses_size - - self.local_key_value_head_start = _calc_shard( - self.num_key_value_heads, self.tp_size, self.tp_rank + kv_heads_per_ulysses = self.num_key_value_heads // ulysses_size + self.local_key_value_head_start = ( + Linear._calc_shard(kv_heads_per_ulysses, self.tp_size, self.tp_rank) * ulysses_size ) - self.local_key_value_head_end = _calc_shard( - self.num_key_value_heads, self.tp_size, self.tp_rank + 1 + self.local_key_value_head_end = ( + Linear._calc_shard(kv_heads_per_ulysses, self.tp_size, self.tp_rank + 1) * ulysses_size ) self.local_num_key_value_heads = ( self.local_key_value_head_end - self.local_key_value_head_start diff --git a/tests/integration/defs/examples/visual_gen/test_visual_gen_multi_gpu.py b/tests/integration/defs/examples/visual_gen/test_visual_gen_multi_gpu.py index a60765fe600c..a8383bba0ac6 100644 --- a/tests/integration/defs/examples/visual_gen/test_visual_gen_multi_gpu.py +++ b/tests/integration/defs/examples/visual_gen/test_visual_gen_multi_gpu.py @@ -59,6 +59,7 @@ WAN22_LPIPS_TP_VARIANTS = [ ("tp2", {"tp_size": 2}), + ("tp3", {"tp_size": 3}), ("cfg2_tp2", {"cfg_size": 2, "tp_size": 2}), ("tp2_ulysses2", {"tp_size": 2, "ulysses_size": 2}), ] From 3a6e8839aeaf2604035145b62edd5e2a33bb0de2 Mon Sep 17 00:00:00 2001 From: Brenden Elgarten Date: Tue, 23 Jun 2026 16:14:34 +0000 Subject: [PATCH 4/4] fix ci: fused scale loading, auto sharding with skip_create_weights Signed-off-by: Brenden Elgarten --- tensorrt_llm/_torch/modules/linear.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/tensorrt_llm/_torch/modules/linear.py b/tensorrt_llm/_torch/modules/linear.py index acd621ef4cc3..ea2f47f2f867 100644 --- a/tensorrt_llm/_torch/modules/linear.py +++ b/tensorrt_llm/_torch/modules/linear.py @@ -1213,8 +1213,9 @@ def load_weights_fused_qkv_linear( ] scales = [ - module.load_shard(s, scale_span=128) if s is not None else None - for s in full_scales_squeezed + module.load_shard(s, scale_span=128, name=name) + if s is not None else None + for s, name in zip(full_scales_squeezed, ('q', 'k', 'v')) ] processed_mapping = self.remap_fused_shard_indices_by_divisible_factor( module.fused_weight_shard_indices_mapping, 128) @@ -1241,9 +1242,11 @@ def load_weights_fused_gate_up_linear( for s in full_scales ] scales = [ - module.load_shard(s, scale_span=128) if s is not None else None - for s in full_scales_squeezed + module.load_shard(s, scale_span=128, name=name) + if s is not None else None + for s, name in zip(full_scales_squeezed, ('gate', 'up')) ] + processed_mapping = self.remap_fused_shard_indices_by_divisible_factor( module.fused_weight_shard_indices_mapping, 128) for shard_key, scale in zip(processed_mapping.keys(), scales): @@ -2943,7 +2946,8 @@ def __init__( self.tp_sharding = override_tp_sharding elif self.tp_size > 1 and self.tp_mode is not None \ and self.weights_loading_config.weight_mode == WeightMode.VANILLA \ - and _quant_algo not in _uneven_tp_unsupported: + and _quant_algo not in _uneven_tp_unsupported \ + and not skip_create_weights_in_init: features = in_features if self.tp_mode == TensorParallelMode.ROW else out_features self.tp_sharding = self._auto_tp_sharding(features, quant_config) else: