@@ -452,6 +452,45 @@ def is_cat_along_outermost_dim(
452452 return False
453453 return True
454454
455+ def _has_duplicate_resolved_sources (
456+ self , cat_tensors : Sequence [torch .fx .Node ]
457+ ) -> bool :
458+ """Return True if two cat inputs resolve to the same underlying tensor."""
459+ if len (cat_tensors ) != len (set (cat_tensors )):
460+ return True
461+ resolved_sources = set ()
462+ for arg in cat_tensors :
463+ resolved = arg
464+ while (
465+ info := self .constraint .get_relative_placement_source (resolved )
466+ ) is not None :
467+ if self .constraint .is_alias_of (info .source , resolved ):
468+ resolved = info .source
469+ else :
470+ break
471+ if id (resolved ) in resolved_sources :
472+ return True
473+ resolved_sources .add (id (resolved ))
474+ return False
475+
def _has_unaligned_cat_tensors(
    self,
    graph: torch.fx.Graph,
    node: torch.fx.Node,
    cat_tensors: Sequence[torch.fx.Node],
) -> bool:
    """Return True if any non-placeholder cat tensor has a misaligned offset.

    Args:
        graph: Graph passed to ``is_node_in_flattened_output`` to decide
            whether ``node`` is one of its outputs.
        node: The cat node being examined.
        cat_tensors: The input nodes of the cat op.

    Returns:
        False if ``node`` is in the graph's flattened output (the check is
        skipped entirely in that case). Otherwise True if at least one input
        whose ``op`` is not ``"placeholder"`` has a relative offset that is
        not a multiple of 8, and False if none does.
    """
    if is_node_in_flattened_output(graph, node):
        return False

    expected_alignment = 8
    # The bitmask test below is only valid because the alignment is a power
    # of two: an offset is aligned iff its low bits are all zero.
    alignment_mask = expected_alignment - 1
    relative_offsets = get_relative_offsets_of_cat_tensors(cat_tensors)
    return any(
        arg.op != "placeholder" and (relative_offsets[idx] & alignment_mask) != 0
        for idx, arg in enumerate(cat_tensors)
    )
493+
455494 # If A = cat(B, C), and the concatenation is along the outermost dimension, then
456495 # we can optimize away this cat operation if (1) B and C are placed contiguously,
457496 # and (2) the absolute memory location of tensor A is the same as B. This function
@@ -486,21 +525,17 @@ def is_removable_cat_op(
486525 return False
487526 # If the same tensor appears multiple times in the cat inputs,
488527 # we cannot place it at multiple different offsets relative to the output.
489- if len (cat_tensors ) != len (set (cat_tensors )):
528+ # Also check resolved sources: two different alias nodes may resolve to
529+ # the same underlying tensor, which can't be at two offsets.
530+ if self ._has_duplicate_resolved_sources (cat_tensors ):
490531 return False
491532
492533 # Many ops in HiFi require the input to be aligned to 8-byte boundary.
493534 # If the cat is not the graph's output, then ensure that the relative
494535 # offset of any concatenated non-placeholder tensor is a multiple of
495536 # 8 bytes.
496- if not is_node_in_flattened_output (graph_module .graph , node ):
497- expected_alignment = 8
498- relative_offsets = get_relative_offsets_of_cat_tensors (cat_tensors )
499- for idx , arg in enumerate (cat_tensors ):
500- if not (arg .op == "placeholder" ) and (
501- relative_offsets [idx ] & (expected_alignment - 1 ) != 0
502- ):
503- return False
537+ if self ._has_unaligned_cat_tensors (graph_module .graph , node , cat_tensors ):
538+ return False
504539
505540 return True
506541
0 commit comments