diff --git a/exir/emit/_emitter.py b/exir/emit/_emitter.py index a48d88fa224..231a7c9b1b8 100644 --- a/exir/emit/_emitter.py +++ b/exir/emit/_emitter.py @@ -622,7 +622,16 @@ def _emit_evalue(self, val: EValue) -> _AbstractValue: return _AbstractValue(len(self.emitter_state.values) - 1, tensor) def _emit_spec(self, spec: ValueSpec) -> _EmitterValue: - """Given the provided spec constructs the corresponding EValue from it and then emits it.""" + """Given the provided spec constructs the corresponding EValue from it and then emits it. + + If `spec` was already emitted earlier (e.g., because two FX + nodes share the same TensorSpec object — typically because the + planner's `_alias_inplace_result_specs` aliased an in-place + op's result onto its mutated input), reuse the existing + value_id. This keeps the invariant "one TensorSpec ↔ one + Value" so downstream emit doesn't create duplicate Values for + aliased FX nodes. + """ def _process(spec: LeafValueSpec) -> _AbstractValue: if isinstance(spec, (list, tuple)): @@ -642,6 +651,25 @@ def _process(spec: LeafValueSpec) -> _AbstractValue: f"Invalid node spec expected TensorSpec received {spec}", ) + # Spec was already emitted — reuse the existing Value so + # two FX nodes sharing one TensorSpec also share one + # value_id in the lowered IR. + existing_id = self.emitter_state.spec2id_dict.get(spec) + if existing_id is not None: + existing_evalue = self.emitter_state.values[existing_id] + # Both insertion sites for `spec2id_dict` (this method + # and the placeholder emitter) only register ids whose + # EValue wraps a `Tensor` (built via + # `_tensor_spec_to_evalue` from a `TensorSpec`). 
+ self._internal_assert_emitter( + isinstance(existing_evalue.val, Tensor), + self.node, + f"spec2id_dict entry for TensorSpec must point to a " + f"Tensor EValue, got " + f"{type(existing_evalue.val).__name__}", + ) + return _AbstractValue(existing_id, existing_evalue.val) + ret = self._emit_evalue(self._tensor_spec_to_evalue(spec)) # pyre-ignore self.emitter_state.spec2id_dict[spec] = ret.id # pyre-ignore return ret @@ -2020,6 +2048,14 @@ def _is_buffer(node: Node, graph_signature: ExportGraphSignature) -> bool: ) value = self._emit_evalue(evalue) + # Populate spec2id_dict so downstream `_emit_spec` calls (e.g., + # for in-place op result FX nodes whose spec was aliased onto + # this placeholder's spec by the planner's + # `_alias_inplace_result_specs`) reuse this placeholder's + # value_id rather than creating a new Value. + if isinstance(spec, TensorSpec): + self.emitter_state.spec2id_dict[spec] = value.id + # Only user inputs should remain as inputs. if is_user_input: self.inputs.append(value.id) diff --git a/exir/memory_planning.py b/exir/memory_planning.py index 9b19cbc7770..3c9f4313ae2 100644 --- a/exir/memory_planning.py +++ b/exir/memory_planning.py @@ -30,7 +30,12 @@ from executorch.exir.control_flow import while_loop as exir_while from executorch.exir.delegate import executorch_call_delegate from executorch.exir.error import internal_assert, InternalError -from executorch.exir.operator.convert import is_inplace_variant, is_out_variant +from executorch.exir.operator.convert import ( + is_inplace_variant, + is_out_variant, + output_to_aliased_input_map, + unwrap_op_overload, +) from executorch.exir.schema import DeviceType, NonConstBufferDevice, TensorShapeDynamism from executorch.exir.tensor import TensorSpec from torch import fx @@ -302,13 +307,137 @@ def _is_out_var_node(node: torch.fx.Node) -> bool: def _is_inplace_node(node: torch.fx.Node) -> bool: - return ( - node.op == "call_function" - and isinstance(node.target, torch._ops.OpOverload) - and 
is_inplace_variant( - node.target._schema.name, node.target._schema.overload_name + if node.op != "call_function": + return False + target = node.target + if not isinstance(target, torch._ops.OpOverload) and not isinstance( + getattr(target, "_op", None), torch._ops.OpOverload + ): + return False + op = unwrap_op_overload(target) + return is_inplace_variant(op._schema.name, op._schema.overload_name) + + +def _alias_inplace_result_specs(node: torch.fx.Node) -> None: # noqa: C901 + """Alias an in-place op's result TensorSpec(s) onto the corresponding + input's spec. + + In-place ops (schema kind == inplace) mutate one or more of their + inputs and return tensors that alias them, declared via the + ``Tensor(a!)`` schema annotation. To make the memory planner treat + result and aliased input as one storage, we copy the input's spec + object onto the output's ``node.meta["spec"]`` slot. + + Output→input correspondence is computed via + ``output_to_aliased_input_map``, which matches each return's + write-alias set against the inputs that share it. + + Gating: + + - Only runs for in-place nodes (caller checks ``_is_inplace_node``). + - Multi-output in-place ops are supported when each return's alias + set matches exactly one input's alias set. + - Falls through silently when alias info is absent or unparseable, + preserving the original spec. ``logging.debug`` records each + early-return reason so silent regressions are observable. + """ + target = node.target + op = unwrap_op_overload(target) + + schema = op._schema + out_to_in = output_to_aliased_input_map(schema) + if not out_to_in: + logging.debug( + f"_alias_inplace_result_specs: schema for {op} declares no " + f"write-aliased outputs matching an input; skipping." ) - ) + return + + # Normalize the current spec container into a list for uniform + # handling. 
Caller guarantees this node was identified as an + # in-place op with a meta spec (see `_is_inplace_node`), so an + # unrecognized container shape is a real bug — assert loudly so it + # surfaces in tests rather than silently disabling aliasing. + current = node.meta.get("spec") + if isinstance(current, TensorSpec): + out_specs_list: List[Optional[TensorSpec]] = [current] + return_container_kind = "scalar" + elif isinstance(current, (list, tuple)): + out_specs_list = list(current) + return_container_kind = type(current).__name__ + else: + raise InternalError( + f"_alias_inplace_result_specs: in-place node {node.name} " + f"({op}) has unrecognized spec container of type " + f"{type(current).__name__!r}; expected TensorSpec, list, " + f"or tuple." + ) + + # Compute new spec for each return; None means "keep original". + # Mutated inputs are usually positional (the `Tensor(a!)` `self` + # arg), but custom ops may pass them via kwargs — fall back to + # `node.kwargs[arg_name]` in that case. + replacements: List[Optional[TensorSpec]] = [None] * len(out_specs_list) + + for out_idx, in_idx in out_to_in.items(): + if out_idx >= len(out_specs_list): + logging.debug( + f"_alias_inplace_result_specs: schema for {op} declares " + f"return {out_idx} but spec container has only " + f"{len(out_specs_list)} entries; skipping this return." + ) + continue + # Resolve the mutated input. Prefer positional args, fall back + # to kwargs by argument name (custom ops may pass `Tensor(a!)` + # args via kwargs). + in_node: object + if in_idx < len(node.args): + in_node = node.args[in_idx] + else: + arg_name = ( + schema.arguments[in_idx].name + if in_idx < len(schema.arguments) + else None + ) + if arg_name is None or arg_name not in node.kwargs: + logging.debug( + f"_alias_inplace_result_specs: schema for {op} " + f"expects mutated input at position {in_idx} " + f"(name={arg_name!r}) but it is supplied neither " + f"positionally nor via kwargs; skipping." 
+ ) + continue + in_node = node.kwargs[arg_name] + if not isinstance(in_node, torch.fx.Node): + continue + # NOTE: alias unconditionally — including when the input is a + # placeholder (named buffer / mutable input). Skipping + # placeholders would leave a dangling out-arg on in-place op + # instructions whose self is a buffer; the runtime kernel + # writes to that out-arg's storage rather than mutating the + # buffer, producing incorrect results. The companion change in + # `_emit_spec` (emit/_emitter.py) deduplicates by spec identity + # so this aliasing doesn't produce two Values for the same FQN. + in_spec = in_node.meta.get("spec") + if not isinstance(in_spec, TensorSpec): + continue + replacements[out_idx] = in_spec + + if not any(r is not None for r in replacements): + return + + # Assemble the new spec container, preserving the original shape. + new_list = [ + replacements[i] if replacements[i] is not None else out_specs_list[i] + for i in range(len(out_specs_list)) + ] + if return_container_kind == "scalar": + if isinstance(new_list[0], TensorSpec): + node.meta["spec"] = new_list[0] + elif return_container_kind == "list": + node.meta["spec"] = list(new_list) + elif return_container_kind == "tuple": + node.meta["spec"] = tuple(new_list) def update_tensor_lifetime( @@ -474,6 +603,11 @@ def collect_specs_from_nodes( # noqa: C901 continue if _is_inplace_node(node): + # Schema-driven: alias the in-place op's result spec(s) onto + # the corresponding mutated input's spec so the planner gives + # them the same allocation slot. See `_alias_inplace_result_specs` + # for the full rationale and gating rules. 
def unwrap_op_overload(target: object) -> torch._ops.OpOverload:
    """Resolve a node target to its underlying ``torch._ops.OpOverload``.

    A target that already is an ``OpOverload`` is returned unchanged.
    Otherwise its ``_op`` attribute is consulted, which covers wrapper
    objects that hold the real overload there without subclassing
    ``OpOverload`` themselves.

    Raises:
        TypeError: if neither ``target`` nor ``target._op`` is an
            ``OpOverload``.
    """
    resolved = target
    if not isinstance(resolved, torch._ops.OpOverload):
        resolved = getattr(target, "_op", None)
        if not isinstance(resolved, torch._ops.OpOverload):
            raise TypeError(
                f"unwrap_op_overload: expected a torch._ops.OpOverload or a "
                f"wrapper exposing one as `_op`, got {type(target).__name__}: "
                f"{target!r}"
            )
    return resolved


