Skip to content

Commit f1fa6fc

Browse files
Reza Sajadiany and facebook-github-bot
authored and committed
memory planner to allocate element-wise output buffer in place of input (#19067)
Summary: Adds a pass, namely `InPlaceElemWiseLikeOpsPass`, which checks for possible elem-wise ops in the graph without a skip connection from the input. The pass then annotates the output as a new alloc type called `memory.alloc_inplace`. In memory planning, nodes with an output spec type of `alloc_inplace` get their output allocated in place of the same node's input. Differential Revision: D100371295
1 parent d858cd9 commit f1fa6fc

7 files changed

Lines changed: 337 additions & 7 deletions

File tree

exir/emit/_emitter.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1755,6 +1755,10 @@ def call_function( # pyre-fixme[14]
17551755
assert len(args) == 1
17561756
return self._emit_spec(self.node.meta["spec"])
17571757

1758+
elif target == memory.alloc_inplace:
1759+
assert len(args) == 2
1760+
return self._emit_spec(self.node.meta["spec"])
1761+
17581762
elif target == memory.view:
17591763
return self._emit_view(args)
17601764

exir/memory.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,26 @@ def alloc(spec: AllocSpec) -> pytree.PyTree:
3333
return torch.empty(shape, dtype=dtype)
3434

3535

36+
def alloc_inplace(base: torch.Tensor, spec: AllocSpec) -> pytree.PyTree:
    """Allocate an output tensor intended to share memory with ``base``.

    Used by InPlaceElemWiseLikeOpsPass to tell the memory planner that the
    output should be placed at the same offset as the base (input) tensor.
    The base must have allocated_memory >= the output's allocated_memory,
    and the base must be dead after the consuming op.

    At runtime this is indistinguishable from alloc(); the in-place
    placement is resolved entirely at planning time.
    """
    # A list spec describes multiple outputs; allocate each against the
    # same base tensor.
    if isinstance(spec, list):
        return [alloc_inplace(base, sub_spec) for sub_spec in spec]

    tensor_shape, tensor_dtype = spec
    return torch.empty(eval_shape(tensor_shape), dtype=tensor_dtype)
3656
def free(spec: TensorSpec) -> None:
3757
"""
3858
The function is nop. The major purpose is to put it in the Fx IR.

exir/memory_planning.py

Lines changed: 61 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -186,9 +186,16 @@ def verify_storage_reuse(
186186
if not allow_lifetime_and_storage_overlap and self.lifetime_overlap(
187187
lhs_spec, rhs_spec
188188
):
189-
raise InternalError(
190-
f"Unexpected storage overlap: {Verifier._debug_message_from_specs(lhs_spec, rhs_spec)}"
189+
# In-place element-wise ops intentionally share storage
190+
# between input and output despite overlapping lifetimes.
191+
is_inplace_pair = (
192+
lhs_spec.inplace_base is rhs_spec
193+
or rhs_spec.inplace_base is lhs_spec
191194
)
195+
if not is_inplace_pair:
196+
raise InternalError(
197+
f"Unexpected storage overlap: {Verifier._debug_message_from_specs(lhs_spec, rhs_spec)}"
198+
)
192199

193200
# Check that each mem_obj_id is consistent with whether the tensors have
194201
# storage overlap
@@ -485,6 +492,7 @@ def collect_specs_from_nodes( # noqa: C901
485492
or node.target
486493
in [
487494
memory.alloc,
495+
memory.alloc_inplace,
488496
memory.view,
489497
operator.getitem,
490498
torch.ops.higher_order.cond,
@@ -838,22 +846,65 @@ def greedy(
838846

839847
sorted_specs.reverse()
840848

849+
deferred_inplace: List[TensorSpec] = []
850+
841851
for spec in sorted_specs:
842-
# Create an entry for this TensorSpec in the result object that we'll be
843-
# returning from this algorithm.
844852
spec_alloc_result = greedy_result.spec_dict.get(spec, SpecAllocResult(0, 0, 0))
845853
if spec.mem_id is None:
846854
spec_alloc_result.mem_id = 1
847855
else:
848856
spec_alloc_result.mem_id = spec.mem_id
849857
greedy_result.spec_dict[spec] = spec_alloc_result
850858
spec.realign(alignment)
859+
860+
if spec.inplace_base is not None:
861+
deferred_inplace.append(spec)
862+
continue
863+
851864
spec2obj[spec] = pick_shared_obj(
852865
shared_objects[spec_alloc_result.mem_id],
853866
spec,
854867
allow_overlapping_allocations,
855868
)
856869

870+
remaining = list(deferred_inplace)
871+
while remaining:
872+
progress = False
873+
next_remaining = []
874+
for spec in remaining:
875+
base = spec.inplace_base
876+
if base not in spec2obj:
877+
next_remaining.append(spec)
878+
continue
879+
progress = True
880+
sobj = spec2obj[base]
881+
882+
base_alloc_result = greedy_result.spec_dict[base]
883+
spec_alloc_result = greedy_result.spec_dict[spec]
884+
spec_alloc_result.mem_id = base_alloc_result.mem_id
885+
886+
base_alloc_offset = None
887+
for alloc_entry in sobj.allocations:
888+
if alloc_entry.spec is base:
889+
base_alloc_offset = alloc_entry.offset
890+
break
891+
assert base_alloc_offset is not None, (
892+
f"Base allocation entry not found in shared object for spec "
893+
f"with allocated_memory={spec.allocated_memory}"
894+
)
895+
sobj.first_used_index = min(sobj.first_used_index, spec.lifetime[0])
896+
sobj.last_used_index = max(sobj.last_used_index, spec.lifetime[1])
897+
sobj.allocations.append(AllocationSpec(base_alloc_offset, spec))
898+
spec2obj[spec] = sobj
899+
if not progress:
900+
unresolved = ", ".join(
901+
f"allocated_memory={s.allocated_memory}" for s in next_remaining
902+
)
903+
raise InternalError(
904+
f"Circular or unresolvable in-place dependency chain: {unresolved}"
905+
)
906+
remaining = next_remaining
907+
857908
if len(shared_objects) == 0:
858909
# Cannot find any tensor in the graph that needs to be allocated.
859910
# Return [0, 0] to be consistent with default behavior of naive.
@@ -1012,6 +1063,12 @@ def _allocate_buf(bufsizes: List[int], mem_id: int, allocated: int) -> int:
10121063
bufsizes = cast(List[int], bufsizes)
10131064

10141065
for spec in specs:
1066+
if spec.inplace_base is not None:
1067+
raise InternalError(
1068+
"The naive memory planning algorithm does not support in-place "
1069+
"element-wise ops (inplace_base). Use the greedy algorithm instead."
1070+
)
1071+
10151072
spec_alloc_result = naive_result.spec_dict.get(spec, SpecAllocResult(0, 0, 0))
10161073
# assume a single memory layer which has mem_id 1
10171074
if spec.mem_id is None:

exir/passes/__init__.py

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,8 @@
7676
"ToDevicePass",
7777
"EdgeToBackendOpsPass",
7878
"MemoryFormatOpsPass",
79+
"InPlaceElemWiseLikeOpsPass",
80+
"ElemWiseInPlaceAwareMemoryPlanningPass",
7981
"MemoryPlanningPass",
8082
"HintBasedSymShapeEvalPass",
8183
"insert_write_back_for_buffers_pass",
@@ -260,6 +262,7 @@ def callWithLoggerEnabled(self, graph_module: torch.fx.GraphModule) -> None:
260262
# we won't see it in the input graph to the to_out_variant pass, unless
261263
# it's retraced after running to_out_variant with the first trace.
262264
memory.alloc,
265+
memory.alloc_inplace,
263266
memory.view,
264267
executorch_call_delegate,
265268
}
@@ -444,6 +447,109 @@ def get_submodule(node: torch.fx.Node) -> torch.fx.GraphModule:
444447
return PassResult(graph_module, True)
445448

446449

450+
class InPlaceElemWiseLikeOpsPass(PassBase):
    """Replace memory.alloc with memory.alloc_inplace for element-wise-like ops.

    For out-variant ops that are element-wise, the output can be allocated in
    the same memory as the input when:
    1. output_bytes <= input_bytes
    2. The input tensor has no other users after this op (dead after consumption)

    This pass replaces the memory.alloc node for the output with
    memory.alloc_inplace(input_node, spec), which signals to the memory planner
    to place the output at the same offset as the input.

    Eligible ops are specified via the constructor's eligible_ops parameter;
    otherwise a default set provided by _default_eligible_ops is used.
    """

    @staticmethod
    def _default_eligible_ops() -> Set[Callable[..., Any]]:
        # Default allow-list of out-variant ops considered element-wise-like.
        return {
            torch.ops.cortex_m.quantize_per_tensor.out,
        }

    def __init__(self, eligible_ops: Optional[Set[Callable[..., Any]]] = None) -> None:
        # eligible_ops: explicit allow-list of op overloads to consider;
        # None selects the built-in default set.
        self._eligible_ops = (
            eligible_ops if eligible_ops is not None else self._default_eligible_ops()
        )

    def _is_eligible(self, target: Any) -> bool:
        # Membership test against the configured allow-list.
        return target in self._eligible_ops

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        """Rewrite eligible alloc nodes to alloc_inplace; returns whether the
        graph was modified."""
        changed = False
        for node in graph_module.graph.nodes:
            # Only eligible out-variant call_function nodes are candidates.
            if node.op != "call_function":
                continue
            if not self._is_eligible(node.target):
                continue
            if not memory_planning._is_out_var_node(node):
                continue

            # Only handle ops with exactly one `out` argument; multi-output
            # ops are skipped.
            out_arg_names = get_out_args_from_opoverload(node.target)
            if len(out_arg_names) != 1:
                continue

            # The out kwarg must currently be a plain memory.alloc node —
            # anything else (already rewritten, view, etc.) is left alone.
            out_alloc_node = node.kwargs.get(out_arg_names[0])
            if out_alloc_node is None or out_alloc_node.target != memory.alloc:
                continue

            # Need a concrete (fake) tensor value to compute the output's
            # byte size for the input-size eligibility check below.
            out_val = out_alloc_node.meta.get("val")
            if out_val is None or not isinstance(out_val, torch.Tensor):
                continue
            out_nbytes = out_val.nelement() * out_val.element_size()

            # Pick the first positional tensor input that (a) is at least as
            # large as the output and (b) has no users other than this node
            # and memory.free markers, i.e. it is dead after this op.
            input_node = None
            for arg in node.args:
                if not isinstance(arg, torch.fx.Node):
                    continue
                in_val = arg.meta.get("val")
                if in_val is None or not isinstance(in_val, torch.Tensor):
                    continue
                if in_val.nelement() * in_val.element_size() < out_nbytes:
                    continue
                # Any other live user means the input's storage cannot be
                # clobbered by the output.
                if any(u != node and u.target != memory.free for u in arg.users):
                    continue
                input_node = arg
                break
            if input_node is None:
                continue

            # Swap memory.alloc(spec) for memory.alloc_inplace(input, spec),
            # preserving the original alloc node's metadata (spec/val/etc.).
            with graph_module.graph.inserting_before(out_alloc_node):
                inplace_node = graph_module.graph.call_function(
                    memory.alloc_inplace,
                    (input_node, out_alloc_node.args[0]),
                )
            inplace_node.meta = out_alloc_node.meta.copy()

            out_alloc_node.replace_all_uses_with(inplace_node)
            # Safe to erase here: the alloc node has no remaining users.
            graph_module.graph.erase_node(out_alloc_node)
            changed = True

        return PassResult(graph_module, changed)
531+
532+
533+
class ElemWiseInPlaceAwareMemoryPlanningPass(MemoryPlanningPass):
    """Memory planning that applies InPlaceElemWiseLikeOpsPass beforehand.

    Running the in-place annotation pass first lets the planner place
    eligible element-wise outputs over their (dead) inputs.
    """

    def __init__(
        self,
        eligible_ops: Optional[Set[Callable[..., Any]]] = None,
        **kwargs: Any,
    ) -> None:
        super().__init__(**kwargs)
        # Annotation pass executed immediately before planning proper.
        self._inplace_pass = InPlaceElemWiseLikeOpsPass(eligible_ops)

    def run(
        self,
        graph_module: torch.fx.GraphModule,
        graph_signature=None,
    ) -> PassResult:
        # First annotate eligible alloc nodes as alloc_inplace, then hand the
        # (possibly rewritten) graph to the regular memory planning run.
        self._inplace_pass(graph_module)
        return super().run(graph_module, graph_signature)
551+
552+
447553
def to_scratch_op_pass(graph_module: torch.fx.GraphModule) -> PassResult:
448554
for node in graph_module.graph.nodes:
449555
if node.op != "call_function":

exir/passes/memory_planning_pass.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
import torch
1515
from executorch.exir._warnings import deprecated
1616
from executorch.exir.error import internal_assert
17-
from executorch.exir.memory import alloc
17+
from executorch.exir.memory import alloc, alloc_inplace
1818
from executorch.exir.memory_planning import (
1919
_is_out_var_node,
2020
apply_algo,
@@ -192,6 +192,12 @@ def _set_alloc_node_spec(self, graph_module: torch.fx.GraphModule) -> None:
192192
if len(out_arg_names) == 1:
193193
out_alloc_node = node.kwargs[out_arg_names[0]]
194194
out_alloc_node.meta["spec"] = node.meta["spec"]
195+
if out_alloc_node.target == alloc_inplace and isinstance(
196+
out_alloc_node.args[0], Node
197+
):
198+
base_spec = out_alloc_node.args[0].meta.get("spec")
199+
if base_spec is not None:
200+
node.meta["spec"].inplace_base = base_spec
195201
continue
196202
specs = get_node_tensor_specs(node)
197203
i = 0
@@ -206,7 +212,7 @@ def _set_alloc_node_spec(self, graph_module: torch.fx.GraphModule) -> None:
206212
# dont increment i as we dont have a spec for this node
207213
internal_assert(
208214
out_alloc_node.op == "call_function"
209-
and out_alloc_node.target == alloc,
215+
and out_alloc_node.target in (alloc, alloc_inplace),
210216
f"Out-var's node {out_alloc_node} has op {out_alloc_node.op} and target {out_alloc_node.target}",
211217
)
212218
internal_assert(

exir/tensor.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,9 @@ def init_mem_planning_fields(self) -> None:
216216
self.mem_id = None
217217
self.mem_obj_id = None
218218
self.mem_offset = None
219+
# Set by InPlaceElemWiseLikeOpsPass: the base TensorSpec whose memory
220+
# this spec should share (output allocated in-place over the input).
221+
self.inplace_base: Optional["TensorSpec"] = None
219222

220223
@property
221224
def dtype(self) -> torch.dtype:

0 commit comments

Comments
 (0)