diff --git a/backends/vulkan/op_registry.py b/backends/vulkan/op_registry.py index e9bf7201b23..b18bf3b81c6 100644 --- a/backends/vulkan/op_registry.py +++ b/backends/vulkan/op_registry.py @@ -1205,6 +1205,43 @@ def register_where(): ) +# ============================================================================= +# IndexTensor.cpp +# ============================================================================= + + +@update_features(exir_ops.edge.aten.index.Tensor) +def register_index_tensor(): + def check_index_tensor_node(node: torch.fx.Node) -> bool: + self_arg = node.args[0] + indices = node.args[1] + + # Only support 1D self tensor + if not isinstance(self_arg, torch.fx.Node): + return False + self_val = self_arg.meta.get("val", None) + if self_val is None: + return False + if len(self_val.size()) != 1: + return False + + # Only support exactly one non-None index tensor + if not isinstance(indices, (list, tuple)): + return False + non_none = [idx for idx in indices if idx is not None] + if len(non_none) != 1: + return False + + return True + + return OpFeatures( + inputs_storage=utils.ANY_STORAGE, + inputs_dtypes=utils.FP_INT_T, + supports_resize=True, + are_node_inputs_supported_fn=check_index_tensor_node, + ) + + # ============================================================================= # Arange.cpp # ============================================================================= diff --git a/backends/vulkan/runtime/graph/ops/glsl/index_tensor_buffer.glsl b/backends/vulkan/runtime/graph/ops/glsl/index_tensor_buffer.glsl new file mode 100644 index 00000000000..3469bb22fcc --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/index_tensor_buffer.glsl @@ -0,0 +1,58 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +${define_required_extensions("buffer", DTYPE)} + +#define PRECISION ${PRECISION} + +#define T ${buffer_scalar_type(DTYPE)} + +${define_active_storage_type("buffer")} + +layout(std430) buffer; + +#include "indexing.glslh" + +${layout_declare_tensor(B, "w", "t_out", DTYPE, "buffer")} +${layout_declare_tensor(B, "r", "t_self", DTYPE, "buffer")} +${layout_declare_tensor(B, "r", "t_index", "int", "buffer")} + +${layout_declare_ubo(B, "BufferMetadata", "outp")} +${layout_declare_ubo(B, "BufferMetadata", "inp")} +${layout_declare_ubo(B, "BufferMetadata", "index")} + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +// Implements aten.index.Tensor for the case where self is 1D and there is +// exactly one index tensor. Each output element is: +// output[...] = self[index[...]] + +void main() { + const uint out_bufi = gl_GlobalInvocationID.x; + if (out_of_bounds(out_bufi, outp)) { + return; + } + + // Convert output buffer index to tensor index + TensorIndex out_tidx = linear_idx_to_tensor_idx(outp, out_bufi); + + // Read the index value at the same tensor position + const uint index_bufi = tensor_idx_to_linear_idx(index, out_tidx); + const int idx = t_index[index_bufi]; + + // Construct a tensor index for the 1D self tensor. + // In WHCN ordering, a 1D tensor has its elements along dim 0 (width). + TensorIndex self_tidx; + self_tidx.data[0] = uvec4(uint(idx), 0, 0, 0); + self_tidx.data[1] = uvec4(0); + const uint self_bufi = tensor_idx_to_linear_idx(inp, self_tidx); + + t_out[out_bufi] = t_self[self_bufi]; +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/index_tensor_buffer.yaml b/backends/vulkan/runtime/graph/ops/glsl/index_tensor_buffer.yaml new file mode 100644 index 00000000000..ef79704203f --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/index_tensor_buffer.yaml @@ -0,0 +1,16 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +index_tensor_buffer: + parameter_names_with_default_values: + DTYPE: float + STORAGE: buffer + generate_variant_forall: + DTYPE: + - VALUE: half + - VALUE: float + shader_variants: + - NAME: index_tensor_buffer diff --git a/backends/vulkan/runtime/graph/ops/glsl/index_tensor_texture.glsl b/backends/vulkan/runtime/graph/ops/glsl/index_tensor_texture.glsl new file mode 100644 index 00000000000..8f8026c0a0c --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/index_tensor_texture.glsl @@ -0,0 +1,72 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +${define_required_extensions("texture3d", DTYPE)} + +#define PRECISION ${PRECISION} + +#define VEC4_T ${texel_load_type(DTYPE, "texture3d")} + +${define_active_storage_type("texture3d")} + +#extension GL_EXT_control_flow_attributes : require + +layout(std430) buffer; + +#include "common.glslh" +#include "indexing.glslh" + +${layout_declare_tensor(B, "w", "t_out", DTYPE, "texture3d")} +${layout_declare_tensor(B, "r", "t_self", DTYPE, "texture3d")} +${layout_declare_tensor(B, "r", "t_index", "int", "texture3d")} + +${layout_declare_ubo(B, "TextureMetadata", "outp")} +${layout_declare_ubo(B, "TextureMetadata", "inp")} +${layout_declare_ubo(B, "TextureMetadata", "index")} + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +// Implements aten.index.Tensor for the case where self is 1D and there is +// exactly one index tensor. Each output element is: +// output[...] = self[index[...]] + +void main() { + const ivec3 out_pos = ivec3(gl_GlobalInvocationID); + + if (out_of_bounds(out_pos, outp)) { + return; + } + + TensorIndex4D out_tidx = texture_pos_to_tensor4d_idx_simple(outp, out_pos); + ivec4 idx_texel = texelFetch(t_index, out_pos, 0); + + VEC4_T out_texel = VEC4_T(0); + + int limit = min( + 4, outp.sizes[outp.packed_dim] - out_tidx.data[outp.packed_dim]); + for (int comp = 0; comp < limit; comp++) { + int idx = idx_texel[comp]; + + // Construct a tensor index for the 1D self tensor. + // In WHCN ordering, a 1D tensor has its elements along dim 0 (width). + TensorIndex4D self_tidx; + self_tidx.data = ivec4(idx, 0, 0, 0); + + TextureElementIndex self_elem = + tensor4d_idx_to_texture_element_idx_simple(inp, self_tidx); + + VEC4_T self_texel = texelFetch(t_self, self_elem.pos, 0); + out_texel[comp] = self_texel[self_elem.comp]; + + out_tidx.data[outp.packed_dim]++; + } + + imageStore(t_out, out_pos, out_texel); +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/index_tensor_texture.yaml b/backends/vulkan/runtime/graph/ops/glsl/index_tensor_texture.yaml new file mode 100644 index 00000000000..3e274fa177a --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/index_tensor_texture.yaml @@ -0,0 +1,16 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +index_tensor_texture: + parameter_names_with_default_values: + DTYPE: float + STORAGE: texture3d + generate_variant_forall: + DTYPE: + - VALUE: half + - VALUE: float + shader_variants: + - NAME: index_tensor_texture3d diff --git a/backends/vulkan/runtime/graph/ops/impl/IndexTensor.cpp b/backends/vulkan/runtime/graph/ops/impl/IndexTensor.cpp new file mode 100644 index 00000000000..b7da1b1ac40 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/impl/IndexTensor.cpp @@ -0,0 +1,80 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include + +#include + +namespace vkcompute { + +void resize_index_tensor_node( + ComputeGraph* graph, + const std::vector& args, + const std::vector& resize_args) { + (void)resize_args; + const ValueRef out = args.at(0).refs.at(0); + const ValueRef index = args.at(1).refs.at(1); + + std::vector out_sizes = graph->sizes_of(index); + graph->virtual_resize(out, out_sizes); +} + +void add_index_tensor_node( + ComputeGraph& graph, + const ValueRef self, + const ValueRef index, + const ValueRef out) { + std::string kernel_name = "index_tensor"; + kernel_name.reserve(kShaderNameReserve); + add_storage_type_suffix(kernel_name, graph.storage_type_of(out)); + add_dtype_suffix(kernel_name, graph.dtype_of(out)); + + vkapi::ParamsBindList param_ubos = { + graph.meta_ubo(out), graph.meta_ubo(self), graph.meta_ubo(index)}; + + graph.execute_nodes().emplace_back(new DynamicDispatchNode( + graph, + VK_KERNEL_FROM_STR(kernel_name), + default_pick_global_wg_size, + default_pick_local_wg_size, + // Inputs and Outputs + {{out, vkapi::kWrite}, {{self, index}, vkapi::kRead}}, + // Shader params buffers + param_ubos, + // Push Constants + {}, + // Specialization Constants + {}, + // Resize Args + {}, + // Resizing Logic + resize_index_tensor_node)); +} + +void index_tensor(ComputeGraph& graph, const std::vector& args) { + ValueRef self = args[0]; + ValueRef indices_list_ref = args[1]; + ValueRef out = args[2]; + + ValueListPtr indices_list = graph.get_value_list(indices_list_ref); + VK_CHECK_COND( + indices_list->size() == 1, + "index.Tensor: only one index tensor is supported"); + + ValueRef index = indices_list->at(0); + + add_index_tensor_node(graph, self, index, out); +} + +REGISTER_OPERATORS { + VK_REGISTER_OP(aten.index.Tensor, index_tensor); +} + +} // namespace vkcompute diff --git a/backends/vulkan/test/op_tests/cases.py b/backends/vulkan/test/op_tests/cases.py index 5ed354ebab3..fe2e4169f05 100644 --- a/backends/vulkan/test/op_tests/cases.py +++ b/backends/vulkan/test/op_tests/cases.py @@ -2024,6 +2024,33 @@ def get_bitwise_and_inputs(): return test_suite +@register_test_suite("aten.index.Tensor") +def get_index_tensor_inputs(): + Test = namedtuple("IndexTensorTest", ["self", "indices"]) + + test_cases = [ + # 1D index tensor + Test(self=(M1,), indices=[(S,)]), + Test(self=(M1,), indices=[(M2,)]), + # 2D index tensor + Test(self=(L,), indices=[(S, S1)]), + Test(self=(L,), indices=[(M1, M2)]), + # 3D index tensor + Test(self=(M1,), indices=[(XS, S, S1)]), + ] + + test_suite = VkTestSuite([tuple(tc) for tc in test_cases]) + test_suite.layouts = [ + "utils::kWidthPacked", + "utils::kChannelsPacked", + ] + test_suite.storage_types = ["utils::kBuffer", "utils::kTexture3D"] + test_suite.dtypes = ["at::kFloat"] + test_suite.arg_dtype["indices"] = "at::kInt" + test_suite.arg_data_gen_fn["indices"] = "make_casted_randint_tensor" + return test_suite + + @register_test_suite("aten.pow.Tensor_Scalar") def get_pow_tensor_scalar_inputs(): test_suite = VkTestSuite( diff --git a/backends/vulkan/test/op_tests/utils/aten_types.py b/backends/vulkan/test/op_tests/utils/aten_types.py index 6ad2f568e91..a78263987a1 100644 --- a/backends/vulkan/test/op_tests/utils/aten_types.py +++ b/backends/vulkan/test/op_tests/utils/aten_types.py @@ -12,6 +12,7 @@ AT_SCALAR = "at::Scalar" AT_TENSOR = "at::Tensor" AT_TENSOR_LIST = "at::TensorList" +OPT_TENSOR_LIST = "c10::List<::std::optional>" BOOL = "bool" DOUBLE = "double" INT = "int64_t" diff --git a/backends/vulkan/test/op_tests/utils/gen_computegraph.py b/backends/vulkan/test/op_tests/utils/gen_computegraph.py index 6a7dc2e5d0a..a09b4d36b18 100644 --- a/backends/vulkan/test/op_tests/utils/gen_computegraph.py +++ b/backends/vulkan/test/op_tests/utils/gen_computegraph.py @@ -26,6 +26,7 @@ OPT_LAYOUT, OPT_MEMORY_FORMAT, OPT_SCALAR_TYPE, + OPT_TENSOR_LIST, STRING, TENSOR_VECTOR, THREE_TENSOR_TUPLE, @@ -86,7 +87,7 @@ def vk_out(self): ValueRefList = Union[ValueRef, List[ValueRef]] -InableCppType = frozenset([AT_TENSOR, AT_TENSOR_LIST]) +InableCppType = frozenset([AT_TENSOR, AT_TENSOR_LIST, OPT_TENSOR_LIST]) class ComputeGraphGen: @@ -313,7 +314,7 @@ def create_value_decl_for(self, ref: ValueRefList) -> str: # noqa: C901 return ret_str cpp_type = "IOValueRef" if (ref.is_in or ref.requires_prepack) else "ValueRef" - if ref.src_cpp_type == AT_TENSOR_LIST: + if ref.src_cpp_type in (AT_TENSOR_LIST, OPT_TENSOR_LIST): ret_str = f"std::vector {ref.name}_io_value_refs;\n" ret_str += f"std::vector {ref.name}_value_refs;\n" return ret_str @@ -409,6 +410,25 @@ def create_value_for( # noqa: C901 ret_str += "}\n" ret_str += f"ValueRef {ref.name} = {self.graph}{self.dot}add_value_list(std::move({ref.name}_value_refs));\n" return ret_str + elif ref.src_cpp_type == OPT_TENSOR_LIST: + assert ref.is_in, "OPT_TENSOR_LIST must be an input" + ret_str = "" + if include_declarations: + ret_str += f"std::vector {ref.name}_io_value_refs;\n" + ret_str += f"std::vector {ref.name}_value_refs;\n" + ret_str += f"for (int i=0; i < (int){ref.src_cpp_name}.size(); i++) {{\n" + ret_str += ( + f" IOValueRef io_value_ref = {self.graph}{self.dot}add_input_tensor(\n" + ) + ret_str += f" {ref.src_cpp_name}[i]->sizes().vec(),\n" + ret_str += ( + f" from_at_scalartype({ref.src_cpp_name}[i]->scalar_type())); \n" + ) + ret_str += f" {ref.name}_value_refs.emplace_back(io_value_ref.value);\n" + ret_str += f" {ref.name}_io_value_refs.emplace_back(io_value_ref);\n" + ret_str += "}\n" + ret_str += f"ValueRef {ref.name} = {self.graph}{self.dot}add_value_list(std::move({ref.name}_value_refs));\n" + return ret_str elif ref.src_cpp_type == TENSOR_VECTOR: ret_str = "" if include_declarations: @@ -491,7 +511,7 @@ def create_op_call(self) -> str: for aten_arg in self.args: ref = self.refs[aten_arg.name] - if ref.src_cpp_type == AT_TENSOR_LIST: + if ref.src_cpp_type in (AT_TENSOR_LIST, OPT_TENSOR_LIST): # Special case. Underlying tensors are input tensors, but the # container itself is just a normal value. op_create_code += f"{ref.name}, " @@ -553,10 +573,20 @@ def virtual_resize(self, ref: ValueRefList) -> str: ret_str += f"{ref.src_cpp_name}.sizes().vec());\n" elif ref.src_cpp_type == AT_TENSOR_LIST: ret_str = "" - ret_str += f"for (int i=0; i < {ref.name}_io_value_refs.size(); i++) {{\n" + ret_str += ( + f"for (int i=0; i < (int){ref.name}_io_value_refs.size(); i++) {{\n" + ) ret_str += f" {self.graph}{self.dot}virtual_resize({ref.name}_io_value_refs[i].value, " ret_str += f"{ref.src_cpp_name}[i].sizes().vec());\n" ret_str += "}\n" + elif ref.src_cpp_type == OPT_TENSOR_LIST: + ret_str = "" + ret_str += ( + f"for (int i=0; i < (int){ref.name}_io_value_refs.size(); i++) {{\n" + ) + ret_str += f" {self.graph}{self.dot}virtual_resize({ref.name}_io_value_refs[i].value, " + ret_str += f"{ref.src_cpp_name}[i]->sizes().vec());\n" + ret_str += "}\n" else: raise AssertionError(f"{ref.src_cpp_type} not expected") @@ -577,13 +607,26 @@ def copy_into_staging(self, ref: ValueRefList) -> str: ret_str += f"from_at_scalartype({ref.src_cpp_name}.scalar_type()));\n" elif ref.src_cpp_type == AT_TENSOR_LIST: ret_str = "" - ret_str += f"for (int i=0; i < {ref.name}_io_value_refs.size(); i++) {{\n" + ret_str += ( + f"for (int i=0; i < (int){ref.name}_io_value_refs.size(); i++) {{\n" + ) ret_str += f" {self.graph}{self.dot}maybe_cast_and_copy_into_staging(" ret_str += f"{ref.name}_io_value_refs[i].staging, " ret_str += f"{ref.src_cpp_name}[i].const_data_ptr(), " ret_str += f"{ref.src_cpp_name}[i].numel(), " ret_str += f"from_at_scalartype({ref.src_cpp_name}[i].scalar_type()));\n" ret_str += "}\n" + elif ref.src_cpp_type == OPT_TENSOR_LIST: + ret_str = "" + ret_str += ( + f"for (int i=0; i < (int){ref.name}_io_value_refs.size(); i++) {{\n" + ) + ret_str += f" {self.graph}{self.dot}maybe_cast_and_copy_into_staging(" + ret_str += f"{ref.name}_io_value_refs[i].staging, " + ret_str += f"{ref.src_cpp_name}[i]->const_data_ptr(), " + ret_str += f"{ref.src_cpp_name}[i]->numel(), " + ret_str += f"from_at_scalartype({ref.src_cpp_name}[i]->scalar_type()));\n" + ret_str += "}\n" else: raise AssertionError(f"{ref.src_cpp_type} not expected") return ret_str diff --git a/backends/vulkan/test/op_tests/utils/gen_correctness_base.py b/backends/vulkan/test/op_tests/utils/gen_correctness_base.py index 15627726173..efd073a0cfb 100644 --- a/backends/vulkan/test/op_tests/utils/gen_correctness_base.py +++ b/backends/vulkan/test/op_tests/utils/gen_correctness_base.py @@ -25,6 +25,7 @@ OPT_LAYOUT, OPT_MEMORY_FORMAT, OPT_SCALAR_TYPE, + OPT_TENSOR_LIST, STRING, ) from executorch.backends.vulkan.test.op_tests.utils.test_suite import TestSuite @@ -166,6 +167,12 @@ def create_input_data(self, arg: Argument, data: Any) -> str: # noqa: C901 ret_str += f"{cpp_type} {arg.name} = tensor_vec;\n" return ret_str + "\n" + if cpp_type == OPT_TENSOR_LIST: + ret_str = f"{OPT_TENSOR_LIST} {arg.name};\n" + for elem in data: + ret_str += f"{arg.name}.push_back({self.call_data_gen_fn(arg, elem, False)});\n" + return ret_str + "\n" + if cpp_type == AT_INT_ARRAY_REF: ret_str = f"std::vector {arg.name} = " elif cpp_type == OPT_AT_DOUBLE_ARRAY_REF and str(data) != "None":