Skip to content

Commit c391738

Browse files
authored
Improve RemovePermutesAroundElementwiseOps robustness (pytorch#18980)
Differential Revision: D101430531. Pull Request resolved: pytorch#18980.
1 parent 6c82462 commit c391738

1 file changed

Lines changed: 88 additions & 1 deletion

File tree

backends/cadence/aot/remove_ops.py

Lines changed: 88 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -404,11 +404,16 @@ class Subgraph:
404404
edges_in: set[tuple[torch.fx.Node, torch.fx.Node]] = field(default_factory=set)
405405
# Outgoing edges of the subgraph to permute nodes.
406406
edges_out: set[tuple[torch.fx.Node, torch.fx.Node]] = field(default_factory=set)
407+
# Incoming edges from constant nodes that need a compensating permute.
408+
constant_edges_in: set[tuple[torch.fx.Node, torch.fx.Node]] = field(
409+
default_factory=set
410+
)
407411

408412
permutable_ops: set[EdgeOpOverload] = {
409413
exir_ops.edge.aten.add.Tensor,
410414
exir_ops.edge.aten.mul.Tensor,
411415
exir_ops.edge.aten.hardtanh.default,
416+
exir_ops.edge.aten.clamp.default,
412417
exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
413418
exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
414419
exir_ops.edge.cadence.quantize_per_tensor.default,
@@ -455,7 +460,7 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
455460

456461
return PassResult(graph_module, False)
457462

458-
def visit(
463+
def visit( # noqa: C901
459464
self,
460465
node: torch.fx.Node,
461466
subgraph: Subgraph,
@@ -474,6 +479,11 @@ def visit(
474479
if self.get_permutation(user) != subgraph.end_permute:
475480
return False
476481
subgraph.edges_out.add((node, user))
482+
elif user.op == "output":
483+
# Graph output requires the data in its original layout.
484+
# Removing permutes here would silently change the output
485+
# format, so treat this as an invalid subgraph boundary.
486+
return False
477487
elif not self.visit(user, subgraph, processed_nodes):
478488
return False
479489

@@ -484,11 +494,44 @@ def visit(
484494
if self.get_permutation(inp) != subgraph.start_permute:
485495
return False
486496
subgraph.edges_in.add((inp, node))
497+
elif self._is_constant(inp):
498+
# Only accept the constant if we can compensate it with a
499+
# permute or view. Otherwise reject the subgraph.
500+
const_rank = self._get_node_rank(inp)
501+
if const_rank is None:
502+
return False
503+
if const_rank > len(subgraph.end_permute):
504+
return False
505+
if (
506+
const_rank < len(subgraph.end_permute)
507+
and inp.meta.get("val") is None
508+
):
509+
return False
510+
subgraph.constant_edges_in.add((inp, node))
487511
elif not self.visit(inp, subgraph, processed_nodes):
488512
return False
489513

490514
return True
491515

516+
def _is_constant(self, node: torch.fx.Node) -> bool:
517+
"""Check if a node's value is available at compile time.
518+
Only considers direct constants (get_attr, parameter/buffer/constant
519+
placeholders) — does not recurse into call_function chains to avoid
520+
stack overflow on deep graphs."""
521+
if node.op == "get_attr":
522+
return True
523+
if node.op == "placeholder":
524+
target = str(node.target)
525+
return target.startswith(("b_", "p_", "c_"))
526+
return False
527+
528+
def _get_node_rank(self, node: torch.fx.Node) -> int | None:
529+
"""Return the tensor rank of a node's output, or None if unknown."""
530+
val = node.meta.get("val")
531+
if val is not None and hasattr(val, "shape"):
532+
return len(val.shape)
533+
return None
534+
492535
def is_node_permutable(self, node: torch.fx.Node) -> bool:
493536
if node.target not in self.permutable_ops:
494537
return False
@@ -514,6 +557,50 @@ def permute_subgraph(self, subgraph: Subgraph) -> None:
514557
else:
515558
out.replace_input_with(inp, cast(torch.fx.Node, inp.kwargs["input"]))
516559

560+
# Insert compensating permute on constant inputs.
561+
# Since the subgraph's start permutes are being removed, the subgraph
562+
# will operate in the un-permuted (original) layout. Constants that
563+
# were in the permuted layout need end_permute (the inverse of
564+
# start_permute) to convert back to the original layout.
565+
for const_node, user_node in subgraph.constant_edges_in:
566+
graph = const_node.graph
567+
const_rank = self._get_node_rank(const_node)
568+
permute_rank = len(subgraph.end_permute)
569+
570+
with graph.inserting_after(const_node):
571+
if const_rank is not None and const_rank == permute_rank:
572+
new_node = graph.create_node(
573+
"call_function",
574+
exir_ops.edge.aten.permute_copy.default,
575+
args=(const_node, subgraph.end_permute),
576+
)
577+
elif (
578+
const_rank is not None
579+
and const_rank < permute_rank
580+
and const_node.meta.get("val") is not None
581+
):
582+
# Rank mismatch (e.g. rank-1 bias with rank-4 permute).
583+
# The constant is broadcastable and its shape is smaller
584+
# than the permute rank, so we can't apply the permute
585+
# directly. Instead, use view_copy to rearrange the
586+
# shape according to the end_permute restricted to
587+
# the trailing dimensions.
588+
original_shape = list(const_node.meta["val"].shape)
589+
# Pad shape to match permute rank for reordering
590+
padded = [1] * (permute_rank - const_rank) + original_shape
591+
target_shape = [padded[d] for d in subgraph.end_permute]
592+
# Strip leading 1s back to original rank
593+
target_shape = target_shape[permute_rank - const_rank :]
594+
new_node = graph.create_node(
595+
"call_function",
596+
exir_ops.edge.aten.view_copy.default,
597+
args=(const_node, target_shape),
598+
)
599+
else:
600+
# Cannot determine rank or handle this case; skip.
601+
continue
602+
user_node.replace_input_with(const_node, new_node)
603+
517604
# Skip outgoing permutes.
518605
for inp, out in subgraph.edges_out:
519606
assert out.target == exir_ops.edge.aten.permute_copy.default

0 commit comments

Comments (0)