diff --git a/exir/emit/_emitter.py b/exir/emit/_emitter.py index a48d88fa224..5c0383533f1 100644 --- a/exir/emit/_emitter.py +++ b/exir/emit/_emitter.py @@ -1259,6 +1259,42 @@ def _emit_view(self, args: Tuple[_Argument, ...]) -> _EmitterValue: self.chain.instructions.append(kernel) return out_arg + def _emit_subview(self, args: Tuple[_Argument, ...]) -> _EmitterValue: + spec = self.node.meta["spec"] + is_static = spec.is_static_shape_tensor + is_memory_planned = (spec.mem_id is not None) and (spec.mem_offset is not None) + if is_static and is_memory_planned: + return self._emit_spec(spec) + + out_arg = self._emit_argument( + self._emit_spec(self.node.meta["spec"]), torch.TensorType # pyre-ignore[6] + ) + + if self.node.target == memory.select: + self_arg = self._emit_argument(args[0], torch.TensorType) # pyre-ignore[6] + dim_arg = self._emit_argument(args[1], torch.IntType) + index_arg = self._emit_argument(args[2], torch.IntType) + op_idx, op = self._get_operator( + name="executorch_prim::et_select", + overload="default", + ) + kernel = Instruction( + KernelCall( + op_idx, + args=[self_arg.id, dim_arg.id, index_arg.id, out_arg.id], + ) + ) + else: + raise InternalError( + self._emit_node_specific_error( + self.node, + f"Unsupported subview target: {self.node.target}", + ) + ) + + self.chain.instructions.append(kernel) + return out_arg + def _add_debug_handle( self, emitter_id: int, @@ -1758,6 +1794,9 @@ def call_function( # pyre-fixme[14] elif target == memory.view: return self._emit_view(args) + elif target == memory.select: + return self._emit_subview(args) + elif target == memory.free: assert len(args) == 1 # pyre-ignore diff --git a/exir/memory.py b/exir/memory.py index 36a244bc02f..7f110023962 100644 --- a/exir/memory.py +++ b/exir/memory.py @@ -48,3 +48,12 @@ def view(base: torch.Tensor, size: List[int]) -> torch.Tensor: It is used to elide view_copy nodes. 
""" return base.view(size) + + +def select(base: torch.Tensor, dim: int, index: int) -> torch.Tensor: + """ + This function mimics torch.ops.aten.select.int. + + It is used to elide select_copy nodes when the result is contiguous. + """ + return base.select(dim, index) diff --git a/exir/memory_planning.py b/exir/memory_planning.py index c5d3441bcde..19b968c8238 100644 --- a/exir/memory_planning.py +++ b/exir/memory_planning.py @@ -486,6 +486,7 @@ def collect_specs_from_nodes( # noqa: C901 in [ memory.alloc, memory.view, + memory.select, operator.getitem, torch.ops.higher_order.cond, exir_while, @@ -763,9 +764,12 @@ def get_node_tensor_specs( has no tensor specs. """ # get tensor specs - if node.target == memory.view: + if node.target in (memory.view, memory.select): base = node.args[0] assert isinstance(base, torch.fx.Node) + while base.target in (memory.view, memory.select): + base = base.args[0] + assert isinstance(base, torch.fx.Node) specs = base.meta.get("spec") else: specs = node.meta.get("spec") diff --git a/exir/passes/__init__.py b/exir/passes/__init__.py index 9b1b8efe682..1f17ab4b617 100644 --- a/exir/passes/__init__.py +++ b/exir/passes/__init__.py @@ -261,6 +261,7 @@ def callWithLoggerEnabled(self, graph_module: torch.fx.GraphModule) -> None: # it's retraced after running to_out_variant with the first trace. 
memory.alloc, memory.view, + memory.select, executorch_call_delegate, } to_out_var_skiplist.update(_EXECUTORCH_SYM_OPS) diff --git a/exir/passes/replace_view_copy_with_view_pass.py b/exir/passes/replace_view_copy_with_view_pass.py index 28fcc97aaf5..13fbd13a9cc 100644 --- a/exir/passes/replace_view_copy_with_view_pass.py +++ b/exir/passes/replace_view_copy_with_view_pass.py @@ -33,7 +33,16 @@ def _is_view_copy(node: torch.fx.Node) -> bool: ) +def _is_select_copy(node: torch.fx.Node) -> bool: + return node.op == "call_function" and node.target in ( + torch.ops.aten.select_copy.int, + ops.edge.aten.select_copy.int, + ) + + _VIEW_OP = memory.view +_SELECT_OP = memory.select + class _Guard: @@ -54,7 +63,12 @@ def __call__(self, view_spec) -> None: # pyre-ignore[2] class _ViewSpec(TensorSpec): - def __init__(self, base: TensorSpec, shape: List[int]) -> None: + def __init__( + self, + base: TensorSpec, + shape: List[int], + byte_offset: int = 0, + ) -> None: """ A _ViewSpec is TensorSpec that shares non-size related fields with its base. The size-related fields are: shape, stride, dim_order, and shape_dynamism. @@ -65,7 +79,11 @@ def __init__(self, base: TensorSpec, shape: List[int]) -> None: A _ViewSpec can only be created from a non-sparse, strided TensorSpec. On creation, a _ViewSpec must be compatible with its base with respect to - shape_dynamism, dtype, and nbytes. + shape_dynamism, dtype, and nbytes (when byte_offset is 0). + + When byte_offset is non-zero (used for select/slice sub-views), the view + describes a contiguous sub-region of the base at the given byte offset. + In this case, nbytes may differ from the base and rank may change. A _ViewSpec contains _guards that are evaluated on every __getattribute__ call. 
The purpose of the guards is to make sure the _ViewSpec is still compatible @@ -119,6 +137,7 @@ def __init__(self, base: TensorSpec, shape: List[int]) -> None: self._guards: List[_Guard] = [] self._unguarded_access = False + self._byte_offset: int = byte_offset # Make sure base is not sparse and add a guard if base.is_sparse: @@ -154,27 +173,6 @@ def __init__(self, base: TensorSpec, shape: List[int]) -> None: torch.Size(self.shape) ) - # Check compatibility with base on creation - if self.shape_dynamism != base.shape_dynamism: - raise Exception( - f"_ViewSpec is incompatible with its base on creation. It has shape_dynamism={self.shape_dynamism}, but its base has shape_dynamism={base.shape_dynamism}." - ) - self._guards.append( - _Guard( - "shape_dynamism_init", - lambda view_spec: view_spec.shape_dynamism, - base.shape_dynamism, - ) - ) - self._guards.append( - _Guard( - "shape_dynamism_eq_base", - lambda view_spec: view_spec.shape_dynamism - == view_spec._base.shape_dynamism, - True, - ) - ) - if self.dtype != base.dtype: raise Exception( f"_ViewSpec is incompatible with its base on creation. It has dtype={self.dtype}, but its base has dtype={base.dtype}." @@ -183,15 +181,32 @@ def __init__(self, base: TensorSpec, shape: List[int]) -> None: _Guard("dtype", lambda view_spec: view_spec.dtype, base.dtype) ) - # We do not guard nbytes because dynamic symints are replaced by upper bounds. - # We do guard on rank, though - if self.nbytes() != base.nbytes(): - raise Exception( - f"_ViewSpec is incompatible with its base on creation. It has nbytes={self.nbytes()}, but its base has nbytes={base.nbytes()}." + # For full views (same nbytes, zero offset), dynamism and rank must match. + # For sub-views (select/slice), the output is a contiguous subset so + # nbytes will differ, rank may change, and dynamism may differ + # (e.g., selecting away a dynamic dim produces a static output). 
+ is_full_view = byte_offset == 0 and self.nbytes() == base.nbytes() + if is_full_view: + if self.shape_dynamism != base.shape_dynamism: + raise Exception( + f"_ViewSpec is incompatible with its base on creation. It has shape_dynamism={self.shape_dynamism}, but its base has shape_dynamism={base.shape_dynamism}." + ) + self._guards.append( + _Guard( + "shape_dynamism_eq_base", + lambda view_spec: view_spec.shape_dynamism + == view_spec._base.shape_dynamism, + True, + ) ) - self._guards.append( - _Guard("rank", lambda view_spec: len(view_spec.shape), len(shape)) - ) + self._guards.append( + _Guard("rank", lambda view_spec: len(view_spec.shape), len(shape)) + ) + else: + if self.nbytes() + byte_offset > base.nbytes(): + raise Exception( + f"_ViewSpec sub-view extends beyond base. Sub-view needs {self.nbytes()} bytes at offset {byte_offset}, but base has {base.nbytes()} bytes." + ) def _run_guards(self) -> None: unguarded_access = self._unguarded_access @@ -211,6 +226,7 @@ def __getattribute__(self, name: str): # pyre-ignore "_guards", "_unguarded_access", "_run_guards", + "_byte_offset", ]: return object.__getattribute__(self, name) @@ -218,7 +234,19 @@ def __getattribute__(self, name: str): # pyre-ignore if name in self._self_fields: val = object.__getattribute__(self, name) elif name in self._base_fields: - val = object.__getattribute__(self._base, name) + val = getattr(self._base, name) + # For static sub-views, adjust mem_offset by byte_offset so the + # emitter can elide the op. For non-static sub-views, return + # None for mem_id/mem_offset — et_select runs at runtime and + # the output needs no allocation info. 
+ if name in ("mem_id", "mem_offset") and val is not None: + byte_offset = object.__getattribute__(self, "_byte_offset") + if byte_offset != 0: + shape_dynamism = object.__getattribute__(self, "shape_dynamism") + if shape_dynamism != TensorShapeDynamism.STATIC: + val = None + elif name == "mem_offset": + val = val + byte_offset else: if len(name) > 0 and name[0] != "_": logger.warning( @@ -239,6 +267,7 @@ def __setattr__(self, name: str, val) -> None: # pyre-ignore "_guards", "_unguarded_access", "_run_guards", + "_byte_offset", ]: object.__setattr__(self, name, val) return @@ -293,11 +322,53 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult: n_replaced += 1 + elif _is_select_copy(node) and all( + u.op != "output" for u in node.users + ): + replaced = self._try_replace_select(node) + if replaced: + n_replaced += 1 + module.recompile() logger.debug(f"Replaced {n_replaced} view_copy nodes with {_VIEW_OP} nodes.") + logger.debug( + f"Replaced {n_replaced} select_copy nodes with {_SELECT_OP} nodes." 
+ ) return PassResult(graph_module, n_replaced > 0) + def _try_replace_select(self, node: torch.fx.Node) -> bool: + base = node.args[0] + assert isinstance(base, torch.fx.Node) + base_spec = base.meta["spec"] + + dim: int = node.args[1] + index: int = node.args[2] + base_shape = list(base_spec.shape) + + if dim < 0: + dim += len(base_shape) + + if any(base_shape[i] != 1 for i in range(dim)): + return False + + if index < 0: + index += base_shape[dim] + + base_stride = contiguous_stride_from_shape(torch.Size(base_shape)) + element_size = torch._utils._element_size(base_spec.dtype) + byte_offset = index * base_stride[dim] * element_size + + if base_spec.const: + return False + + node.target = _SELECT_OP + view_spec = _ViewSpec(base_spec, node.meta["val"].shape, byte_offset) + if base_spec.shape_dynamism != TensorShapeDynamism.STATIC: + view_spec.shape_dynamism = base_spec.shape_dynamism + node.meta["spec"] = view_spec + return True + def ensures(self, graph_module: torch.fx.GraphModule) -> None: for module in graph_module.modules(): if not isinstance(module, torch.fx.GraphModule): @@ -310,6 +381,8 @@ def ensures(self, graph_module: torch.fx.GraphModule) -> None: ) if node.op == "call_function" and node.target == _VIEW_OP: assert isinstance(node.meta["spec"], _ViewSpec) + if node.op == "call_function" and node.target == _SELECT_OP: + assert isinstance(node.meta["spec"], _ViewSpec) def requires(self, graph_module: torch.fx.GraphModule) -> None: """ diff --git a/exir/serde/export_serialize.py b/exir/serde/export_serialize.py index 7fd1f9470d4..0a594c0fa83 100644 --- a/exir/serde/export_serialize.py +++ b/exir/serde/export_serialize.py @@ -209,6 +209,7 @@ def _reverse_map(d: Dict[Any, Enum]): _KNOWN_FUNCTIONS = { exir.memory.view, + exir.memory.select, } diff --git a/exir/serde/serialize.py b/exir/serde/serialize.py index a2eb2491067..febf69d76c3 100644 --- a/exir/serde/serialize.py +++ b/exir/serde/serialize.py @@ -376,6 +376,7 @@ def serialize( _KNOWN_FUNCTIONS_MAP = { 
"executorch.exir.memory.view": exir.memory.view, + "executorch.exir.memory.select": exir.memory.select, } diff --git a/exir/tests/test_remove_view_copy.py b/exir/tests/test_remove_view_copy.py index bea7e3ff83c..ec400e63d7f 100644 --- a/exir/tests/test_remove_view_copy.py +++ b/exir/tests/test_remove_view_copy.py @@ -12,6 +12,7 @@ from executorch.exir import memory, to_edge from executorch.exir.capture._config import ExecutorchBackendConfig from executorch.exir.passes import MemoryPlanningPass +from executorch.exir.passes.replace_view_copy_with_view_pass import _ViewSpec class TestModel1(nn.Module): @@ -234,3 +235,369 @@ def forward(self, x): plan = etpm.executorch_program.execution_plan[0] op_names = [op.name for op in plan.operators] self.assertTrue("executorch_prim::et_view" in op_names) + + def test_contiguous_select_replaced(self) -> None: + class SelectDim0Model(nn.Module): + __test__ = False + + def forward(self, x): + y = x + 1 + z = y.select(0, 2) + return z * 2 + + model = SelectDim0Model() + model.eval() + example_inputs = (torch.rand(4, 3),) + ep = torch.export.export(model, example_inputs, strict=True) + etpm = to_edge(ep).to_executorch( + config=ExecutorchBackendConfig( + remove_view_copy=True, + memory_planning_pass=MemoryPlanningPass(alloc_graph_input=False), + ), + ) + + found_select_view = False + for node in etpm.exported_program().graph.nodes: + if node.target == memory.select: + found_select_view = True + self.assertIsInstance(node.meta["spec"], _ViewSpec) + base = node.args[0] + self.assertEqual( + node.meta["spec"].mem_id, base.meta["spec"].mem_id + ) + self.assertEqual( + node.meta["spec"].lifetime, base.meta["spec"].lifetime + ) + self.assertTrue(found_select_view) + + def test_non_contiguous_select_not_replaced(self) -> None: + class SelectDim1Model(nn.Module): + __test__ = False + + def forward(self, x): + y = x + 1 + z = y.select(1, 1) + return z * 2 + + model = SelectDim1Model() + model.eval() + example_inputs = (torch.rand(4, 3),) + ep 
= torch.export.export(model, example_inputs, strict=True) + etpm = to_edge(ep).to_executorch( + config=ExecutorchBackendConfig( + remove_view_copy=True, + memory_planning_pass=MemoryPlanningPass(alloc_graph_input=False), + ), + ) + + for node in etpm.exported_program().graph.nodes: + self.assertNotEqual(node.target, memory.select) + + def test_select_output_matches(self) -> None: + class SelectModel(nn.Module): + __test__ = False + + def forward(self, x): + y = x + 1 + z = y.select(0, 2) + return z * 2 + + model = SelectModel() + model.eval() + example_inputs = (torch.rand(4, 3),) + ep = torch.export.export(model, example_inputs, strict=True) + + epm_remove = to_edge(ep) + epm_no_remove = copy.deepcopy(epm_remove) + + etpm_remove = epm_remove.to_executorch( + config=ExecutorchBackendConfig( + remove_view_copy=True, + memory_planning_pass=MemoryPlanningPass(alloc_graph_input=False), + ), + ) + etpm_no_remove = epm_no_remove.to_executorch( + config=ExecutorchBackendConfig( + remove_view_copy=False, + memory_planning_pass=MemoryPlanningPass(alloc_graph_input=False), + ), + ) + + out_remove = etpm_remove.exported_program().module()(*example_inputs) + out_no_remove = etpm_no_remove.exported_program().module()(*example_inputs) + self.assertTrue(torch.allclose(out_remove, out_no_remove)) + + def test_select_spec_mem_offset(self) -> None: + class SelectModel(nn.Module): + __test__ = False + + def forward(self, x): + y = x + 1 + z = y.select(0, 2) + return z * 2 + + model = SelectModel() + model.eval() + example_inputs = (torch.rand(4, 3),) + ep = torch.export.export(model, example_inputs, strict=True) + etpm = to_edge(ep).to_executorch( + config=ExecutorchBackendConfig( + remove_view_copy=True, + memory_planning_pass=MemoryPlanningPass(alloc_graph_input=False), + ), + ) + + for node in etpm.exported_program().graph.nodes: + if node.target == memory.select: + base = node.args[0] + base_offset = base.meta["spec"].mem_offset + select_offset = node.meta["spec"].mem_offset + # 
select(dim=0, index=2) on [4,3] float32: offset = 2 * 3 * 4 = 24 + self.assertEqual(select_offset, base_offset + 24) + + def test_dynamic_shape_select_replaced(self) -> None: + class SelectDim0Model(nn.Module): + __test__ = False + + def forward(self, x): + y = x + 1 + z = y.select(0, 2) + return z * 2 + + model = SelectDim0Model() + model.eval() + example_inputs = (torch.rand(4, 3),) + dynamic_shapes = {"x": {1: torch.export.Dim("dim1", min=1, max=10)}} + ep = torch.export.export( + model, example_inputs, strict=True, dynamic_shapes=dynamic_shapes + ) + + epm = to_edge(ep) + epm_copy = copy.deepcopy(epm) + + etpm_on = epm.to_executorch( + config=ExecutorchBackendConfig( + remove_view_copy=True, + memory_planning_pass=MemoryPlanningPass(alloc_graph_input=True), + ), + ) + etpm_off = epm_copy.to_executorch( + config=ExecutorchBackendConfig( + remove_view_copy=False, + memory_planning_pass=MemoryPlanningPass(alloc_graph_input=True), + ), + ) + + found_select_view = False + for node in etpm_on.exported_program().graph.nodes: + if node.target == memory.select: + found_select_view = True + self.assertTrue(found_select_view) + + def test_dynamic_shape_select_no_allocation(self) -> None: + class SelectDim0Model(nn.Module): + __test__ = False + + def forward(self, x): + y = x + 1 + z = y.select(0, 2) + return z * 2 + + model = SelectDim0Model() + model.eval() + example_inputs = (torch.rand(4, 3),) + dynamic_shapes = {"x": {1: torch.export.Dim("dim1", min=1, max=10)}} + ep = torch.export.export( + model, example_inputs, strict=True, dynamic_shapes=dynamic_shapes + ) + + etpm = to_edge(ep).to_executorch( + config=ExecutorchBackendConfig( + remove_view_copy=True, + memory_planning_pass=MemoryPlanningPass(alloc_graph_input=True), + ), + ) + + for node in etpm.exported_program().graph.nodes: + if node.target == memory.select: + spec = node.meta["spec"] + self.assertIsNone(spec.mem_id) + self.assertIsNone(spec.mem_offset) + + def test_input_select_shares_allocation(self) -> 
None: + class InputSelectModel(nn.Module): + __test__ = False + + def forward(self, x): + z = x.select(0, 1) + return z * 2 + + model = InputSelectModel() + model.eval() + example_inputs = (torch.rand(4, 3),) + ep = torch.export.export(model, example_inputs, strict=True) + + etpm = to_edge(ep).to_executorch( + config=ExecutorchBackendConfig( + remove_view_copy=True, + memory_planning_pass=MemoryPlanningPass(alloc_graph_input=True), + ), + ) + + found_select = False + for node in etpm.exported_program().graph.nodes: + if node.target == memory.select: + found_select = True + spec = node.meta["spec"] + base = node.args[0] + self.assertEqual(spec.mem_id, base.meta["spec"].mem_id) + base_offset = base.meta["spec"].mem_offset + # select(0, 1) on [4, 3] float32: byte_offset = 1 * 3 * 4 = 12 + self.assertEqual(spec.mem_offset, base_offset + 12) + self.assertTrue(found_select) + + def test_dynamic_dim_selected_away_replaced(self) -> None: + class SelectDynDimModel(nn.Module): + __test__ = False + + def forward(self, x): + y = x + 1 + z = y.select(1, 0) + return z * 2 + + model = SelectDynDimModel() + model.eval() + example_inputs = (torch.rand(1, 4, 3),) + dynamic_shapes = {"x": {1: torch.export.Dim("seq", min=1, max=128)}} + ep = torch.export.export( + model, example_inputs, strict=True, dynamic_shapes=dynamic_shapes + ) + etpm = to_edge(ep).to_executorch( + config=ExecutorchBackendConfig( + remove_view_copy=True, + memory_planning_pass=MemoryPlanningPass(alloc_graph_input=True), + ), + ) + + found_select_view = False + for node in etpm.exported_program().graph.nodes: + if node.target == memory.select: + found_select_view = True + self.assertTrue(found_select_view) + + def test_dynamic_leading_dim_select_not_replaced(self) -> None: + class SelectDim1Model(nn.Module): + __test__ = False + + def forward(self, x): + y = x + 1 + z = y.select(1, 0) + return z * 2 + + model = SelectDim1Model() + model.eval() + example_inputs = (torch.rand(2, 3),) + dynamic_shapes = {"x": {0: 
torch.export.Dim("batch", min=1, max=10)}} + ep = torch.export.export( + model, example_inputs, strict=True, dynamic_shapes=dynamic_shapes + ) + etpm = to_edge(ep).to_executorch( + config=ExecutorchBackendConfig( + remove_view_copy=True, + memory_planning_pass=MemoryPlanningPass(alloc_graph_input=True), + ), + ) + + for node in etpm.exported_program().graph.nodes: + self.assertNotEqual(node.target, memory.select) + + def test_view_then_select_chained(self) -> None: + class ViewThenSelectModel(nn.Module): + __test__ = False + + def forward(self, x): + y = x + 1 + z = y.view(3, 1, 4) + a = z.select(0, 0) + b = z.select(0, 1) + c = z.select(0, 2) + return a + b + c + + model = ViewThenSelectModel() + model.eval() + example_inputs = (torch.rand(3, 4),) + ep = torch.export.export(model, example_inputs, strict=True) + etpm = to_edge(ep).to_executorch( + config=ExecutorchBackendConfig( + remove_view_copy=True, + memory_planning_pass=MemoryPlanningPass(alloc_graph_input=False), + ), + ) + + found_select_view = False + for node in etpm.exported_program().graph.nodes: + if node.target == memory.select: + found_select_view = True + spec = node.meta["spec"] + self.assertIsInstance(spec, _ViewSpec) + self.assertIsNotNone(spec.mem_id) + self.assertIsNotNone(spec.mem_offset) + self.assertTrue(found_select_view) + + import math + from executorch.exir.schema import ScalarType + + plan = etpm.executorch_program.execution_plan[0] + buffer_sizes = plan.non_const_buffer_sizes + for tensor in plan.values: + t = tensor.val + if hasattr(t, "allocation_info") and t.allocation_info is not None: + mem_id = t.allocation_info.memory_id + if mem_id >= len(buffer_sizes): + continue + buf_size = buffer_sizes[mem_id] + offset = ( + t.allocation_info.memory_offset_high << 32 + ) | t.allocation_info.memory_offset_low + elem_size = { + ScalarType.FLOAT: 4, + ScalarType.INT: 4, + ScalarType.LONG: 8, + ScalarType.DOUBLE: 8, + ScalarType.HALF: 2, + ScalarType.BYTE: 1, + }.get(t.scalar_type, 4) + nbytes 
= math.prod(t.sizes) * elem_size + self.assertLessEqual( + offset + nbytes, + buf_size, + f"Tensor with sizes={t.sizes} at offset={offset} " + f"exceeds buffer[{mem_id}] size {buf_size}", + ) + + def test_constant_base_not_replaced(self) -> None: + class ConstSelectModel(nn.Module): + __test__ = False + + def __init__(self): + super().__init__() + self.param = nn.Parameter(torch.rand(4, 3)) + self.param.requires_grad = False + + def forward(self, x): + z = self.param.select(0, 1) + return x + z + + model = ConstSelectModel() + model.eval() + example_inputs = (torch.rand(3),) + ep = torch.export.export(model, example_inputs, strict=True) + etpm = to_edge(ep).to_executorch( + config=ExecutorchBackendConfig( + remove_view_copy=True, + memory_planning_pass=MemoryPlanningPass(alloc_graph_input=False), + ), + ) + + for node in etpm.exported_program().graph.nodes: + self.assertNotEqual(node.target, memory.select) diff --git a/kernels/prim_ops/et_select.cpp b/kernels/prim_ops/et_select.cpp new file mode 100644 index 00000000000..e7d86344458 --- /dev/null +++ b/kernels/prim_ops/et_select.cpp @@ -0,0 +1,114 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */
+
+#include <executorch/kernels/prim_ops/et_select.h>
+
+#include <executorch/runtime/core/exec_aten/util/tensor_util.h>
+#include <executorch/runtime/kernel/kernel_includes.h>
+
+using executorch::aten::SizesType;
+using executorch::aten::Tensor;
+using torch::executor::Error;
+using torch::executor::resize_tensor;
+
+namespace torch {
+namespace executor {
+namespace function {
+
+constexpr size_t kTensorDimensionLimit = 16;
+
+void et_select(KernelRuntimeContext& context, Span<EValue*> stack) {
+  // executorch_prim::et_select.default(Tensor self, int dim, int index) -> out
+  ET_KERNEL_CHECK_MSG(
+      context,
+      stack.size() == 4,
+      InvalidProgram,
+      /* void */,
+      "Expected %zu args, got %zu",
+      (size_t)4,
+      stack.size());
+
+  auto self = (*stack[0]).toTensor();
+  auto dim = (*stack[1]).toInt();
+  auto index = (*stack[2]).toInt();
+  auto out = (*stack[3]).toTensor();
+
+  ET_KERNEL_CHECK(
+      context, tensors_have_same_dtype(self, out), InvalidArgument, );
+
+  ET_KERNEL_CHECK_MSG(
+      context,
+      dim >= 0 && dim < self.dim(),
+      InvalidArgument,
+      ,
+      "dim %" PRId64 " out of range for tensor with %" PRId64 " dimensions",
+      dim,
+      static_cast<int64_t>(self.dim()));
+
+  // Normalize negative index (aten.select semantics)
+  int64_t dim_size = self.size(dim);
+  if (index < 0) {
+    index += dim_size;
+  }
+
+  ET_KERNEL_CHECK_MSG(
+      context,
+      index >= 0 && index < dim_size,
+      InvalidArgument,
+      ,
+      "index %" PRId64 " out of range for dim %" PRId64 " with size %" PRId64,
+      index,
+      dim,
+      static_cast<int64_t>(dim_size));
+
+  // Compute output sizes: self.sizes() with the selected dim removed.
+  SizesType expected_output_size[kTensorDimensionLimit];
+  int out_dims = 0;
+  for (int i = 0; i < self.dim(); i++) {
+    if (i != dim) {
+      expected_output_size[out_dims++] = self.size(i);
+    }
+  }
+
+  ET_KERNEL_CHECK_MSG(
+      context,
+      out_dims == out.dim(),
+      InvalidArgument,
+      ,
+      "Expected output to have %d dims, got %" PRId64,
+      out_dims,
+      static_cast<int64_t>(out.dim()));
+
+  ET_KERNEL_CHECK_MSG(
+      context,
+      resize_tensor(
+          out, {expected_output_size, static_cast<size_t>(out_dims)}) ==
+          Error::Ok,
+      Internal,
+      ,
+      "Failed to resize output tensor.");
+
+  // Compute byte offset: index * stride_at_dim * element_size
+  auto stride_at_dim = self.strides()[dim];
+  ssize_t byte_offset =
+      static_cast<ssize_t>(index) * stride_at_dim * self.element_size();
+
+  ET_KERNEL_CHECK_MSG(
+      context,
+      internal::set_tensor_data(
+          out,
+          static_cast<char*>(self.mutable_data_ptr()) + byte_offset,
+          out.nbytes()) == Error::Ok,
+      Internal,
+      ,
+      "Failed to set data_ptr for out to self + offset.");
+}
+
+} // namespace function
+} // namespace executor
+} // namespace torch
diff --git a/kernels/prim_ops/et_select.h b/kernels/prim_ops/et_select.h
new file mode 100644
index 00000000000..dc4dee81bb5
--- /dev/null
+++ b/kernels/prim_ops/et_select.h
@@ -0,0 +1,23 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#pragma once
+
+#include <executorch/runtime/core/evalue.h>
+#include <executorch/runtime/core/span.h>
+#include <executorch/runtime/kernel/kernel_runtime_context.h>
+
+namespace torch {
+namespace executor {
+namespace function {
+
+void et_select(KernelRuntimeContext& context, Span<EValue*> stack);
+
+} // namespace function
+} // namespace executor
+} // namespace torch
diff --git a/kernels/prim_ops/register_prim_ops.cpp b/kernels/prim_ops/register_prim_ops.cpp
index 62eb089298a..471a10cd2d7 100644
--- a/kernels/prim_ops/register_prim_ops.cpp
+++ b/kernels/prim_ops/register_prim_ops.cpp
@@ -7,6 +7,7 @@
  */
 
 #include
+#include <executorch/kernels/prim_ops/et_select.h>
 #include
 #include
 #include
@@ -663,6 +664,12 @@ static Kernel prim_ops[] = {
     Kernel("executorch_prim::et_view.default", et_view),
 #endif
 
+#if !defined(EXECUTORCH_ENABLE_PRIM_OPS_SELECTIVE_BUILD) || \
+    defined(INCLUDE_EXECUTORCH_PRIM_ET_SELECT_DEFAULT)
+    // executorch_prim::et_select.default(Tensor, int, int) -> Tensor
+    Kernel("executorch_prim::et_select.default", et_select),
+#endif
+
 };
 
 executorch::runtime::Span
diff --git a/kernels/prim_ops/targets.bzl b/kernels/prim_ops/targets.bzl
index 1750eb5c34c..c0bf40c574f 100644
--- a/kernels/prim_ops/targets.bzl
+++ b/kernels/prim_ops/targets.bzl
@@ -57,6 +57,21 @@ def define_common_targets():
         ],
     )
 
+    runtime.cxx_library(
+        name = "et_select" + aten_suffix,
+        srcs = ["et_select.cpp"],
+        visibility = ["PUBLIC"],
+        exported_headers = ["et_select.h"],
+        deps = [
+            "//executorch/runtime/kernel:kernel_includes" + aten_suffix,
+            "//executorch/runtime/core:core",
+        ],
+        exported_deps = [
+            "//executorch/runtime/core:evalue" + aten_suffix,
+            "//executorch/runtime/kernel:kernel_runtime_context" + aten_suffix,
+        ],
+    )
+
     runtime.cxx_library(
         name = "prim_ops_registry" + aten_suffix,
         srcs = ["register_prim_ops.cpp"],
@@ -70,6 +85,7 @@
         }),
         deps = [
             ":et_copy_index" + aten_suffix,
+            ":et_select" + aten_suffix,
             ":et_view" + aten_suffix,
             "//executorch/runtime/core:evalue" + aten_suffix,
             "//executorch/runtime/kernel:operator_registry" + aten_suffix,
diff --git a/kernels/prim_ops/test/prim_ops_test.cpp
b/kernels/prim_ops/test/prim_ops_test.cpp
index 37a4e8b157a..e0977fcbd10 100644
--- a/kernels/prim_ops/test/prim_ops_test.cpp
+++ b/kernels/prim_ops/test/prim_ops_test.cpp
@@ -569,6 +569,70 @@ TEST_F(RegisterPrimOpsTest, TestETViewEmpty) {
       context_, Span<EValue*>(bad_stack)));
 }
 
+TEST_F(RegisterPrimOpsTest, TestETSelect) {
+  EXPECT_TRUE(hasOpsFn("executorch_prim::et_select.default"));
+
+  testing::TensorFactory<ScalarType::Int> tf;
+
+  // self: shape [4, 3], data {1..12}
+  auto self = tf.make({4, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
+  auto self_evalue = EValue(self);
+
+  // select(dim=0, index=2) -> shape [3], data {7, 8, 9}
+  auto dim_evalue = EValue(static_cast<int64_t>(0));
+  auto index_evalue = EValue(static_cast<int64_t>(2));
+
+  auto out = tf.ones({3});
+  internal::reset_data_ptr(out);
+  auto out_evalue = EValue(out);
+
+  EValue* stack[4] = {&self_evalue, &dim_evalue, &index_evalue, &out_evalue};
+  getOpsFn("executorch_prim::et_select.default")(
+      context_, Span<EValue*>(stack));
+
+  EXPECT_TENSOR_EQ(out, tf.make({3}, {7, 8, 9}));
+  // Verify data pointer sharing: out should point into self's memory
+  auto* self_base = static_cast<const char*>(self.const_data_ptr());
+  auto* out_base = static_cast<const char*>(out.const_data_ptr());
+  // offset = index(2) * stride(3) * sizeof(int32_t) = 2 * 3 * 4 = 24
+  EXPECT_EQ(out_base, self_base + 24);
+}
+
+TEST_F(RegisterPrimOpsTest, TestETSelectDynamicShape) {
+  testing::TensorFactory<ScalarType::Int> tf;
+
+  // self: shape [2, 6] at max capacity, resized to [2, 3] to simulate
+  // dynamic dim 1. After resize, strides become {3, 1}, so the first 6
+  // elements are the logical data for shape [2, 3].
+  auto self = tf.make(
+      {2, 6},
+      {1, 2, 3, 4, 5, 6, 0, 0, 0, 0, 0, 0},
+      {},
+      TensorShapeDynamism::DYNAMIC_BOUND);
+  SizesType self_sizes[2] = {2, 3};
+  EXPECT_EQ(resize_tensor(self, {self_sizes, 2}), Error::Ok);
+  auto self_evalue = EValue(self);
+
+  // select(dim=0, index=1) -> shape [3], data {4, 5, 6}
+  auto dim_evalue = EValue(static_cast<int64_t>(0));
+  auto index_evalue = EValue(static_cast<int64_t>(1));
+
+  // out: max capacity 6, resized to 1 to force et_select to resize it.
+  auto out = tf.zeros({6}, TensorShapeDynamism::DYNAMIC_BOUND);
+  SizesType out_sizes[1] = {1};
+  EXPECT_EQ(resize_tensor(out, {out_sizes, 1}), Error::Ok);
+  internal::reset_data_ptr(out);
+  auto out_evalue = EValue(out);
+
+  EValue* stack[4] = {&self_evalue, &dim_evalue, &index_evalue, &out_evalue};
+  getOpsFn("executorch_prim::et_select.default")(
+      context_, Span<EValue*>(stack));
+
+  // et_select should have resized out from [1] to [3]
+  EXPECT_EQ(out.size(0), 3);
+  EXPECT_TENSOR_EQ(out, tf.make({3}, {4, 5, 6}));
+}
+
 TEST_F(RegisterPrimOpsTest, TestCeil) {
   std::array<double, 10> inputs = {
       0.0, 0.25, 0.5, 0.75, 1.0, 1.75, -0.5, -1.0, -1.5, 9.999999};
diff --git a/shim_et/xplat/executorch/build/build_variables.bzl b/shim_et/xplat/executorch/build/build_variables.bzl
index b0545b8ce18..0d5bf003af5 100644
--- a/shim_et/xplat/executorch/build/build_variables.bzl
+++ b/shim_et/xplat/executorch/build/build_variables.bzl
@@ -28,6 +28,7 @@ EXECUTORCH_SRCS = [
     "kernels/prim_ops/et_copy_index.cpp",
+    "kernels/prim_ops/et_select.cpp",
    "kernels/prim_ops/et_view.cpp",
     "kernels/prim_ops/register_prim_ops.cpp",
 ]
diff --git a/shim_et/xplat/executorch/kernels/prim_ops/selective_build.bzl b/shim_et/xplat/executorch/kernels/prim_ops/selective_build.bzl
index 4d73e86dfb5..6f0ecd56445 100644
--- a/shim_et/xplat/executorch/kernels/prim_ops/selective_build.bzl
+++ b/shim_et/xplat/executorch/kernels/prim_ops/selective_build.bzl
@@ -48,6 +48,7 @@ def prim_ops_registry_selective(name, selected_prim_ops_header_target, aten_suff
         }) +
["-DET_PRIM_OPS_SELECTIVE_BUILD"], deps = [ "//executorch/kernels/prim_ops:et_copy_index" + aten_suffix, + "//executorch/kernels/prim_ops:et_select" + aten_suffix, "//executorch/kernels/prim_ops:et_view" + aten_suffix, "//executorch/runtime/core:evalue" + aten_suffix, "//executorch/runtime/kernel:operator_registry" + aten_suffix,