def output_to_aliased_input_map(
    schema: torch.FunctionSchema,
) -> Dict[int, int]:
    """Map each write-aliased return index to the input index it aliases.

    Inputs and returns are paired through their alias sets (the
    ``before_set`` of their ``alias_info``); only entries whose alias info
    has ``is_write`` set take part. When several inputs carry the same
    alias set, the lowest input index is kept. Returns without a matching
    input are left out, so the result is empty for schemas with no
    write-aliased returns.
    """
    # alias set -> first input position that declares it as a write alias.
    first_input_for: Dict[FrozenSet[str], int] = {}
    for position, argument in enumerate(schema.arguments):
        alias = argument.alias_info
        if alias is not None and alias.is_write:
            key = frozenset(alias.before_set)
            if key not in first_input_for:
                first_input_for[key] = position

    mapping: Dict[int, int] = {}
    for position, returned in enumerate(schema.returns):
        alias = returned.alias_info
        if alias is None or not alias.is_write:
            continue
        source = first_input_for.get(frozenset(alias.before_set))
        if source is not None:
            mapping[position] = source
    return mapping
class TestUnwrapOpOverload(unittest.TestCase):
    """Checks `unwrap_op_overload` on raw ops, wrappers, and invalid targets."""

    def test_aten_overload_returned_as_is(self) -> None:
        aten_add = torch.ops.aten.add.Tensor
        self.assertIs(unwrap_op_overload(aten_add), aten_add)

    def test_wrapper_with_op_attr_peeled_to_aten(self) -> None:
        # Stand-in for a wrapper type: not an OpOverload itself, but it
        # carries the real aten op in `_op`.
        class _FakeWrapper:  # noqa: B903
            def __init__(self, op: torch._ops.OpOverload) -> None:
                self._op = op

        aten_add = torch.ops.aten.add.Tensor
        self.assertIs(unwrap_op_overload(_FakeWrapper(aten_add)), aten_add)

    def test_non_op_raises(self) -> None:
        for bogus_target in ("not an op", None, 42):
            with self.assertRaises(TypeError):
                unwrap_op_overload(bogus_target)

    def test_wrapper_with_non_op_underlying_raises(self) -> None:
        class _BadWrapper:
            _op = "not an op overload"

        with self.assertRaises(TypeError):
            unwrap_op_overload(_BadWrapper())


class TestOutputToAliasedInputMap(unittest.TestCase):
    """Checks the return -> input alias map derived from op schemas."""

    def test_functional_op_returns_empty(self) -> None:
        # `aten::add.Tensor` has no write-aliased return.
        functional = torch.ops.aten.add.Tensor._schema
        self.assertEqual(output_to_aliased_input_map(functional), {})

    def test_single_output_inplace_op(self) -> None:
        # `aten::index_put_` returns its mutated `self` (arg 0).
        inplace = torch.ops.aten.index_put_.default._schema
        self.assertEqual(output_to_aliased_input_map(inplace), {0: 0})

    def test_single_output_inplace_via_pybind_parse(self) -> None:
        # Synthetic schema: one mutated arg at position 0, returned as-is.
        parsed = torch._C.parse_schema(
            "test::single(Tensor(a!) self, int n) -> Tensor(a!)"
        )
        self.assertEqual(output_to_aliased_input_map(parsed), {0: 0})

    def test_multi_output_inplace_via_pybind_parse(self) -> None:
        # Two write-aliased inputs, each returned as its own output:
        # output 0 ({x}) -> input 0, output 1 ({y}) -> input 1.
        parsed = torch._C.parse_schema(
            "test::multi(Tensor(x!) a, Tensor(y!) b, int n) "
            "-> (Tensor(x!), Tensor(y!))"
        )
        self.assertEqual(output_to_aliased_input_map(parsed), {0: 0, 1: 1})

    def test_partial_aliasing_returns_only_matched(self) -> None:
        # Only the first of the two returns carries write-alias info.
        parsed = torch._C.parse_schema(
            "test::partial(Tensor(z!) self, Tensor other) -> (Tensor(z!), Tensor)"
        )
        self.assertEqual(output_to_aliased_input_map(parsed), {0: 0})

    def test_tied_inputs_first_match_wins(self) -> None:
        # Both inputs share one write-alias set; the map must point the
        # single output at input 0, not input 1.
        parsed = torch._C.parse_schema(
            "test::tied(Tensor(a!) x, Tensor(a!) y) -> Tensor(a!)"
        )
        self.assertEqual(output_to_aliased_input_map(parsed), {0: 0})
