Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 37 additions & 1 deletion exir/emit/_emitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -622,7 +622,16 @@ def _emit_evalue(self, val: EValue) -> _AbstractValue:
return _AbstractValue(len(self.emitter_state.values) - 1, tensor)

def _emit_spec(self, spec: ValueSpec) -> _EmitterValue:
"""Given the provided spec constructs the corresponding EValue from it and then emits it."""
"""Given the provided spec constructs the corresponding EValue from it and then emits it.

If `spec` was already emitted earlier (e.g., because two FX
nodes share the same TensorSpec object — typically because the
planner's `_alias_inplace_result_specs` aliased an in-place
op's result onto its mutated input), reuse the existing
value_id. This keeps the invariant "one TensorSpec ↔ one
Value" so downstream emit doesn't create duplicate Values for
aliased FX nodes.
"""

def _process(spec: LeafValueSpec) -> _AbstractValue:
if isinstance(spec, (list, tuple)):
Expand All @@ -642,6 +651,25 @@ def _process(spec: LeafValueSpec) -> _AbstractValue:
f"Invalid node spec expected TensorSpec received {spec}",
)

# Spec was already emitted — reuse the existing Value so
# two FX nodes sharing one TensorSpec also share one
# value_id in the lowered IR.
existing_id = self.emitter_state.spec2id_dict.get(spec)
if existing_id is not None:
existing_evalue = self.emitter_state.values[existing_id]
# Both insertion sites for `spec2id_dict` (this method
# and the placeholder emitter) only register ids whose
# EValue wraps a `Tensor` (built via
# `_tensor_spec_to_evalue` from a `TensorSpec`).
self._internal_assert_emitter(
isinstance(existing_evalue.val, Tensor),
self.node,
f"spec2id_dict entry for TensorSpec must point to a "
f"Tensor EValue, got "
f"{type(existing_evalue.val).__name__}",
)
return _AbstractValue(existing_id, existing_evalue.val)

ret = self._emit_evalue(self._tensor_spec_to_evalue(spec)) # pyre-ignore
self.emitter_state.spec2id_dict[spec] = ret.id # pyre-ignore
return ret
Expand Down Expand Up @@ -2020,6 +2048,14 @@ def _is_buffer(node: Node, graph_signature: ExportGraphSignature) -> bool:
)
value = self._emit_evalue(evalue)

# Populate spec2id_dict so downstream `_emit_spec` calls (e.g.,
# for in-place op result FX nodes whose spec was aliased onto
# this placeholder's spec by the planner's
# `_alias_inplace_result_specs`) reuse this placeholder's
# value_id rather than creating a new Value.
if isinstance(spec, TensorSpec):
self.emitter_state.spec2id_dict[spec] = value.id

# Only user inputs should remain as inputs.
if is_user_input:
self.inputs.append(value.id)
Expand Down
148 changes: 141 additions & 7 deletions exir/memory_planning.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,12 @@
from executorch.exir.control_flow import while_loop as exir_while
from executorch.exir.delegate import executorch_call_delegate
from executorch.exir.error import internal_assert, InternalError
from executorch.exir.operator.convert import is_inplace_variant, is_out_variant
from executorch.exir.operator.convert import (
is_inplace_variant,
is_out_variant,
output_to_aliased_input_map,
unwrap_op_overload,
)
from executorch.exir.schema import DeviceType, NonConstBufferDevice, TensorShapeDynamism
from executorch.exir.tensor import TensorSpec
from torch import fx
Expand Down Expand Up @@ -302,13 +307,137 @@ def _is_out_var_node(node: torch.fx.Node) -> bool:


def _is_inplace_node(node: torch.fx.Node) -> bool:
return (
node.op == "call_function"
and isinstance(node.target, torch._ops.OpOverload)
and is_inplace_variant(
node.target._schema.name, node.target._schema.overload_name
if node.op != "call_function":
return False
target = node.target
if not isinstance(target, torch._ops.OpOverload) and not isinstance(
getattr(target, "_op", None), torch._ops.OpOverload
):
return False
op = unwrap_op_overload(target)
return is_inplace_variant(op._schema.name, op._schema.overload_name)


def _alias_inplace_result_specs(node: torch.fx.Node) -> None: # noqa: C901
"""Alias an in-place op's result TensorSpec(s) onto the corresponding
input's spec.

In-place ops (schema kind == inplace) mutate one or more of their
inputs and return tensors that alias them, declared via the
``Tensor(a!)`` schema annotation. To make the memory planner treat
result and aliased input as one storage, we copy the input's spec
object onto the output's ``node.meta["spec"]`` slot.

Output→input correspondence is computed via
``output_to_aliased_input_map``, which matches each return's
write-alias set against the inputs that share it.

Gating:

- Only runs for in-place nodes (caller checks ``_is_inplace_node``).
- Multi-output in-place ops are supported when each return's alias
set matches exactly one input's alias set.
- Falls through silently when alias info is absent or unparseable,
preserving the original spec. ``logging.debug`` records each
early-return reason so silent regressions are observable.
"""
target = node.target
op = unwrap_op_overload(target)

schema = op._schema
out_to_in = output_to_aliased_input_map(schema)
if not out_to_in:
logging.debug(
f"_alias_inplace_result_specs: schema for {op} declares no "
f"write-aliased outputs matching an input; skipping."
)
)
return

# Normalize the current spec container into a list for uniform
# handling. Caller guarantees this node was identified as an
# in-place op with a meta spec (see `_is_inplace_node`), so an
# unrecognized container shape is a real bug — assert loudly so it
# surfaces in tests rather than silently disabling aliasing.
current = node.meta.get("spec")
if isinstance(current, TensorSpec):
out_specs_list: List[Optional[TensorSpec]] = [current]
return_container_kind = "scalar"
elif isinstance(current, (list, tuple)):
out_specs_list = list(current)
return_container_kind = type(current).__name__
else:
raise InternalError(
f"_alias_inplace_result_specs: in-place node {node.name} "
f"({op}) has unrecognized spec container of type "
f"{type(current).__name__!r}; expected TensorSpec, list, "
f"or tuple."
)

# Compute new spec for each return; None means "keep original".
# Mutated inputs are usually positional (the `Tensor(a!)` `self`
# arg), but custom ops may pass them via kwargs — fall back to
# `node.kwargs[arg_name]` in that case.
replacements: List[Optional[TensorSpec]] = [None] * len(out_specs_list)

for out_idx, in_idx in out_to_in.items():
if out_idx >= len(out_specs_list):
logging.debug(
f"_alias_inplace_result_specs: schema for {op} declares "
f"return {out_idx} but spec container has only "
f"{len(out_specs_list)} entries; skipping this return."
)
continue
# Resolve the mutated input. Prefer positional args, fall back
# to kwargs by argument name (custom ops may pass `Tensor(a!)`
# args via kwargs).
in_node: object
if in_idx < len(node.args):
in_node = node.args[in_idx]
else:
arg_name = (
schema.arguments[in_idx].name
if in_idx < len(schema.arguments)
else None
)
if arg_name is None or arg_name not in node.kwargs:
logging.debug(
f"_alias_inplace_result_specs: schema for {op} "
f"expects mutated input at position {in_idx} "
f"(name={arg_name!r}) but it is supplied neither "
f"positionally nor via kwargs; skipping."
)
continue
in_node = node.kwargs[arg_name]
if not isinstance(in_node, torch.fx.Node):
continue
# NOTE: alias unconditionally — including when the input is a
# placeholder (named buffer / mutable input). Skipping
# placeholders would leave a dangling out-arg on in-place op
# instructions whose self is a buffer; the runtime kernel
# writes to that out-arg's storage rather than mutating the
# buffer, producing incorrect results. The companion change in
# `_emit_spec` (emit/_emitter.py) deduplicates by spec identity
# so this aliasing doesn't produce two Values for the same FQN.
in_spec = in_node.meta.get("spec")
if not isinstance(in_spec, TensorSpec):
continue
replacements[out_idx] = in_spec

if not any(r is not None for r in replacements):
return

# Assemble the new spec container, preserving the original shape.
new_list = [
replacements[i] if replacements[i] is not None else out_specs_list[i]
for i in range(len(out_specs_list))
]
if return_container_kind == "scalar":
if isinstance(new_list[0], TensorSpec):
node.meta["spec"] = new_list[0]
elif return_container_kind == "list":
node.meta["spec"] = list(new_list)
elif return_container_kind == "tuple":
node.meta["spec"] = tuple(new_list)


def update_tensor_lifetime(
Expand Down Expand Up @@ -474,6 +603,11 @@ def collect_specs_from_nodes( # noqa: C901
continue

if _is_inplace_node(node):
# Schema-driven: alias the in-place op's result spec(s) onto
# the corresponding mutated input's spec so the planner gives
# them the same allocation slot. See `_alias_inplace_result_specs`
# for the full rationale and gating rules.
_alias_inplace_result_specs(node)
continue

if _is_mutable_buffer(node, graph_signature) and ignore_mutable_buffers:
Expand Down
66 changes: 65 additions & 1 deletion exir/operator/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@

import dataclasses
import logging
from typing import Dict, Optional, Tuple
from typing import Dict, FrozenSet, Optional, Tuple

import torch
from torch._ops import OpOverload
Expand Down Expand Up @@ -105,6 +105,70 @@ def get_out_args_from_schema(out_var_schema: FunctionSchema) -> Tuple[str]:
return tuple(arg.name for arg in out_var_schema.arguments.out)


def unwrap_op_overload(target: object) -> torch._ops.OpOverload:
"""Return the underlying ``torch._ops.OpOverload`` for a node target.

Handles both raw aten ops (``torch.ops.aten.X.Y``) and edge-dialect
wrappers (``ops.edge.aten.X.Y`` / ``BackendOpOverload``), which expose
the underlying ATen op as ``target._op`` but are not themselves
subclasses of ``torch._ops.OpOverload``.

Raises ``TypeError`` if ``target`` is neither an ``OpOverload`` nor
wraps one.
"""
if isinstance(target, torch._ops.OpOverload):
return target
underlying = getattr(target, "_op", None)
if isinstance(underlying, torch._ops.OpOverload):
return underlying
raise TypeError(
f"unwrap_op_overload: expected a torch._ops.OpOverload or a "
f"wrapper exposing one as `_op`, got {type(target).__name__}: "
f"{target!r}"
)


def output_to_aliased_input_map(
schema: torch.FunctionSchema,
) -> Dict[int, int]:
"""For a mutating op, return a map from output index to input arg
index, where the output aliases (i.e., mutates) the input via a
shared ``Tensor(a!)`` write-alias set.

Only outputs whose ``alias_info.is_write`` is True are considered.
Inputs without a write-aliased ``alias_info`` cannot be aliased
targets and are skipped. If multiple inputs share a single
return's alias set (no known cases in aten today), the first
matching input wins.

Returns an empty dict for ops with no write-aliased outputs (i.e.,
purely functional ops, or ops whose schema declares no returns).

Note: ``schema`` is the pybind ``torch._C.FunctionSchema`` (as
obtained from ``op._schema``), not the torchgen native
``FunctionSchema`` used elsewhere in this file.
"""
# Build alias_set -> input_idx so we can look up each return.
alias_set_to_input: Dict[FrozenSet[str], int] = {}
for in_idx, arg in enumerate(schema.arguments):
info = arg.alias_info
if info is None or not info.is_write:
continue
alias_set = frozenset(info.before_set)
alias_set_to_input.setdefault(alias_set, in_idx)

out_to_in: Dict[int, int] = {}
for out_idx, ret in enumerate(schema.returns):
info = ret.alias_info
if info is None or not info.is_write:
continue
alias_set = frozenset(info.before_set)
in_idx = alias_set_to_input.get(alias_set)
if in_idx is not None:
out_to_in[out_idx] = in_idx
return out_to_in


def parse_qualified_opname(qualified_opname: str) -> Tuple[str, str]:
"""
Given a qualified opname like aten::add, return a tuple for namespace
Expand Down
Loading
Loading