Skip to content

Commit e56397d

Browse files
authored
[None][feat] Support tensor parallelism of trtllm moe backend for nemotron-h model (NVIDIA#11470)
Signed-off-by: Wanli Jiang <35160485+Wanli-Jiang@users.noreply.github.com>
1 parent 617440d commit e56397d

4 files changed

Lines changed: 81 additions & 1 deletion

File tree

tensorrt_llm/_torch/modules/fused_moe/quantization.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,19 @@ def maybe_pad_for_mxfp4(weight: torch.Tensor,
176176
return weight
177177

178178

179+
def _pad_tensor_to_shape(tensor: torch.Tensor, shape: tuple) -> torch.Tensor:
180+
"""Pad tensor to match target shape. Used for post-shard alignment."""
181+
if tensor.numel() == 0:
182+
return tensor
183+
if tensor.shape == shape:
184+
return tensor
185+
if len(tensor.shape) == 1:
186+
return F.pad(tensor, (0, shape[0] - tensor.shape[0])).contiguous()
187+
row_pad = shape[0] - tensor.shape[0]
188+
col_pad = shape[1] - tensor.shape[1]
189+
return F.pad(tensor, (0, col_pad, 0, row_pad)).contiguous()
190+
191+
179192
def interleave_linear_and_gate(x: torch.Tensor,
180193
group_size: int = 64,
181194
dim: int = -1) -> torch.Tensor:
@@ -2915,6 +2928,9 @@ def round_up(x, alignment):
29152928
return (w3_w1_weight_shape, w2_weight_shape, w3_w1_bias_shape,
29162929
w2_bias_shape, w3_w1_weight_scale_shape, w2_weight_scale_shape)
29172930

2931+
def _round_up(self, x, alignment):
2932+
return (x + alignment - 1) // alignment * alignment
2933+
29182934
def create_weights(self, module: torch.nn.Module):
29192935
# Here we only enable padding for hidden_size > 1024 since there are small unit tests that expect no padding.
29202936
if module.hidden_size > 1024 and module.hidden_size % 256 != 0:
@@ -2923,6 +2939,15 @@ def create_weights(self, module: torch.nn.Module):
29232939
# See the comment in MXFP4WeightTRTLLMGenFusedMoEMethod for more details.
29242940
self.input_hidden_alignment = 256
29252941

2942+
else:
2943+
# Weight scales require M % 128 in get_shuffle_matrix_sf_a_row_indices.
2944+
# Check if intermediate_size after padding satisfies this requirement.
2945+
# If not, set weight_alignment to 128.
2946+
intermediate_size_padded = self._round_up(
2947+
module.intermediate_size_per_partition, self.weight_alignment)
2948+
if intermediate_size_padded % 128 != 0:
2949+
self.weight_alignment = 128
2950+
29262951
super().create_weights(module, bias_dtype=torch.float32)
29272952

29282953
def setup_quant_scales(self, module: torch.nn.Module):
@@ -2981,6 +3006,8 @@ def load_expert_w3_w1_weight(self, module: torch.nn.Module,
29813006
dst_w3_weight.copy_(w3_weight_shard.view(dst_w3_weight.dtype))
29823007
dst_w1_weight.copy_(w1_weight_shard.view(dst_w1_weight.dtype))
29833008
else:
3009+
w1_weight_shard = _pad_tensor_to_shape(w1_weight_shard,
3010+
dst_w3_w1_weight_gpu.shape)
29843011
dst_w3_w1_weight_gpu.copy_(
29853012
w1_weight_shard.view(dst_w3_w1_weight_gpu.dtype))
29863013

@@ -3038,6 +3065,8 @@ def load_expert_w2_weight(self, module: torch.nn.Module,
30383065
epilogue_tile_m = 128
30393066

30403067
# Keep weights in device buffer
3068+
w2_weight_shard = _pad_tensor_to_shape(w2_weight_shard,
3069+
dst_w2_weight_gpu.shape)
30413070
dst_w2_weight_gpu.copy_(w2_weight_shard.view(dst_w2_weight_gpu.dtype),
30423071
non_blocking=dst_on_gpu)
30433072
# Get permuted indices
@@ -3071,7 +3100,7 @@ def load_expert_w3_w1_weight_scale_nvfp4(
30713100
alignment = _get_weight_alignment(self.weight_alignment,
30723101
module.scaling_vector_size,
30733102
module.tp_size,
3074-
w3_weight_scale.shape[0])
3103+
w1_weight_scale.shape[0])
30753104
w1_weight_scale = maybe_pad_for_mxfp4(
30763105
w1_weight_scale,
30773106
self.input_hidden_alignment // module.scaling_vector_size,
@@ -3113,6 +3142,8 @@ def load_expert_w3_w1_weight_scale_nvfp4(
31133142
w1_weight_scale.view(dst_w1_weight_scale.dtype))
31143143
else:
31153144
# Non-gated activation (e.g., ReLU2): buffer only contains w1 scale
3145+
w1_weight_scale = _pad_tensor_to_shape(
3146+
w1_weight_scale, dst_w3_w1_weight_scale_gpu.shape)
31163147
dst_w3_w1_weight_scale_gpu.copy_(
31173148
w1_weight_scale.view(dst_w3_w1_weight_scale_gpu.dtype))
31183149

@@ -3170,6 +3201,8 @@ def load_expert_w2_weight_scale_nvfp4(self,
31703201
TensorParallelMode.ROW,
31713202
device=device)
31723203
# Keep weights in device buffer
3204+
w2_weight_scale = _pad_tensor_to_shape(w2_weight_scale,
3205+
dst_w2_weight_scale_gpu.shape)
31733206
dst_w2_weight_scale_gpu.copy_(
31743207
w2_weight_scale.view(dst_w2_weight_scale_gpu.dtype))
31753208

tests/integration/defs/accuracy/test_llm_api_pytorch.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5798,6 +5798,46 @@ def test_nvfp4_8gpus(self, attention_dp, moe_backend):
57985798
task.evaluate(llm,
57995799
extra_evaluator_kwargs=self.EXTRA_EVALUATOR_KWARGS)
58005800

5801+
@skip_pre_blackwell
5802+
@pytest.mark.skip_less_mpi_world_size(8)
5803+
@pytest.mark.parametrize(
5804+
"tp_size, ep_size, pp_size, attention_dp",
5805+
[
5806+
(4, 1, 2, False),
5807+
(4, 4, 2, False),
5808+
(8, 1, 1, False),
5809+
(8, 8, 1, False),
5810+
(8, 1, 1, True),
5811+
],
5812+
ids=["TP4_PP2", "TEP4_PP2", "TP8_PP1", "TEP8_PP1", "TP8_PP1_ADP"],
5813+
)
5814+
def test_nvfp4_parallelism(self, tp_size, ep_size, pp_size, attention_dp):
5815+
with LLM(
5816+
f"{llm_models_root()}/Nemotron-SuperV3-phase1-mtp-nvfp4-fp8kv",
5817+
kv_cache_config=KvCacheConfig(
5818+
enable_block_reuse=False,
5819+
mamba_ssm_cache_dtype="float16",
5820+
free_gpu_memory_fraction=0.8,
5821+
),
5822+
max_batch_size=512,
5823+
tensor_parallel_size=tp_size,
5824+
moe_expert_parallel_size=ep_size,
5825+
pipeline_parallel_size=pp_size,
5826+
enable_attention_dp=attention_dp,
5827+
cuda_graph_config=CudaGraphConfig(max_batch_size=512,
5828+
enable_padding=True),
5829+
disable_overlap_scheduler=False,
5830+
moe_config=MoeConfig(backend="TRTLLM"),
5831+
) as llm:
5832+
task = MMLU(self.MODEL_NAME)
5833+
task.evaluate(llm,
5834+
extra_evaluator_kwargs=self.EXTRA_EVALUATOR_KWARGS)
5835+
# TODO: GSM8K will be failed due to mamba cache issue for pp_size > 1.
5836+
if pp_size == 1:
5837+
task = GSM8K(self.MODEL_NAME)
5838+
task.evaluate(
5839+
llm, extra_evaluator_kwargs=self.EXTRA_EVALUATOR_KWARGS)
5840+
58015841
@skip_pre_blackwell
58025842
@pytest.mark.skip_less_mpi_world_size(8)
58035843
def test_nvfp4_8gpus_mtp(self):

tests/integration/test_lists/qa/llm_function_core.txt

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -283,6 +283,11 @@ accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_fp8_4gpus[attention_
283283
accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_fp8_4gpus[attention_dp_on-python_mamba_cache]
284284
accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_fp8_4gpus[attention_dp_on-cpp_mamba_cache]
285285
accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_nvfp4_8gpus[attention_dp_on-trtllm]
286+
accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_nvfp4_parallelism[TP4_PP2]
287+
accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_nvfp4_parallelism[TEP4_PP2]
288+
accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_nvfp4_parallelism[TP8_PP1]
289+
accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_nvfp4_parallelism[TEP8_PP1]
290+
accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_nvfp4_parallelism[TP8_PP1_ADP]
286291

287292
# multimodal accuracy tests
288293
accuracy/test_llm_api_pytorch_multimodal.py::TestQwen2_VL_7B::test_auto_dtype

tests/integration/test_lists/test-db/l0_dgx_b200.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,7 @@ l0_dgx_b200:
117117
- accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_nvfp4_8gpus_mtp TIMEOUT (60)
118118
- accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_nvfp4_8gpus[attention_dp_on-trtllm] TIMEOUT (60)
119119
- accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_nvfp4_8gpus[attention_dp_on-cutlass] TIMEOUT (60)
120+
- accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_nvfp4_parallelism[TP4_PP2] TIMEOUT (60)
120121
- condition:
121122
ranges:
122123
system_gpu_count:
@@ -146,6 +147,7 @@ l0_dgx_b200:
146147
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_nvfp4_multi_gpus[baseline_pp4_mtp1] TIMEOUT (60)
147148
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_nvfp4_multi_gpus_chunked_prefill[baseline_fp8kv] TIMEOUT (60)
148149
- accuracy/test_llm_api_pytorch.py::TestMistralLarge3_675B::test_fp8[latency_moe_deepgemm] TIMEOUT (60)
150+
- accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_nvfp4_parallelism[TP8_PP1] TIMEOUT (60)
149151
- test_e2e.py::test_deepseek_r1_mtp_bench TIMEOUT(60) # Cover https://nvbugs/5670108
150152
- condition:
151153
ranges:

0 commit comments

Comments
 (0)