y) -> Tensor(a!)" + ) + self.assertEqual(output_to_aliased_input_map(schema), {0: 0}) diff --git a/exir/passes/reinplace.py b/exir/passes/reinplace.py index 349869a2f4b..3c6bad77da7 100644 --- a/exir/passes/reinplace.py +++ b/exir/passes/reinplace.py @@ -6,19 +6,224 @@ # pyre-strict -from typing import Set +from typing import Any, Dict, FrozenSet, Iterable, Optional, Set, Tuple import torch from executorch.exir.dialects._ops import ops +from executorch.exir.operator.convert import ( + output_to_aliased_input_map, + unwrap_op_overload, +) from torch.export import ExportedProgram -def _is_index_put(node: torch.fx.Node) -> bool: - """Check if a node is an index_put operation.""" - return node.op == "call_function" and node.target in ( - torch.ops.aten.index_put.default, +# --------------------------------------------------------------------------- +# Public API for extending the pass with additional ops. +# --------------------------------------------------------------------------- + + +# Default set of edge-dialect functional ops the pass attempts to rewrite +# in-place. Today only `index_put -> index_put_`. The pass auto-derives +# the in-place form by name + schema match — callers who add ops here +# do not need to specify the in-place op explicitly. +# +# `reinplace_pass` runs after `to_edge` (inside `to_executorch`), so it +# only ever sees edge-dialect targets (`EdgeOpOverload`). Aten targets +# do not appear at this stage; only edge ops belong here. +DEFAULT_INPLACEABLE_OPS: FrozenSet[Any] = frozenset( + { ops.edge.aten.index_put.default, - ) + } +) + + +# --------------------------------------------------------------------------- +# Schema-based discovery and validation. +# --------------------------------------------------------------------------- + + +def _op_schema(op: Any) -> torch.FunctionSchema: + """Return the underlying ``FunctionSchema`` for an op overload. 
+ + Delegates to ``unwrap_op_overload`` to peel any edge-dialect or + backend wrapper down to the bare ``torch._ops.OpOverload`` before + reading ``_schema``. Falls back to ``op._schema`` if the input is + a schema-bearing object that ``unwrap_op_overload`` doesn't + recognize (e.g., a custom op-like with a ``_schema`` attribute but + not a true ``OpOverload``). + """ + try: + return unwrap_op_overload(op)._schema + except TypeError: + return op._schema + + +def _is_inplace_of( + functional_schema: torch.FunctionSchema, + candidate_schema: torch.FunctionSchema, +) -> bool: + """Return True if `candidate_schema` is the in-place form of + `functional_schema`: same input arg types positionally, with the + first arg of the candidate carrying a `Tensor(a!)` write alias. + """ + f_args = functional_schema.arguments + c_args = candidate_schema.arguments + if len(f_args) != len(c_args): + return False + if not c_args: + return False + first = c_args[0] + if first.alias_info is None or not first.alias_info.is_write: + return False + for fa, ca in zip(f_args, c_args): + # Compare JIT types directly; `str(...)` repr equality is + # fragile to type-printing changes (qualified vs. unqualified + # names, container parameterization). Fall back to str only if + # direct equality raises (some C-bound types may not implement + # __eq__ uniformly across Torch versions). + try: + if fa.type != ca.type: + return False + except Exception: + if str(fa.type) != str(ca.type): + return False + return True + + +def _derive_edge_inplace_overload(functional_op: Any) -> Optional[Any]: + """Auto-derive the in-place edge-dialect overload for a functional + op (works for both `EdgeOpOverload` inputs — the common case — + and bare aten `OpOverload` inputs; the result is always an + edge-dialect op). + + Strategy: peel `EdgeOpOverload._op` (no-op for bare aten) to find + the underlying aten schema. 
The aten in-place form lives at + `torch.ops.aten._` (for example, + `aten::index_put -> aten::index_put_`). Walk the aten in-place + package's overloads (which exposes `.overloads()` cleanly; the + edge package does not) to find the one whose schema matches the + functional schema (with the first arg promoted to `Tensor(a!)`). + Then look up the corresponding edge-dialect op at + `ops.edge.aten._.`. + + Handles cross-overload-name asymmetry — e.g. + `aten::pow.Tensor_Scalar` -> `aten::pow_.Scalar` + where simple name-based lookup `aten.pow_.Tensor_Scalar` fails. + + Note: if multiple aten in-place overloads have schemas that match + the functional op's schema (no known cases in aten today), the + first one returned by `.overloads()` wins. Callers can disambiguate + by passing an explicit `inplace_overrides` entry. + + Returns None if no schema match is found, or the op is not an + aten op. Callers should provide an explicit override via + `inplace_overrides` for non-conventional rewrites. + """ + schema = _op_schema(functional_op) + name = schema.name # e.g. "aten::index_put" + if "::" not in name: + return None + namespace, base = name.split("::", 1) + if namespace != "aten": + return None + + # Find the matching aten in-place overload first (aten exposes + # `.overloads()` cleanly; the edge package does not). + aten_inplace_pkg = getattr(torch.ops.aten, base + "_", None) + if aten_inplace_pkg is None: + return None + matched_overload_name: Optional[str] = None + for overload_name in aten_inplace_pkg.overloads(): + candidate = getattr(aten_inplace_pkg, overload_name, None) + if candidate is None: + continue + cand_schema = getattr(candidate, "_schema", None) + if cand_schema is None: + continue + if _is_inplace_of(schema, cand_schema): + matched_overload_name = overload_name + break + if matched_overload_name is None: + return None + + # Translate to the edge-dialect op of the same name + overload. 
+ edge_inplace_pkg = getattr(ops.edge.aten, base + "_", None) + if edge_inplace_pkg is None: + return None + return getattr(edge_inplace_pkg, matched_overload_name, None) + + +def _validate_inplace_mapping(functional_op: Any, inplace_op: Any) -> None: + """Validate that `inplace_op` is a plausible in-place form of + `functional_op`. Raises `ValueError` on misregistration. + + Two checks: + 1. **Schema shape**: same arg types positionally, first arg of + `inplace_op` carries `Tensor(a!)`. Catches gross mismatches + like `add -> mul_` (different signatures). + 2. **Name affinity**: in-place op's name starts with the + functional op's name + "_". Catches subtle mismatches with + matching schemas like `add -> sub_`. + """ + f_schema = _op_schema(functional_op) + i_schema = _op_schema(inplace_op) + + if not _is_inplace_of(f_schema, i_schema): + raise ValueError( + f"Schema mismatch in reinplace registration: " + f"{functional_op} -> {inplace_op}. " + f"Expected in-place arg types to match functional arg types " + f"positionally with the first arg promoted to Tensor(a!). " + f"Got functional schema {f_schema} and in-place schema " + f"{i_schema}." + ) + + expected_prefix = f_schema.name + "_" + if not i_schema.name.startswith(expected_prefix): + raise ValueError( + f"Suspicious reinplace registration: " + f"{functional_op} -> {inplace_op}. " + f"In-place op name '{i_schema.name}' should start with " + f"'{expected_prefix}'. This usually indicates a typo " + f"(e.g. 'add -> sub_' instead of 'add -> add_')." + ) + + +def _derive_mutated_args(inplace_op: Any) -> Tuple[int, ...]: + """Return the positional indices of args that the in-place op + mutates, derived from the schema's `Tensor(a!)` annotations. + + Computed via ``output_to_aliased_input_map``: each return with a + write alias points back to the input arg position carrying the + same alias set. The mutated positions are the input indices that + appear as values in that map. 
def _derive_mutated_args(inplace_op: Any) -> Tuple[int, ...]:
    """Return the sorted input positions that ``inplace_op`` mutates.

    The positions are the distinct values of
    ``output_to_aliased_input_map`` applied to the op's schema, i.e. the
    inputs that a write-aliased return points back to.

    Raises:
        ValueError: if the schema has no write-aliased return matching an
            input, meaning the op does not declare itself as in-place.
    """
    schema = _op_schema(inplace_op)
    out_to_in = output_to_aliased_input_map(schema)
    indices = tuple(sorted(set(out_to_in.values())))
    if not indices:
        raise ValueError(
            f"{inplace_op} has no Tensor(a!) write-aliased args that "
            f"match a corresponding return. If this op truly mutates, "
            f"fix its schema to declare the mutated arg(s) with "
            f"`Tensor(a!)` on both the input and the matching return. "
            f"The schema is the contract that all of export, memory "
            f"planning, and functionalization rely on."
        )
    return indices


