
Commit ed21ca2

JacobSzwejbka authored and facebook-github-bot committed
Convert contiguous select_copy to zero-copy view in ReplaceViewCopyWithViewPass (pytorch#19198)
Summary:
Extends ReplaceViewCopyWithViewPass to convert `select_copy` ops into zero-copy `memory.select` views when the output is a contiguous sub-region of the base tensor. This is the same pattern already used for `view_copy` -> `memory.view`, applied to select operations. The pass requires that the base is densely packed, non-constant, and static, and that the selected output itself forms a dense packing.

For static, memory-planned sub-views, the emitter elides the op entirely (no runtime instruction) by serializing tensor metadata with `mem_offset = base_offset + byte_delta`. For dynamic shapes, a new `executorch_prim::et_select` runtime op sets the output data pointer to `self.data_ptr + offset`.

Changes:
- `exir/memory.py`: added a `memory.select` function
- `exir/passes/replace_view_copy_with_view_pass.py`: extended `_ViewSpec` with a `byte_offset` param, added a contiguity check via `contiguous_stride_from_shape`, and added `select_copy` handling to the pass
- Pipeline integration: memory planner, to_out_var skiplist, emitter, serialization
- `kernels/prim_ops/et_select.{h,cpp}`: C++ runtime op for dynamic select views
- Tests: 5 new Python tests + 1 C++ test

Authored with Claude.

Differential Revision: D102396195
1 parent 5a206ab commit ed21ca2

14 files changed: 582 additions & 11 deletions
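The replacement is only legal when the selected slab occupies one dense block of the base's storage. A minimal torch-only sketch of that condition and of the resulting byte offset (illustrative, not part of the diff):

import torch

# A select output is a contiguous sub-region of its base only when every dim
# before the selected one has size 1; the result then starts at byte offset
# index * stride[dim] * element_size within the base's storage.
base = torch.arange(24.0).reshape(1, 2, 3, 4)  # contiguous, leading dim is 1
ok = base.select(1, 1)                         # all dims before dim 1 are size 1
assert ok.is_contiguous()
assert ok.data_ptr() == base.data_ptr() + 1 * base.stride(1) * base.element_size()

bad = torch.arange(24.0).reshape(2, 3, 4).select(1, 0)  # dim 0 has size 2
assert not bad.is_contiguous()  # selected rows are interleaved; a copy is needed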


exir/emit/_emitter.py

Lines changed: 39 additions & 0 deletions
@@ -1259,6 +1259,42 @@ def _emit_view(self, args: Tuple[_Argument, ...]) -> _EmitterValue:
         self.chain.instructions.append(kernel)
         return out_arg

+    def _emit_subview(self, args: Tuple[_Argument, ...]) -> _EmitterValue:
+        spec = self.node.meta["spec"]
+        is_static = spec.is_static_shape_tensor
+        is_memory_planned = (spec.mem_id is not None) and (spec.mem_offset is not None)
+        if is_static and is_memory_planned:
+            return self._emit_spec(spec)
+
+        out_arg = self._emit_argument(
+            self._emit_spec(self.node.meta["spec"]), torch.TensorType  # pyre-ignore[6]
+        )
+
+        if self.node.target == memory.select:
+            self_arg = self._emit_argument(args[0], torch.TensorType)  # pyre-ignore[6]
+            dim_arg = self._emit_argument(args[1], torch.IntType)
+            index_arg = self._emit_argument(args[2], torch.IntType)
+            op_idx, op = self._get_operator(
+                name="executorch_prim::et_select",
+                overload="default",
+            )
+            kernel = Instruction(
+                KernelCall(
+                    op_idx,
+                    args=[self_arg.id, dim_arg.id, index_arg.id, out_arg.id],
+                )
+            )
+        else:
+            raise InternalError(
+                self._emit_node_specific_error(
+                    self.node,
+                    f"Unsupported subview target: {self.node.target}",
+                )
+            )
+
+        self.chain.instructions.append(kernel)
+        return out_arg
+
     def _add_debug_handle(
         self,
         emitter_id: int,
@@ -1758,6 +1794,9 @@ def call_function(  # pyre-fixme[14]
         elif target == memory.view:
             return self._emit_view(args)

+        elif target == memory.select:
+            return self._emit_subview(args)
+
         elif target == memory.free:
             assert len(args) == 1
             # pyre-ignore
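For intuition, the dynamic-shape path that `_emit_subview` falls back to can be modeled in plain PyTorch. This is a hypothetical sketch of what `executorch_prim::et_select` does at runtime (the real implementation is the C++ op in `kernels/prim_ops/et_select.cpp`, which is not shown in this diff): re-point the output at the base's storage instead of copying.

import torch

# Hypothetical model of et_select: drop `dim` from the sizes/strides and
# offset the data pointer by index * stride[dim] elements, i.e.
# out.data_ptr = self.data_ptr + offset. Assumes `dim` is non-negative.
def et_select_model(self: torch.Tensor, dim: int, index: int) -> torch.Tensor:
    out_sizes = list(self.shape[:dim]) + list(self.shape[dim + 1:])
    out_strides = list(self.stride()[:dim]) + list(self.stride()[dim + 1:])
    return self.as_strided(out_sizes, out_strides, index * self.stride(dim))

x = torch.arange(8.0).reshape(2, 4)
y = et_select_model(x, 0, 1)
assert torch.equal(y, x.select(0, 1))
assert y.data_ptr() == x.data_ptr() + 4 * x.element_size()  # zero-copy alias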

exir/memory.py

Lines changed: 9 additions & 0 deletions
@@ -48,3 +48,12 @@ def view(base: torch.Tensor, size: List[int]) -> torch.Tensor:
     It is used to elide view_copy nodes.
     """
     return base.view(size)
+
+
+def select(base: torch.Tensor, dim: int, index: int) -> torch.Tensor:
+    """
+    This function mimics torch.ops.aten.select.int.
+
+    It is used to elide select_copy nodes when the result is contiguous.
+    """
+    return base.select(dim, index)
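Illustrative usage of the new function (assumes an environment where `executorch.exir.memory` is importable): it matches `torch.ops.aten.select.int` in values but returns a view that aliases the base's storage.

import torch
from executorch.exir import memory  # assumes executorch is installed

base = torch.arange(6.0).reshape(2, 3)
out = memory.select(base, 0, 1)
assert torch.equal(out, torch.select_copy(base, 0, 1))              # same values
assert out.data_ptr() == base.data_ptr() + 3 * base.element_size()  # no copy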

exir/memory_planning.py

Lines changed: 2 additions & 1 deletion
@@ -486,6 +486,7 @@ def collect_specs_from_nodes(  # noqa: C901
             in [
                 memory.alloc,
                 memory.view,
+                memory.select,
                 operator.getitem,
                 torch.ops.higher_order.cond,
                 exir_while,
@@ -763,7 +764,7 @@ def get_node_tensor_specs(
     has no tensor specs.
     """
     # get tensor specs
-    if node.target == memory.view:
+    if node.target in (memory.view, memory.select):
         base = node.args[0]
         assert isinstance(base, torch.fx.Node)
         specs = base.meta.get("spec")

exir/passes/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -261,6 +261,7 @@ def callWithLoggerEnabled(self, graph_module: torch.fx.GraphModule) -> None:
        # it's retraced after running to_out_variant with the first trace.
        memory.alloc,
        memory.view,
+       memory.select,
        executorch_call_delegate,
    }
    to_out_var_skiplist.update(_EXECUTORCH_SYM_OPS)

exir/passes/replace_view_copy_with_view_pass.py

Lines changed: 92 additions & 10 deletions
@@ -33,7 +33,27 @@ def _is_view_copy(node: torch.fx.Node) -> bool:
     )


+def _is_select_copy(node: torch.fx.Node) -> bool:
+    return node.op == "call_function" and node.target in (
+        torch.ops.aten.select_copy.int,
+        ops.edge.aten.select_copy.int,
+    )
+
+
 _VIEW_OP = memory.view
+_SELECT_OP = memory.select
+
+
+def _is_contiguous_base(spec: TensorSpec) -> bool:
+    if not isinstance(spec, TensorSpec):
+        return False
+    if spec.const:
+        return False
+    shape = list(spec.shape)
+    if tuple(spec.dim_order) != tuple(range(len(shape))):
+        return False
+    expected_stride = contiguous_stride_from_shape(torch.Size(shape))
+    return tuple(spec.stride) == tuple(expected_stride)


 class _Guard:
@@ -54,7 +74,12 @@ def __call__(self, view_spec) -> None:  # pyre-ignore[2]


 class _ViewSpec(TensorSpec):
-    def __init__(self, base: TensorSpec, shape: List[int]) -> None:
+    def __init__(
+        self,
+        base: TensorSpec,
+        shape: List[int],
+        byte_offset: int = 0,
+    ) -> None:
         """
         A _ViewSpec is TensorSpec that shares non-size related fields with its base.
         The size-related fields are: shape, stride, dim_order, and shape_dynamism.
@@ -65,7 +90,11 @@ def __init__(self, base: TensorSpec, shape: List[int]) -> None:

         A _ViewSpec can only be created from a non-sparse, strided TensorSpec.
         On creation, a _ViewSpec must be compatible with its base with respect to
-        shape_dynamism, dtype, and nbytes.
+        shape_dynamism, dtype, and nbytes (when byte_offset is 0).
+
+        When byte_offset is non-zero (used for select/slice sub-views), the view
+        describes a contiguous sub-region of the base at the given byte offset.
+        In this case, nbytes may differ from the base and rank may change.

         A _ViewSpec contains _guards that are evaluated on every __getattribute__ call.
         The purpose of the guards is to make sure the _ViewSpec is still compatible
@@ -119,6 +148,7 @@ def __init__(self, base: TensorSpec, shape: List[int]) -> None:

         self._guards: List[_Guard] = []
         self._unguarded_access = False
+        self._byte_offset: int = byte_offset

         # Make sure base is not sparse and add a guard
         if base.is_sparse:
@@ -183,15 +213,19 @@ def __init__(self, base: TensorSpec, shape: List[int]) -> None:
             _Guard("dtype", lambda view_spec: view_spec.dtype, base.dtype)
         )

-        # We do not guard nbytes because dynamic symints are replaced by upper bounds.
-        # We do guard on rank, though
-        if self.nbytes() != base.nbytes():
-            raise Exception(
-                f"_ViewSpec is incompatible with its base on creation. It has nbytes={self.nbytes()}, but its base has nbytes={base.nbytes()}."
+        # For traditional views (same nbytes, zero offset), rank is guarded.
+        # For sub-views (select/slice), the output is a contiguous subset so
+        # nbytes will differ and rank may change.
+        is_full_view = byte_offset == 0 and self.nbytes() == base.nbytes()
+        if is_full_view:
+            self._guards.append(
+                _Guard("rank", lambda view_spec: len(view_spec.shape), len(shape))
             )
-        self._guards.append(
-            _Guard("rank", lambda view_spec: len(view_spec.shape), len(shape))
-        )
+        else:
+            if self.nbytes() + byte_offset > base.nbytes():
+                raise Exception(
+                    f"_ViewSpec sub-view extends beyond base. Sub-view needs {self.nbytes()} bytes at offset {byte_offset}, but base has {base.nbytes()} bytes."
+                )

     def _run_guards(self) -> None:
         unguarded_access = self._unguarded_access
@@ -211,6 +245,7 @@ def __getattribute__(self, name: str):  # pyre-ignore
             "_guards",
             "_unguarded_access",
             "_run_guards",
+            "_byte_offset",
         ]:
             return object.__getattribute__(self, name)

@@ -219,6 +254,11 @@ def __getattribute__(self, name: str):  # pyre-ignore
             val = object.__getattribute__(self, name)
         elif name in self._base_fields:
             val = object.__getattribute__(self._base, name)
+            # For sub-views (select/slice), adjust mem_offset by byte_offset
+            if name == "mem_offset" and val is not None:
+                byte_offset = object.__getattribute__(self, "_byte_offset")
+                if byte_offset != 0:
+                    val = val + byte_offset
         else:
             if len(name) > 0 and name[0] != "_":
                 logger.warning(
@@ -239,6 +279,7 @@ def __setattr__(self, name: str, val) -> None:  # pyre-ignore
             "_guards",
             "_unguarded_access",
             "_run_guards",
+            "_byte_offset",
         ]:
             object.__setattr__(self, name, val)
             return
@@ -293,11 +334,50 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:

                     n_replaced += 1

+                elif _is_select_copy(node) and all(
+                    u.op != "output" for u in node.users
+                ):
+                    replaced = self._try_replace_select(node)
+                    if replaced:
+                        n_replaced += 1
+
             module.recompile()

         logger.debug(f"Replaced {n_replaced} view_copy nodes with {_VIEW_OP} nodes.")
+        logger.debug(
+            f"Replaced {n_replaced} select_copy nodes with {_SELECT_OP} nodes."
+        )
         return PassResult(graph_module, n_replaced > 0)

+    def _try_replace_select(self, node: torch.fx.Node) -> bool:
+        base = node.args[0]
+        assert isinstance(base, torch.fx.Node)
+        base_spec = base.meta["spec"]
+        if not _is_contiguous_base(base_spec):
+            return False
+
+        dim: int = node.args[1]
+        index: int = node.args[2]
+        base_shape = [int(s) for s in base_spec.shape]
+
+        if dim < 0:
+            dim += len(base_shape)
+
+        if any(base_shape[i] != 1 for i in range(dim)):
+            return False
+
+        if index < 0:
+            index += base_shape[dim]
+
+        out_shape = list(node.meta["spec"].shape)
+        base_stride = contiguous_stride_from_shape(torch.Size(base_shape))
+        element_size = torch._utils._element_size(base_spec.dtype)
+        byte_offset = index * base_stride[dim] * element_size
+
+        node.target = _SELECT_OP
+        node.meta["spec"] = _ViewSpec(base_spec, out_shape, byte_offset)
+        return True
+
     def ensures(self, graph_module: torch.fx.GraphModule) -> None:
         for module in graph_module.modules():
             if not isinstance(module, torch.fx.GraphModule):
@@ -310,6 +390,8 @@ def ensures(self, graph_module: torch.fx.GraphModule) -> None:
                 )
                 if node.op == "call_function" and node.target == _VIEW_OP:
                     assert isinstance(node.meta["spec"], _ViewSpec)
+                if node.op == "call_function" and node.target == _SELECT_OP:
+                    assert isinstance(node.meta["spec"], _ViewSpec)

     def requires(self, graph_module: torch.fx.GraphModule) -> None:
         """

exir/serde/export_serialize.py

Lines changed: 1 addition & 0 deletions
@@ -209,6 +209,7 @@ def _reverse_map(d: Dict[Any, Enum]):

 _KNOWN_FUNCTIONS = {
     exir.memory.view,
+    exir.memory.select,
 }


exir/serde/serialize.py

Lines changed: 1 addition & 0 deletions
@@ -376,6 +376,7 @@ def serialize(

 _KNOWN_FUNCTIONS_MAP = {
     "executorch.exir.memory.view": exir.memory.view,
+    "executorch.exir.memory.select": exir.memory.select,
 }

