Skip to content

Commit 490ec5c

Browse files
authored
Fix multiple allocation constraints for two cat inputs of the same underlying tensor (pytorch#18830)
Differential Revision: D100494796 Pull Request resolved: pytorch#18830
1 parent 97a86bb commit 490ec5c

1 file changed

Lines changed: 44 additions & 9 deletions

File tree

backends/cadence/aot/memory_constraints.py

Lines changed: 44 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -452,6 +452,45 @@ def is_cat_along_outermost_dim(
452452
return False
453453
return True
454454

455+
def _has_duplicate_resolved_sources(
    self, cat_tensors: Sequence[torch.fx.Node]
) -> bool:
    """Report whether any two cat inputs end up at the same resolved node.

    Each input is resolved by repeatedly asking ``self.constraint`` for its
    relative-placement source and stepping to that source for as long as
    ``self.constraint.is_alias_of(source, current)`` holds. Two inputs that
    stop at the same node object count as duplicates.

    Args:
        cat_tensors: The input nodes of the cat op being examined.

    Returns:
        True if the same node is listed twice, or if two different nodes
        resolve to the same node object; False otherwise.
    """
    # Fast path: the very same node appears more than once in the inputs.
    if len(set(cat_tensors)) < len(cat_tensors):
        return True

    # Resolved nodes are tracked by object identity.
    seen_ids = set()
    for tensor in cat_tensors:
        root = tensor
        while True:
            placement = self.constraint.get_relative_placement_source(root)
            if placement is None:
                # No recorded relative placement: `root` is final.
                break
            if not self.constraint.is_alias_of(placement.source, root):
                # A placement exists but is not reported as an alias, so
                # stop here (presumably a distinct buffer - see is_alias_of).
                break
            root = placement.source

        key = id(root)
        if key in seen_ids:
            return True
        seen_ids.add(key)

    return False
475+
476+
def _has_unaligned_cat_tensors(
    self,
    graph: torch.fx.Graph,
    node: torch.fx.Node,
    cat_tensors: Sequence[torch.fx.Node],
) -> bool:
    """Report whether a non-placeholder cat input sits at a misaligned offset.

    Args:
        graph: Graph that contains ``node``.
        node: The cat node being examined.
        cat_tensors: The input nodes of that cat op.

    Returns:
        False when ``is_node_in_flattened_output(graph, node)`` is true, since
        the alignment check is skipped in that case. Otherwise True if any
        input whose ``op`` is not ``"placeholder"`` has a relative offset that
        is not a multiple of 8, and False if there is none.
    """
    if is_node_in_flattened_output(graph, node):
        return False

    # Low-bit mask for an 8-aligned offset: any set bit means misalignment.
    alignment_mask = 8 - 1
    offsets = get_relative_offsets_of_cat_tensors(cat_tensors)
    return any(
        tensor.op != "placeholder" and (offsets[pos] & alignment_mask) != 0
        for pos, tensor in enumerate(cat_tensors)
    )
493+
455494
# If A = cat(B, C), and the concatenation is along the outermost dimension, then
456495
# we can optimize away this cat operation if (1) B and C are placed contiguously,
457496
# and (2) the absolute memory location of tensor A is the same as B. This function
@@ -486,21 +525,17 @@ def is_removable_cat_op(
486525
return False
487526
# If the same tensor appears multiple times in the cat inputs,
488527
# we cannot place it at multiple different offsets relative to the output.
489-
if len(cat_tensors) != len(set(cat_tensors)):
528+
# Also check resolved sources: two different alias nodes may resolve to
529+
# the same underlying tensor, which can't be at two offsets.
530+
if self._has_duplicate_resolved_sources(cat_tensors):
490531
return False
491532

492533
# Many ops in HiFi require the input to be aligned to 8-byte boundary.
493534
# If the cat is not the graph's output, then ensure that the relative
494535
# offset of any concatenated non-placeholder tensor is a multiple of
495536
# 8 bytes.
496-
if not is_node_in_flattened_output(graph_module.graph, node):
497-
expected_alignment = 8
498-
relative_offsets = get_relative_offsets_of_cat_tensors(cat_tensors)
499-
for idx, arg in enumerate(cat_tensors):
500-
if not (arg.op == "placeholder") and (
501-
relative_offsets[idx] & (expected_alignment - 1) != 0
502-
):
503-
return False
537+
if self._has_unaligned_cat_tensors(graph_module.graph, node, cat_tensors):
538+
return False
504539

505540
return True
506541

0 commit comments

Comments
 (0)