@@ -404,11 +404,16 @@ class Subgraph:
404404 edges_in : set [tuple [torch .fx .Node , torch .fx .Node ]] = field (default_factory = set )
405405 # Outgoing edges of the subgraph to permute nodes.
406406 edges_out : set [tuple [torch .fx .Node , torch .fx .Node ]] = field (default_factory = set )
407+ # Incoming edges from constant nodes that need a compensating permute.
408+ constant_edges_in : set [tuple [torch .fx .Node , torch .fx .Node ]] = field (
409+ default_factory = set
410+ )
407411
408412 permutable_ops : set [EdgeOpOverload ] = {
409413 exir_ops .edge .aten .add .Tensor ,
410414 exir_ops .edge .aten .mul .Tensor ,
411415 exir_ops .edge .aten .hardtanh .default ,
416+ exir_ops .edge .aten .clamp .default ,
412417 exir_ops .edge .quantized_decomposed .quantize_per_tensor .default ,
413418 exir_ops .edge .quantized_decomposed .dequantize_per_tensor .default ,
414419 exir_ops .edge .cadence .quantize_per_tensor .default ,
@@ -455,7 +460,7 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
455460
456461 return PassResult (graph_module , False )
457462
458- def visit (
463+ def visit ( # noqa: C901
459464 self ,
460465 node : torch .fx .Node ,
461466 subgraph : Subgraph ,
@@ -474,6 +479,11 @@ def visit(
474479 if self .get_permutation (user ) != subgraph .end_permute :
475480 return False
476481 subgraph .edges_out .add ((node , user ))
482+ elif user .op == "output" :
483+ # Graph output requires the data in its original layout.
484+ # Removing permutes here would silently change the output
485+ # format, so treat this as an invalid subgraph boundary.
486+ return False
477487 elif not self .visit (user , subgraph , processed_nodes ):
478488 return False
479489
@@ -484,11 +494,44 @@ def visit(
484494 if self .get_permutation (inp ) != subgraph .start_permute :
485495 return False
486496 subgraph .edges_in .add ((inp , node ))
497+ elif self ._is_constant (inp ):
498+ # Only accept the constant if we can compensate it with a
499+ # permute or view. Otherwise reject the subgraph.
500+ const_rank = self ._get_node_rank (inp )
501+ if const_rank is None :
502+ return False
503+ if const_rank > len (subgraph .end_permute ):
504+ return False
505+ if (
506+ const_rank < len (subgraph .end_permute )
507+ and inp .meta .get ("val" ) is None
508+ ):
509+ return False
510+ subgraph .constant_edges_in .add ((inp , node ))
487511 elif not self .visit (inp , subgraph , processed_nodes ):
488512 return False
489513
490514 return True
491515
516+ def _is_constant (self , node : torch .fx .Node ) -> bool :
517+ """Check if a node's value is available at compile time.
518+ Only considers direct constants (get_attr, parameter/buffer/constant
519+ placeholders) — does not recurse into call_function chains to avoid
520+ stack overflow on deep graphs."""
521+ if node .op == "get_attr" :
522+ return True
523+ if node .op == "placeholder" :
524+ target = str (node .target )
525+ return target .startswith (("b_" , "p_" , "c_" ))
526+ return False
527+
528+ def _get_node_rank (self , node : torch .fx .Node ) -> int | None :
529+ """Return the tensor rank of a node's output, or None if unknown."""
530+ val = node .meta .get ("val" )
531+ if val is not None and hasattr (val , "shape" ):
532+ return len (val .shape )
533+ return None
534+
492535 def is_node_permutable (self , node : torch .fx .Node ) -> bool :
493536 if node .target not in self .permutable_ops :
494537 return False
@@ -514,6 +557,50 @@ def permute_subgraph(self, subgraph: Subgraph) -> None:
514557 else :
515558 out .replace_input_with (inp , cast (torch .fx .Node , inp .kwargs ["input" ]))
516559
560+ # Insert compensating permute on constant inputs.
561+ # Since the subgraph's start permutes are being removed, the subgraph
562+ # will operate in the un-permuted (original) layout. Constants that
563+ # were in the permuted layout need end_permute (the inverse of
564+ # start_permute) to convert back to the original layout.
565+ for const_node , user_node in subgraph .constant_edges_in :
566+ graph = const_node .graph
567+ const_rank = self ._get_node_rank (const_node )
568+ permute_rank = len (subgraph .end_permute )
569+
570+ with graph .inserting_after (const_node ):
571+ if const_rank is not None and const_rank == permute_rank :
572+ new_node = graph .create_node (
573+ "call_function" ,
574+ exir_ops .edge .aten .permute_copy .default ,
575+ args = (const_node , subgraph .end_permute ),
576+ )
577+ elif (
578+ const_rank is not None
579+ and const_rank < permute_rank
580+ and const_node .meta .get ("val" ) is not None
581+ ):
582+ # Rank mismatch (e.g. rank-1 bias with rank-4 permute).
583+ # The constant is broadcastable and its shape is smaller
584+ # than the permute rank, so we can't apply the permute
585+ # directly. Instead, use view_copy to rearrange the
586+ # shape according to the end_permute restricted to
587+ # the trailing dimensions.
588+ original_shape = list (const_node .meta ["val" ].shape )
589+ # Pad shape to match permute rank for reordering
590+ padded = [1 ] * (permute_rank - const_rank ) + original_shape
591+ target_shape = [padded [d ] for d in subgraph .end_permute ]
592+ # Strip leading 1s back to original rank
593+ target_shape = target_shape [permute_rank - const_rank :]
594+ new_node = graph .create_node (
595+ "call_function" ,
596+ exir_ops .edge .aten .view_copy .default ,
597+ args = (const_node , target_shape ),
598+ )
599+ else :
600+ # Cannot determine rank or handle this case; skip.
601+ continue
602+ user_node .replace_input_with (const_node , new_node )
603+
517604 # Skip outgoing permutes.
518605 for inp , out in subgraph .edges_out :
519606 assert out .target == exir_ops .edge .aten .permute_copy .default
0 commit comments