Skip to content

Commit e9f3546

Browse files
Reza Sajadiany authored and facebook-github-bot committed
memory planner to allocate element-wise output buffer in place of input (#19067)
Summary: Adds a pass, namely `InPlaceElemWiseLikeOpsPass`, which checks for possible elem-wise ops in the graph w/o skip connection from the input. The pass then annotates the output as a new alloc type called `memory.alloc_inplace`. In memory planning, the nodes with output spec type of `alloc_inplace` get output allocation in place of the same node's input. Differential Revision: D100371295
1 parent 45ff0bd commit e9f3546

6 files changed

Lines changed: 179 additions & 6 deletions

File tree

exir/emit/_emitter.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1755,6 +1755,9 @@ def call_function( # pyre-fixme[14]
17551755
assert len(args) == 1
17561756
return self._emit_spec(self.node.meta["spec"])
17571757

1758+
elif target == memory.alloc_inplace:
1759+
return self._emit_spec(self.node.meta["spec"])
1760+
17581761
elif target == memory.view:
17591762
return self._emit_view(args)
17601763

exir/memory.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,26 @@ def alloc(spec: AllocSpec) -> pytree.PyTree:
3333
return torch.empty(shape, dtype=dtype)
3434

3535

36+
def alloc_inplace(base: torch.Tensor, spec: AllocSpec) -> pytree.PyTree:
    """
    Allocate output tensor(s) to be planned in the same memory as *base*.

    This is used by InPlaceElemWiseLikeOpsPass to signal to the memory planner
    that the output should share the same memory offset as the base (input)
    tensor. The base tensor must have allocated_memory >= output's
    allocated_memory, and the base must be dead after the consuming op.

    At runtime this behaves identically to alloc() — the in-place semantics
    are resolved at planning time.

    Args:
        base: Input tensor whose planned storage the output will reuse.
        spec: A (shape, dtype) pair, or a (possibly nested) list of such
            pairs; every element of a list spec shares the same base.

    Returns:
        An empty tensor, or a list of tensors mirroring the structure of
        ``spec`` (hence the PyTree return type, matching alloc()).
    """
    # A list spec fans out recursively, same as alloc(); this is why the
    # return annotation is PyTree rather than torch.Tensor.
    if isinstance(spec, list):
        return [alloc_inplace(base, s) for s in spec]

    shape, dtype = spec
    # Shapes may contain symbolic expressions; resolve to concrete ints.
    shape = eval_shape(shape)
    return torch.empty(shape, dtype=dtype)
54+
55+
3656
def free(spec: TensorSpec) -> None:
3757
"""
3858
The function is nop. The major purpose is to put it in the Fx IR.

exir/memory_planning.py

Lines changed: 40 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -186,9 +186,16 @@ def verify_storage_reuse(
186186
if not allow_lifetime_and_storage_overlap and self.lifetime_overlap(
187187
lhs_spec, rhs_spec
188188
):
189-
raise InternalError(
190-
f"Unexpected storage overlap: {Verifier._debug_message_from_specs(lhs_spec, rhs_spec)}"
189+
# In-place element-wise ops intentionally share storage
190+
# between input and output despite overlapping lifetimes.
191+
is_inplace_pair = (
192+
getattr(lhs_spec, "inplace_base", None) is rhs_spec
193+
or getattr(rhs_spec, "inplace_base", None) is lhs_spec
191194
)
195+
if not is_inplace_pair:
196+
raise InternalError(
197+
f"Unexpected storage overlap: {Verifier._debug_message_from_specs(lhs_spec, rhs_spec)}"
198+
)
192199

193200
# Check that each mem_obj_id is consistent with whether the tensors have
194201
# storage overlap
@@ -485,6 +492,7 @@ def collect_specs_from_nodes( # noqa: C901
485492
or node.target
486493
in [
487494
memory.alloc,
495+
memory.alloc_inplace,
488496
memory.view,
489497
operator.getitem,
490498
torch.ops.higher_order.cond,
@@ -838,22 +846,50 @@ def greedy(
838846

839847
sorted_specs.reverse()
840848

849+
deferred_inplace: List[TensorSpec] = []
850+
841851
for spec in sorted_specs:
842-
# Create an entry for this TensorSpec in the result object that we'll be
843-
# returning from this algorithm.
844852
spec_alloc_result = greedy_result.spec_dict.get(spec, SpecAllocResult(0, 0, 0))
845853
if spec.mem_id is None:
846854
spec_alloc_result.mem_id = 1
847855
else:
848856
spec_alloc_result.mem_id = spec.mem_id
849857
greedy_result.spec_dict[spec] = spec_alloc_result
850858
spec.realign(alignment)
859+
860+
if getattr(spec, "inplace_base", None) is not None:
861+
deferred_inplace.append(spec)
862+
continue
863+
851864
spec2obj[spec] = pick_shared_obj(
852865
shared_objects[spec_alloc_result.mem_id],
853866
spec,
854867
allow_overlapping_allocations,
855868
)
856869

870+
for spec in deferred_inplace:
871+
base = spec.inplace_base
872+
assert base in spec2obj, (
873+
f"In-place base spec not found in allocated objects. "
874+
f"Base allocated_memory={base.allocated_memory}, "
875+
f"spec allocated_memory={spec.allocated_memory}"
876+
)
877+
sobj = spec2obj[base]
878+
879+
base_alloc_result = greedy_result.spec_dict[base]
880+
spec_alloc_result = greedy_result.spec_dict[spec]
881+
spec_alloc_result.mem_id = base_alloc_result.mem_id
882+
883+
base_alloc_offset = 0
884+
for alloc_entry in sobj.allocations:
885+
if alloc_entry.spec is base:
886+
base_alloc_offset = alloc_entry.offset
887+
break
888+
sobj.first_used_index = min(sobj.first_used_index, spec.lifetime[0])
889+
sobj.last_used_index = max(sobj.last_used_index, spec.lifetime[1])
890+
sobj.allocations.append(AllocationSpec(base_alloc_offset, spec))
891+
spec2obj[spec] = sobj
892+
857893
if len(shared_objects) == 0:
858894
# Cannot find any tensor in the graph that needs to be allocated.
859895
# Return [0, 0] to be consistent with default behavior of naive.

exir/passes/__init__.py

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
to_scratch_op,
3333
)
3434
from executorch.exir.pass_base import ExportPass
35+
from executorch.exir.tensor import TensorSpec
3536
from executorch.exir.pass_manager import PassManager, PassType
3637
from executorch.exir.passes.const_prop_pass import ConstPropPass
3738
from executorch.exir.passes.debug_handle_generator_pass import DebugHandleGeneratorPass
@@ -76,6 +77,8 @@
7677
"ToDevicePass",
7778
"EdgeToBackendOpsPass",
7879
"MemoryFormatOpsPass",
80+
"InPlaceElemWiseLikeOpsPass",
81+
"ElemWiseInPlaceAwareMemoryPlanningPass",
7982
"MemoryPlanningPass",
8083
"HintBasedSymShapeEvalPass",
8184
"insert_write_back_for_buffers_pass",
@@ -260,6 +263,7 @@ def callWithLoggerEnabled(self, graph_module: torch.fx.GraphModule) -> None:
260263
# we won't see it in the input graph to the to_out_variant pass, unless
261264
# it's retraced after running to_out_variant with the first trace.
262265
memory.alloc,
266+
memory.alloc_inplace,
263267
memory.view,
264268
executorch_call_delegate,
265269
}
@@ -444,6 +448,106 @@ def get_submodule(node: torch.fx.Node) -> torch.fx.GraphModule:
444448
return PassResult(graph_module, True)
445449

446450

451+
class InPlaceElemWiseLikeOpsPass(PassBase):
    """Replace memory.alloc with memory.alloc_inplace for element-wise-like ops.

    For out-variant ops that are element-wise, the output can be allocated in
    the same memory as the input when:
      1. output_bytes <= input_bytes
      2. The input tensor has no other users after this op (dead after
         consumption)

    This pass replaces the memory.alloc node for the output with
    memory.alloc_inplace(input_node, spec), which signals to the memory
    planner to place the output at the same offset as the input.

    Eligible ops are specified via the constructor's eligible_ops parameter.
    """

    def __init__(
        self, eligible_ops: Optional[Set[Callable[..., Any]]] = None
    ) -> None:
        self._eligible_ops = eligible_ops or set()

    def _is_eligible(self, target: Any) -> bool:
        return target in self._eligible_ops

    def _rewrite_output_alloc(
        self, graph_module: torch.fx.GraphModule, op_node: torch.fx.Node
    ) -> bool:
        """Rewrite op_node's out-alloc to alloc_inplace when all checks pass.

        Returns True when a rewrite happened, False otherwise.
        """
        out_names = get_out_args_from_opoverload(op_node.target)
        if len(out_names) != 1:
            return False

        alloc_node = op_node.kwargs.get(out_names[0])
        if alloc_node is None or alloc_node.target != memory.alloc:
            return False

        src = op_node.args[0]
        if not isinstance(src, torch.fx.Node):
            return False

        # Compute sizes from FakeTensor metadata (specs may not be set yet
        # on alloc nodes since _set_alloc_node_spec runs in
        # MemoryPlanningPass).
        out_val = alloc_node.meta.get("val")
        src_val = src.meta.get("val")
        if not isinstance(out_val, torch.Tensor) or not isinstance(
            src_val, torch.Tensor
        ):
            return False

        out_nbytes = out_val.nelement() * out_val.element_size()
        src_nbytes = src_val.nelement() * src_val.element_size()
        if out_nbytes > src_nbytes:
            return False

        # Input must be dead after this op: no users other than op_node.
        if any(user != op_node for user in src.users):
            return False

        with graph_module.graph.inserting_before(alloc_node):
            replacement = graph_module.graph.call_function(
                memory.alloc_inplace,
                (src, alloc_node.args[0]),
            )
            replacement.meta = alloc_node.meta.copy()

        alloc_node.replace_all_uses_with(replacement)
        graph_module.graph.erase_node(alloc_node)
        return True

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        changed = False
        for op_node in graph_module.graph.nodes:
            if (
                op_node.op == "call_function"
                and self._is_eligible(op_node.target)
                and memory_planning._is_out_var_node(op_node)
            ):
                changed = self._rewrite_output_alloc(graph_module, op_node) or changed
        return PassResult(graph_module, changed)
529+
530+
531+
class ElemWiseInPlaceAwareMemoryPlanningPass(MemoryPlanningPass):
    """MemoryPlanningPass that first runs InPlaceElemWiseLikeOpsPass.

    Running the annotation pass immediately before planning guarantees that
    memory.alloc_inplace nodes are present when the planner runs, without
    requiring callers to schedule InPlaceElemWiseLikeOpsPass separately.
    """

    def __init__(
        self,
        eligible_ops: Optional[Set[Callable[..., Any]]] = None,
        **kwargs: Any,
    ) -> None:
        """
        Args:
            eligible_ops: ops eligible for in-place output allocation,
                forwarded to InPlaceElemWiseLikeOpsPass.
            **kwargs: forwarded unchanged to MemoryPlanningPass.
        """
        super().__init__(**kwargs)
        self._inplace_pass = InPlaceElemWiseLikeOpsPass(eligible_ops)

    def run(
        self,
        graph_module: torch.fx.GraphModule,
        graph_signature: Optional["ExportGraphSignature"] = None,
    ) -> PassResult:
        # Fix: do not discard the sub-pass's PassResult. The fx pass
        # protocol allows __call__ to return a PassResult wrapping a
        # different (transformed) graph module; assuming in-place mutation
        # would silently drop the alloc_inplace annotations in that case.
        inplace_result = self._inplace_pass(graph_module)
        if inplace_result is not None:
            graph_module = inplace_result.graph_module
        return super().run(graph_module, graph_signature)
549+
550+
447551
def to_scratch_op_pass(graph_module: torch.fx.GraphModule) -> PassResult:
448552
for node in graph_module.graph.nodes:
449553
if node.op != "call_function":

exir/passes/memory_planning_pass.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
import torch
1515
from executorch.exir._warnings import deprecated
1616
from executorch.exir.error import internal_assert
17-
from executorch.exir.memory import alloc
17+
from executorch.exir.memory import alloc, alloc_inplace
1818
from executorch.exir.memory_planning import (
1919
_is_out_var_node,
2020
apply_algo,
@@ -192,6 +192,13 @@ def _set_alloc_node_spec(self, graph_module: torch.fx.GraphModule) -> None:
192192
if len(out_arg_names) == 1:
193193
out_alloc_node = node.kwargs[out_arg_names[0]]
194194
out_alloc_node.meta["spec"] = node.meta["spec"]
195+
if (
196+
out_alloc_node.target == alloc_inplace
197+
and isinstance(out_alloc_node.args[0], Node)
198+
):
199+
base_spec = out_alloc_node.args[0].meta.get("spec")
200+
if base_spec is not None:
201+
node.meta["spec"].inplace_base = base_spec
195202
continue
196203
specs = get_node_tensor_specs(node)
197204
i = 0
@@ -206,7 +213,7 @@ def _set_alloc_node_spec(self, graph_module: torch.fx.GraphModule) -> None:
206213
# dont increment i as we dont have a spec for this node
207214
internal_assert(
208215
out_alloc_node.op == "call_function"
209-
and out_alloc_node.target == alloc,
216+
and out_alloc_node.target in (alloc, alloc_inplace),
210217
f"Out-var's node {out_alloc_node} has op {out_alloc_node.op} and target {out_alloc_node.target}",
211218
)
212219
internal_assert(

exir/tensor.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,9 @@ def init_mem_planning_fields(self) -> None:
216216
self.mem_id = None
217217
self.mem_obj_id = None
218218
self.mem_offset = None
219+
# Set by InPlaceElemWiseLikeOpsPass: the base TensorSpec whose memory
220+
# this spec should share (output allocated in-place over the input).
221+
self.inplace_base: Optional["TensorSpec"] = None
219222

220223
@property
221224
def dtype(self) -> torch.dtype:

0 commit comments

Comments
 (0)