|
28 | 28 |
|
29 | 29 | The fused kernel avoids computing softmax over all experts (e.g. 256), instead |
30 | 30 | finding top-k from raw logits and computing softmax only over the k selected values. |
| 31 | +
|
| 32 | +Also detects the noaux_tc routing pattern used by DeepSeek-V3 / NemotronH / |
| 33 | +GLM4-MoE / Kimi-K2, replacing it with ``torch.ops.trtllm.noaux_tc_op``. |
31 | 34 | """ |
32 | 35 |
|
33 | 36 | import operator |
@@ -223,3 +226,298 @@ def _apply( |
223 | 226 | has_valid_shapes=num_matches == 0, |
224 | 227 | ) |
225 | 228 | return gm, info |
| 229 | + |
| 230 | + |
| 231 | +# --------------------------------------------------------------------------- |
| 232 | +# noaux_tc routing pattern helpers |
| 233 | +# --------------------------------------------------------------------------- |
| 234 | + |
| 235 | +_TOPK_OPS = (torch.ops.aten.topk.default,) |
| 236 | +_VIEW_OPS = (torch.ops.aten.view.default, torch.ops.aten.reshape.default) |
| 237 | +_ADD_OPS = (torch.ops.aten.add.Tensor,) |
| 238 | + |
| 239 | + |
| 240 | +def _scalar_int(node_or_value) -> Optional[int]: |
| 241 | + """Return *node_or_value* as a Python int if it is a literal, else None.""" |
| 242 | + if isinstance(node_or_value, int): |
| 243 | + return node_or_value |
| 244 | + return None |
| 245 | + |
| 246 | + |
| 247 | +def _find_bias_add_after_sigmoid(sigmoid_node: Node) -> Optional[Tuple[Node, Node]]: |
| 248 | + """Find ``scores + bias`` user of *sigmoid_node*; return (add_node, bias_node).""" |
| 249 | + for user in sigmoid_node.users: |
| 250 | + if not is_op(user, _ADD_OPS): |
| 251 | + continue |
| 252 | + a, b = user.args[0], user.args[1] |
| 253 | + if a is sigmoid_node and isinstance(b, Node): |
| 254 | + return user, b |
| 255 | + if b is sigmoid_node and isinstance(a, Node): |
| 256 | + return user, a |
| 257 | + return None |
| 258 | + |
| 259 | + |
| 260 | +def _find_group_topk(scores_with_bias: Node) -> Optional[Tuple[Node, int]]: |
| 261 | + """Find the ``topk(view(scores_with_bias, ...), k=2)`` user; return (node, n_group).""" |
| 262 | + for user in scores_with_bias.users: |
| 263 | + view_node = user if is_op(user, _VIEW_OPS) else None |
| 264 | + if view_node is None: |
| 265 | + continue |
| 266 | + # view shape can be the second arg (list of ints/Nodes) |
| 267 | + shape = view_node.args[1] if len(view_node.args) > 1 else None |
| 268 | + if not isinstance(shape, (list, tuple)) or len(shape) < 2: |
| 269 | + continue |
| 270 | + n_group = _scalar_int(shape[-2]) |
| 271 | + if n_group is None: |
| 272 | + continue |
| 273 | + for vu in view_node.users: |
| 274 | + if is_op(vu, _TOPK_OPS) and _scalar_int(vu.args[1]) == 2: |
| 275 | + return vu, n_group |
| 276 | + return None |
| 277 | + |
| 278 | + |
| 279 | +def _find_outer_topk(masked_node: Node) -> Optional[Tuple[Node, int]]: |
| 280 | + """Find ``topk(masked, k=top_k)`` user of *masked_node*; return (node, top_k).""" |
| 281 | + for user in masked_node.users: |
| 282 | + candidate = user |
| 283 | + if is_op(candidate, _TOPK_OPS): |
| 284 | + top_k = _scalar_int(candidate.args[1]) |
| 285 | + if top_k is not None: |
| 286 | + return candidate, top_k |
| 287 | + # allow one view in between |
| 288 | + if is_op(candidate, _VIEW_OPS): |
| 289 | + for vu in candidate.users: |
| 290 | + if is_op(vu, _TOPK_OPS): |
| 291 | + top_k = _scalar_int(vu.args[1]) |
| 292 | + if top_k is not None: |
| 293 | + return vu, top_k |
| 294 | + return None |
| 295 | + |
| 296 | + |
| 297 | +def _descends_from(node: Node, target: Node, max_depth: int = 10) -> bool: |
| 298 | + """Return True if *target* is reachable from *node*'s input ancestry within max_depth hops.""" |
| 299 | + if not isinstance(node, Node) or not isinstance(target, Node): |
| 300 | + return False |
| 301 | + visited = set() |
| 302 | + frontier = [(node, 0)] |
| 303 | + while frontier: |
| 304 | + n, d = frontier.pop() |
| 305 | + if n is target: |
| 306 | + return True |
| 307 | + if d >= max_depth or n in visited: |
| 308 | + continue |
| 309 | + visited.add(n) |
| 310 | + for inp in n.all_input_nodes: |
| 311 | + frontier.append((inp, d + 1)) |
| 312 | + return False |
| 313 | + |
| 314 | + |
| 315 | +def _find_gather_from_indices(indices_node: Node, scores_node: Node) -> Optional[Node]: |
| 316 | + """Find ``aten.gather.default(scores_node, dim, indices_node)`` user of *indices_node*.""" |
| 317 | + for user in indices_node.users: |
| 318 | + if not is_op(user, torch.ops.aten.gather.default): |
| 319 | + continue |
| 320 | + if len(user.args) >= 3 and user.args[0] is scores_node and user.args[2] is indices_node: |
| 321 | + return user |
| 322 | + return None |
| 323 | + |
| 324 | + |
| 325 | +def _is_sum_of(node, cur: Node) -> bool: |
| 326 | + return ( |
| 327 | + isinstance(node, Node) |
| 328 | + and is_op(node, torch.ops.aten.sum.dim_IntList) |
| 329 | + and node.args[0] is cur |
| 330 | + ) |
| 331 | + |
| 332 | + |
| 333 | +def _is_normalize_divisor(divisor, cur: Node) -> bool: |
| 334 | + """Accept ``sum(cur, ...)`` or its epsilon-stabilized form ``sum(...) + eps``.""" |
| 335 | + if _is_sum_of(divisor, cur): |
| 336 | + return True |
| 337 | + if isinstance(divisor, Node) and is_op(divisor, torch.ops.aten.add.Tensor): |
| 338 | + a, b = divisor.args[0], divisor.args[1] |
| 339 | + for sum_cand, eps_cand in ((a, b), (b, a)): |
| 340 | + if _is_sum_of(sum_cand, cur) and isinstance(eps_cand, (int, float)): |
| 341 | + return True |
| 342 | + return False |
| 343 | + |
| 344 | + |
| 345 | +def _walk_div_then_mul(start: Node) -> Tuple[Node, float]: |
| 346 | + """Walk forward through optional ``div.Tensor(self, sum)`` then ``mul.Tensor(self, scalar)``. |
| 347 | +
|
| 348 | + Returns ``(final_node, routed_scaling_factor)``. If no scale is found, the |
| 349 | + factor defaults to ``1.0`` and *final_node* is the last node reached on the |
| 350 | + chain (gather, or div if no mul, etc.). |
| 351 | + """ |
| 352 | + cur = start |
| 353 | + # optional norm: div(cur, sum(cur, ..., keepdim=True) [+ eps]) |
| 354 | + for user in cur.users: |
| 355 | + if ( |
| 356 | + is_op(user, torch.ops.aten.div.Tensor) |
| 357 | + and user.args[0] is cur |
| 358 | + and _is_normalize_divisor(user.args[1], cur) |
| 359 | + ): |
| 360 | + cur = user |
| 361 | + break |
| 362 | + # optional scalar multiply |
| 363 | + for user in cur.users: |
| 364 | + if is_op(user, torch.ops.aten.mul.Tensor) and user.args[0] is cur: |
| 365 | + scalar = user.args[1] |
| 366 | + if isinstance(scalar, (int, float)): |
| 367 | + return user, float(scalar) |
| 368 | + return cur, 1.0 |
| 369 | + |
| 370 | + |
| 371 | +@TransformRegistry.register("match_noaux_tc_pattern") |
| 372 | +class MatchNoAuxTCPattern(BaseTransform): |
| 373 | + """Match the noaux_tc MoE routing chain and replace with a fused trtllm op. |
| 374 | +
|
| 375 | + This transform detects the DeepSeek-V3 style routing pattern:: |
| 376 | +
|
| 377 | + sigmoid → +bias → group top-k → mask → top-k → gather → [norm] → scale |
| 378 | +
|
| 379 | + and replaces it with:: |
| 380 | +
|
| 381 | + topk_weights, topk_idx = trtllm.noaux_tc_op( |
| 382 | + router_logits, bias, n_group, topk_group, top_k, routed_scaling_factor |
| 383 | + ) |
| 384 | +
|
| 385 | + The fused kernel performs sigmoid, bias correction, group-based top-k |
| 386 | + selection, gather, normalization and scaling in a single CUDA kernel. |
| 387 | + """ |
| 388 | + |
| 389 | + config: TransformConfig |
| 390 | + |
| 391 | + @classmethod |
| 392 | + def get_config_class(cls) -> Type[TransformConfig]: |
| 393 | + return TransformConfig |
| 394 | + |
| 395 | + def _apply( |
| 396 | + self, |
| 397 | + gm: GraphModule, |
| 398 | + cm: CachedSequenceInterface, |
| 399 | + factory: ModelFactory, |
| 400 | + shared_config: SharedConfig, |
| 401 | + ) -> Tuple[GraphModule, TransformInfo]: |
| 402 | + graph = gm.graph |
| 403 | + num_matches = 0 |
| 404 | + |
| 405 | + for node in list(graph.nodes): |
| 406 | + # ---- anchor: aten.sigmoid -> add(bias) ------------------------ |
| 407 | + if not is_op(node, torch.ops.aten.sigmoid.default): |
| 408 | + continue |
| 409 | + sigmoid_node = node |
| 410 | + router_logits = sigmoid_node.args[0] |
| 411 | + if not isinstance(router_logits, Node): |
| 412 | + continue |
| 413 | + |
| 414 | + bias_add = _find_bias_add_after_sigmoid(sigmoid_node) |
| 415 | + if bias_add is None: |
| 416 | + continue |
| 417 | + scores_with_bias_node, bias_node = bias_add |
| 418 | + |
| 419 | + # ---- group top-k(k=2) ---------------------------------------- |
| 420 | + inner = _find_group_topk(scores_with_bias_node) |
| 421 | + if inner is None: |
| 422 | + continue |
| 423 | + inner_topk, n_group = inner |
| 424 | + |
| 425 | + inner_values = _get_single_getitem_user(inner_topk, 0) |
| 426 | + if inner_values is None: |
| 427 | + continue |
| 428 | + |
| 429 | + # ---- sum -> outer topk(k=topk_group) ------------------------- |
| 430 | + sum_node = None |
| 431 | + for u in inner_values.users: |
| 432 | + if is_op(u, torch.ops.aten.sum.dim_IntList): |
| 433 | + sum_node = u |
| 434 | + break |
| 435 | + if sum_node is None: |
| 436 | + continue |
| 437 | + |
| 438 | + outer_grp_topk = None |
| 439 | + for u in sum_node.users: |
| 440 | + if is_op(u, _TOPK_OPS): |
| 441 | + outer_grp_topk = u |
| 442 | + break |
| 443 | + if outer_grp_topk is None: |
| 444 | + continue |
| 445 | + topk_group = _scalar_int(outer_grp_topk.args[1]) |
| 446 | + if topk_group is None: |
| 447 | + continue |
| 448 | + |
| 449 | + # ---- final masked top-k(k=top_k) ----------------------------- |
| 450 | + # Only accept a masked_node whose mask input descends from outer_grp_topk; |
| 451 | + # otherwise an unrelated branch consuming scores_with_bias could be picked. |
| 452 | + masked_node = None |
| 453 | + for u in scores_with_bias_node.users: |
| 454 | + if not is_op( |
| 455 | + u, |
| 456 | + ( |
| 457 | + torch.ops.aten.where.self, |
| 458 | + torch.ops.aten.masked_fill.Scalar, |
| 459 | + torch.ops.aten.mul.Tensor, |
| 460 | + ), |
| 461 | + ): |
| 462 | + continue |
| 463 | + if not _descends_from(u, outer_grp_topk): |
| 464 | + continue |
| 465 | + masked_node = u |
| 466 | + break |
| 467 | + if masked_node is None: |
| 468 | + continue |
| 469 | + |
| 470 | + outer_topk = _find_outer_topk(masked_node) |
| 471 | + if outer_topk is None: |
| 472 | + continue |
| 473 | + final_topk_node, top_k = outer_topk |
| 474 | + |
| 475 | + final_indices = _get_single_getitem_user(final_topk_node, 1) |
| 476 | + if final_indices is None: |
| 477 | + continue |
| 478 | + |
| 479 | + # ---- weights branch: gather(scores, -1, topk_idx) [/ sum] * scale -- |
| 480 | + gather_node = _find_gather_from_indices(final_indices, sigmoid_node) |
| 481 | + if gather_node is None: |
| 482 | + continue |
| 483 | + weights_tail, routed_scaling_factor = _walk_div_then_mul(gather_node) |
| 484 | + |
| 485 | + # ---- emit fused noaux_tc_op --------------------------------- |
| 486 | + ad_logger.info( |
| 487 | + "Matched noaux_tc routing pattern: " |
| 488 | + f"n_group={n_group}, topk_group={topk_group}, top_k={top_k}, " |
| 489 | + f"scale={routed_scaling_factor}" |
| 490 | + ) |
| 491 | + |
| 492 | + with graph.inserting_before(sigmoid_node): |
| 493 | + fused = graph.call_function( |
| 494 | + torch.ops.trtllm.noaux_tc_op, |
| 495 | + args=( |
| 496 | + router_logits, |
| 497 | + bias_node, |
| 498 | + n_group, |
| 499 | + topk_group, |
| 500 | + top_k, |
| 501 | + routed_scaling_factor, |
| 502 | + ), |
| 503 | + ) |
| 504 | + fused_weights = graph.call_function(operator.getitem, args=(fused, 0)) |
| 505 | + fused_indices = graph.call_function(operator.getitem, args=(fused, 1)) |
| 506 | + |
| 507 | + final_indices.replace_all_uses_with(fused_indices) |
| 508 | + weights_tail.replace_all_uses_with(fused_weights) |
| 509 | + |
| 510 | + num_matches += 1 |
| 511 | + |
| 512 | + if num_matches > 0: |
| 513 | + eliminate_dead_code(gm) |
| 514 | + gm.recompile() |
| 515 | + ad_logger.info(f"Fused {num_matches} noaux_tc routing pattern(s).") |
| 516 | + |
| 517 | + info = TransformInfo( |
| 518 | + skipped=False, |
| 519 | + num_matches=num_matches, |
| 520 | + is_clean=num_matches == 0, |
| 521 | + has_valid_shapes=num_matches == 0, |
| 522 | + ) |
| 523 | + return gm, info |
0 commit comments