Skip to content

Commit 7dda05e

Browse files
committed
[#9164][feat] AutoDeploy: noaux_tc MoE routing pattern matcher
Signed-off-by: Guan-Ming (Wesley) Chiu <105915352+guan404ming@users.noreply.github.com>
1 parent 0a19205 commit 7dda05e

3 files changed

Lines changed: 442 additions & 0 deletions

File tree

tensorrt_llm/_torch/auto_deploy/config/default.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,8 @@ transforms:
7272
stage: pattern_matcher
7373
match_moe_routing_pattern:
7474
stage: pattern_matcher
75+
match_noaux_tc_pattern:
76+
stage: pattern_matcher
7577
############################################################################################
7678
# RUN TRANSFORMATIONS ON STANDARDIZED GRAPH REPRESENTATION
7779
############################################################################################

tensorrt_llm/_torch/auto_deploy/transform/library/moe_routing.py

Lines changed: 298 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,9 @@
2828
2929
The fused kernel avoids computing softmax over all experts (e.g. 256), instead
3030
finding top-k from raw logits and computing softmax only over the k selected values.
31+
32+
Also detects the noaux_tc routing pattern used by DeepSeek-V3 / NemotronH /
33+
GLM4-MoE / Kimi-K2, replacing it with ``torch.ops.trtllm.noaux_tc_op``.
3134
"""
3235

3336
import operator
@@ -223,3 +226,298 @@ def _apply(
223226
has_valid_shapes=num_matches == 0,
224227
)
225228
return gm, info
229+
230+
231+
# ---------------------------------------------------------------------------
232+
# noaux_tc routing pattern helpers
233+
# ---------------------------------------------------------------------------
234+
235+
# ATen overloads accepted as a "top-k" step when matching the noaux_tc routing chain.
_TOPK_OPS = (torch.ops.aten.topk.default,)
236+
# ATen overloads accepted as a shape change (view or reshape) inside the matched chain.
_VIEW_OPS = (torch.ops.aten.view.default, torch.ops.aten.reshape.default)
237+
# ATen overloads accepted as the ``scores + bias`` addition.
_ADD_OPS = (torch.ops.aten.add.Tensor,)
238+
239+
240+
def _scalar_int(node_or_value) -> Optional[int]:
241+
"""Return *node_or_value* as a Python int if it is a literal, else None."""
242+
if isinstance(node_or_value, int):
243+
return node_or_value
244+
return None
245+
246+
247+
def _find_bias_add_after_sigmoid(sigmoid_node: Node) -> Optional[Tuple[Node, Node]]:
    """Locate the bias addition that directly consumes *sigmoid_node*.

    Scans the users of *sigmoid_node* for an add (one of ``_ADD_OPS``) in which
    one operand is *sigmoid_node* itself and the other operand is a graph ``Node``.

    Returns:
        ``(add_node, bias_node)`` for the first matching user, or ``None`` when
        no user qualifies.
    """
    for consumer in sigmoid_node.users:
        if is_op(consumer, _ADD_OPS):
            lhs, rhs = consumer.args[0], consumer.args[1]
            # The scores may be either operand of the add; try the
            # scores-on-the-left arrangement first.
            for scores, bias in ((lhs, rhs), (rhs, lhs)):
                if scores is sigmoid_node and isinstance(bias, Node):
                    return consumer, bias
    return None
258+
259+
260+
def _find_group_topk(scores_with_bias: Node) -> Optional[Tuple[Node, int]]:
    """Locate the per-group ``topk(view(scores_with_bias, shape), 2)`` node.

    A user of *scores_with_bias* qualifies when it is a view/reshape whose
    target shape is a list or tuple of at least two entries with a literal int
    in the second-to-last position; that int is reported as ``n_group``. The
    view must in turn feed a top-k whose ``k`` argument is the literal ``2``.

    Returns:
        ``(topk_node, n_group)`` for the first match, or ``None``.
    """
    for reshaped in scores_with_bias.users:
        if not is_op(reshaped, _VIEW_OPS):
            continue
        if len(reshaped.args) <= 1:
            # No explicit shape argument to inspect.
            continue
        target_shape = reshaped.args[1]
        if not isinstance(target_shape, (list, tuple)) or len(target_shape) < 2:
            continue
        n_group = _scalar_int(target_shape[-2])
        if n_group is None:
            # Second-to-last dimension is not a literal int; cannot read n_group.
            continue
        for consumer in reshaped.users:
            if is_op(consumer, _TOPK_OPS) and _scalar_int(consumer.args[1]) == 2:
                return consumer, n_group
    return None
277+
278+
279+
def _find_outer_topk(masked_node: Node) -> Optional[Tuple[Node, int]]:
    """Locate the final ``topk(masked, k)`` that consumes *masked_node*.

    The top-k may be a direct user of *masked_node* or sit behind exactly one
    view/reshape. Only a top-k whose ``k`` argument is a literal int is accepted.

    Returns:
        ``(topk_node, k)`` for the first match, or ``None``.
    """

    def _as_literal_topk(candidate) -> Optional[Tuple[Node, int]]:
        # A top-k with a literal k yields (node, k); anything else yields None.
        if not is_op(candidate, _TOPK_OPS):
            return None
        k = _scalar_int(candidate.args[1])
        return None if k is None else (candidate, k)

    for consumer in masked_node.users:
        hit = _as_literal_topk(consumer)
        if hit is not None:
            return hit
        # Tolerate a single view/reshape between the mask and the top-k.
        if is_op(consumer, _VIEW_OPS):
            for downstream in consumer.users:
                hit = _as_literal_topk(downstream)
                if hit is not None:
                    return hit
    return None
295+
296+
297+
def _descends_from(node: Node, target: Node, max_depth: int = 10) -> bool:
298+
"""Return True if *target* is reachable from *node*'s input ancestry within max_depth hops."""
299+
if not isinstance(node, Node) or not isinstance(target, Node):
300+
return False
301+
visited = set()
302+
frontier = [(node, 0)]
303+
while frontier:
304+
n, d = frontier.pop()
305+
if n is target:
306+
return True
307+
if d >= max_depth or n in visited:
308+
continue
309+
visited.add(n)
310+
for inp in n.all_input_nodes:
311+
frontier.append((inp, d + 1))
312+
return False
313+
314+
315+
def _find_gather_from_indices(indices_node: Node, scores_node: Node) -> Optional[Node]:
    """Locate the ``aten.gather.default(scores_node, dim, indices_node)`` user of *indices_node*.

    A user matches only when it has at least three positional arguments, its
    first argument is *scores_node* and its third is *indices_node*.

    Returns:
        The first matching gather node, or ``None``.
    """
    for consumer in indices_node.users:
        if is_op(consumer, torch.ops.aten.gather.default):
            operands = consumer.args
            if (
                len(operands) >= 3
                and operands[0] is scores_node
                and operands[2] is indices_node
            ):
                return consumer
    return None
323+
324+
325+
def _is_sum_of(node, cur: Node) -> bool:
    """Return True if *node* is an ``aten.sum.dim_IntList`` whose first argument is *cur*."""
    if not isinstance(node, Node):
        return False
    if not is_op(node, torch.ops.aten.sum.dim_IntList):
        return False
    return node.args[0] is cur
331+
332+
333+
def _is_normalize_divisor(divisor, cur: Node) -> bool:
    """Return True if *divisor* is ``sum(cur, ...)`` or ``sum(cur, ...) + eps``.

    In the epsilon-stabilized form the sum may be either operand of the
    ``aten.add.Tensor``; the other operand must be a literal int or float.
    """
    if _is_sum_of(divisor, cur):
        return True
    if not (isinstance(divisor, Node) and is_op(divisor, torch.ops.aten.add.Tensor)):
        return False
    lhs, rhs = divisor.args[0], divisor.args[1]
    return any(
        _is_sum_of(summed, cur) and isinstance(eps, (int, float))
        for summed, eps in ((lhs, rhs), (rhs, lhs))
    )
343+
344+
345+
def _walk_div_then_mul(start: Node) -> Tuple[Node, float]:
    """Follow an optional normalization ``div`` and an optional scalar ``mul`` after *start*.

    First, if *start* has a ``div.Tensor(start, divisor)`` user whose divisor
    is accepted by ``_is_normalize_divisor``, the walk advances to that div.
    Then, if the current node has a ``mul.Tensor(current, scalar)`` user with a
    literal int/float scalar, that mul ends the chain.

    Returns:
        ``(final_node, routed_scaling_factor)``. When no scalar multiply is
        found the factor is ``1.0`` and *final_node* is the last node reached
        (the div if one matched, otherwise *start*).
    """
    tail = start

    # Optional norm: div(tail, sum(tail, ...) [+ eps]).
    normalized = next(
        (
            consumer
            for consumer in tail.users
            if is_op(consumer, torch.ops.aten.div.Tensor)
            and consumer.args[0] is tail
            and _is_normalize_divisor(consumer.args[1], tail)
        ),
        None,
    )
    if normalized is not None:
        tail = normalized

    # Optional scalar multiply.
    for consumer in tail.users:
        if not (is_op(consumer, torch.ops.aten.mul.Tensor) and consumer.args[0] is tail):
            continue
        factor = consumer.args[1]
        if isinstance(factor, (int, float)):
            return consumer, float(factor)
    return tail, 1.0
369+
370+
371+
@TransformRegistry.register("match_noaux_tc_pattern")
class MatchNoAuxTCPattern(BaseTransform):
    """Match the noaux_tc MoE routing chain and replace with a fused trtllm op.

    This transform detects the DeepSeek-V3 style routing pattern::

        sigmoid → +bias → group top-k → mask → top-k → gather → [norm] → scale

    and replaces it with::

        topk_weights, topk_idx = trtllm.noaux_tc_op(
            router_logits, bias, n_group, topk_group, top_k, routed_scaling_factor
        )

    The fused kernel performs sigmoid, bias correction, group-based top-k
    selection, gather, normalization and scaling in a single CUDA kernel.

    NOTE(review): the matcher treats the normalization ``div`` as optional, but
    no "normalize" flag is passed to the fused op — confirm the kernel's
    behavior matches graphs that omit the normalization step.
    """

    config: TransformConfig

    @classmethod
    def get_config_class(cls) -> Type[TransformConfig]:
        """Return the config class for this transform (the base ``TransformConfig``)."""
        return TransformConfig

    def _apply(
        self,
        gm: GraphModule,
        cm: CachedSequenceInterface,
        factory: ModelFactory,
        shared_config: SharedConfig,
    ) -> Tuple[GraphModule, TransformInfo]:
        """Rewrite every noaux_tc routing chain found in *gm* into the fused op.

        Each ``aten.sigmoid`` node is used as an anchor; the rest of the chain
        is matched step by step and the anchor is skipped as soon as any step
        fails to match. ``cm``, ``factory`` and ``shared_config`` are not used.

        Returns:
            ``(gm, info)`` where ``info.num_matches`` counts the rewritten
            chains. Dead-code elimination and recompilation only run when at
            least one chain was rewritten.
        """
        graph = gm.graph
        num_matches = 0

        # Iterate over a snapshot: the graph is mutated while matching.
        for node in list(graph.nodes):
            # ---- anchor: aten.sigmoid -> add(bias) ------------------------
            if not is_op(node, torch.ops.aten.sigmoid.default):
                continue
            sigmoid_node = node
            router_logits = sigmoid_node.args[0]
            if not isinstance(router_logits, Node):
                continue

            bias_add = _find_bias_add_after_sigmoid(sigmoid_node)
            if bias_add is None:
                continue
            scores_with_bias_node, bias_node = bias_add

            # ---- group top-k(k=2) ----------------------------------------
            inner = _find_group_topk(scores_with_bias_node)
            if inner is None:
                continue
            inner_topk, n_group = inner

            inner_values = _get_single_getitem_user(inner_topk, 0)
            if inner_values is None:
                continue

            # ---- sum -> outer topk(k=topk_group) -------------------------
            sum_node = None
            for u in inner_values.users:
                if is_op(u, torch.ops.aten.sum.dim_IntList):
                    sum_node = u
                    break
            if sum_node is None:
                continue

            outer_grp_topk = None
            for u in sum_node.users:
                if is_op(u, _TOPK_OPS):
                    outer_grp_topk = u
                    break
            if outer_grp_topk is None:
                continue
            topk_group = _scalar_int(outer_grp_topk.args[1])
            if topk_group is None:
                continue

            # ---- final masked top-k(k=top_k) -----------------------------
            # Only accept a masked_node whose mask input descends from outer_grp_topk;
            # otherwise an unrelated branch consuming scores_with_bias could be picked.
            masked_node = None
            for u in scores_with_bias_node.users:
                if not is_op(
                    u,
                    (
                        torch.ops.aten.where.self,
                        torch.ops.aten.masked_fill.Scalar,
                        torch.ops.aten.mul.Tensor,
                    ),
                ):
                    continue
                if not _descends_from(u, outer_grp_topk):
                    continue
                masked_node = u
                break
            if masked_node is None:
                continue

            outer_topk = _find_outer_topk(masked_node)
            if outer_topk is None:
                continue
            final_topk_node, top_k = outer_topk

            final_indices = _get_single_getitem_user(final_topk_node, 1)
            if final_indices is None:
                continue

            # ---- weights branch: gather(scores, -1, topk_idx) [/ sum] * scale --
            # The weights are gathered from the un-biased sigmoid scores.
            gather_node = _find_gather_from_indices(final_indices, sigmoid_node)
            if gather_node is None:
                continue
            weights_tail, routed_scaling_factor = _walk_div_then_mul(gather_node)

            # ---- emit fused noaux_tc_op ---------------------------------
            ad_logger.info(
                "Matched noaux_tc routing pattern: "
                f"n_group={n_group}, topk_group={topk_group}, top_k={top_k}, "
                f"scale={routed_scaling_factor}"
            )

            # Insert right before the final top-k rather than before the sigmoid:
            # bias_node is only known to precede the bias add, so it may be
            # defined *after* the sigmoid, and inserting there would create a
            # use-before-definition. The final top-k comes after both
            # router_logits and bias_node, and before every user rewired below.
            with graph.inserting_before(final_topk_node):
                fused = graph.call_function(
                    torch.ops.trtllm.noaux_tc_op,
                    args=(
                        router_logits,
                        bias_node,
                        n_group,
                        topk_group,
                        top_k,
                        routed_scaling_factor,
                    ),
                )
                fused_weights = graph.call_function(operator.getitem, args=(fused, 0))
                fused_indices = graph.call_function(operator.getitem, args=(fused, 1))

            # Redirect consumers to the fused outputs; the original chain becomes
            # dead and is removed by eliminate_dead_code below.
            final_indices.replace_all_uses_with(fused_indices)
            weights_tail.replace_all_uses_with(fused_weights)

            num_matches += 1

        if num_matches > 0:
            eliminate_dead_code(gm)
            gm.recompile()
            ad_logger.info(f"Fused {num_matches} noaux_tc routing pattern(s).")

        info = TransformInfo(
            skipped=False,
            num_matches=num_matches,
            is_clean=num_matches == 0,
            has_valid_shapes=num_matches == 0,
        )
        return gm, info

0 commit comments

Comments
 (0)