Skip to content

Commit 8183123

Browse files
committed
[None][fix] fix AutoDeploy sharding IR dist config
Signed-off-by: Balamurugan Marimuthu <246387390+bmarimuthu-nv@users.noreply.github.com>
1 parent ad30176 commit 8183123

3 files changed

Lines changed: 5 additions & 11 deletions

File tree

tensorrt_llm/_torch/auto_deploy/transform/library/sharding_ir.py

Lines changed: 0 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -1069,14 +1069,6 @@ def get_partition(lst, world_size, rank):
10691069
self._localize_expert_indices(
10701070
gm, selected_experts, routing_weights, experts_per_rank, ep_rank, ep_size
10711071
)
1072-
_, all_reduce_op = _get_dist_ops(dc.dist_backend)
1073-
with gm.graph.inserting_after(self.node):
1074-
red = gm.graph.call_function(
1075-
all_reduce_op,
1076-
args=(self.node, dc.allreduce_strategy),
1077-
)
1078-
self.node.replace_all_uses_with(red)
1079-
red.replace_input_with(red, self.node)
10801072

10811073
ad_logger.debug(
10821074
f" sharded MoE: {num_experts} experts, ep={ep_size}, ep_rank={ep_rank}, "

tensorrt_llm/_torch/auto_deploy/utils/dist_config.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -137,6 +137,7 @@ def from_sharding_params(
137137
dist_mapping: dict,
138138
enable_attention_dp: bool = False,
139139
allreduce_strategy: str = "NCCL",
140+
dist_backend: Literal["auto", "torch", "trtllm"] = "auto",
140141
) -> "DistConfig":
141142
"""Build ``DistConfig`` from sharding-transform YAML inputs + runtime MPI info.
142143
@@ -154,6 +155,7 @@ def from_sharding_params(
154155
moe_cluster_size=dist_mapping.get("moe_cluster", 1),
155156
enable_attention_dp=enable_attention_dp,
156157
allreduce_strategy=allreduce_strategy,
158+
dist_backend=dist_backend,
157159
)
158160

159161
def to_mapping(self) -> Any:

tests/unittest/auto_deploy/multigpu/transformations/library/test_apply_sharding_hints.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -480,8 +480,8 @@ def _make_list_moe_graph():
480480
return gm
481481

482482

483-
def test_list_moe_ir_contract_inserts_all_reduce_for_ep():
484-
"""List-based MoE EP sharding localizes experts and adds a graph collective."""
483+
def test_list_moe_ir_contract_leaves_ep_reduction_to_modeling():
484+
"""List-based MoE EP sharding localizes experts without choosing a reduction site."""
485485
gm = _make_list_moe_graph()
486486
gm_out = _make_optimizer(world_size=2)(None, gm)
487487
moe_nodes = _call_nodes(gm_out, torch.ops.auto_deploy.torch_moe)
@@ -492,7 +492,7 @@ def test_list_moe_ir_contract_inserts_all_reduce_for_ep():
492492
assert len(w1_weight) == 2
493493
assert len(w2_weight) == 2
494494
assert len(w3_weight) == 2
495-
assert len(_call_nodes(gm_out, torch.ops.auto_deploy.torch_dist_all_reduce)) == 1
495+
assert len(_call_nodes(gm_out, torch.ops.auto_deploy.torch_dist_all_reduce)) == 0
496496

497497

498498
def _optional_auto_deploy_default(name):

0 commit comments

Comments
 (0)