|
34 | 34 | from executorch.exir.pass_manager import PassManager, PassType |
35 | 35 | from executorch.exir.passes import dead_code_elimination_pass |
36 | 36 | from torch.fx.node import Node |
| 37 | +from torch.utils import _pytree as pytree |
37 | 38 |
|
38 | 39 |
|
39 | 40 | @register_cadence_pass(CadencePassAttribute(opt_level=0)) |
@@ -387,6 +388,219 @@ def maybe_remove_or_replace(self, node: Node) -> bool: |
387 | 388 | return False |
388 | 389 |
|
389 | 390 |
|
@register_cadence_pass(CadencePassAttribute(opt_level=1))
class RemovePermuteBeforeMeanPass(RemoveOrReplacePassInterface):
    """Remove or sink permute ops that precede mean reductions through unary chains.

    When a permute feeds into a mean (possibly through unary ops like
    dequantize/quantize), two optimizations apply:

    1. If non-reduced dims maintain their relative order and positions, the
       permute is fully removed and the mean's reduction dims are remapped.
    2. Otherwise, the permute is moved after the mean so it operates on
       smaller data.

    Cost model
    ----------
    Let S_in = input size in bytes, S_out = output size in bytes,
    R = S_in / S_out (reduction ratio), C_c = contiguous mean compute
    cost, C_s = strided mean compute cost, and Delta = C_s - C_c.

    Original graph (permute -> mean):
        Cost_orig = 3*S_in + S_out + C_c
        Breakdown: permute reads and writes S_in (2*S_in), mean reads
        S_in and writes S_out.

    Case 1 -- full removal (mean with remapped dims, no permute):
        Cost_remove = S_in + S_out + C_s
        Profitable when: Delta < 2*S_in
        The strided access penalty must be less than twice the full
        input tensor I/O (the eliminated permute cost).

    Case 2 -- reorder (mean with remapped dims -> small permute):
        Cost_reorder = S_in + 3*S_out + C_s
        Profitable when: Delta < 2*(S_in - S_out) = 2*S_in*(R-1)/R
        At R = 4 the threshold is 1.5*S_in; at R = 16 it approaches
        2*S_in, converging to the full-removal bound as S_out becomes
        negligible.

    Full removal always dominates reorder when both are structurally
    possible (Cost_remove < Cost_reorder since S_out < 3*S_out), so
    removal is applied without a cost gate.

    Additionally, if the original permute does not place the reduction
    dims as the trailing (innermost) dimensions, the mean is already
    strided in the original graph. Removing or sinking the permute
    saves the permute I/O (2*S_in) while the mean performance can only
    improve or stay the same. This makes the transformation
    unconditionally profitable without needing the cost model.
    """

    # Elementwise (per-tensor) ops that commute with permute: the rewrite may
    # walk backward through a single-user chain of these to find the permute.
    _UNARY_TARGETS: frozenset[EdgeOpOverload] = frozenset(
        {
            exir_ops.edge.cadence.dequantize_per_tensor.default,
            exir_ops.edge.cadence.quantize_per_tensor.default,
            exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
            exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
            exir_ops.edge.aten.clone.default,
            exir_ops.edge.aten.relu.default,
            exir_ops.edge.aten.neg.default,
            exir_ops.edge.aten.abs.default,
        }
    )

    # Minimum R (see cost model, case 2) for the reorder to be applied when
    # the mean becomes strided.
    _MIN_REDUCTION_RATIO_FOR_REORDER: int = 4

    @property
    def targets(self) -> list[EdgeOpOverload]:
        return [exir_ops.edge.aten.mean.dim]

    def _find_permute_through_unary_chain(self, mean_node: Node) -> Optional[Node]:
        """Walk backward from mean through single-user unary ops to find a permute."""
        current = mean_node.args[0]
        if not isinstance(current, Node):
            return None
        while True:
            if current.target == exir_ops.edge.aten.permute_copy.default:
                return current
            if current.target not in self._UNARY_TARGETS:
                return None
            if len(current.users) != 1:
                return None
            parent = current.args[0]
            if not isinstance(parent, Node):
                return None
            current = parent

    @staticmethod
    def _reduction_dims_are_trailing(reduction_dims: list[int], ndim: int) -> bool:
        """Check whether all reduction dims are the trailing (innermost) dims."""
        # `% ndim` canonicalizes negative dims (e.g. -1 -> ndim - 1).
        canonical = sorted(d % ndim for d in reduction_dims)
        return canonical == list(range(ndim - len(canonical), ndim))

    def _is_reorder_profitable(
        self,
        reduction_dims: list[int],
        new_reduction_dims: list[int],
        ndim: int,
        input_shape: torch.Size,
    ) -> bool:
        """Determine whether sinking the permute past the mean is profitable.

        Three cases, in order of evaluation:

        1. The original permute does not place the reduction dims as trailing
           (innermost) dimensions. The mean is already strided, so removing
           the permute saves 2*S_in with no mean degradation. Always
           profitable.

        2. The new (remapped) reduction dims are trailing in the original
           layout. The mean stays contiguous after the transformation
           (Delta ~ 0). Always profitable.

        3. The mean becomes strided. The reorder is profitable when
           Delta < 2*S_in*(R-1)/R. We approximate this by requiring
           R >= _MIN_REDUCTION_RATIO_FOR_REORDER, ensuring the I/O savings
           from operating on a smaller tensor outweigh the strided access
           penalty.
        """
        if not self._reduction_dims_are_trailing(reduction_dims, ndim):
            return True

        if self._reduction_dims_are_trailing(new_reduction_dims, ndim):
            return True

        # R is the product of the reduced extents (dedup'd via set in case a
        # caller passes repeated dims).
        reduction_ratio = 1
        canonical_reduction = set(new_reduction_dims)
        for d in canonical_reduction:
            reduction_ratio *= input_shape[d]

        return reduction_ratio >= self._MIN_REDUCTION_RATIO_FOR_REORDER

    def maybe_remove_or_replace(self, node: Node) -> bool:
        reduction_dims = get_arg(node, "dim", list[int])

        # An empty/absent dim list means "reduce over all dims" for mean;
        # there is no per-dim remapping to do and the rewrite below would be
        # vacuous (or wrong), so conservatively leave the graph untouched.
        if not reduction_dims:
            return False

        permute_node = self._find_permute_through_unary_chain(node)
        if permute_node is None:
            return False

        perm = get_arg(permute_node, "dims", list[int])
        ndim = len(perm)

        if len(permute_node.users) != 1:
            return False

        permute_input = permute_node.args[0]
        assert isinstance(permute_input, Node)

        # Reducing dim d of the permuted tensor reduces dim perm[d] of the
        # permute's input. (Negative d indexes perm from the end, which is
        # exactly the canonical dim d % ndim.)
        new_reduction_dims = [perm[d] for d in reduction_dims]
        keepdim = get_arg(node, "keepdim", bool)

        # Determine if the permute can be fully removed (post-mean permute
        # would be a no-op) vs needing to be sunk after the mean.
        canonical_reduction = set(new_reduction_dims)
        if keepdim:
            # With keepdim the reduced dims are size 1; the residual permute
            # is a no-op iff every non-reduced dim maps to itself.
            can_remove = all(
                perm[d] == d for d in range(ndim) if d not in canonical_reduction
            )
        else:
            # Without keepdim the output dims are the non-reduced input dims
            # in sorted order; removal is valid iff the permute preserves
            # that order among the surviving dims.
            non_reduced_in_perm_order = [
                d for d in perm if d not in canonical_reduction
            ]
            can_remove = non_reduced_in_perm_order == sorted(non_reduced_in_perm_order)

        # Full removal is almost always profitable. Reorder requires a
        # tighter cost bound; verify via the cost model before proceeding.
        if not can_remove:
            input_shape = permute_input.meta["val"].shape
            if not self._is_reorder_profitable(
                reduction_dims, new_reduction_dims, ndim, input_shape
            ):
                return False

        # Rewire: the permute's single user (either the mean itself or the
        # first unary op in the chain) should read from the permute's input.
        # The dead permute node is left for dead-code elimination.
        permute_user = next(iter(permute_node.users))
        permute_user.replace_input_with(permute_node, permute_input)
        node.args = (node.args[0], new_reduction_dims) + node.args[2:]

        # Re-derive the mean's meta since its reduction dims changed. Map
        # kwargs through as well: keepdim/dtype may be passed by keyword
        # (get_arg above already tolerates both spellings), and dropping them
        # here would propagate a wrong output shape downstream.
        fake_args = pytree.tree_map(
            lambda x: x.meta["val"] if isinstance(x, Node) else x,
            node.args,
        )
        fake_kwargs = pytree.tree_map(
            lambda x: x.meta["val"] if isinstance(x, Node) else x,
            node.kwargs,
        )
        node.meta["val"] = node.target(*fake_args, **fake_kwargs)  # pyre-ignore[29]

        if not can_remove:
            # Compute the post-mean permute on the reduced output.
            if keepdim:
                # Output rank is unchanged; the residual permute is perm itself.
                post_perm = list(perm)
            else:
                # Output dim i of the new mean corresponds to the i-th
                # smallest surviving input dim; map each desired output dim
                # (surviving dims in perm order) to that index space.
                non_reduced_original = sorted(
                    d for d in range(ndim) if d not in canonical_reduction
                )
                non_reduced_permuted = [d for d in perm if d not in canonical_reduction]
                post_perm = [
                    non_reduced_original.index(d) for d in non_reduced_permuted
                ]

            graph = node.graph
            with graph.inserting_after(node):
                new_permute = graph.create_node(
                    "call_function",
                    exir_ops.edge.aten.permute_copy.default,
                    args=(node, post_perm),
                )
            new_permute.meta["val"] = exir_ops.edge.aten.permute_copy.default(
                node.meta["val"], post_perm
            )
            for user in list(node.users):
                if user is not new_permute:
                    user.replace_input_with(node, new_permute)

        return True
390 | 604 | @register_cadence_pass(CadencePassAttribute(opt_level=2)) |
391 | 605 | class RemovePermutesAroundElementwiseOps(_SharedRemovePermutesAroundElementwiseOps): |
392 | 606 | permutable_ops: set[EdgeOpOverload] = ( |
@@ -646,6 +860,7 @@ class CommonRemovePasses: |
646 | 860 | RemoveNopSliceOrViewOpPass, |
647 | 861 | RemoveToOpsPass, |
648 | 862 | RemoveZeroSizedCatArgsPass, |
| 863 | + RemovePermuteBeforeMeanPass, |
649 | 864 | RemovePermutesAroundElementwiseOps, |
650 | 865 | FuseTransposeOrPermuteOpPairsPass, |
651 | 866 | RemoveSqueezeViewBeforeElementwiseOps, |
|
0 commit comments