diff --git a/exir/emit/_emitter.py b/exir/emit/_emitter.py
index a48d88fa224..231a7c9b1b8 100644
--- a/exir/emit/_emitter.py
+++ b/exir/emit/_emitter.py
@@ -622,7 +622,16 @@ def _emit_evalue(self, val: EValue) -> _AbstractValue:
return _AbstractValue(len(self.emitter_state.values) - 1, tensor)
def _emit_spec(self, spec: ValueSpec) -> _EmitterValue:
- """Given the provided spec constructs the corresponding EValue from it and then emits it."""
+ """Given the provided spec constructs the corresponding EValue from it and then emits it.
+
+ If `spec` was already emitted earlier (e.g., because two FX
+ nodes share the same TensorSpec object — typically because the
+ planner's `_alias_inplace_result_specs` aliased an in-place
+ op's result onto its mutated input), reuse the existing
+ value_id. This keeps the invariant "one TensorSpec ↔ one
+ Value" so downstream emit doesn't create duplicate Values for
+ aliased FX nodes.
+ """
def _process(spec: LeafValueSpec) -> _AbstractValue:
if isinstance(spec, (list, tuple)):
@@ -642,6 +651,25 @@ def _process(spec: LeafValueSpec) -> _AbstractValue:
f"Invalid node spec expected TensorSpec received {spec}",
)
+ # Spec was already emitted — reuse the existing Value so
+ # two FX nodes sharing one TensorSpec also share one
+ # value_id in the lowered IR.
+ existing_id = self.emitter_state.spec2id_dict.get(spec)
+ if existing_id is not None:
+ existing_evalue = self.emitter_state.values[existing_id]
+ # Both insertion sites for `spec2id_dict` (this method
+ # and the placeholder emitter) only register ids whose
+ # EValue wraps a `Tensor` (built via
+ # `_tensor_spec_to_evalue` from a `TensorSpec`).
+ self._internal_assert_emitter(
+ isinstance(existing_evalue.val, Tensor),
+ self.node,
+ f"spec2id_dict entry for TensorSpec must point to a "
+ f"Tensor EValue, got "
+ f"{type(existing_evalue.val).__name__}",
+ )
+ return _AbstractValue(existing_id, existing_evalue.val)
+
ret = self._emit_evalue(self._tensor_spec_to_evalue(spec)) # pyre-ignore
self.emitter_state.spec2id_dict[spec] = ret.id # pyre-ignore
return ret
@@ -2020,6 +2048,14 @@ def _is_buffer(node: Node, graph_signature: ExportGraphSignature) -> bool:
)
value = self._emit_evalue(evalue)
+ # Populate spec2id_dict so downstream `_emit_spec` calls (e.g.,
+ # for in-place op result FX nodes whose spec was aliased onto
+ # this placeholder's spec by the planner's
+ # `_alias_inplace_result_specs`) reuse this placeholder's
+ # value_id rather than creating a new Value.
+ if isinstance(spec, TensorSpec):
+ self.emitter_state.spec2id_dict[spec] = value.id
+
# Only user inputs should remain as inputs.
if is_user_input:
self.inputs.append(value.id)
diff --git a/exir/memory_planning.py b/exir/memory_planning.py
index 9b19cbc7770..3c9f4313ae2 100644
--- a/exir/memory_planning.py
+++ b/exir/memory_planning.py
@@ -30,7 +30,12 @@
from executorch.exir.control_flow import while_loop as exir_while
from executorch.exir.delegate import executorch_call_delegate
from executorch.exir.error import internal_assert, InternalError
-from executorch.exir.operator.convert import is_inplace_variant, is_out_variant
+from executorch.exir.operator.convert import (
+ is_inplace_variant,
+ is_out_variant,
+ output_to_aliased_input_map,
+ unwrap_op_overload,
+)
from executorch.exir.schema import DeviceType, NonConstBufferDevice, TensorShapeDynamism
from executorch.exir.tensor import TensorSpec
from torch import fx
@@ -302,13 +307,137 @@ def _is_out_var_node(node: torch.fx.Node) -> bool:
def _is_inplace_node(node: torch.fx.Node) -> bool:
- return (
- node.op == "call_function"
- and isinstance(node.target, torch._ops.OpOverload)
- and is_inplace_variant(
- node.target._schema.name, node.target._schema.overload_name
+ if node.op != "call_function":
+ return False
+ target = node.target
+ if not isinstance(target, torch._ops.OpOverload) and not isinstance(
+ getattr(target, "_op", None), torch._ops.OpOverload
+ ):
+ return False
+ op = unwrap_op_overload(target)
+ return is_inplace_variant(op._schema.name, op._schema.overload_name)
+
+
+def _alias_inplace_result_specs(node: torch.fx.Node) -> None: # noqa: C901
+ """Alias an in-place op's result TensorSpec(s) onto the corresponding
+ input's spec.
+
+ In-place ops (schema kind == inplace) mutate one or more of their
+ inputs and return tensors that alias them, declared via the
+ ``Tensor(a!)`` schema annotation. To make the memory planner treat
+ result and aliased input as one storage, we copy the input's spec
+ object onto the output's ``node.meta["spec"]`` slot.
+
+ Output→input correspondence is computed via
+ ``output_to_aliased_input_map``, which matches each return's
+ write-alias set against the inputs that share it.
+
+ Gating:
+
+ - Only runs for in-place nodes (caller checks ``_is_inplace_node``).
+ - Multi-output in-place ops are supported when each return's alias
+ set matches exactly one input's alias set.
+ - Skips a return (keeping its original spec) when alias info is
+ absent or its input cannot be resolved; several of these skips are
+ recorded via ``logging.debug``. An unrecognized spec container raises.
+ """
+ target = node.target
+ op = unwrap_op_overload(target)
+
+ schema = op._schema
+ out_to_in = output_to_aliased_input_map(schema)
+ if not out_to_in:
+ logging.debug(
+ f"_alias_inplace_result_specs: schema for {op} declares no "
+ f"write-aliased outputs matching an input; skipping."
)
- )
+ return
+
+ # Normalize the current spec container into a list for uniform
+ # handling. Caller guarantees this node was identified as an
+ # in-place op with a meta spec (see `_is_inplace_node`), so an
+ # unrecognized container shape is a real bug — assert loudly so it
+ # surfaces in tests rather than silently disabling aliasing.
+ current = node.meta.get("spec")
+ if isinstance(current, TensorSpec):
+ out_specs_list: List[Optional[TensorSpec]] = [current]
+ return_container_kind = "scalar"
+ elif isinstance(current, (list, tuple)):
+ out_specs_list = list(current)
+ return_container_kind = type(current).__name__
+ else:
+ raise InternalError(
+ f"_alias_inplace_result_specs: in-place node {node.name} "
+ f"({op}) has unrecognized spec container of type "
+ f"{type(current).__name__!r}; expected TensorSpec, list, "
+ f"or tuple."
+ )
+
+ # Compute new spec for each return; None means "keep original".
+ # Mutated inputs are usually positional (the `Tensor(a!)` `self`
+ # arg), but custom ops may pass them via kwargs — fall back to
+ # `node.kwargs[arg_name]` in that case.
+ replacements: List[Optional[TensorSpec]] = [None] * len(out_specs_list)
+
+ for out_idx, in_idx in out_to_in.items():
+ if out_idx >= len(out_specs_list):
+ logging.debug(
+ f"_alias_inplace_result_specs: schema for {op} declares "
+ f"return {out_idx} but spec container has only "
+ f"{len(out_specs_list)} entries; skipping this return."
+ )
+ continue
+ # Resolve the mutated input. Prefer positional args, fall back
+ # to kwargs by argument name (custom ops may pass `Tensor(a!)`
+ # args via kwargs).
+ in_node: object
+ if in_idx < len(node.args):
+ in_node = node.args[in_idx]
+ else:
+ arg_name = (
+ schema.arguments[in_idx].name
+ if in_idx < len(schema.arguments)
+ else None
+ )
+ if arg_name is None or arg_name not in node.kwargs:
+ logging.debug(
+ f"_alias_inplace_result_specs: schema for {op} "
+ f"expects mutated input at position {in_idx} "
+ f"(name={arg_name!r}) but it is supplied neither "
+ f"positionally nor via kwargs; skipping."
+ )
+ continue
+ in_node = node.kwargs[arg_name]
+ if not isinstance(in_node, torch.fx.Node):
+ continue
+ # NOTE: alias unconditionally — including when the input is a
+ # placeholder (named buffer / mutable input). Skipping
+ # placeholders would leave a dangling out-arg on in-place op
+ # instructions whose self is a buffer; the runtime kernel
+ # writes to that out-arg's storage rather than mutating the
+ # buffer, producing incorrect results. The companion change in
+ # `_emit_spec` (emit/_emitter.py) deduplicates by spec identity
+ # so this aliasing doesn't produce two Values for the same FQN.
+ in_spec = in_node.meta.get("spec")
+ if not isinstance(in_spec, TensorSpec):
+ continue
+ replacements[out_idx] = in_spec
+
+ if not any(r is not None for r in replacements):
+ return
+
+ # Assemble the new spec container, preserving the original shape.
+ new_list = [
+ replacements[i] if replacements[i] is not None else out_specs_list[i]
+ for i in range(len(out_specs_list))
+ ]
+ if return_container_kind == "scalar":
+ if isinstance(new_list[0], TensorSpec):
+ node.meta["spec"] = new_list[0]
+ elif return_container_kind == "list":
+ node.meta["spec"] = list(new_list)
+ elif return_container_kind == "tuple":
+ node.meta["spec"] = tuple(new_list)
def update_tensor_lifetime(
@@ -474,6 +603,11 @@ def collect_specs_from_nodes( # noqa: C901
continue
if _is_inplace_node(node):
+ # Schema-driven: alias the in-place op's result spec(s) onto
+ # the corresponding mutated input's spec so the planner gives
+ # them the same allocation slot. See `_alias_inplace_result_specs`
+ # for the full rationale and gating rules.
+ _alias_inplace_result_specs(node)
continue
if _is_mutable_buffer(node, graph_signature) and ignore_mutable_buffers:
diff --git a/exir/operator/convert.py b/exir/operator/convert.py
index 74bd686c542..ab0dced87f5 100644
--- a/exir/operator/convert.py
+++ b/exir/operator/convert.py
@@ -25,7 +25,7 @@
import dataclasses
import logging
-from typing import Dict, Optional, Tuple
+from typing import Dict, FrozenSet, Optional, Tuple
import torch
from torch._ops import OpOverload
@@ -105,6 +105,70 @@ def get_out_args_from_schema(out_var_schema: FunctionSchema) -> Tuple[str]:
return tuple(arg.name for arg in out_var_schema.arguments.out)
+def unwrap_op_overload(target: object) -> torch._ops.OpOverload:
+ """Return the underlying ``torch._ops.OpOverload`` for a node target.
+
+ Handles both raw aten ops (``torch.ops.aten.X.Y``) and edge-dialect
+ wrappers (``ops.edge.aten.X.Y`` / ``BackendOpOverload``), which expose
+ the underlying ATen op as ``target._op`` but are not themselves
+ subclasses of ``torch._ops.OpOverload``.
+
+ Raises ``TypeError`` if ``target`` is neither an ``OpOverload`` nor
+ wraps one.
+ """
+ if isinstance(target, torch._ops.OpOverload):
+ return target
+ underlying = getattr(target, "_op", None)
+ if isinstance(underlying, torch._ops.OpOverload):
+ return underlying
+ raise TypeError(
+ f"unwrap_op_overload: expected a torch._ops.OpOverload or a "
+ f"wrapper exposing one as `_op`, got {type(target).__name__}: "
+ f"{target!r}"
+ )
+
+
+def output_to_aliased_input_map(
+ schema: torch.FunctionSchema,
+) -> Dict[int, int]:
+ """For a mutating op, return a map from output index to input arg
+ index, where the output aliases (i.e., mutates) the input via a
+ shared ``Tensor(a!)`` write-alias set.
+
+ Only outputs whose ``alias_info.is_write`` is True are considered.
+ Inputs without a write-aliased ``alias_info`` cannot be aliased
+ targets and are skipped. If multiple inputs share a single
+ return's alias set (no known cases in aten today), the first
+ matching input wins.
+
+ Returns an empty dict for ops with no write-aliased outputs (i.e.,
+ purely functional ops, or ops whose schema declares no returns).
+
+ Note: ``schema`` is the pybind ``torch._C.FunctionSchema`` (as
+ obtained from ``op._schema``), not the torchgen native
+ ``FunctionSchema`` used elsewhere in this file.
+ """
+ # Build alias_set -> input_idx so we can look up each return.
+ alias_set_to_input: Dict[FrozenSet[str], int] = {}
+ for in_idx, arg in enumerate(schema.arguments):
+ info = arg.alias_info
+ if info is None or not info.is_write:
+ continue
+ alias_set = frozenset(info.before_set)
+ alias_set_to_input.setdefault(alias_set, in_idx)
+
+ out_to_in: Dict[int, int] = {}
+ for out_idx, ret in enumerate(schema.returns):
+ info = ret.alias_info
+ if info is None or not info.is_write:
+ continue
+ alias_set = frozenset(info.before_set)
+ in_idx = alias_set_to_input.get(alias_set)
+ if in_idx is not None:
+ out_to_in[out_idx] = in_idx
+ return out_to_in
+
+
def parse_qualified_opname(qualified_opname: str) -> Tuple[str, str]:
"""
Given a qualified opname like aten::add, return a tuple for namespace
diff --git a/exir/operator/test/test_operator.py b/exir/operator/test/test_operator.py
index 8235fd0ba82..770991b15f9 100644
--- a/exir/operator/test/test_operator.py
+++ b/exir/operator/test/test_operator.py
@@ -9,7 +9,12 @@
import unittest
import torch
-from executorch.exir.operator.convert import _get_overload_schema, to_out_variant
+from executorch.exir.operator.convert import (
+ _get_overload_schema,
+ output_to_aliased_input_map,
+ to_out_variant,
+ unwrap_op_overload,
+)
from executorch.exir.operator.util import gen_out_variant_schema
from torch.library import _scoped_library, impl, impl_abstract
@@ -82,3 +87,86 @@ def custom_mutator_out(
schema.__str__(),
"DO_NOT_USE_TEST_ONLY::custom_mutator.out(Tensor x, Tensor(a!) y, *, Tensor(b!) out) -> Tensor(b!)",
)
+
+
+class TestUnwrapOpOverload(unittest.TestCase):
+ def test_aten_overload_returned_as_is(self) -> None:
+ op = torch.ops.aten.add.Tensor
+ self.assertIs(unwrap_op_overload(op), op)
+
+ def test_wrapper_with_op_attr_peeled_to_aten(self) -> None:
+ # Mimic the structural shape of `EdgeOpOverload` /
+ # `BackendOpOverload`: a non-OpOverload wrapper that exposes
+ # the underlying aten op via `_op`.
+ class _FakeWrapper: # noqa: B903
+ def __init__(self, op: torch._ops.OpOverload) -> None:
+ self._op = op
+
+ aten_op = torch.ops.aten.add.Tensor
+ wrapper = _FakeWrapper(aten_op)
+ self.assertIs(unwrap_op_overload(wrapper), aten_op)
+
+ def test_non_op_raises(self) -> None:
+ with self.assertRaises(TypeError):
+ unwrap_op_overload("not an op")
+ with self.assertRaises(TypeError):
+ unwrap_op_overload(None)
+ with self.assertRaises(TypeError):
+ unwrap_op_overload(42)
+
+ def test_wrapper_with_non_op_underlying_raises(self) -> None:
+ class _BadWrapper:
+ _op = "not an op overload"
+
+ with self.assertRaises(TypeError):
+ unwrap_op_overload(_BadWrapper())
+
+
+class TestOutputToAliasedInputMap(unittest.TestCase):
+ def test_functional_op_returns_empty(self) -> None:
+ # `aten::add.Tensor` is purely functional — no Tensor(a!) on
+ # any return.
+ schema = torch.ops.aten.add.Tensor._schema
+ self.assertEqual(output_to_aliased_input_map(schema), {})
+
+ def test_single_output_inplace_op(self) -> None:
+ # `aten::index_put_` mutates `self` (arg 0) and returns it
+ # (return 0).
+ schema = torch.ops.aten.index_put_.default._schema
+ self.assertEqual(output_to_aliased_input_map(schema), {0: 0})
+
+ def test_single_output_inplace_via_pybind_parse(self) -> None:
+ # Synthetic single-mutation schema parsed via the pybind
+ # FunctionSchema parser; mutates `self` at position 0 and
+ # returns it.
+ schema = torch._C.parse_schema(
+ "test::single(Tensor(a!) self, int n) -> Tensor(a!)"
+ )
+ self.assertEqual(output_to_aliased_input_map(schema), {0: 0})
+
+ def test_multi_output_inplace_via_pybind_parse(self) -> None:
+ # Synthetic multi-mutation schema: two write-aliased inputs
+ # `a` and `b`, each returned as its own aliased output.
+ schema = torch._C.parse_schema(
+ "test::multi(Tensor(x!) a, Tensor(y!) b, int n) "
+ "-> (Tensor(x!), Tensor(y!))"
+ )
+ # Output 0 (alias set {x}) → input 0 (a).
+ # Output 1 (alias set {y}) → input 1 (b).
+ self.assertEqual(output_to_aliased_input_map(schema), {0: 0, 1: 1})
+
+ def test_partial_aliasing_returns_only_matched(self) -> None:
+ # Two returns, only the first carries write-alias info.
+ schema = torch._C.parse_schema(
+ "test::partial(Tensor(z!) self, Tensor other) " "-> (Tensor(z!), Tensor)"
+ )
+ self.assertEqual(output_to_aliased_input_map(schema), {0: 0})
+
+ def test_tied_inputs_first_match_wins(self) -> None:
+ # Two inputs share the same write-alias set; per the docstring
+ # contract ("the first matching input wins"), the helper must
+ # map the single output back to input index 0, not 1.
+ schema = torch._C.parse_schema(
+ "test::tied(Tensor(a!) x, Tensor(a!) y) -> Tensor(a!)"
+ )
+ self.assertEqual(output_to_aliased_input_map(schema), {0: 0})
diff --git a/exir/passes/reinplace.py b/exir/passes/reinplace.py
index 349869a2f4b..3c6bad77da7 100644
--- a/exir/passes/reinplace.py
+++ b/exir/passes/reinplace.py
@@ -6,19 +6,224 @@
# pyre-strict
-from typing import Set
+from typing import Any, Dict, FrozenSet, Iterable, Optional, Set, Tuple
import torch
from executorch.exir.dialects._ops import ops
+from executorch.exir.operator.convert import (
+ output_to_aliased_input_map,
+ unwrap_op_overload,
+)
from torch.export import ExportedProgram
-def _is_index_put(node: torch.fx.Node) -> bool:
- """Check if a node is an index_put operation."""
- return node.op == "call_function" and node.target in (
- torch.ops.aten.index_put.default,
+# ---------------------------------------------------------------------------
+# Public API for extending the pass with additional ops.
+# ---------------------------------------------------------------------------
+
+
+# Default set of edge-dialect functional ops the pass attempts to rewrite
+# in-place. Today only `index_put -> index_put_`. The pass auto-derives
+# the in-place form by name + schema match — callers who add ops here
+# do not need to specify the in-place op explicitly.
+#
+# `reinplace_pass` runs after `to_edge` (inside `to_executorch`), so it
+# only ever sees edge-dialect targets (`EdgeOpOverload`). Aten targets
+# do not appear at this stage; only edge ops belong here.
+DEFAULT_INPLACEABLE_OPS: FrozenSet[Any] = frozenset(
+ {
ops.edge.aten.index_put.default,
- )
+ }
+)
+
+
+# ---------------------------------------------------------------------------
+# Schema-based discovery and validation.
+# ---------------------------------------------------------------------------
+
+
+def _op_schema(op: Any) -> torch.FunctionSchema:
+ """Return the underlying ``FunctionSchema`` for an op overload.
+
+ Delegates to ``unwrap_op_overload`` to peel any edge-dialect or
+ backend wrapper down to the bare ``torch._ops.OpOverload`` before
+ reading ``_schema``. Falls back to ``op._schema`` if the input is
+ a schema-bearing object that ``unwrap_op_overload`` doesn't
+ recognize (e.g., a custom op-like with a ``_schema`` attribute but
+ not a true ``OpOverload``).
+ """
+ try:
+ return unwrap_op_overload(op)._schema
+ except TypeError:
+ return op._schema
+
+
+def _is_inplace_of(
+ functional_schema: torch.FunctionSchema,
+ candidate_schema: torch.FunctionSchema,
+) -> bool:
+ """Return True if `candidate_schema` is the in-place form of
+ `functional_schema`: same input arg types positionally, with the
+ first arg of the candidate carrying a `Tensor(a!)` write alias.
+ """
+ f_args = functional_schema.arguments
+ c_args = candidate_schema.arguments
+ if len(f_args) != len(c_args):
+ return False
+ if not c_args:
+ return False
+ first = c_args[0]
+ if first.alias_info is None or not first.alias_info.is_write:
+ return False
+ for fa, ca in zip(f_args, c_args):
+ # Compare JIT types directly; `str(...)` repr equality is
+ # fragile to type-printing changes (qualified vs. unqualified
+ # names, container parameterization). Fall back to str only if
+ # direct equality raises (some C-bound types may not implement
+ # __eq__ uniformly across Torch versions).
+ try:
+ if fa.type != ca.type:
+ return False
+ except Exception:
+ if str(fa.type) != str(ca.type):
+ return False
+ return True
+
+
+def _derive_edge_inplace_overload(functional_op: Any) -> Optional[Any]:
+ """Auto-derive the in-place edge-dialect overload for a functional
+ op (works for both `EdgeOpOverload` inputs — the common case —
+ and bare aten `OpOverload` inputs; the result is always an
+ edge-dialect op).
+
+ Strategy: peel `EdgeOpOverload._op` (no-op for bare aten) to find
+ the underlying aten schema. The aten in-place form lives at
+ `torch.ops.aten.<base>_` (for example,
+ `aten::index_put -> aten::index_put_`). Walk the aten in-place
+ package's overloads (which exposes `.overloads()` cleanly; the
+ edge package does not) to find the one whose schema matches the
+ functional schema (with the first arg promoted to `Tensor(a!)`).
+ Then look up the corresponding edge-dialect op at
+ `ops.edge.aten.<base>_.<overload>`.
+
+ Handles cross-overload-name asymmetry — e.g.
+ `aten::pow.Tensor_Scalar` -> `aten::pow_.Scalar`
+ where simple name-based lookup `aten.pow_.Tensor_Scalar` fails.
+
+ Note: if multiple aten in-place overloads have schemas that match
+ the functional op's schema (no known cases in aten today), the
+ first one returned by `.overloads()` wins. Callers can disambiguate
+ by passing an explicit `inplace_overrides` entry.
+
+ Returns None if no schema match is found, or the op is not an
+ aten op. Callers should provide an explicit override via
+ `inplace_overrides` for non-conventional rewrites.
+ """
+ schema = _op_schema(functional_op)
+ name = schema.name # e.g. "aten::index_put"
+ if "::" not in name:
+ return None
+ namespace, base = name.split("::", 1)
+ if namespace != "aten":
+ return None
+
+ # Find the matching aten in-place overload first (aten exposes
+ # `.overloads()` cleanly; the edge package does not).
+ aten_inplace_pkg = getattr(torch.ops.aten, base + "_", None)
+ if aten_inplace_pkg is None:
+ return None
+ matched_overload_name: Optional[str] = None
+ for overload_name in aten_inplace_pkg.overloads():
+ candidate = getattr(aten_inplace_pkg, overload_name, None)
+ if candidate is None:
+ continue
+ cand_schema = getattr(candidate, "_schema", None)
+ if cand_schema is None:
+ continue
+ if _is_inplace_of(schema, cand_schema):
+ matched_overload_name = overload_name
+ break
+ if matched_overload_name is None:
+ return None
+
+ # Translate to the edge-dialect op of the same name + overload.
+ edge_inplace_pkg = getattr(ops.edge.aten, base + "_", None)
+ if edge_inplace_pkg is None:
+ return None
+ return getattr(edge_inplace_pkg, matched_overload_name, None)
+
+
+def _validate_inplace_mapping(functional_op: Any, inplace_op: Any) -> None:
+ """Validate that `inplace_op` is a plausible in-place form of
+ `functional_op`. Raises `ValueError` on misregistration.
+
+ Two checks:
+ 1. **Schema shape**: same arg types positionally, first arg of
+ `inplace_op` carries `Tensor(a!)`. Catches gross mismatches
+ like `add -> mul_` (different signatures).
+ 2. **Name affinity**: in-place op's name starts with the
+ functional op's name + "_". Catches subtle mismatches with
+ matching schemas like `add -> sub_`.
+ """
+ f_schema = _op_schema(functional_op)
+ i_schema = _op_schema(inplace_op)
+
+ if not _is_inplace_of(f_schema, i_schema):
+ raise ValueError(
+ f"Schema mismatch in reinplace registration: "
+ f"{functional_op} -> {inplace_op}. "
+ f"Expected in-place arg types to match functional arg types "
+ f"positionally with the first arg promoted to Tensor(a!). "
+ f"Got functional schema {f_schema} and in-place schema "
+ f"{i_schema}."
+ )
+
+ expected_prefix = f_schema.name + "_"
+ if not i_schema.name.startswith(expected_prefix):
+ raise ValueError(
+ f"Suspicious reinplace registration: "
+ f"{functional_op} -> {inplace_op}. "
+ f"In-place op name '{i_schema.name}' should start with "
+ f"'{expected_prefix}'. This usually indicates a typo "
+ f"(e.g. 'add -> sub_' instead of 'add -> add_')."
+ )
+
+
+def _derive_mutated_args(inplace_op: Any) -> Tuple[int, ...]:
+ """Return the positional indices of args that the in-place op
+ mutates, derived from the schema's `Tensor(a!)` annotations.
+
+ Computed via ``output_to_aliased_input_map``: each return with a
+ write alias points back to the input arg position carrying the
+ same alias set. The mutated positions are the input indices that
+ appear as values in that map.
+
+ Raises `ValueError` if `inplace_op` has no write-aliased outputs
+ that match an input (i.e. it isn't actually an in-place op). The
+ schema is the source of truth — if a custom op truly mutates an
+ arg, its schema must declare so via `Tensor(a!)` on both the
+ mutated input and the corresponding return. Otherwise
+ functionalization, memory planning, and autograd will all silently
+ mishandle it.
+ """
+ schema = _op_schema(inplace_op)
+ out_to_in = output_to_aliased_input_map(schema)
+ indices = tuple(sorted(set(out_to_in.values())))
+ if not indices:
+ raise ValueError(
+ f"{inplace_op} has no Tensor(a!) write-aliased args that "
+ f"match a corresponding return. If this op truly mutates, "
+ f"fix its schema to declare the mutated arg(s) with "
+ f"`Tensor(a!)` on both the input and the matching return. "
+ f"The schema is the contract that all of export, memory "
+ f"planning, and functionalization rely on."
+ )
+ return indices
+
+
+# ---------------------------------------------------------------------------
+# Internal helpers.
+# ---------------------------------------------------------------------------
def _is_safe_to_reinplace(
@@ -56,48 +261,172 @@ def _is_mutable_buffer(node: torch.fx.Node, exported_program: ExportedProgram) -
return buf in exported_program.graph_signature.buffers_to_mutate.values()
-def reinplace_pass(ep: ExportedProgram) -> ExportedProgram:
- """
- Pass that loops over nodes in an exported program and collects the first argument
- of every call_function node that is a view_copy operation.
+# ---------------------------------------------------------------------------
+# Pass entry point.
+# ---------------------------------------------------------------------------
+
+
+def reinplace_pass( # noqa: C901
+ ep: ExportedProgram,
+ ops_to_inplace: Optional[Iterable[Any]] = None,
+ inplace_overrides: Optional[Dict[Any, Any]] = None,
+) -> ExportedProgram:
+ """Rewrite functional ops in-place when safe.
+
+ Walks the graph in reverse topological order. For each
+ `call_function` node whose target is in the resolved set of
+ in-place candidates, checks whether each mutated-arg position is
+ safe to reinplace (`_is_safe_to_reinplace`). If all checks pass,
+ replaces the node with a call to its in-place op variant.
+
+ Safety rules:
+ * The mutated arg must not be used by any later node in the
+ graph.
+ * If the mutated arg is a placeholder (program input), it must
+ be a *mutable* input — i.e., declared in
+ `graph_signature.user_inputs_to_mutate` or
+ `graph_signature.buffers_to_mutate`. Immutable inputs are
+ never reinplaced because mutating them would be a side effect
+ the caller did not opt into.
Args:
- exported_program: The ExportedProgram to analyze
+ ep: The ExportedProgram to rewrite. Modified in place; returned
+ for chaining.
+ ops_to_inplace: Optional iterable of functional edge-dialect
+ ops the pass should try to reinplace. The in-place form is
+ auto-derived from the schema (`X -> X_*` by name + schema
+ match). Defaults to `DEFAULT_INPLACEABLE_OPS`. Pass an
+ empty iterable to disable all rewrites:
+
+ from executorch.exir.dialects._ops import ops
+
+ reinplace_pass(
+ ep,
+ ops_to_inplace=DEFAULT_INPLACEABLE_OPS | {
+ ops.edge.aten.index_copy.default,
+ },
+ )
+
+ inplace_overrides: Optional explicit map for rewrites that
+ cannot be auto-derived (e.g. non-aten or backend-fused ops);
+ the `X -> X_*` name check listed below still applies. Keys are
+ also included in the set of ops to consider — you do NOT need
+ to also list them in `ops_to_inplace`. Example:
+
+ reinplace_pass(
+ ep,
+ inplace_overrides={
+ my_backend.functional: my_backend.fused_inplace,
+ },
+ )
+
+ Each entry is validated at pass startup:
+ * Schema arg types must match positionally between functional
+ and in-place forms.
+ * The in-place op's first arg must carry `Tensor(a!)`.
+ * The in-place op's name must start with the functional op's
+ name + "_" (e.g. `aten::add -> aten::add_*`).
+ * The in-place op must have at least one `Tensor(a!)` arg.
+ Misregistrations raise `ValueError` immediately.
Returns:
- Set of nodes that are first arguments to view_copy operations
+ The (possibly mutated) ExportedProgram.
"""
+ overrides = inplace_overrides or {}
+ op_set: Set[Any] = set(
+ ops_to_inplace if ops_to_inplace is not None else DEFAULT_INPLACEABLE_OPS
+ )
+ # Overrides also enroll their key in the candidate set.
+ op_set.update(overrides.keys())
+
+ # Validate every entry up front and pre-compute mutated_args so we
+ # don't re-do the schema introspection per node.
+ resolved: Dict[Any, Tuple[Any, Tuple[int, ...]]] = {}
+ for functional_op in op_set:
+ if functional_op in overrides:
+ inplace_op = overrides[functional_op]
+ else:
+ inplace_op = _derive_edge_inplace_overload(functional_op)
+ if inplace_op is None:
+ raise ValueError(
+ f"Cannot auto-derive in-place form for "
+ f"{functional_op}. Provide an explicit mapping via "
+ f"`inplace_overrides={{{functional_op}: }}`."
+ )
+ _validate_inplace_mapping(functional_op, inplace_op)
+ mutated_args = _derive_mutated_args(inplace_op)
+ resolved[functional_op] = (inplace_op, mutated_args)
+
seen_nodes: Set[torch.fx.Node] = set()
# Get all placeholders
- inputs = set()
+ inputs: Set[torch.fx.Node] = set()
for node in ep.graph.nodes:
if node.op == "placeholder":
inputs.add(node)
# Get all inputs that we could potentially mutate
- mutable_nodes = set(
- [
- node
- for node in inputs
- if _is_mutable_user_input(node, ep) or _is_mutable_buffer(node, ep)
- ]
- )
+ mutable_nodes: Set[torch.fx.Node] = {
+ node
+ for node in inputs
+ if _is_mutable_user_input(node, ep) or _is_mutable_buffer(node, ep)
+ }
- results = set()
for node in reversed(ep.graph.nodes):
- if _is_index_put(node):
- # Check if this index_put node is safe to inplace
- # The first argument is the base tensor being indexed into
- first_arg = node.args[0]
- if _is_safe_to_reinplace(first_arg, seen_nodes, inputs, mutable_nodes):
- # This index_put is safe to reinplace
+ entry = resolved.get(node.target) if node.op == "call_function" else None
+ if entry is not None:
+ inplace_op, mutated_args = entry
+ # Every mutated arg position must independently be safe.
+ all_safe = True
+ for arg_idx in mutated_args:
+ if arg_idx >= len(node.args):
+ raise ValueError(
+ f"reinplace: {node.target} call at {node} has "
+ f"{len(node.args)} positional args, but the "
+ f"schema declares position {arg_idx} as "
+ f"Tensor(a!). Export should normalize mutated "
+ f"args to positional; this graph violates that "
+ f"assumption."
+ )
+ arg_node = node.args[arg_idx]
+ if not isinstance(arg_node, torch.fx.Node):
+ raise ValueError(
+ f"reinplace: {node.target} call at {node} has a "
+ f"non-Node value {arg_node!r} at position "
+ f"{arg_idx}, but the schema declares it as "
+ f"Tensor(a!). A Tensor input in an FX graph "
+ f"must be a torch.fx.Node."
+ )
+ if not _is_safe_to_reinplace(
+ arg_node, seen_nodes, inputs, mutable_nodes
+ ):
+ all_safe = False
+ break
+ if all_safe:
with ep.graph.inserting_before(node):
+ # Forward both args and kwargs: the in-place overload
+ # is schema-matched to the functional one, so any
+ # kwarg valid on the functional op (e.g.
+ # `accumulate=` for `index_put`) is also valid on
+ # the in-place form. Dropping kwargs would silently
+ # change semantics.
new_node = ep.graph.call_function(
- ops.edge.aten.index_put_.default, args=node.args
+ inplace_op,
+ args=node.args,
+ kwargs=node.kwargs,
)
new_node.meta["val"] = node.meta["val"]
node.replace_all_uses_with(new_node)
ep.graph.erase_node(node)
- results.add(first_arg)
- elif node.op == "call_function":
+ # No explicit `seen_nodes` update needed: the new
+ # in-place node's target isn't in `op_set`, so the
+ # reverse iterator visits it next and falls through
+ # to the generic update below.
+ continue
+ # Note: this intentionally falls through for mapping-matched
+ # nodes that failed the safety check. Their inputs *are* added
+ # to seen_nodes, so further-upstream candidates correctly see
+ # those tensors as "used later" and refuse to reinplace any op
+ # that mutates them.
+ # See test_unsafe_downstream_blocks_upstream_reinplace.
+ if node.op == "call_function":
seen_nodes.update(node.all_input_nodes)
return ep
diff --git a/exir/tests/test_memory_planning.py b/exir/tests/test_memory_planning.py
index e0df7a713e6..8227f3a54b0 100644
--- a/exir/tests/test_memory_planning.py
+++ b/exir/tests/test_memory_planning.py
@@ -24,7 +24,7 @@
del logging
import torch
-from executorch.exir import ExecutorchBackendConfig, to_edge
+from executorch.exir import EdgeCompileConfig, ExecutorchBackendConfig, to_edge
from executorch.exir.capture._capture import patch_forward
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.memory_planning import (
@@ -46,6 +46,7 @@
SpecPropPass,
ToOutVarPass,
)
+from executorch.exir.passes.reinplace import DEFAULT_INPLACEABLE_OPS, reinplace_pass
from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass
from executorch.exir.schema import DeviceType
from executorch.exir.tensor import TensorSpec
@@ -676,51 +677,69 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
et = to_edge(export(model, inputs, strict=True)).to_executorch()
- # The mutable buffer (5x5 float32 = 100 bytes) should not be double allocated.
- # The input and output of copy_ should share the same memory location.
- values = et.executorch_program.execution_plan[0].values
- expected_buffer_size = 5 * 5 * 4 # 5x5 float32
-
- # Collect all tensor allocations by their (memory_id, offset) and track sizes
- # Size is computed from tensor's sizes and scalar_type, not from allocation_info
- # (memory_offset_low/high are low/high 32-bit parts of a 64-bit offset, not bounds)
- scalar_type_sizes = {
- 0: 1, # BYTE
- 1: 1, # CHAR
- 2: 2, # SHORT
- 3: 4, # INT
- 4: 8, # LONG
- 5: 2, # HALF
- 6: 4, # FLOAT
- 7: 8, # DOUBLE
- }
- offset_to_indices = {}
- for i, val in enumerate(values):
- tensor = val.val
- if hasattr(tensor, "allocation_info") and tensor.allocation_info:
- alloc = tensor.allocation_info
- # Compute tensor size from sizes and scalar_type
- num_elements = 1
- for dim in tensor.sizes:
- num_elements *= dim
- element_size = scalar_type_sizes.get(int(tensor.scalar_type), 4)
- size = num_elements * element_size
- key = (alloc.memory_id, alloc.memory_offset)
- if key not in offset_to_indices:
- offset_to_indices[key] = {"indices": [], "size": size}
- offset_to_indices[key]["indices"].append(i)
-
- # Find shared allocations matching the mutable buffer size (before/after copy_)
- mutable_buffer_shares = [
- info
- for info in offset_to_indices.values()
- if len(info["indices"]) == 2 and info["size"] == expected_buffer_size
- ]
+ # The mutable buffer (5x5 float32 = 100 bytes) should not be
+ # double allocated. After the upstream emit dedup
+ # (`_emit_spec` reusing value_id when two FX nodes share a
+ # TensorSpec via the planner's `_alias_inplace_result_specs`),
+ # the `copy_` writeback's "out" arg uses the SAME value_id as
+ # its "self" arg (the buffer), rather than creating a separate
+ # Value at the same (mem_id, offset).
+ execution_plan = et.executorch_program.execution_plan[0]
+ values = execution_plan.values
+
+ # Find the `copy_` writeback instruction.
+ copy_instructions = []
+ for chain in execution_plan.chains:
+ for ins in chain.instructions:
+ inner = ins.instr_args
+ if hasattr(inner, "op_index"):
+ op = execution_plan.operators[inner.op_index]
+ if op.name == "aten::copy_":
+ copy_instructions.append(inner)
self.assertEqual(
- len(mutable_buffer_shares),
+ len(copy_instructions),
1,
- f"Expected exactly one shared allocation of size {expected_buffer_size} "
- f"with 2 values (copy_ input/output), found: {mutable_buffer_shares}",
+ "Expected exactly one copy_ writeback for the buffer mutation",
+ )
+
+ # For an in-place copy_(self, src, ..., out), self (arg 0) and
+ # out (the emitted synthetic last arg) must share a value_id
+ # per the `(a!)` schema annotation. Emit's spec2id_dict
+ # dedup enforces this.
+ copy_args = list(copy_instructions[0].args)
+ self.assertEqual(
+ copy_args[0],
+ copy_args[-1],
+ f"copy_'s out arg should reference the same value_id as its "
+ f"self arg (buffer) via emit dedup. args={copy_args}",
+ )
+
+ # Additionally verify no distinct second Value at the buffer's
+ # (mem_id, offset): after dedup, the buffer occupies its slot alone.
+ buffer_value_id = copy_args[0]
+ buffer_val = values[buffer_value_id].val
+ self.assertTrue(
+ hasattr(buffer_val, "allocation_info") and buffer_val.allocation_info,
+ "Buffer value should have allocation_info",
+ )
+ buffer_alloc = buffer_val.allocation_info
+ duplicates_at_buffer_slot = [
+ i
+ for i, val in enumerate(values)
+ if i != buffer_value_id
+ and hasattr(val.val, "allocation_info")
+ and val.val.allocation_info
+ and val.val.allocation_info.memory_id == buffer_alloc.memory_id
+ and val.val.allocation_info.memory_offset == buffer_alloc.memory_offset
+ ]
+ self.assertEqual(
+ duplicates_at_buffer_slot,
+ [],
+ f"Expected no other Values at the buffer's allocation "
+ f"(mem_id={buffer_alloc.memory_id}, "
+ f"offset={buffer_alloc.memory_offset}); emit dedup should "
+ f"collapse placeholder + writeback into one value_id. "
+ f"Found duplicates at indices: {duplicates_at_buffer_slot}",
)
def test_mutable_buffers_infinite_lifespan(self) -> None:
@@ -764,6 +783,147 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
)
self.assertTrue(not_overlapping)
+ def test_custom_inplace_op_memory_aliasing(self) -> None:
+ """Memory planning correctly handles in-place ops registered via
+ the ``ops_to_inplace`` extension API (i.e. outside
+ ``DEFAULT_INPLACEABLE_OPS``).
+
+ Uses the HF-static-cache pattern: ``index_copy_`` updates two
+ mutable buffers (``keys``, ``values``). We:
+ 1. Preserve ``index_copy`` through edge lowering.
+ 2. Manually call ``reinplace_pass`` with a custom set that
+ includes ``index_copy`` (the in-place form is
+ auto-derived).
+ 3. Lower with ``run_reinplace_pass=False`` (the pass already
+ ran).
+
+ Then assert that no other planned tensor's allocation starts
+ inside either buffer's storage region. This pins the schema-driven
+ ``_alias_inplace_result_specs`` path for non-default ops: the
+ ``index_copy_`` result spec must be aliased to the buffer's
+ spec, otherwise the planner would carve out a separate
+ allocation that could land inside the buffer's slot.
+ """
+ max_batch_size, num_heads, max_cache_len, head_dim = 1, 2, 4, 8
+
+ class HFStyleStaticCache(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.register_buffer(
+ "keys",
+ torch.zeros((max_batch_size, num_heads, max_cache_len, head_dim)),
+ )
+ self.register_buffer(
+ "values",
+ torch.zeros((max_batch_size, num_heads, max_cache_len, head_dim)),
+ )
+
+ def forward(
+ self,
+ key_states: torch.Tensor,
+ value_states: torch.Tensor,
+ cache_position: torch.Tensor,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ self.keys.index_copy_(2, cache_position, key_states)
+ self.values.index_copy_(2, cache_position, value_states)
+ return self.keys, self.values
+
+ model = HFStyleStaticCache()
+ key_states = torch.full((max_batch_size, num_heads, 1, head_dim), 1.0)
+ value_states = torch.full((max_batch_size, num_heads, 1, head_dim), 2.0)
+ cache_position = torch.tensor([1])
+
+ exported_program = export(
+ model, (key_states, value_states, cache_position), strict=True
+ )
+
+ edge = to_edge(
+ exported_program,
+ compile_config=EdgeCompileConfig(
+ _check_ir_validity=False,
+ preserve_ops=[torch.ops.aten.index_copy.default],
+ ),
+ )
+
+ # Manually run reinplace_pass with a custom set that
+ # includes the (non-default) index_copy edge op. The in-place
+ # form is auto-derived by name + schema match.
+ custom_set = DEFAULT_INPLACEABLE_OPS | {
+ exir_ops.edge.aten.index_copy.default,
+ }
+ edge_program = reinplace_pass(
+ edge.exported_program(), ops_to_inplace=custom_set
+ )
+ # Sanity: both updates are now in-place.
+ inplace_nodes = [
+ n
+ for n in edge_program.graph.nodes
+ if n.op == "call_function" and "index_copy_" in str(n.target)
+ ]
+ self.assertEqual(
+ len(inplace_nodes),
+ 2,
+ "Both buffer updates should be reinplaced before lowering",
+ )
+
+ # Lower with run_reinplace_pass=False — the pass already ran
+ # with our custom set above. Memory planning should now
+ # correctly alias the index_copy_ result spec onto the buffer
+ # placeholder spec via _alias_inplace_result_specs.
+ et = edge.to_executorch(
+ ExecutorchBackendConfig(
+ emit_mutable_buffer_names=True,
+ run_reinplace_pass=False,
+ )
+ )
+
+ execution_plan = et.executorch_program.execution_plan[0]
+ values = execution_plan.values
+
+ # Collect the keys / values buffer Values by FQN.
+ buffer_value_ids: dict[str, int] = {}
+ for i, value in enumerate(values):
+ val = value.val
+ extra = getattr(val, "extra_tensor_info", None)
+ fqn = getattr(extra, "fully_qualified_name", None) if extra else None
+ if fqn in ("keys", "values"):
+ buffer_value_ids[fqn] = i
+
+ self.assertEqual(
+ set(buffer_value_ids.keys()),
+ {"keys", "values"},
+ "Both keys and values buffers should appear in the program "
+ "with their FQN",
+ )
+
+ # For each buffer, verify no other planned Value's allocation
+ # starts inside the buffer's memory region. Offsets are read via
+ # `memory_offset` (the full 64-bit offset); `memory_offset_low`
+ # and `memory_offset_high` are only its 32-bit halves, so
+ # comparing `_low` alone would ignore the upper bits.
+ # NOTE(review): this is a start-offset check only. A Value that
+ # begins before the buffer and extends into it is not detected,
+ # because the other Value's byte size is not computed here.
+ for fqn, vid in buffer_value_ids.items():
+ buf_alloc = values[vid].val.allocation_info
+ self.assertIsNotNone(buf_alloc, f"Buffer {fqn} should have allocation_info")
+ buf_base = buf_alloc.memory_offset
+ # 4 bytes per float32 element.
+ num_elements = max_batch_size * num_heads * max_cache_len * head_dim
+ buf_end = buf_base + num_elements * 4
+
+ for j, other in enumerate(values):
+ if j == vid:
+ continue
+ other_alloc = getattr(other.val, "allocation_info", None)
+ if other_alloc is None:
+ continue
+ if other_alloc.memory_id != buf_alloc.memory_id:
+ continue
+ offset = other_alloc.memory_offset
+ overlaps = buf_base <= offset < buf_end
+ self.assertFalse(
+ overlaps,
+ f"Value {j} (alloc offset={offset}) overlaps the "
+ f"{fqn} buffer's region [{buf_base}, {buf_end}) — "
+ "the in-place index_copy_ result spec was not "
+ "correctly aliased to the buffer spec",
+ )
+
+
def test_constants_not_memory_planned(self) -> None:
class Simple(torch.nn.Module):
def __init__(self) -> None:
diff --git a/exir/tests/test_reinplace_pass.py b/exir/tests/test_reinplace_pass.py
index 13661ef8cf9..c7b72a73a13 100644
--- a/exir/tests/test_reinplace_pass.py
+++ b/exir/tests/test_reinplace_pass.py
@@ -7,15 +7,37 @@
# pyre-strict
import unittest
+from typing import List, Optional
import torch
-from executorch.exir import to_edge
+from executorch.exir import EdgeCompileConfig, to_edge
from executorch.exir.capture._config import ExecutorchBackendConfig
-from executorch.exir.passes.reinplace import reinplace_pass
+from executorch.exir.dialects._ops import ops as edge_ops
+from executorch.exir.passes.reinplace import DEFAULT_INPLACEABLE_OPS, reinplace_pass
from executorch.extension.pybindings.portable_lib import ( # @manual=//executorch/extension/pybindings:portable_lib
_load_for_executorch_from_buffer,
)
from torch.export import export
+from torch.export.exported_program import ExportedProgram
+
+
+def _find_nodes(
+    ep: ExportedProgram,
+    contains: str,
+    excludes: Optional[str] = None,
+) -> List[torch.fx.Node]:
+    """Collect the ``call_function`` nodes of ``ep`` selected by target name.
+
+    A node is kept when ``str(node.target)`` contains the substring
+    ``contains``. When ``excludes`` is provided, nodes whose target
+    string also contains ``excludes`` are filtered out; this is how
+    callers tell the functional ``index_put`` apart from the in-place
+    ``index_put_``. Nodes are returned in graph order.
+    """
+    matches: List[torch.fx.Node] = []
+    for candidate in ep.graph.nodes:
+        if candidate.op != "call_function":
+            continue
+        target_name = str(candidate.target)
+        if contains not in target_name:
+            continue
+        if excludes is not None and excludes in target_name:
+            continue
+        matches.append(candidate)
+    return matches
class TestReinplacePass(unittest.TestCase):
@@ -42,33 +64,24 @@ def forward(
edge = to_edge(exported_program)
edge_program = edge._edge_programs["forward"]
- # Find the index_put node
- index_put_node = None
- for node in edge_program.graph.nodes:
- if node.op == "call_function" and "index_put" in str(node.target):
- index_put_node = node
- break
-
- self.assertIsNotNone(index_put_node, "Should find an index_put node")
+ self.assertEqual(
+ len(_find_nodes(edge_program, "index_put")),
+ 1,
+ "Should find an index_put node",
+ )
et = edge.to_executorch(ExecutorchBackendConfig(run_reinplace_pass=True))
- # Find the index_put node
- index_put_node = None
- for node in et.exported_program().graph.nodes:
- if node.op == "call_function" and "index_put_" in str(node.target):
- index_put_node = node
- break
-
- self.assertIsNotNone(index_put_node, "Should find an index_put_ node")
-
- # Find the copy_ node
- copy_node = None
- for node in et.exported_program().graph.nodes:
- if node.op == "call_function" and "copy_" in str(node.target):
- copy_node = node
- break
-
- self.assertIsNone(copy_node, "Shouldn't find an copy_ node")
+ et_program = et.exported_program()
+ self.assertEqual(
+ len(_find_nodes(et_program, "index_put_")),
+ 1,
+ "Should find an index_put_ node",
+ )
+ self.assertEqual(
+ len(_find_nodes(et_program, "copy_")),
+ 0,
+ "Shouldn't find a copy_ node",
+ )
e = _load_for_executorch_from_buffer(et.buffer)
self.assertTrue(
@@ -98,27 +111,471 @@ def forward(
values = torch.tensor([1.0])
exported_program = export(model, (indices, values), strict=True)
- edge_program = to_edge(exported_program).exported_program()
+ edge = to_edge(exported_program)
+ edge_program = edge.exported_program()
+
+ self.assertEqual(
+ len(_find_nodes(edge_program, "index_put")),
+ 1,
+ "Should find an index_put node",
+ )
+
+ ep = reinplace_pass(edge_program)
+ self.assertEqual(
+ len(_find_nodes(ep, "index_put", excludes="index_put_")),
+ 1,
+ "Should still find a functional index_put node",
+ )
+
+ # Lower to ExecuTorch and verify runtime correctness against eager.
+ et = edge.to_executorch()
+ loaded = _load_for_executorch_from_buffer(et.buffer)
+ et_out = loaded.forward((indices, values))
+ eager_out = IndexPutModel()(indices, values)
+ self.assertTrue(torch.allclose(et_out[0], eager_out))
+
+ def test_unsafe_downstream_blocks_upstream_reinplace(self) -> None:
+ """When an upstream index_put's mutated arg is also an input to an
+ unsafe downstream index_put, the upstream one must not be reinplaced
+ either. Otherwise the in-place mutation by the upstream op would
+ change the value the (still-functional) downstream op reads.
+ """
+
+ class TwoIndexPutModel(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.register_buffer("state", torch.zeros(5))
+
+ def forward(
+ self, indices: torch.Tensor, values: torch.Tensor
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ # Intermediate tensor consumed by both index_puts.
+ t = values + 1.0
+ # Upstream index_put — mutates t (an intermediate, so
+ # by itself it would look safe to reinplace).
+ a = t.index_put((indices,), torch.tensor([5.0]))
+ # Downstream index_put — reads t. This one is itself
+ # unsafe because state is read again by the add_ below.
+ b = self.state.index_put((indices,), t)
+ self.state.add_(1)
+ return a, b
+
+ model = TwoIndexPutModel()
+ # `indices` and `values` must have matching shape so the
+ # downstream `state.index_put((indices,), t)` is well-formed
+ # (state[indices[i]] = t[i]) and the lowered program can be
+ # executed at the end of this test, not just inspected as a
+ # graph.
+ indices = torch.tensor([0])
+ values = torch.tensor([2.0])
+
+ exported_program = export(model, (indices, values), strict=True)
+ edge = to_edge(exported_program)
+ edge_program = edge.exported_program()
+
+ # Sanity check: both functional index_puts present pre-pass.
+ self.assertEqual(
+ len(_find_nodes(edge_program, "index_put", excludes="index_put_")),
+ 2,
+ "Should have two functional index_put nodes",
+ )
+
+ ep = reinplace_pass(edge_program)
+
+ # Neither index_put should be reinplaced. The downstream one is
+ # unsafe directly (state used later); the upstream one is unsafe
+ # because its mutated arg `t` is read by the unsafe downstream op.
+ self.assertEqual(
+ len(_find_nodes(ep, "index_put", excludes="index_put_")),
+ 2,
+ "Neither index_put should be reinplaced",
+ )
+ self.assertEqual(
+ len(_find_nodes(ep, "index_put_")),
+ 0,
+ "No index_put_ should have been introduced",
+ )
+
+ # Lower to ExecuTorch and verify runtime correctness against eager.
+ et = edge.to_executorch()
+ loaded = _load_for_executorch_from_buffer(et.buffer)
+ et_out = loaded.forward((indices, values))
+ # A fresh module instance supplies the eager reference, so the
+ # comparison starts from the initial all-zero `state` buffer.
+ eager_a, eager_b = TwoIndexPutModel()(indices, values)
+ self.assertTrue(torch.allclose(et_out[0], eager_a))
+ self.assertTrue(torch.allclose(et_out[1], eager_b))
+
+
+ def test_kwargs_are_forwarded(self) -> None:
+ """When the matched node carries a value in ``node.kwargs`` (e.g.
+ ``accumulate=True`` for ``index_put``), the rewrite must forward
+ those kwargs to the in-place form. Otherwise the in-place op
+ falls back to the schema default and silently changes semantics.
+
+ ``export`` normalizes most arguments into positional form, so we
+ explicitly move ``accumulate`` into ``node.kwargs`` after export
+ to exercise the kwarg-forwarding path.
+
+ This test only inspects the rewritten graph; the program is not
+ lowered or executed here.
+ """
+
+ class IndexPutAccumulateModel(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.register_buffer("state", torch.zeros(5))
+
+ def forward(
+ self, indices: torch.Tensor, values: torch.Tensor
+ ) -> torch.Tensor:
+ self.state.index_put_((indices,), values, accumulate=True)
+ return self.state
+
+ model = IndexPutAccumulateModel()
+ # Both entries target index 0, so `accumulate=True` vs the default
+ # would produce different buffer contents for these inputs.
+ indices = torch.tensor([0, 0])
+ values = torch.tensor([1.0, 2.0])
+
+ exported_program = export(model, (indices, values), strict=True)
+ edge = to_edge(exported_program)
+ edge_program = edge.exported_program()
- # Find the index_put node
- index_put_node = None
- for node in edge_program.graph.nodes:
- if node.op == "call_function" and "index_put" in str(node.target):
- index_put_node = node
- break
+ # Find the functional index_put node and force `accumulate` onto
+ # its kwargs (export normalized it to positional). This is what
+ # exercises the kwarg-forwarding path in reinplace_pass.
+ functionals = _find_nodes(edge_program, "index_put", excludes="index_put_")
+ self.assertEqual(len(functionals), 1, "Should find a functional index_put")
+ functional = functionals[0]
- self.assertIsNotNone(index_put_node, "Should find an index_put node")
+ # index_put schema: (self, indices, values, accumulate=False).
+ # Move arg[3] -> kwargs["accumulate"] if present.
+ if len(functional.args) >= 4:
+ new_args = functional.args[:3]
+ new_kwargs = dict(functional.kwargs)
+ new_kwargs["accumulate"] = functional.args[3]
+ functional.args = new_args
+ functional.kwargs = new_kwargs
+ self.assertEqual(
+ functional.kwargs.get("accumulate"),
+ True,
+ "Test setup: accumulate should now be a kwarg",
+ )
ep = reinplace_pass(edge_program)
- # Find the index_put node
- index_put_node = None
- for node in ep.graph.nodes:
- if (
- node.op == "call_function"
- and "index_put" in str(node.target)
- and "index_put_" not in str(node.target)
- ):
- index_put_node = node
- break
-
- self.assertIsNotNone(index_put_node, "Should still find an index_put node")
+
+ # Find the rewritten in-place index_put_ node.
+ inplace_nodes = _find_nodes(ep, "index_put_")
+ self.assertEqual(len(inplace_nodes), 1, "Should find an index_put_ node")
+ index_put_inplace = inplace_nodes[0]
+
+ # `accumulate` must survive the rewrite, in either args or kwargs.
+ # `kwargs.get` yields None when the key is missing, which triggers
+ # the fallback to the fourth positional slot.
+ accumulate = index_put_inplace.kwargs.get("accumulate")
+ if accumulate is None and len(index_put_inplace.args) >= 4:
+ accumulate = index_put_inplace.args[3]
+ self.assertEqual(
+ accumulate,
+ True,
+ "accumulate=True must be preserved through the rewrite",
+ )
+
+ def test_ops_to_inplace_extends_with_add(self) -> None:
+ """A custom ``ops_to_inplace`` set can extend the pass to ops
+ outside the default set. Add the edge-dialect ``add.Tensor`` to
+ the set; the in-place form (``add_.Tensor``) is auto-derived by
+ name + schema match. Verify a safe-to-reinplace add gets
+ rewritten while an unsafe one (mutating an immutable input)
+ does not.
+ """
+
+ class TwoAddModel(torch.nn.Module):
+ def forward(
+ self,
+ x: torch.Tensor,
+ y: torch.Tensor,
+ z: torch.Tensor,
+ ) -> torch.Tensor:
+ # First add: mutated arg `x` is an immutable user input
+ # → not safe to reinplace.
+ t = x + y
+ # Second add: mutated arg `t` is an intermediate with no
+ # later use → safe to reinplace.
+ return t + z
+
+ model = TwoAddModel()
+ args = (torch.zeros(5), torch.ones(5), torch.full((5,), 2.0))
+ exported_program = export(model, args, strict=True)
+ edge = to_edge(exported_program)
+ edge_program = edge.exported_program()
+
+ # Sanity: no in-place adds before the pass.
+ self.assertEqual(len(_find_nodes(edge_program, "add_")), 0)
+
+ # The set holds only `add.Tensor`; nothing from
+ # DEFAULT_INPLACEABLE_OPS is included.
+ custom_set = {edge_ops.edge.aten.add.Tensor}
+ ep = reinplace_pass(edge_program, ops_to_inplace=custom_set)
+
+ # Exactly one of the two adds should be reinplaced.
+ self.assertEqual(
+ len(_find_nodes(ep, "add_")),
+ 1,
+ "Exactly one (intermediate-mutating) add should be reinplaced",
+ )
+ # `_find_nodes` matches by substring, so the functional op is
+ # looked up by its longer `aten.add.Tensor` name; the test expects
+ # this not to match the in-place `add_` overload.
+ self.assertEqual(
+ len(_find_nodes(ep, "aten.add.Tensor")),
+ 1,
+ "The add mutating an immutable input must remain functional",
+ )
+
+ # Lower to ExecuTorch. The portable runtime does not register a
+ # kernel for `add_.Tensor`, so we only verify lowering succeeds
+ # (the in-place rewrite must serialize cleanly into the ET
+ # program). Runtime execution is covered by ops with portable
+ # kernels in the other tests in this file.
+ edge.to_executorch()
+
+
+ def test_ops_to_inplace_empty_disables_all_rewrites(self) -> None:
+ """Passing an empty ``ops_to_inplace`` set should disable every
+ rewrite, even ops that are in ``DEFAULT_INPLACEABLE_OPS``.
+ """
+
+ class IndexPutModel(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.register_buffer("state", torch.zeros(5))
+
+ def forward(
+ self, indices: torch.Tensor, values: torch.Tensor
+ ) -> torch.Tensor:
+ self.state.index_put_((indices,), values)
+ return self.state
+
+ model = IndexPutModel()
+ indices = torch.tensor([0])
+ values = torch.tensor([1.0])
+
+ exported_program = export(model, (indices, values), strict=True)
+ edge = to_edge(exported_program)
+ edge_program = edge.exported_program()
+
+ # Sanity: index_put is in the default set, so without our
+ # explicit override it would be reinplaced.
+ self.assertIn(
+ edge_ops.edge.aten.index_put.default,
+ DEFAULT_INPLACEABLE_OPS,
+ "Sanity check: edge index_put should be in the default set",
+ )
+
+ # Use `set()`, not `{}`: the latter is an empty dict literal, while
+ # this test is about passing an empty *set* of ops.
+ ep = reinplace_pass(edge_program, ops_to_inplace=set())
+
+ self.assertEqual(
+ len(_find_nodes(ep, "index_put_")),
+ 0,
+ "Empty set must disable all rewrites, including the default ones",
+ )
+ self.assertEqual(
+ len(_find_nodes(ep, "index_put", excludes="index_put_")),
+ 1,
+ "The functional index_put must remain",
+ )
+
+ # Lower to ExecuTorch and verify runtime correctness against eager.
+ et = edge.to_executorch()
+ loaded = _load_for_executorch_from_buffer(et.buffer)
+ et_out = loaded.forward((indices, values))
+ eager_out = IndexPutModel()(indices, values)
+ self.assertTrue(torch.allclose(et_out[0], eager_out))
+
+
+ def test_ops_to_inplace_custom_does_not_inherit_default(self) -> None:
+ """A custom ``ops_to_inplace`` set replaces — not augments —
+ ``DEFAULT_INPLACEABLE_OPS``. Passing a set that doesn't include
+ ``index_put`` leaves it functional, even though it would be
+ reinplaced under the default. Callers who want to extend rather
+ than replace should union with ``DEFAULT_INPLACEABLE_OPS``
+ explicitly (per the docstring guidance on ``reinplace_pass``).
+ """
+
+ class IndexPutModel(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.register_buffer("state", torch.zeros(5))
+
+ def forward(
+ self, indices: torch.Tensor, values: torch.Tensor
+ ) -> torch.Tensor:
+ self.state.index_put_((indices,), values)
+ return self.state
+
+ model = IndexPutModel()
+ indices = torch.tensor([0])
+ values = torch.tensor([1.0])
+
+ exported_program = export(model, (indices, values), strict=True)
+ edge = to_edge(exported_program)
+ edge_program = edge.exported_program()
+
+ # Custom set containing only `add` — no `index_put`.
+ custom_set = {edge_ops.edge.aten.add.Tensor}
+ ep = reinplace_pass(edge_program, ops_to_inplace=custom_set)
+
+ # Only the absence of the in-place form is asserted; the model
+ # contains no add, so the custom set should match nothing here.
+ self.assertEqual(
+ len(_find_nodes(ep, "index_put_")),
+ 0,
+ "index_put must remain functional when not in the custom set",
+ )
+
+ # Lower to ExecuTorch and verify runtime correctness against eager.
+ et = edge.to_executorch()
+ loaded = _load_for_executorch_from_buffer(et.buffer)
+ et_out = loaded.forward((indices, values))
+ # A fresh module instance gives the eager reference starting from
+ # the initial all-zero `state` buffer.
+ eager_out = IndexPutModel()(indices, values)
+ self.assertTrue(torch.allclose(et_out[0], eager_out))
+
+
+ def test_kv_cache_style_reinplace(self) -> None:
+ """HF-style static KV cache update via ``index_copy_`` along the
+ ``cache_len`` dim. Mirrors transformers' ``StaticCache.update``::
+
+ self.keys.index_copy_(2, cache_position, key_states)
+ self.values.index_copy_(2, cache_position, value_states)
+
+ Cache shape mirrors HF::
+ (max_batch_size, num_heads, max_cache_len, head_dim)
+
+ ``index_copy`` is **not** in ``DEFAULT_INPLACEABLE_OPS``, so this
+ also exercises the new ``ops_to_inplace`` extension API. We use
+ ``EdgeCompileConfig(preserve_ops=[index_copy])`` to keep the op
+ from being decomposed by ``to_edge``, then add ``index_copy`` to
+ a custom set (the in-place form is auto-derived by name + schema
+ match) and verify both buffer updates are rewritten in-place.
+ """
+
+ max_batch_size, num_heads, max_cache_len, head_dim = 1, 2, 4, 8
+
+ class HFStyleStaticCache(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.register_buffer(
+ "keys",
+ torch.zeros((max_batch_size, num_heads, max_cache_len, head_dim)),
+ )
+ self.register_buffer(
+ "values",
+ torch.zeros((max_batch_size, num_heads, max_cache_len, head_dim)),
+ )
+
+ def forward(
+ self,
+ key_states: torch.Tensor,
+ value_states: torch.Tensor,
+ cache_position: torch.Tensor,
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ # HF static-cache style: in-place index_copy along dim=2.
+ self.keys.index_copy_(2, cache_position, key_states)
+ self.values.index_copy_(2, cache_position, value_states)
+ return self.keys, self.values
+
+ model = HFStyleStaticCache()
+ # The new states have size 1 along dim 2 (a single cache slot) and
+ # `cache_position` selects slot 1 of the `max_cache_len` slots.
+ key_states = torch.full((max_batch_size, num_heads, 1, head_dim), 1.0)
+ value_states = torch.full((max_batch_size, num_heads, 1, head_dim), 2.0)
+ cache_position = torch.tensor([1])
+
+ exported_program = export(
+ model, (key_states, value_states, cache_position), strict=True
+ )
+
+ # Preserve `index_copy` through the edge lowering so the pass
+ # actually sees it (default to_edge decomposes it into index_put).
+ edge = to_edge(
+ exported_program,
+ compile_config=EdgeCompileConfig(
+ _check_ir_validity=False,
+ preserve_ops=[torch.ops.aten.index_copy.default],
+ ),
+ )
+ edge_program = edge.exported_program()
+
+ # Sanity: pre-pass, both updates are functional `index_copy`.
+ self.assertEqual(
+ len(_find_nodes(edge_program, "index_copy", excludes="index_copy_")),
+ 2,
+ "Pre-pass: should have two functional index_copy nodes (keys, values)",
+ )
+
+ # Add `index_copy` to the set; the in-place form is
+ # auto-derived by name + schema match
+ # (ops.edge.aten.index_copy_.default).
+ # The union keeps the default ops enabled alongside the new one.
+ custom_set = DEFAULT_INPLACEABLE_OPS | {
+ edge_ops.edge.aten.index_copy.default,
+ }
+ ep = reinplace_pass(edge_program, ops_to_inplace=custom_set)
+
+ # Both updates should now be in-place.
+ self.assertEqual(
+ len(_find_nodes(ep, "index_copy_")),
+ 2,
+ "Both keys and values index_copy ops should be reinplaced",
+ )
+ self.assertEqual(
+ len(_find_nodes(ep, "index_copy", excludes="index_copy_")),
+ 0,
+ "No functional index_copy nodes should remain",
+ )
+
+ # Lower to ExecuTorch. The portable runtime does not register a
+ # kernel for `index_copy_`, so we only verify lowering succeeds
+ # (the in-place rewrite must serialize cleanly into the ET
+ # program).
+ edge.to_executorch()
+
+
+ def test_chain_of_inplaceable_ops(self) -> None:
+ """A chain of safe-to-reinplace ops gets fully rewritten in
+ topological order. Exercises:
+ * Multiple distinct ops (`add` and `relu`) registered together
+ via a single set, with in-place forms auto-derived.
+ * Reverse-walk safety propagation: each intermediate is
+ consumed exactly once by the next op, so every step except
+ the first sees its mutated arg as not-yet-used.
+ * The first ``add`` mutates an immutable user input
+ (``x``) and must remain functional.
+ """
+
+ class ChainModel(torch.nn.Module):
+ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+ t = x + y # add #1: mutates `x` (immutable input) -> unsafe.
+ t = torch.relu(t) # relu #1: mutates intermediate -> safe.
+ t = t + y # add #2: mutates intermediate -> safe.
+ t = torch.relu(t) # relu #2: mutates intermediate -> safe.
+ return t
+
+ model = ChainModel()
+ x = torch.tensor([-1.0, 2.0, -3.0, 4.0])
+ y = torch.tensor([1.0, 1.0, 1.0, 1.0])
+
+ exported_program = export(model, (x, y), strict=True)
+ edge = to_edge(exported_program)
+ edge_program = edge.exported_program()
+
+ # Only `add` and `relu` are enabled; the default set is not
+ # unioned in.
+ custom_set = {
+ edge_ops.edge.aten.add.Tensor,
+ edge_ops.edge.aten.relu.default,
+ }
+ ep = reinplace_pass(edge_program, ops_to_inplace=custom_set)
+
+ # `_find_nodes` matches by substring: functional ops are looked up
+ # by their longer `aten.<op>.<overload>` name, in-place ones by the
+ # trailing-underscore op name (`add_`, `relu_`).
+ self.assertEqual(
+ len(_find_nodes(ep, "aten.add.Tensor")),
+ 1,
+ "First add mutates an immutable input; must stay functional",
+ )
+ self.assertEqual(
+ len(_find_nodes(ep, "add_")),
+ 1,
+ "Second add mutates an intermediate; should be reinplaced",
+ )
+ self.assertEqual(
+ len(_find_nodes(ep, "aten.relu.default")),
+ 0,
+ "Both relus mutate intermediates; neither should remain functional",
+ )
+ self.assertEqual(
+ len(_find_nodes(ep, "relu_")),
+ 2,
+ "Both relus should be reinplaced",
+ )
+
+ # Lower to ExecuTorch. The portable runtime does not register
+ # kernels for `add_.Tensor` or `relu_.default`, so we only
+ # verify lowering succeeds (the in-place rewrites must
+ # serialize cleanly into the ET program).
+ edge.to_executorch()