Skip to content

Commit 6f022ca

Browse files
committed
[#14679][fix] Fix fused-QKV TP sharding for Phi-3/Phi-4
Signed-off-by: Guan-Ming (Wesley) Chiu <105915352+guan404ming@users.noreply.github.com>
1 parent 1aa232a commit 6f022ca

3 files changed

Lines changed: 51 additions & 13 deletions

File tree

examples/auto_deploy/model_registry/models.yaml

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -67,18 +67,15 @@ models:
6767
- name: Qwen/Qwen3-8B
6868
config_id: default_ws_2
6969
yaml_extra: ['dashboard_default.yaml', 'world_size_2.yaml', 'enable_sharder_ir.yaml']
70-
# RuntimeError: a and b must have same reduction dim, but got [s44*s70, 5120] X [2560, 5120]. See https://github.com/NVIDIA/TensorRT-LLM/issues/14679
71-
# - name: microsoft/phi-4
72-
# config_id: default_ws_2
73-
# yaml_extra: ['dashboard_default.yaml', 'world_size_2.yaml']
74-
# RuntimeError: a and b must have same reduction dim, but got [s44*s70, 5120] X [2560, 5120]. See https://github.com/NVIDIA/TensorRT-LLM/issues/14679
75-
# - name: microsoft/Phi-4-reasoning
76-
# config_id: default_ws_2
77-
# yaml_extra: ['dashboard_default.yaml', 'world_size_2.yaml']
78-
# RuntimeError: a and b must have same reduction dim, but got [s44*s70, 5120] X [2560, 5120]. See https://github.com/NVIDIA/TensorRT-LLM/issues/14679
79-
# - name: microsoft/Phi-4-reasoning-plus
80-
# config_id: default_ws_2
81-
# yaml_extra: ['dashboard_default.yaml', 'world_size_2.yaml']
70+
- name: microsoft/phi-4
71+
config_id: default_ws_2
72+
yaml_extra: ['dashboard_default.yaml', 'world_size_2.yaml']
73+
- name: microsoft/Phi-4-reasoning
74+
config_id: default_ws_2
75+
yaml_extra: ['dashboard_default.yaml', 'world_size_2.yaml']
76+
- name: microsoft/Phi-4-reasoning-plus
77+
config_id: default_ws_2
78+
yaml_extra: ['dashboard_default.yaml', 'world_size_2.yaml']
8279
# IndexError: list index out of range in AutoDeploy sharding path. See https://github.com/NVIDIA/TensorRT-LLM/issues/14681
8380
# - name: google/gemma-1.1-7b-it
8481
# config_id: default_ws_2

tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2855,7 +2855,7 @@ def _process_mla_sharding(
28552855

28562856
def _determine_fused_weight_dims(
28572857
linear_nodes: List[Node],
2858-
) -> None:
2858+
) -> Optional[List[int]]:
28592859
"""
28602860
Determine the fused weight dims for the given linear nodes and subgraph nodes.
28612861
"""
@@ -2900,6 +2900,8 @@ def _determine_fused_weight_dims(
29002900
weight_dim = linear_node.meta["val"].shape[2]
29012901
fused_weight_dims = [weight_dim // num_chunks] * num_chunks
29022902

2903+
return fused_weight_dims
2904+
29032905

29042906
def _find_upstream_qk_proj(node: Node, gm: GraphModule) -> Optional[str]:
29052907
"""

tests/unittest/auto_deploy/multigpu/transformations/library/test_tp_sharding.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
ShardingTransformConfig,
4141
SplitDimension,
4242
WeightShardingInfo,
43+
_determine_fused_weight_dims,
4344
_update_node_args,
4445
)
4546
from tensorrt_llm._torch.auto_deploy.transform.optimizer import InferenceOptimizer
@@ -502,6 +503,44 @@ def test_update_node_args_preserves_nested_symbolic_shape_nodes():
502503
assert placeholder_targets == ["x"]
503504

504505

506+
class _FusedQKVProj(nn.Module):
507+
"""Single fused qkv_proj sliced into q/k/v (Phi-3/Phi-4 layout)."""
508+
509+
def __init__(self, hidden_size: int, n_heads: int, n_kv_heads: int, head_dim: int):
510+
super().__init__()
511+
self.q_dim = n_heads * head_dim
512+
self.kv_dim = n_kv_heads * head_dim
513+
self.qkv_proj = nn.Linear(hidden_size, self.q_dim + 2 * self.kv_dim, bias=False)
514+
515+
def forward(self, x):
516+
qkv = self.qkv_proj(x)
517+
q = qkv[..., : self.q_dim]
518+
k = qkv[..., self.q_dim : self.q_dim + self.kv_dim]
519+
v = qkv[..., self.q_dim + self.kv_dim :]
520+
return q.sum() + k.sum() + v.sum()
521+
522+
523+
def test_determine_fused_weight_dims_qkv():
524+
"""Regression for NVIDIA/TensorRT-LLM#14679: fused qkv_proj column sharding.
525+
526+
`_determine_fused_weight_dims` must return the [q, k, v] split sizes so the
527+
slice boundaries get divided by world_size during column sharding. A missing
528+
return made it yield None, leaving the slices at full width and breaking TP
529+
for fused-qkv models like Phi-3/Phi-4.
530+
"""
531+
hidden_size, n_heads, n_kv_heads, head_dim = 32, 4, 2, 8
532+
model = _FusedQKVProj(hidden_size, n_heads, n_kv_heads, head_dim)
533+
x = torch.randn(2, 3, hidden_size)
534+
gm = torch_export_to_gm(model, args=(x,), clone=True)
535+
536+
slice_nodes = [n for n in gm.graph.nodes if is_op(n, torch.ops.aten.slice)]
537+
assert len(slice_nodes) == 3, "Expected 3 slice nodes for fused QKV"
538+
qkv_node = slice_nodes[0].args[0]
539+
540+
kv_dim = n_kv_heads * head_dim
541+
assert _determine_fused_weight_dims([qkv_node]) == [n_heads * head_dim, kv_dim, kv_dim]
542+
543+
505544
def _run_sharding_execution_job(
506545
model_cls: nn.Module,
507546
dist_op_expected: str,

0 commit comments

Comments
 (0)