def _resolve_inplace_mappings(
    ops_to_inplace: Optional[Iterable[Any]],
    inplace_overrides: Optional[Dict[Any, Any]],
) -> Dict[Any, Tuple[Any, Tuple[int, ...]]]:
    """Build ``functional op -> (in-place op, mutated arg positions)``.

    Override keys are added to the candidate set. Every pairing, derived or
    overridden, is validated up front so schema introspection is not
    repeated per graph node.
    """
    overrides = inplace_overrides or {}
    op_set: Set[Any] = set(
        ops_to_inplace if ops_to_inplace is not None else DEFAULT_INPLACEABLE_OPS
    )
    op_set.update(overrides.keys())

    resolved: Dict[Any, Tuple[Any, Tuple[int, ...]]] = {}
    for functional_op in op_set:
        if functional_op in overrides:
            inplace_op = overrides[functional_op]
        else:
            inplace_op = _derive_edge_inplace_overload(functional_op)
            if inplace_op is None:
                raise ValueError(
                    f"Cannot auto-derive in-place form for "
                    f"{functional_op}. Provide an explicit mapping via "
                    f"`inplace_overrides={{{functional_op}: INPLACE_OP}}`."
                )
        _validate_inplace_mapping(functional_op, inplace_op)
        resolved[functional_op] = (inplace_op, _derive_mutated_args(inplace_op))
    return resolved


def _iter_mutated_arg_nodes(
    node: torch.fx.Node, mutated_args: Tuple[int, ...]
) -> Iterable[torch.fx.Node]:
    """Lazily yield the fx Nodes sitting at ``node``'s mutated arg positions.

    Implemented as a generator so that a consumer which stops early (at the
    first unsafe arg) never validates the remaining positions.

    Raises:
        ValueError: if a mutated position is not supplied positionally, or
            holds something other than a ``torch.fx.Node``.
    """
    for arg_idx in mutated_args:
        if arg_idx >= len(node.args):
            raise ValueError(
                f"reinplace: {node.target} call at {node} has "
                f"{len(node.args)} positional args, but the "
                f"schema declares position {arg_idx} as "
                f"Tensor(a!). Export should normalize mutated "
                f"args to positional; this graph violates that "
                f"assumption."
            )
        arg_node = node.args[arg_idx]
        if not isinstance(arg_node, torch.fx.Node):
            raise ValueError(
                f"reinplace: {node.target} call at {node} has a "
                f"non-Node value {arg_node!r} at position "
                f"{arg_idx}, but the schema declares it as "
                f"Tensor(a!). A Tensor input in an FX graph "
                f"must be a torch.fx.Node."
            )
        yield arg_node


def reinplace_pass(
    ep: ExportedProgram,
    ops_to_inplace: Optional[Iterable[Any]] = None,
    inplace_overrides: Optional[Dict[Any, Any]] = None,
) -> ExportedProgram:
    """Rewrite functional ops to their in-place variants when safe.

    The graph is walked in reverse order. For each ``call_function`` node
    whose target is a registered candidate, every mutated-arg position is
    checked with ``_is_safe_to_reinplace`` (given the nodes already seen as
    inputs of later nodes, the placeholders, and the mutable placeholders).
    If all positions pass, the node is replaced by a call to the in-place
    op with the same args and kwargs. Candidates that fail the check are
    treated like any other call: their inputs are recorded as "used later",
    which blocks reinplacing further-upstream ops on those tensors.

    Args:
        ep: The ExportedProgram to rewrite. Its graph is modified in place
            and the same object is returned for chaining.
        ops_to_inplace: Functional ops to try to reinplace; the in-place
            form is auto-derived from the schema. Defaults to
            ``DEFAULT_INPLACEABLE_OPS``. An empty iterable leaves only the
            keys of ``inplace_overrides`` as candidates. To extend the
            defaults::

                reinplace_pass(
                    ep,
                    ops_to_inplace=DEFAULT_INPLACEABLE_OPS | {
                        ops.edge.aten.index_copy.default,
                    },
                )

        inplace_overrides: Explicit ``functional op -> in-place op`` map,
            used instead of auto-derivation for its keys. Keys are enrolled
            as candidates automatically and need not be repeated in
            ``ops_to_inplace``. Overrides pass through the same validation
            as derived mappings, including the name check below, so the
            in-place op's schema name must still start with the functional
            op's schema name plus ``"_"``.

    Raises:
        ValueError: at startup, if an in-place form cannot be derived, or a
            pairing fails validation:

            * argument types must match positionally;
            * the in-place op's first arg must carry ``Tensor(a!)``;
            * the in-place op's name must start with the functional op's
              name + ``"_"``;
            * the in-place op must have a write-aliased return matching an
              input.

            Also raised during the walk if a mutated arg of a candidate
            node is not a positional ``torch.fx.Node``.

    Returns:
        The (possibly mutated) ExportedProgram.
    """
    resolved = _resolve_inplace_mappings(ops_to_inplace, inplace_overrides)

    # Nodes consumed by something later in the graph (filled in as we walk
    # backwards).
    seen_nodes: Set[torch.fx.Node] = set()
    inputs: Set[torch.fx.Node] = {
        node for node in ep.graph.nodes if node.op == "placeholder"
    }
    # Placeholders we are allowed to mutate.
    mutable_nodes: Set[torch.fx.Node] = {
        node
        for node in inputs
        if _is_mutable_user_input(node, ep) or _is_mutable_buffer(node, ep)
    }

    for node in reversed(ep.graph.nodes):
        if node.op != "call_function":
            continue
        entry = resolved.get(node.target)
        if entry is not None:
            inplace_op, mutated_args = entry
            # Every mutated position must independently be safe; `all`
            # stops at the first unsafe one.
            all_safe = all(
                _is_safe_to_reinplace(arg_node, seen_nodes, inputs, mutable_nodes)
                for arg_node in _iter_mutated_arg_nodes(node, mutated_args)
            )
            if all_safe:
                with ep.graph.inserting_before(node):
                    # Forward kwargs as well as args: dropping a kwarg
                    # accepted by the functional op would silently change
                    # semantics.
                    new_node = ep.graph.call_function(
                        inplace_op,
                        args=node.args,
                        kwargs=node.kwargs,
                    )
                    new_node.meta["val"] = node.meta["val"]
                    node.replace_all_uses_with(new_node)
                    ep.graph.erase_node(node)
                # No `seen_nodes` update here: the inserted in-place node is
                # expected to be visited next by the reverse walk and, not
                # being a registered candidate, to reach the generic update.
                continue
        # Reached by ordinary calls and by candidates that failed the
        # safety check alike.
        seen_nodes.update(node.all_input_nodes)

    return ep
