Skip to content

Commit 30c6c9e

Browse files
Reza Sajadianyfacebook-github-bot
authored andcommitted
memory planner to allocate element-wise output buffer in place of input
Summary: Adds a pass namely `InPlaceElemWiseLikeOpsPass` which checks for possible elem-wise ops in the graph w/o skip conection from input. The pass then annotate the output as a new alloc type called `memory.alloc_inplace`. In memory planning, the nodes with output spec type of `alloc_inplace` get output allocation in place of the same node's input. Differential Revision: D100371295
1 parent 9ce1b55 commit 30c6c9e

8 files changed

Lines changed: 162 additions & 6 deletions

File tree

exir/capture/_config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ class ExecutorchBackendConfig:
6161
# EdgeProgramManager or can be defined per program.
6262
memory_planning_pass: Union[PassType, Dict[str, PassType]] = MemoryPlanningPass()
6363
to_out_var_pass: PassType = ToOutVarPass(ignore_to_out_var_failure=False)
64+
inplace_elem_wise_like_ops_pass: Optional[PassType] = None
6465
dynamic_memory_planning_mode: DynamicMemoryPlanningMode = (
6566
DynamicMemoryPlanningMode.UPPER_BOUND
6667
)

exir/emit/_emitter.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1755,6 +1755,9 @@ def call_function( # pyre-fixme[14]
17551755
assert len(args) == 1
17561756
return self._emit_spec(self.node.meta["spec"])
17571757

1758+
elif target == memory.alloc_inplace:
1759+
return self._emit_spec(self.node.meta["spec"])
1760+
17581761
elif target == memory.view:
17591762
return self._emit_view(args)
17601763

exir/memory.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,26 @@ def alloc(spec: AllocSpec) -> pytree.PyTree:
3333
return torch.empty(shape, dtype=dtype)
3434

3535

36+
def alloc_inplace(base: torch.Tensor, spec: AllocSpec) -> torch.Tensor:
37+
"""
38+
Allocate output tensor in the same memory as base tensor.
39+
40+
This is used by InPlaceElemWiseLikeOpsPass to signal to the memory planner
41+
that the output should share the same memory offset as the base (input)
42+
tensor. The base tensor must have allocated_memory >= output's
43+
allocated_memory, and the base must be dead after the consuming op.
44+
45+
At runtime this behaves identically to alloc() — the in-place semantics
46+
are resolved at planning time.
47+
"""
48+
if isinstance(spec, list):
49+
return [alloc_inplace(base, s) for s in spec]
50+
51+
shape, dtype = spec
52+
shape = eval_shape(shape)
53+
return torch.empty(shape, dtype=dtype)
54+
55+
3656
def free(spec: TensorSpec) -> None:
3757
"""
3858
The function is nop. The major purpose is to put it in the Fx IR.

exir/memory_planning.py

Lines changed: 36 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -186,9 +186,16 @@ def verify_storage_reuse(
186186
if not allow_lifetime_and_storage_overlap and self.lifetime_overlap(
187187
lhs_spec, rhs_spec
188188
):
189-
raise InternalError(
190-
f"Unexpected storage overlap: {Verifier._debug_message_from_specs(lhs_spec, rhs_spec)}"
189+
# In-place element-wise ops intentionally share storage
190+
# between input and output despite overlapping lifetimes.
191+
is_inplace_pair = (
192+
getattr(lhs_spec, "inplace_base", None) is rhs_spec
193+
or getattr(rhs_spec, "inplace_base", None) is lhs_spec
191194
)
195+
if not is_inplace_pair:
196+
raise InternalError(
197+
f"Unexpected storage overlap: {Verifier._debug_message_from_specs(lhs_spec, rhs_spec)}"
198+
)
192199

193200
# Check that each mem_obj_id is consistent with whether the tensors have
194201
# storage overlap
@@ -485,6 +492,7 @@ def collect_specs_from_nodes( # noqa: C901
485492
or node.target
486493
in [
487494
memory.alloc,
495+
memory.alloc_inplace,
488496
memory.view,
489497
operator.getitem,
490498
torch.ops.higher_order.cond,
@@ -838,22 +846,46 @@ def greedy(
838846

839847
sorted_specs.reverse()
840848

849+
deferred_inplace: List[TensorSpec] = []
850+
841851
for spec in sorted_specs:
842-
# Create an entry for this TensorSpec in the result object that we'll be
843-
# returning from this algorithm.
844852
spec_alloc_result = greedy_result.spec_dict.get(spec, SpecAllocResult(0, 0, 0))
845853
if spec.mem_id is None:
846854
spec_alloc_result.mem_id = 1
847855
else:
848856
spec_alloc_result.mem_id = spec.mem_id
849857
greedy_result.spec_dict[spec] = spec_alloc_result
850858
spec.realign(alignment)
859+
860+
if getattr(spec, "inplace_base", None) is not None:
861+
deferred_inplace.append(spec)
862+
continue
863+
851864
spec2obj[spec] = pick_shared_obj(
852865
shared_objects[spec_alloc_result.mem_id],
853866
spec,
854867
allow_overlapping_allocations,
855868
)
856869

870+
for spec in deferred_inplace:
871+
base = spec.inplace_base
872+
assert base in spec2obj, (
873+
f"In-place base spec not found in allocated objects. "
874+
f"Base allocated_memory={base.allocated_memory}, "
875+
f"spec allocated_memory={spec.allocated_memory}"
876+
)
877+
sobj = spec2obj[base]
878+
879+
base_alloc_offset = 0
880+
for alloc_entry in sobj.allocations:
881+
if alloc_entry.spec is base:
882+
base_alloc_offset = alloc_entry.offset
883+
break
884+
sobj.first_used_index = min(sobj.first_used_index, spec.lifetime[0])
885+
sobj.last_used_index = max(sobj.last_used_index, spec.lifetime[1])
886+
sobj.allocations.append(AllocationSpec(base_alloc_offset, spec))
887+
spec2obj[spec] = sobj
888+
857889
if len(shared_objects) == 0:
858890
# Cannot find any tensor in the graph that needs to be allocated.
859891
# Return [0, 0] to be consistent with default behavior of naive.

exir/passes/__init__.py

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
to_scratch_op,
3333
)
3434
from executorch.exir.pass_base import ExportPass
35+
from executorch.exir.tensor import TensorSpec
3536
from executorch.exir.pass_manager import PassManager, PassType
3637
from executorch.exir.passes.const_prop_pass import ConstPropPass
3738
from executorch.exir.passes.debug_handle_generator_pass import DebugHandleGeneratorPass
@@ -76,6 +77,7 @@
7677
"ToDevicePass",
7778
"EdgeToBackendOpsPass",
7879
"MemoryFormatOpsPass",
80+
"InPlaceElemWiseLikeOpsPass",
7981
"MemoryPlanningPass",
8082
"HintBasedSymShapeEvalPass",
8183
"insert_write_back_for_buffers_pass",
@@ -260,6 +262,7 @@ def callWithLoggerEnabled(self, graph_module: torch.fx.GraphModule) -> None:
260262
# we won't see it in the input graph to the to_out_variant pass, unless
261263
# it's retraced after running to_out_variant with the first trace.
262264
memory.alloc,
265+
memory.alloc_inplace,
263266
memory.view,
264267
executorch_call_delegate,
265268
}
@@ -444,6 +447,86 @@ def get_submodule(node: torch.fx.Node) -> torch.fx.GraphModule:
444447
return PassResult(graph_module, True)
445448

446449

450+
class InPlaceElemWiseLikeOpsPass(PassBase):
451+
"""Replace memory.alloc with memory.alloc_inplace for element-wise-like ops.
452+
453+
For out-variant ops that are element-wise, the output can be allocated in
454+
the same memory as the input when:
455+
1. output_bytes <= input_bytes
456+
2. The input tensor has no other users after this op (dead after consumption)
457+
458+
This pass replaces the memory.alloc node for the output with
459+
memory.alloc_inplace(input_node, spec), which signals to the memory planner
460+
to place the output at the same offset as the input.
461+
462+
Eligible ops are specified via the constructor's eligible_ops parameter.
463+
"""
464+
465+
def __init__(
466+
self, eligible_ops: Optional[Set[Callable[..., Any]]] = None
467+
) -> None:
468+
self._eligible_ops = eligible_ops or set()
469+
470+
def _is_eligible(self, target: Any) -> bool:
471+
return target in self._eligible_ops
472+
473+
def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
474+
changed = False
475+
for node in graph_module.graph.nodes:
476+
if node.op != "call_function":
477+
continue
478+
if not self._is_eligible(node.target):
479+
continue
480+
if not memory_planning._is_out_var_node(node):
481+
continue
482+
483+
out_arg_names = get_out_args_from_opoverload(node.target)
484+
if len(out_arg_names) != 1:
485+
continue
486+
487+
out_alloc_node = node.kwargs.get(out_arg_names[0])
488+
if out_alloc_node is None or out_alloc_node.target != memory.alloc:
489+
continue
490+
491+
input_node = node.args[0]
492+
if not isinstance(input_node, torch.fx.Node):
493+
continue
494+
495+
# Compute sizes from FakeTensor metadata (specs may not be set yet
496+
# on alloc nodes since _set_alloc_node_spec runs in MemoryPlanningPass).
497+
out_val = out_alloc_node.meta.get("val")
498+
in_val = input_node.meta.get("val")
499+
if out_val is None or in_val is None:
500+
continue
501+
if not isinstance(out_val, torch.Tensor) or not isinstance(
502+
in_val, torch.Tensor
503+
):
504+
continue
505+
506+
out_nbytes = out_val.nelement() * out_val.element_size()
507+
in_nbytes = in_val.nelement() * in_val.element_size()
508+
if out_nbytes > in_nbytes:
509+
continue
510+
511+
# Input must have no other users besides this node
512+
input_users = [u for u in input_node.users if u != node]
513+
if len(input_users) > 0:
514+
continue
515+
516+
with graph_module.graph.inserting_before(out_alloc_node):
517+
inplace_node = graph_module.graph.call_function(
518+
memory.alloc_inplace,
519+
(input_node, out_alloc_node.args[0]),
520+
)
521+
inplace_node.meta = out_alloc_node.meta.copy()
522+
523+
out_alloc_node.replace_all_uses_with(inplace_node)
524+
graph_module.graph.erase_node(out_alloc_node)
525+
changed = True
526+
527+
return PassResult(graph_module, changed)
528+
529+
447530
def to_scratch_op_pass(graph_module: torch.fx.GraphModule) -> PassResult:
448531
for node in graph_module.graph.nodes:
449532
if node.op != "call_function":

exir/passes/memory_planning_pass.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
import torch
1515
from executorch.exir._warnings import deprecated
1616
from executorch.exir.error import internal_assert
17-
from executorch.exir.memory import alloc
17+
from executorch.exir.memory import alloc, alloc_inplace
1818
from executorch.exir.memory_planning import (
1919
_is_out_var_node,
2020
apply_algo,
@@ -192,6 +192,13 @@ def _set_alloc_node_spec(self, graph_module: torch.fx.GraphModule) -> None:
192192
if len(out_arg_names) == 1:
193193
out_alloc_node = node.kwargs[out_arg_names[0]]
194194
out_alloc_node.meta["spec"] = node.meta["spec"]
195+
if (
196+
out_alloc_node.target == alloc_inplace
197+
and isinstance(out_alloc_node.args[0], Node)
198+
):
199+
base_spec = out_alloc_node.args[0].meta.get("spec")
200+
if base_spec is not None:
201+
node.meta["spec"].inplace_base = base_spec
195202
continue
196203
specs = get_node_tensor_specs(node)
197204
i = 0
@@ -206,7 +213,7 @@ def _set_alloc_node_spec(self, graph_module: torch.fx.GraphModule) -> None:
206213
# dont increment i as we dont have a spec for this node
207214
internal_assert(
208215
out_alloc_node.op == "call_function"
209-
and out_alloc_node.target == alloc,
216+
and out_alloc_node.target in (alloc, alloc_inplace),
210217
f"Out-var's node {out_alloc_node} has op {out_alloc_node.op} and target {out_alloc_node.target}",
211218
)
212219
internal_assert(

exir/program/_program.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -821,18 +821,25 @@ def pre_memory_planning_passes(
821821
raise RuntimeError(
822822
f"sym_shape_eval_pass must be a dict or a PassBase, got {config.sym_shape_eval_pass}"
823823
)
824+
inplace_pass = (
825+
[config.inplace_elem_wise_like_ops_pass]
826+
if config.inplace_elem_wise_like_ops_pass is not None
827+
else []
828+
)
824829
if config.remove_view_copy:
825830
return [
826831
NormalizeViewCopyBasePass(),
827832
dead_code_elimination_pass,
828833
ReplaceViewCopyWithViewPass(),
829834
sym_shape_eval_pass,
830835
config.to_out_var_pass,
836+
*inplace_pass,
831837
]
832838
else:
833839
return [
834840
sym_shape_eval_pass,
835841
config.to_out_var_pass,
842+
*inplace_pass,
836843
]
837844

838845

exir/tensor.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,9 @@ def init_mem_planning_fields(self) -> None:
216216
self.mem_id = None
217217
self.mem_obj_id = None
218218
self.mem_offset = None
219+
# Set by InPlaceElemWiseLikeOpsPass: the base TensorSpec whose memory
220+
# this spec should share (output allocated in-place over the input).
221+
self.inplace_base: Optional["TensorSpec"] = None
219222

220223
@property
221224
def dtype(self) -> torch.dtype:

0 commit comments

Comments
 (0)