Skip to content

Commit b24a88e

Browse files
authored
Fix noop pass when there are dynamic shapes (#17678)
The original pass only checked orig_tensor, but it should also check output_tensor
1 parent f478cb3 commit b24a88e

1 file changed

Lines changed: 9 additions & 3 deletions

File tree

exir/passes/remove_noop_pass.py

Lines changed: 9 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -77,9 +77,15 @@ def call(self, graph_module: GraphModule) -> PassResult:
77 77
continue
78 78

79 79
if node.target == torch.ops.aten.slice_copy.Tensor:
80-
# Only do this check if all the dims are static.
81-
if all(isinstance(dim, int) for dim in orig_tensor.size()):
82-
if orig_tensor.shape == node.meta["val"].shape:
80+
output_tensor = node.meta["val"]
81+
# Only do this check if all dims are static on both sides.
82+
# The output may contain unbacked SymInts (e.g. from
83+
# data-dependent slicing with .item()), so we must check
84+
# both tensors before comparing shapes.
85+
if all(isinstance(dim, int) for dim in orig_tensor.size()) and all(
86+
isinstance(dim, int) for dim in output_tensor.size()
87+
):
88+
if orig_tensor.shape == output_tensor.shape:
83 89
# If the graph is quantized, we must remove the entire pattern consisting of dq->op->q.
84 90
# Otherwise, removing only the op will suffice.
85 91
if node.args[0].target in _DEQUANT_OPS:

0 commit comments

Comments (0)