- values = et.executorch_program.execution_plan[0].values - expected_buffer_size = 5 * 5 * 4 # 5x5 float32 - - # Collect all tensor allocations by their (memory_id, offset) and track sizes - # Size is computed from tensor's sizes and scalar_type, not from allocation_info - # (memory_offset_low/high are low/high 32-bit parts of a 64-bit offset, not bounds) - scalar_type_sizes = { - 0: 1, # BYTE - 1: 1, # CHAR - 2: 2, # SHORT - 3: 4, # INT - 4: 8, # LONG - 5: 2, # HALF - 6: 4, # FLOAT - 7: 8, # DOUBLE - } - offset_to_indices = {} - for i, val in enumerate(values): - tensor = val.val - if hasattr(tensor, "allocation_info") and tensor.allocation_info: - alloc = tensor.allocation_info - # Compute tensor size from sizes and scalar_type - num_elements = 1 - for dim in tensor.sizes: - num_elements *= dim - element_size = scalar_type_sizes.get(int(tensor.scalar_type), 4) - size = num_elements * element_size - key = (alloc.memory_id, alloc.memory_offset) - if key not in offset_to_indices: - offset_to_indices[key] = {"indices": [], "size": size} - offset_to_indices[key]["indices"].append(i) - - # Find shared allocations matching the mutable buffer size (before/after copy_) - mutable_buffer_shares = [ - info - for info in offset_to_indices.values() - if len(info["indices"]) == 2 and info["size"] == expected_buffer_size - ] + # The mutable buffer (5x5 float32 = 100 bytes) should not be + # double allocated. After the upstream emit dedup + # (`_emit_spec` reusing value_id when two FX nodes share a + # TensorSpec via the planner's `_alias_inplace_result_specs`), + # the `copy_` writeback's "out" arg uses the SAME value_id as + # its "self" arg (the buffer), rather than creating a separate + # Value at the same (mem_id, offset). + execution_plan = et.executorch_program.execution_plan[0] + values = execution_plan.values + + # Find the `copy_` writeback instruction. 
+ copy_instructions = [] + for chain in execution_plan.chains: + for ins in chain.instructions: + inner = ins.instr_args + if hasattr(inner, "op_index"): + op = execution_plan.operators[inner.op_index] + if op.name == "aten::copy_": + copy_instructions.append(inner) self.assertEqual( - len(mutable_buffer_shares), + len(copy_instructions), 1, - f"Expected exactly one shared allocation of size {expected_buffer_size} " - f"with 2 values (copy_ input/output), found: {mutable_buffer_shares}", + "Expected exactly one copy_ writeback for the buffer mutation", + ) + + # For an in-place copy_(self, src, ..., out), self (arg 0) and + # out (the emitted synthetic last arg) must share a value_id + # per the `(a!)` schema annotation. Emit's spec2id_dict + # dedup enforces this. + copy_args = list(copy_instructions[0].args) + self.assertEqual( + copy_args[0], + copy_args[-1], + f"copy_'s out arg should reference the same value_id as its " + f"self arg (buffer) via emit dedup. args={copy_args}", + ) + + # Additionally verify no distinct second Value at the buffer's + # (mem_id, offset): after dedup, the buffer occupies its slot alone. + buffer_value_id = copy_args[0] + buffer_val = values[buffer_value_id].val + self.assertTrue( + hasattr(buffer_val, "allocation_info") and buffer_val.allocation_info, + "Buffer value should have allocation_info", + ) + buffer_alloc = buffer_val.allocation_info + duplicates_at_buffer_slot = [ + i + for i, val in enumerate(values) + if i != buffer_value_id + and hasattr(val.val, "allocation_info") + and val.val.allocation_info + and val.val.allocation_info.memory_id == buffer_alloc.memory_id + and val.val.allocation_info.memory_offset == buffer_alloc.memory_offset + ] + self.assertEqual( + duplicates_at_buffer_slot, + [], + f"Expected no other Values at the buffer's allocation " + f"(mem_id={buffer_alloc.memory_id}, " + f"offset={buffer_alloc.memory_offset}); emit dedup should " + f"collapse placeholder + writeback into one value_id. 
" + f"Found duplicates at indices: {duplicates_at_buffer_slot}", ) def test_mutable_buffers_infinite_lifespan(self) -> None: @@ -764,6 +783,147 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: ) self.assertTrue(not_overlapping) + def test_custom_inplace_op_memory_aliasing(self) -> None: + """Memory planning correctly handles in-place ops registered via + the ``ops_to_inplace`` extension API (i.e. outside + ``DEFAULT_INPLACEABLE_OPS``). + + Uses the HF-static-cache pattern: ``index_copy_`` updates two + mutable buffers (``keys``, ``values``). We: + 1. Preserve ``index_copy`` through edge lowering. + 2. Manually call ``reinplace_pass`` with a custom set that + includes ``index_copy`` (the in-place form is + auto-derived). + 3. Lower with ``run_reinplace_pass=False`` (the pass already + ran). + + Then assert that no other planned tensor's allocation overlaps + either buffer's storage region. This pins the schema-driven + ``_alias_inplace_result_specs`` path for non-default ops: the + ``index_copy_`` result spec must be aliased to the buffer's + spec, otherwise the planner would carve out a separate + allocation that could land inside the buffer's slot. 
+ """ + max_batch_size, num_heads, max_cache_len, head_dim = 1, 2, 4, 8 + + class HFStyleStaticCache(torch.nn.Module): + def __init__(self): + super().__init__() + self.register_buffer( + "keys", + torch.zeros((max_batch_size, num_heads, max_cache_len, head_dim)), + ) + self.register_buffer( + "values", + torch.zeros((max_batch_size, num_heads, max_cache_len, head_dim)), + ) + + def forward( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + cache_position: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + self.keys.index_copy_(2, cache_position, key_states) + self.values.index_copy_(2, cache_position, value_states) + return self.keys, self.values + + model = HFStyleStaticCache() + key_states = torch.full((max_batch_size, num_heads, 1, head_dim), 1.0) + value_states = torch.full((max_batch_size, num_heads, 1, head_dim), 2.0) + cache_position = torch.tensor([1]) + + exported_program = export( + model, (key_states, value_states, cache_position), strict=True + ) + + edge = to_edge( + exported_program, + compile_config=EdgeCompileConfig( + _check_ir_validity=False, + preserve_ops=[torch.ops.aten.index_copy.default], + ), + ) + + # Manually run reinplace_pass with a custom set that + # includes the (non-default) index_copy edge op. The in-place + # form is auto-derived by name + schema match. + custom_set = DEFAULT_INPLACEABLE_OPS | { + exir_ops.edge.aten.index_copy.default, + } + edge_program = reinplace_pass( + edge.exported_program(), ops_to_inplace=custom_set + ) + # Sanity: both updates are now in-place. + inplace_nodes = [ + n + for n in edge_program.graph.nodes + if n.op == "call_function" and "index_copy_" in str(n.target) + ] + self.assertEqual( + len(inplace_nodes), + 2, + "Both buffer updates should be reinplaced before lowering", + ) + + # Lower with run_reinplace_pass=False — the pass already ran + # with our custom set above. 
Memory planning should now + # correctly alias the index_copy_ result spec onto the buffer + # placeholder spec via _alias_inplace_result_specs. + et = edge.to_executorch( + ExecutorchBackendConfig( + emit_mutable_buffer_names=True, + run_reinplace_pass=False, + ) + ) + + execution_plan = et.executorch_program.execution_plan[0] + values = execution_plan.values + + # Collect the keys / values buffer Values by FQN. + buffer_value_ids: dict[str, int] = {} + for i, value in enumerate(values): + val = value.val + extra = getattr(val, "extra_tensor_info", None) + fqn = getattr(extra, "fully_qualified_name", None) if extra else None + if fqn in ("keys", "values"): + buffer_value_ids[fqn] = i + + self.assertEqual( + set(buffer_value_ids.keys()), + {"keys", "values"}, + "Both keys and values buffers should appear in the program " + "with their FQN", + ) + + # For each buffer, verify no other planned Value's allocation + # overlaps the buffer's memory region. + for fqn, vid in buffer_value_ids.items(): + buf_alloc = values[vid].val.allocation_info + self.assertIsNotNone(buf_alloc, f"Buffer {fqn} should have allocation_info") + buf_base = buf_alloc.memory_offset_low + # 4 bytes per float32 element. 
+ num_elements = max_batch_size * num_heads * max_cache_len * head_dim + buf_end = buf_base + num_elements * 4 + + for j, other in enumerate(values): + if j == vid: + continue + other_alloc = getattr(other.val, "allocation_info", None) + if other_alloc is None: + continue + if other_alloc.memory_id != buf_alloc.memory_id: + continue + offset = other_alloc.memory_offset_low + overlaps = buf_base <= offset < buf_end + self.assertFalse( + overlaps, + f"Value {j} (alloc offset={offset}) overlaps the " + f"{fqn} buffer's region [{buf_base}, {buf_end}) — " + "the in-place index_copy_ result spec was not " + "correctly aliased to the buffer spec", + ) + def test_constants_not_memory_planned(self) -> None: class Simple(torch.nn.Module): def __init__(self) -> None: diff --git a/exir/tests/test_reinplace_pass.py b/exir/tests/test_reinplace_pass.py index 13661ef8cf9..c7b72a73a13 100644 --- a/exir/tests/test_reinplace_pass.py +++ b/exir/tests/test_reinplace_pass.py @@ -7,15 +7,37 @@ # pyre-strict import unittest +from typing import List, Optional import torch -from executorch.exir import to_edge +from executorch.exir import EdgeCompileConfig, to_edge from executorch.exir.capture._config import ExecutorchBackendConfig -from executorch.exir.passes.reinplace import reinplace_pass +from executorch.exir.dialects._ops import ops as edge_ops +from executorch.exir.passes.reinplace import DEFAULT_INPLACEABLE_OPS, reinplace_pass from executorch.extension.pybindings.portable_lib import ( # @manual=//executorch/extension/pybindings:portable_lib _load_for_executorch_from_buffer, ) from torch.export import export +from torch.export.exported_program import ExportedProgram + + +def _find_nodes( + ep: ExportedProgram, + contains: str, + excludes: Optional[str] = None, +) -> List[torch.fx.Node]: + """Return all ``call_function`` nodes whose target string contains + ``contains``. 
If ``excludes`` is given, also drop any node whose + target string contains that substring (used to distinguish the + functional form ``index_put`` from the in-place ``index_put_``). + """ + return [ + n + for n in ep.graph.nodes + if n.op == "call_function" + and contains in str(n.target) + and (excludes is None or excludes not in str(n.target)) + ] class TestReinplacePass(unittest.TestCase): @@ -42,33 +64,24 @@ def forward( edge = to_edge(exported_program) edge_program = edge._edge_programs["forward"] - # Find the index_put node - index_put_node = None - for node in edge_program.graph.nodes: - if node.op == "call_function" and "index_put" in str(node.target): - index_put_node = node - break - - self.assertIsNotNone(index_put_node, "Should find an index_put node") + self.assertEqual( + len(_find_nodes(edge_program, "index_put")), + 1, + "Should find an index_put node", + ) et = edge.to_executorch(ExecutorchBackendConfig(run_reinplace_pass=True)) - # Find the index_put node - index_put_node = None - for node in et.exported_program().graph.nodes: - if node.op == "call_function" and "index_put_" in str(node.target): - index_put_node = node - break - - self.assertIsNotNone(index_put_node, "Should find an index_put_ node") - - # Find the copy_ node - copy_node = None - for node in et.exported_program().graph.nodes: - if node.op == "call_function" and "copy_" in str(node.target): - copy_node = node - break - - self.assertIsNone(copy_node, "Shouldn't find an copy_ node") + et_program = et.exported_program() + self.assertEqual( + len(_find_nodes(et_program, "index_put_")), + 1, + "Should find an index_put_ node", + ) + self.assertEqual( + len(_find_nodes(et_program, "copy_")), + 0, + "Shouldn't find a copy_ node", + ) e = _load_for_executorch_from_buffer(et.buffer) self.assertTrue( @@ -98,27 +111,471 @@ def forward( values = torch.tensor([1.0]) exported_program = export(model, (indices, values), strict=True) - edge_program = to_edge(exported_program).exported_program() + 
edge = to_edge(exported_program) + edge_program = edge.exported_program() + + self.assertEqual( + len(_find_nodes(edge_program, "index_put")), + 1, + "Should find an index_put node", + ) + + ep = reinplace_pass(edge_program) + self.assertEqual( + len(_find_nodes(ep, "index_put", excludes="index_put_")), + 1, + "Should still find a functional index_put node", + ) + + # Lower to ExecuTorch and verify runtime correctness against eager. + et = edge.to_executorch() + loaded = _load_for_executorch_from_buffer(et.buffer) + et_out = loaded.forward((indices, values)) + eager_out = IndexPutModel()(indices, values) + self.assertTrue(torch.allclose(et_out[0], eager_out)) + + def test_unsafe_downstream_blocks_upstream_reinplace(self) -> None: + """When an upstream index_put's mutated arg is also an input to an + unsafe downstream index_put, the upstream one must not be reinplaced + either. Otherwise the in-place mutation by the upstream op would + change the value the (still-functional) downstream op reads. + """ + + class TwoIndexPutModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.register_buffer("state", torch.zeros(5)) + + def forward( + self, indices: torch.Tensor, values: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: + # Intermediate tensor consumed by both index_puts. + t = values + 1.0 + # Upstream index_put — mutates t (an intermediate, so + # by itself it would look safe to reinplace). + a = t.index_put((indices,), torch.tensor([5.0])) + # Downstream index_put — reads t. This one is itself + # unsafe because state is read again by the add_ below. + b = self.state.index_put((indices,), t) + self.state.add_(1) + return a, b + + model = TwoIndexPutModel() + # `indices` and `values` must have matching shape so the + # downstream `state.index_put((indices,), t)` is well-formed + # (state[indices[i]] = t[i]). The original test used values + # shape [5] which is fine for graph-shape checks but not + # runnable in ET. 
+ indices = torch.tensor([0]) + values = torch.tensor([2.0]) + + exported_program = export(model, (indices, values), strict=True) + edge = to_edge(exported_program) + edge_program = edge.exported_program() + + # Sanity check: both functional index_puts present pre-pass. + self.assertEqual( + len(_find_nodes(edge_program, "index_put", excludes="index_put_")), + 2, + "Should have two functional index_put nodes", + ) + + ep = reinplace_pass(edge_program) + + # Neither index_put should be reinplaced. The downstream one is + # unsafe directly (state used later); the upstream one is unsafe + # because its mutated arg `t` is read by the unsafe downstream op. + self.assertEqual( + len(_find_nodes(ep, "index_put", excludes="index_put_")), + 2, + "Neither index_put should be reinplaced", + ) + self.assertEqual( + len(_find_nodes(ep, "index_put_")), + 0, + "No index_put_ should have been introduced", + ) + + # Lower to ExecuTorch and verify runtime correctness against eager. + et = edge.to_executorch() + loaded = _load_for_executorch_from_buffer(et.buffer) + et_out = loaded.forward((indices, values)) + eager_a, eager_b = TwoIndexPutModel()(indices, values) + self.assertTrue(torch.allclose(et_out[0], eager_a)) + self.assertTrue(torch.allclose(et_out[1], eager_b)) + + def test_kwargs_are_forwarded(self) -> None: + """When the matched node carries a value in ``node.kwargs`` (e.g. + ``accumulate=True`` for ``index_put``), the rewrite must forward + those kwargs to the in-place form. Otherwise the in-place op + falls back to the schema default and silently changes semantics. + + ``export`` normalizes most arguments into positional form, so we + explicitly move ``accumulate`` into ``node.kwargs`` after export + to exercise the kwarg-forwarding path. 
+ """ + + class IndexPutAccumulateModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.register_buffer("state", torch.zeros(5)) + + def forward( + self, indices: torch.Tensor, values: torch.Tensor + ) -> torch.Tensor: + self.state.index_put_((indices,), values, accumulate=True) + return self.state + + model = IndexPutAccumulateModel() + indices = torch.tensor([0, 0]) + values = torch.tensor([1.0, 2.0]) + + exported_program = export(model, (indices, values), strict=True) + edge = to_edge(exported_program) + edge_program = edge.exported_program() - # Find the index_put node - index_put_node = None - for node in edge_program.graph.nodes: - if node.op == "call_function" and "index_put" in str(node.target): - index_put_node = node - break + # Find the functional index_put node and force `accumulate` onto + # its kwargs (export normalized it to positional). This is what + # exercises the kwarg-forwarding path in reinplace_pass. + functionals = _find_nodes(edge_program, "index_put", excludes="index_put_") + self.assertEqual(len(functionals), 1, "Should find a functional index_put") + functional = functionals[0] - self.assertIsNotNone(index_put_node, "Should find an index_put node") + # index_put schema: (self, indices, values, accumulate=False). + # Move arg[3] -> kwargs["accumulate"] if present. 
+ if len(functional.args) >= 4: + new_args = functional.args[:3] + new_kwargs = dict(functional.kwargs) + new_kwargs["accumulate"] = functional.args[3] + functional.args = new_args + functional.kwargs = new_kwargs + self.assertEqual( + functional.kwargs.get("accumulate"), + True, + "Test setup: accumulate should now be a kwarg", + ) ep = reinplace_pass(edge_program) - # Find the index_put node - index_put_node = None - for node in ep.graph.nodes: - if ( - node.op == "call_function" - and "index_put" in str(node.target) - and "index_put_" not in str(node.target) - ): - index_put_node = node - break - - self.assertIsNotNone(index_put_node, "Should still find an index_put node") + + # Find the rewritten in-place index_put_ node. + inplace_nodes = _find_nodes(ep, "index_put_") + self.assertEqual(len(inplace_nodes), 1, "Should find an index_put_ node") + index_put_inplace = inplace_nodes[0] + + # `accumulate` must survive the rewrite, in either args or kwargs. + accumulate = index_put_inplace.kwargs.get("accumulate") + if accumulate is None and len(index_put_inplace.args) >= 4: + accumulate = index_put_inplace.args[3] + self.assertEqual( + accumulate, + True, + "accumulate=True must be preserved through the rewrite", + ) + + def test_ops_to_inplace_extends_with_add(self) -> None: + """A custom ``ops_to_inplace`` set can extend the pass to ops + outside the default set. Add the edge-dialect ``add.Tensor`` to + the set; the in-place form (``add_.Tensor``) is auto-derived by + name + schema match. Verify a safe-to-reinplace add gets + rewritten while an unsafe one (mutating an immutable input) + does not. + """ + + class TwoAddModel(torch.nn.Module): + def forward( + self, + x: torch.Tensor, + y: torch.Tensor, + z: torch.Tensor, + ) -> torch.Tensor: + # First add: mutated arg `x` is an immutable user input + # → not safe to reinplace. + t = x + y + # Second add: mutated arg `t` is an intermediate with no + # later use → safe to reinplace. 
+ return t + z + + model = TwoAddModel() + args = (torch.zeros(5), torch.ones(5), torch.full((5,), 2.0)) + exported_program = export(model, args, strict=True) + edge = to_edge(exported_program) + edge_program = edge.exported_program() + + # Sanity: no in-place adds before the pass. + self.assertEqual(len(_find_nodes(edge_program, "add_")), 0) + + custom_set = {edge_ops.edge.aten.add.Tensor} + ep = reinplace_pass(edge_program, ops_to_inplace=custom_set) + + # Exactly one of the two adds should be reinplaced. + self.assertEqual( + len(_find_nodes(ep, "add_")), + 1, + "Exactly one (intermediate-mutating) add should be reinplaced", + ) + self.assertEqual( + len(_find_nodes(ep, "aten.add.Tensor")), + 1, + "The add mutating an immutable input must remain functional", + ) + + # Lower to ExecuTorch. The portable runtime does not register a + # kernel for `add_.Tensor`, so we only verify lowering succeeds + # (the in-place rewrite must serialize cleanly into the ET + # program). Runtime execution is covered by ops with portable + # kernels in the other tests in this file. + edge.to_executorch() + + def test_ops_to_inplace_empty_disables_all_rewrites(self) -> None: + """Passing an empty ``ops_to_inplace`` set should disable every + rewrite, even ops that are in ``DEFAULT_INPLACEABLE_OPS``. + """ + + class IndexPutModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.register_buffer("state", torch.zeros(5)) + + def forward( + self, indices: torch.Tensor, values: torch.Tensor + ) -> torch.Tensor: + self.state.index_put_((indices,), values) + return self.state + + model = IndexPutModel() + indices = torch.tensor([0]) + values = torch.tensor([1.0]) + + exported_program = export(model, (indices, values), strict=True) + edge = to_edge(exported_program) + edge_program = edge.exported_program() + + # Sanity: index_put is in the default set, so without our + # explicit override it would be reinplaced. 
+ self.assertIn( + edge_ops.edge.aten.index_put.default, + DEFAULT_INPLACEABLE_OPS, + "Sanity check: edge index_put should be in the default set", + ) + + ep = reinplace_pass(edge_program, ops_to_inplace={}) + + self.assertEqual( + len(_find_nodes(ep, "index_put_")), + 0, + "Empty set must disable all rewrites, including the default ones", + ) + self.assertEqual( + len(_find_nodes(ep, "index_put", excludes="index_put_")), + 1, + "The functional index_put must remain", + ) + + # Lower to ExecuTorch and verify runtime correctness against eager. + et = edge.to_executorch() + loaded = _load_for_executorch_from_buffer(et.buffer) + et_out = loaded.forward((indices, values)) + eager_out = IndexPutModel()(indices, values) + self.assertTrue(torch.allclose(et_out[0], eager_out)) + + def test_ops_to_inplace_custom_does_not_inherit_default(self) -> None: + """A custom ``ops_to_inplace`` set replaces — not augments — + ``DEFAULT_INPLACEABLE_OPS``. Passing a set that doesn't include + ``index_put`` leaves it functional, even though it would be + reinplaced under the default. Callers who want to extend rather + than replace should union with ``DEFAULT_INPLACEABLE_OPS`` + explicitly (per the docstring guidance on ``reinplace_pass``). + """ + + class IndexPutModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.register_buffer("state", torch.zeros(5)) + + def forward( + self, indices: torch.Tensor, values: torch.Tensor + ) -> torch.Tensor: + self.state.index_put_((indices,), values) + return self.state + + model = IndexPutModel() + indices = torch.tensor([0]) + values = torch.tensor([1.0]) + + exported_program = export(model, (indices, values), strict=True) + edge = to_edge(exported_program) + edge_program = edge.exported_program() + + # Custom set containing only `add` — no `index_put`. 
+ custom_set = {edge_ops.edge.aten.add.Tensor} + ep = reinplace_pass(edge_program, ops_to_inplace=custom_set) + + self.assertEqual( + len(_find_nodes(ep, "index_put_")), + 0, + "index_put must remain functional when not in the custom set", + ) + + # Lower to ExecuTorch and verify runtime correctness against eager. + et = edge.to_executorch() + loaded = _load_for_executorch_from_buffer(et.buffer) + et_out = loaded.forward((indices, values)) + eager_out = IndexPutModel()(indices, values) + self.assertTrue(torch.allclose(et_out[0], eager_out)) + + def test_kv_cache_style_reinplace(self) -> None: + """HF-style static KV cache update via ``index_copy_`` along the + ``cache_len`` dim. Mirrors transformers' ``StaticCache.update``:: + + self.keys.index_copy_(2, cache_position, key_states) + self.values.index_copy_(2, cache_position, value_states) + + Cache shape mirrors HF:: + (max_batch_size, num_heads, max_cache_len, head_dim) + + ``index_copy`` is **not** in ``DEFAULT_INPLACEABLE_OPS``, so this + also exercises the new ``ops_to_inplace`` extension API. We use + ``EdgeCompileConfig(preserve_ops=[index_copy])`` to keep the op + from being decomposed by ``to_edge``, then add ``index_copy`` to + a custom set (the in-place form is auto-derived by name + schema + match) and verify both buffer updates are rewritten in-place. + """ + + max_batch_size, num_heads, max_cache_len, head_dim = 1, 2, 4, 8 + + class HFStyleStaticCache(torch.nn.Module): + def __init__(self): + super().__init__() + self.register_buffer( + "keys", + torch.zeros((max_batch_size, num_heads, max_cache_len, head_dim)), + ) + self.register_buffer( + "values", + torch.zeros((max_batch_size, num_heads, max_cache_len, head_dim)), + ) + + def forward( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + cache_position: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + # HF static-cache style: in-place index_copy along dim=2. 
+ self.keys.index_copy_(2, cache_position, key_states) + self.values.index_copy_(2, cache_position, value_states) + return self.keys, self.values + + model = HFStyleStaticCache() + key_states = torch.full((max_batch_size, num_heads, 1, head_dim), 1.0) + value_states = torch.full((max_batch_size, num_heads, 1, head_dim), 2.0) + cache_position = torch.tensor([1]) + + exported_program = export( + model, (key_states, value_states, cache_position), strict=True + ) + + # Preserve `index_copy` through the edge lowering so the pass + # actually sees it (default to_edge decomposes it into index_put). + edge = to_edge( + exported_program, + compile_config=EdgeCompileConfig( + _check_ir_validity=False, + preserve_ops=[torch.ops.aten.index_copy.default], + ), + ) + edge_program = edge.exported_program() + + # Sanity: pre-pass, both updates are functional `index_copy`. + self.assertEqual( + len(_find_nodes(edge_program, "index_copy", excludes="index_copy_")), + 2, + "Pre-pass: should have two functional index_copy nodes (keys, values)", + ) + + # Add `index_copy` to the set; the in-place form is + # auto-derived by name + schema match + # (ops.edge.aten.index_copy_.default). + custom_set = DEFAULT_INPLACEABLE_OPS | { + edge_ops.edge.aten.index_copy.default, + } + ep = reinplace_pass(edge_program, ops_to_inplace=custom_set) + + # Both updates should now be in-place. + self.assertEqual( + len(_find_nodes(ep, "index_copy_")), + 2, + "Both keys and values index_copy ops should be reinplaced", + ) + self.assertEqual( + len(_find_nodes(ep, "index_copy", excludes="index_copy_")), + 0, + "No functional index_copy nodes should remain", + ) + + # Lower to ExecuTorch. The portable runtime does not register a + # kernel for `index_copy_`, so we only verify lowering succeeds + # (the in-place rewrite must serialize cleanly into the ET + # program). 
+ edge.to_executorch() + + def test_chain_of_inplaceable_ops(self) -> None: + """A chain of safe-to-reinplace ops gets fully rewritten in + topological order. Exercises: + * Multiple distinct ops (`add` and `relu`) registered together + via a single set, with in-place forms auto-derived. + * Reverse-walk safety propagation: each intermediate is + consumed exactly once by the next op, so every step except + the first sees its mutated arg as not-yet-used. + * The first ``add`` mutates an immutable user input + (``x``) and must remain functional. + """ + + class ChainModel(torch.nn.Module): + def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + t = x + y # add #1: mutates `x` (immutable input) -> unsafe. + t = torch.relu(t) # relu #1: mutates intermediate -> safe. + t = t + y # add #2: mutates intermediate -> safe. + t = torch.relu(t) # relu #2: mutates intermediate -> safe. + return t + + model = ChainModel() + x = torch.tensor([-1.0, 2.0, -3.0, 4.0]) + y = torch.tensor([1.0, 1.0, 1.0, 1.0]) + + exported_program = export(model, (x, y), strict=True) + edge = to_edge(exported_program) + edge_program = edge.exported_program() + + custom_set = { + edge_ops.edge.aten.add.Tensor, + edge_ops.edge.aten.relu.default, + } + ep = reinplace_pass(edge_program, ops_to_inplace=custom_set) + + self.assertEqual( + len(_find_nodes(ep, "aten.add.Tensor")), + 1, + "First add mutates an immutable input; must stay functional", + ) + self.assertEqual( + len(_find_nodes(ep, "add_")), + 1, + "Second add mutates an intermediate; should be reinplaced", + ) + self.assertEqual( + len(_find_nodes(ep, "aten.relu.default")), + 0, + "Both relus mutate intermediates; neither should remain functional", + ) + self.assertEqual( + len(_find_nodes(ep, "relu_")), + 2, + "Both relus should be reinplaced", + ) + + # Lower to ExecuTorch. 
The portable runtime does not register + # kernels for `add_.Tensor` or `relu_.default`, so we only + # verify lowering succeeds (the in-place rewrites must + # serialize cleanly into the ET program). + edge.to_executorch()