Commit 24b23f5

metascroy authored and facebook-github-bot committed
Extend ExecuTorch reinplace pass
Summary: This diff expands the ExecuTorch reinplace pass so that backends can configure it, and so that more in-place portable ops can be added to ExecuTorch's default set over time.

Note: if two tensors have the same spec, they are now deduped and emitted as the same value. Previously (with copy_ writeback), the synthetic out got a new value, but with the same allocation info as the self value.

What's new:

* reinplace_pass now takes ops_to_inplace= and inplace_overrides= so backends can extend the candidate set (see the usage sketch below). The in-place form is auto-derived from the op schema, and registrations are validated up front.
* Mutated arg positions are derived from Tensor(a!) annotations, so ops with multiple mutated args are handled correctly. kwargs are also forwarded to the in-place op (e.g. accumulate= for index_put).
* Fixed a bug where, after a successful rewrite, the new node's input nodes were not added to seen_nodes, so an upstream candidate mutating one of those reads could incorrectly look safe to reinplace. Covered by a new index_put-only regression test.
* The memory planner now aliases an in-place op's result TensorSpec onto its mutated input's spec (schema-driven via output_to_aliased_input_map), so they share an allocation slot.
* The emitter dedupes by TensorSpec identity, so two FX nodes sharing one spec share one Value in the lowered IR.

Differential Revision: D104152427
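To make the new configuration surface concrete, here is a minimal usage sketch. The keyword names ops_to_inplace= and inplace_overrides= come from the summary above; the import path, the positional graph-module argument, and the specific ops chosen are assumptions for illustration, not part of this diff.

# Hypothetical usage sketch: keyword names come from the summary above;
# the import path, call shape, and op choices are assumptions.
import torch
from executorch.exir.passes.reinplace import reinplace_pass  # path assumed


def apply_backend_reinplace(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    return reinplace_pass(
        gm,
        # Extend the default candidate set; the in-place forms (add_,
        # index_put_) are auto-derived from each op's schema.
        ops_to_inplace=[
            torch.ops.aten.add.Tensor,
            torch.ops.aten.index_put.default,
        ],
        # Explicit functional -> in-place override for cases where the
        # schema-derived name is not the desired one.
        inplace_overrides={
            torch.ops.aten.relu.default: torch.ops.aten.relu_.default,
        },
    )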
1 parent a5a9621 commit 24b23f5

7 files changed: 1433 additions & 132 deletions


exir/emit/_emitter.py

Lines changed: 37 additions & 1 deletion
@@ -622,7 +622,16 @@ def _emit_evalue(self, val: EValue) -> _AbstractValue:
         return _AbstractValue(len(self.emitter_state.values) - 1, tensor)
 
     def _emit_spec(self, spec: ValueSpec) -> _EmitterValue:
-        """Given the provided spec constructs the corresponding EValue from it and then emits it."""
+        """Given the provided spec constructs the corresponding EValue from it and then emits it.
+
+        If `spec` was already emitted earlier (e.g., because two FX
+        nodes share the same TensorSpec object — typically because the
+        planner's `_alias_inplace_result_specs` aliased an in-place
+        op's result onto its mutated input), reuse the existing
+        value_id. This keeps the invariant "one TensorSpec ↔ one
+        Value" so downstream emit doesn't create duplicate Values for
+        aliased FX nodes.
+        """
 
         def _process(spec: LeafValueSpec) -> _AbstractValue:
             if isinstance(spec, (list, tuple)):
@@ -642,6 +651,25 @@ def _process(spec: LeafValueSpec) -> _AbstractValue:
                     f"Invalid node spec expected TensorSpec received {spec}",
                 )
 
+            # Spec was already emitted — reuse the existing Value so
+            # two FX nodes sharing one TensorSpec also share one
+            # value_id in the lowered IR.
+            existing_id = self.emitter_state.spec2id_dict.get(spec)
+            if existing_id is not None:
+                existing_evalue = self.emitter_state.values[existing_id]
+                # Both insertion sites for `spec2id_dict` (this method
+                # and the placeholder emitter) only register ids whose
+                # EValue wraps a `Tensor` (built via
+                # `_tensor_spec_to_evalue` from a `TensorSpec`).
+                self._internal_assert_emitter(
+                    isinstance(existing_evalue.val, Tensor),
+                    self.node,
+                    f"spec2id_dict entry for TensorSpec must point to a "
+                    f"Tensor EValue, got "
+                    f"{type(existing_evalue.val).__name__}",
+                )
+                return _AbstractValue(existing_id, existing_evalue.val)
+
             ret = self._emit_evalue(self._tensor_spec_to_evalue(spec))  # pyre-ignore
             self.emitter_state.spec2id_dict[spec] = ret.id  # pyre-ignore
             return ret
@@ -2020,6 +2048,14 @@ def _is_buffer(node: Node, graph_signature: ExportGraphSignature) -> bool:
             )
         value = self._emit_evalue(evalue)
 
+        # Populate spec2id_dict so downstream `_emit_spec` calls (e.g.,
+        # for in-place op result FX nodes whose spec was aliased onto
+        # this placeholder's spec by the planner's
+        # `_alias_inplace_result_specs`) reuse this placeholder's
+        # value_id rather than creating a new Value.
+        if isinstance(spec, TensorSpec):
+            self.emitter_state.spec2id_dict[spec] = value.id
+
         # Only user inputs should remain as inputs.
         if is_user_input:
             self.inputs.append(value.id)
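To see why spec-identity dedup yields exactly one Value per TensorSpec, here is a minimal standalone sketch of the memoization idea. ToyEmitter and its TensorSpec stand-in are hypothetical, not the real emitter classes.

# Minimal sketch of the "one TensorSpec <-> one Value" invariant: emission
# is memoized on spec object identity, so two nodes whose meta["spec"] is
# literally the same object share one value_id.
from typing import Dict, List


class TensorSpec:  # stand-in; no __eq__/__hash__, so dict keys use identity
    def __init__(self, shape):
        self.shape = shape


class ToyEmitter:
    def __init__(self) -> None:
        self.values: List[TensorSpec] = []
        self.spec2id: Dict[TensorSpec, int] = {}

    def emit_spec(self, spec: TensorSpec) -> int:
        existing = self.spec2id.get(spec)
        if existing is not None:
            return existing  # reuse: no duplicate Value for an aliased spec
        self.values.append(spec)
        self.spec2id[spec] = len(self.values) - 1
        return self.spec2id[spec]


emitter = ToyEmitter()
buf = TensorSpec((4,))   # e.g. a mutable-buffer placeholder's spec
inplace_result = buf     # planner aliased the result spec onto the input spec
assert emitter.emit_spec(buf) == emitter.emit_spec(inplace_result)
assert len(emitter.values) == 1  # exactly one Value was emitted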

exir/memory_planning.py

Lines changed: 141 additions & 7 deletions
@@ -30,7 +30,12 @@
 from executorch.exir.control_flow import while_loop as exir_while
 from executorch.exir.delegate import executorch_call_delegate
 from executorch.exir.error import internal_assert, InternalError
-from executorch.exir.operator.convert import is_inplace_variant, is_out_variant
+from executorch.exir.operator.convert import (
+    is_inplace_variant,
+    is_out_variant,
+    output_to_aliased_input_map,
+    unwrap_op_overload,
+)
 from executorch.exir.schema import TensorShapeDynamism
 from executorch.exir.tensor import TensorSpec
 from torch import fx
@@ -302,13 +307,137 @@ def _is_out_var_node(node: torch.fx.Node) -> bool:
 
 
 def _is_inplace_node(node: torch.fx.Node) -> bool:
-    return (
-        node.op == "call_function"
-        and isinstance(node.target, torch._ops.OpOverload)
-        and is_inplace_variant(
-            node.target._schema.name, node.target._schema.overload_name
+    if node.op != "call_function":
+        return False
+    target = node.target
+    if not isinstance(target, torch._ops.OpOverload) and not isinstance(
+        getattr(target, "_op", None), torch._ops.OpOverload
+    ):
+        return False
+    op = unwrap_op_overload(target)
+    return is_inplace_variant(op._schema.name, op._schema.overload_name)
+
+
+def _alias_inplace_result_specs(node: torch.fx.Node) -> None:
+    """Alias an in-place op's result TensorSpec(s) onto the corresponding
+    input's spec.
+
+    In-place ops (schema kind == inplace) mutate one or more of their
+    inputs and return tensors that alias them, declared via the
+    ``Tensor(a!)`` schema annotation. To make the memory planner treat
+    result and aliased input as one storage, we copy the input's spec
+    object onto the output's ``node.meta["spec"]`` slot.
+
+    Output→input correspondence is computed via
+    ``output_to_aliased_input_map``, which matches each return's
+    write-alias set against the inputs that share it.
+
+    Gating:
+
+    - Only runs for in-place nodes (caller checks ``_is_inplace_node``).
+    - Multi-output in-place ops are supported when each return's alias
+      set matches exactly one input's alias set.
+    - Falls through silently when alias info is absent or unparseable,
+      preserving the original spec. ``logging.debug`` records each
+      early-return reason so silent regressions are observable.
+    """
+    target = node.target
+    op = unwrap_op_overload(target)
+
+    schema = op._schema
+    out_to_in = output_to_aliased_input_map(schema)
+    if not out_to_in:
+        logging.debug(
+            f"_alias_inplace_result_specs: schema for {op} declares no "
+            f"write-aliased outputs matching an input; skipping."
         )
-    )
+        return
+
+    # Normalize the current spec container into a list for uniform
+    # handling. Caller guarantees this node was identified as an
+    # in-place op with a meta spec (see `_is_inplace_node`), so an
+    # unrecognized container shape is a real bug — assert loudly so it
+    # surfaces in tests rather than silently disabling aliasing.
+    current = node.meta.get("spec")
+    if isinstance(current, TensorSpec):
+        out_specs_list: List[Optional[TensorSpec]] = [current]
+        return_container_kind = "scalar"
+    elif isinstance(current, (list, tuple)):
+        out_specs_list = list(current)
+        return_container_kind = type(current).__name__
+    else:
+        raise InternalError(
+            f"_alias_inplace_result_specs: in-place node {node.name} "
+            f"({op}) has unrecognized spec container of type "
+            f"{type(current).__name__!r}; expected TensorSpec, list, "
+            f"or tuple."
+        )
+
+    # Compute new spec for each return; None means "keep original".
+    # Mutated inputs are usually positional (the `Tensor(a!)` `self`
+    # arg), but custom ops may pass them via kwargs — fall back to
+    # `node.kwargs[arg_name]` in that case.
+    replacements: List[Optional[TensorSpec]] = [None] * len(out_specs_list)
+
+    for out_idx, in_idx in out_to_in.items():
+        if out_idx >= len(out_specs_list):
+            logging.debug(
+                f"_alias_inplace_result_specs: schema for {op} declares "
+                f"return {out_idx} but spec container has only "
+                f"{len(out_specs_list)} entries; skipping this return."
+            )
+            continue
+        # Resolve the mutated input. Prefer positional args, fall back
+        # to kwargs by argument name (custom ops may pass `Tensor(a!)`
+        # args via kwargs).
+        in_node: object
+        if in_idx < len(node.args):
+            in_node = node.args[in_idx]
+        else:
+            arg_name = (
+                schema.arguments[in_idx].name
+                if in_idx < len(schema.arguments)
+                else None
+            )
+            if arg_name is None or arg_name not in node.kwargs:
+                logging.debug(
+                    f"_alias_inplace_result_specs: schema for {op} "
+                    f"expects mutated input at position {in_idx} "
+                    f"(name={arg_name!r}) but it is supplied neither "
+                    f"positionally nor via kwargs; skipping."
+                )
+                continue
+            in_node = node.kwargs[arg_name]
+        if not isinstance(in_node, torch.fx.Node):
+            continue
+        # NOTE: alias unconditionally — including when the input is a
+        # placeholder (named buffer / mutable input). Skipping
+        # placeholders would leave a dangling out-arg on in-place op
+        # instructions whose self is a buffer; the runtime kernel
+        # writes to that out-arg's storage rather than mutating the
+        # buffer, producing incorrect results. The companion change in
+        # `_emit_spec` (emit/_emitter.py) deduplicates by spec identity
+        # so this aliasing doesn't produce two Values for the same FQN.
+        in_spec = in_node.meta.get("spec")
+        if not isinstance(in_spec, TensorSpec):
+            continue
+        replacements[out_idx] = in_spec
+
+    if not any(r is not None for r in replacements):
+        return
+
+    # Assemble the new spec container, preserving the original shape.
+    new_list = [
+        replacements[i] if replacements[i] is not None else out_specs_list[i]
+        for i in range(len(out_specs_list))
+    ]
+    if return_container_kind == "scalar":
+        if isinstance(new_list[0], TensorSpec):
+            node.meta["spec"] = new_list[0]
+    elif return_container_kind == "list":
+        node.meta["spec"] = list(new_list)
+    elif return_container_kind == "tuple":
+        node.meta["spec"] = tuple(new_list)
 
 
 def update_tensor_lifetime(
@@ -474,6 +603,11 @@ def collect_specs_from_nodes(  # noqa: C901
             continue
 
         if _is_inplace_node(node):
+            # Schema-driven: alias the in-place op's result spec(s) onto
+            # the corresponding mutated input's spec so the planner gives
+            # them the same allocation slot. See `_alias_inplace_result_specs`
+            # for the full rationale and gating rules.
+            _alias_inplace_result_specs(node)
             continue
 
         if _is_mutable_buffer(node, graph_signature) and ignore_mutable_buffers:
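The Tensor(a!) annotations this aliasing keys off can be inspected with stock PyTorch. The snippet below uses only public schema introspection; the choice of index_put_ is illustrative.

# index_put_(Tensor(a!) self, Tensor?[] indices, Tensor values,
#            bool accumulate=False) -> Tensor(a!)
import torch

schema = torch.ops.aten.index_put_.default._schema
self_arg = schema.arguments[0]
ret = schema.returns[0]

# `self` carries write-alias set "a"; the return carries the same set.
# This shared set is exactly the output->input correspondence that
# output_to_aliased_input_map recovers.
print(self_arg.alias_info.is_write)    # True
print(self_arg.alias_info.before_set)  # {'a'}
print(ret.alias_info.before_set)       # {'a'}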

exir/operator/convert.py

Lines changed: 65 additions & 1 deletion
@@ -25,7 +25,7 @@
 
 import dataclasses
 import logging
-from typing import Dict, Optional, Tuple
+from typing import Dict, FrozenSet, Optional, Tuple
 
 import torch
 from torch._ops import OpOverload
@@ -105,6 +105,70 @@ def get_out_args_from_schema(out_var_schema: FunctionSchema) -> Tuple[str]:
     return tuple(arg.name for arg in out_var_schema.arguments.out)
 
 
+def unwrap_op_overload(target: object) -> torch._ops.OpOverload:
+    """Return the underlying ``torch._ops.OpOverload`` for a node target.
+
+    Handles both raw aten ops (``torch.ops.aten.X.Y``) and edge-dialect
+    wrappers (``ops.edge.aten.X.Y`` / ``BackendOpOverload``), which expose
+    the underlying ATen op as ``target._op`` but are not themselves
+    subclasses of ``torch._ops.OpOverload``.
+
+    Raises ``TypeError`` if ``target`` is neither an ``OpOverload`` nor
+    wraps one.
+    """
+    if isinstance(target, torch._ops.OpOverload):
+        return target
+    underlying = getattr(target, "_op", None)
+    if isinstance(underlying, torch._ops.OpOverload):
+        return underlying
+    raise TypeError(
+        f"unwrap_op_overload: expected a torch._ops.OpOverload or a "
+        f"wrapper exposing one as `_op`, got {type(target).__name__}: "
+        f"{target!r}"
+    )
+
+
+def output_to_aliased_input_map(
+    schema: torch.FunctionSchema,
+) -> Dict[int, int]:
+    """For a mutating op, return a map from output index to input arg
+    index, where the output aliases (i.e., mutates) the input via a
+    shared ``Tensor(a!)`` write-alias set.
+
+    Only outputs whose ``alias_info.is_write`` is True are considered.
+    Inputs without a write-aliased ``alias_info`` cannot be aliased
+    targets and are skipped. If multiple inputs share a single
+    return's alias set (no known cases in aten today), the first
+    matching input wins.
+
+    Returns an empty dict for ops with no write-aliased outputs (i.e.,
+    purely functional ops, or ops whose schema declares no returns).
+
+    Note: ``schema`` is the pybind ``torch._C.FunctionSchema`` (as
+    obtained from ``op._schema``), not the torchgen native
+    ``FunctionSchema`` used elsewhere in this file.
+    """
+    # Build alias_set -> input_idx so we can look up each return.
+    alias_set_to_input: Dict[FrozenSet[str], int] = {}
+    for in_idx, arg in enumerate(schema.arguments):
+        info = arg.alias_info
+        if info is None or not info.is_write:
+            continue
+        alias_set = frozenset(info.before_set)
+        alias_set_to_input.setdefault(alias_set, in_idx)
+
+    out_to_in: Dict[int, int] = {}
+    for out_idx, ret in enumerate(schema.returns):
+        info = ret.alias_info
+        if info is None or not info.is_write:
+            continue
+        alias_set = frozenset(info.before_set)
+        in_idx = alias_set_to_input.get(alias_set)
+        if in_idx is not None:
+            out_to_in[out_idx] = in_idx
+    return out_to_in
+
+
 def parse_qualified_opname(qualified_opname: str) -> Tuple[str, str]:
     """
     Given a qualified opname like aten::add, return a tuple for namespace
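A brief usage sketch for the two helpers above. The import path is this diff's own module; the aten ops are an illustrative choice, and the printed results follow from their public schemas.

# Usage sketch for unwrap_op_overload and output_to_aliased_input_map.
import torch

from executorch.exir.operator.convert import (
    output_to_aliased_input_map,
    unwrap_op_overload,
)

# add_(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!):
# return 0 write-aliases input 0 (`self`).
print(output_to_aliased_input_map(torch.ops.aten.add_.Tensor._schema))  # {0: 0}

# Functional add has no write-aliased outputs, so the map is empty.
print(output_to_aliased_input_map(torch.ops.aten.add.Tensor._schema))  # {}

# unwrap_op_overload is the identity on a raw OpOverload; edge-dialect
# wrappers instead return their underlying `_op`.
assert unwrap_op_overload(torch.ops.aten.add_.Tensor) is torch.ops.aten.add_.Tensor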
