Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions backends/vulkan/op_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -1205,6 +1205,43 @@ def register_where():
)


# =============================================================================
# IndexTensor.cpp
# =============================================================================


@update_features(exir_ops.edge.aten.index.Tensor)
def register_index_tensor():
    def check_index_tensor_node(node: torch.fx.Node) -> bool:
        """Accept only the restricted aten.index.Tensor pattern implemented
        by the Vulkan shaders: a 1D ``self`` tensor gathered by exactly one
        integer index tensor."""
        self_arg = node.args[0]
        indices = node.args[1]

        # Only support 1D self tensor
        if not isinstance(self_arg, torch.fx.Node):
            return False
        self_val = self_arg.meta.get("val", None)
        if self_val is None:
            return False
        if len(self_val.size()) != 1:
            return False

        # Only support exactly one non-None index tensor
        if not isinstance(indices, (list, tuple)):
            return False
        non_none = [idx for idx in indices if idx is not None]
        if len(non_none) != 1:
            return False

        # Only support integer index tensors. Boolean (and legacy byte)
        # indices have mask semantics in aten.index.Tensor and would be
        # silently misinterpreted by the shaders, which read index values
        # as int offsets.
        index_arg = non_none[0]
        if not isinstance(index_arg, torch.fx.Node):
            return False
        index_val = index_arg.meta.get("val", None)
        if index_val is None:
            return False
        if index_val.dtype not in (torch.int32, torch.int64):
            return False

        return True

    return OpFeatures(
        inputs_storage=utils.ANY_STORAGE,
        inputs_dtypes=utils.FP_INT_T,
        supports_resize=True,
        are_node_inputs_supported_fn=check_index_tensor_node,
    )


# =============================================================================
# Arange.cpp
# =============================================================================
Expand Down
58 changes: 58 additions & 0 deletions backends/vulkan/runtime/graph/ops/glsl/index_tensor_buffer.glsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#version 450 core

${define_required_extensions("buffer", DTYPE)}

#define PRECISION ${PRECISION}

#define T ${buffer_scalar_type(DTYPE)}

${define_active_storage_type("buffer")}

layout(std430) buffer;

#include "indexing.glslh"

${layout_declare_tensor(B, "w", "t_out", DTYPE, "buffer")}
${layout_declare_tensor(B, "r", "t_self", DTYPE, "buffer")}
${layout_declare_tensor(B, "r", "t_index", "int", "buffer")}

${layout_declare_ubo(B, "BufferMetadata", "outp")}
${layout_declare_ubo(B, "BufferMetadata", "inp")}
${layout_declare_ubo(B, "BufferMetadata", "index")}

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

// Implements aten.index.Tensor for the case where self is 1D and there is
// exactly one index tensor. Each output element is:
//   output[...] = self[index[...]]
// One invocation produces exactly one output element.
//
// NOTE(review): negative index values are not wrapped here, even though
// aten.index.Tensor accepts indices in [-n, n-1] — confirm that indices
// are normalized upstream, or add `if (idx < 0) idx += n`. Out-of-range
// indices are likewise unchecked and would read out of bounds.

void main() {
  // Linear buffer index of the output element this invocation writes.
  const uint out_bufi = gl_GlobalInvocationID.x;
  if (out_of_bounds(out_bufi, outp)) {
    return;
  }

  // Convert output buffer index to tensor index
  TensorIndex out_tidx = linear_idx_to_tensor_idx(outp, out_bufi);

  // Read the index value at the same tensor position. The output is
  // resized to the index tensor's sizes (see resize_index_tensor_node in
  // IndexTensor.cpp), so out_tidx is a valid position in t_index; the
  // explicit re-linearization accounts for the two tensors possibly
  // having different buffer layouts.
  const uint index_bufi = tensor_idx_to_linear_idx(index, out_tidx);
  const int idx = t_index[index_bufi];

  // Construct a tensor index for the 1D self tensor.
  // In WHCN ordering, a 1D tensor has its elements along dim 0 (width).
  TensorIndex self_tidx;
  self_tidx.data[0] = uvec4(uint(idx), 0, 0, 0);
  self_tidx.data[1] = uvec4(0);
  const uint self_bufi = tensor_idx_to_linear_idx(inp, self_tidx);

  t_out[out_bufi] = t_self[self_bufi];
}
16 changes: 16 additions & 0 deletions backends/vulkan/runtime/graph/ops/glsl/index_tensor_buffer.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# Shader variant generation for index_tensor_buffer.glsl — the
# buffer-storage implementation of the restricted aten.index.Tensor op.
# One variant is generated per DTYPE listed under generate_variant_forall.
index_tensor_buffer:
  parameter_names_with_default_values:
    DTYPE: float
    STORAGE: buffer
  generate_variant_forall:
    DTYPE:
      - VALUE: half
      - VALUE: float
  shader_variants:
    - NAME: index_tensor_buffer
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#version 450 core

${define_required_extensions("texture3d", DTYPE)}

#define PRECISION ${PRECISION}

#define VEC4_T ${texel_load_type(DTYPE, "texture3d")}

${define_active_storage_type("texture3d")}

#extension GL_EXT_control_flow_attributes : require

layout(std430) buffer;

#include "common.glslh"
#include "indexing.glslh"

${layout_declare_tensor(B, "w", "t_out", DTYPE, "texture3d")}
${layout_declare_tensor(B, "r", "t_self", DTYPE, "texture3d")}
${layout_declare_tensor(B, "r", "t_index", "int", "texture3d")}

${layout_declare_ubo(B, "TextureMetadata", "outp")}
${layout_declare_ubo(B, "TextureMetadata", "inp")}
${layout_declare_ubo(B, "TextureMetadata", "index")}

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

// Implements aten.index.Tensor for the case where self is 1D and there is
// exactly one index tensor. Each output element is:
//   output[...] = self[index[...]]
// One invocation produces one output texel (up to 4 elements along the
// packed dim). The gather is performed element-wise because the looked-up
// self elements are generally not contiguous.

void main() {
  const ivec3 out_pos = ivec3(gl_GlobalInvocationID);

  if (out_of_bounds(out_pos, outp)) {
    return;
  }

  // Tensor index of the first element held by this output texel.
  TensorIndex4D out_tidx = texture_pos_to_tensor4d_idx_simple(outp, out_pos);
  // Output is resized to the index tensor's sizes (see IndexTensor.cpp),
  // so the index texel for this output position sits at the same texture
  // coordinate.
  ivec4 idx_texel = texelFetch(t_index, out_pos, 0);

  VEC4_T out_texel = VEC4_T(0);

  // Number of valid components in this (possibly partial, final) texel.
  int limit = min(
      4, outp.sizes[outp.packed_dim] - out_tidx.data[outp.packed_dim]);
  for (int comp = 0; comp < limit; comp++) {
    int idx = idx_texel[comp];
    // aten.index.Tensor accepts negative indices in [-n, 0); wrap them to
    // the equivalent non-negative offset. self is 1D, so its length is
    // sizes[0] in WHCN ordering. Out-of-range indices remain unchecked
    // (undefined behavior, matching ATen).
    if (idx < 0) {
      idx += inp.sizes[0];
    }

    // Construct a tensor index for the 1D self tensor.
    // In WHCN ordering, a 1D tensor has its elements along dim 0 (width).
    TensorIndex4D self_tidx;
    self_tidx.data = ivec4(idx, 0, 0, 0);

    TextureElementIndex self_elem =
        tensor4d_idx_to_texture_element_idx_simple(inp, self_tidx);

    VEC4_T self_texel = texelFetch(t_self, self_elem.pos, 0);
    out_texel[comp] = self_texel[self_elem.comp];

    out_tidx.data[outp.packed_dim]++;
  }

  imageStore(t_out, out_pos, out_texel);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# Shader variant generation for index_tensor_texture.glsl — the
# texture-storage implementation of the restricted aten.index.Tensor op.
# One variant is generated per DTYPE listed under generate_variant_forall.
index_tensor_texture:
  parameter_names_with_default_values:
    DTYPE: float
    STORAGE: texture3d
  generate_variant_forall:
    DTYPE:
      - VALUE: half
      - VALUE: float
  shader_variants:
    - NAME: index_tensor_texture3d
80 changes: 80 additions & 0 deletions backends/vulkan/runtime/graph/ops/impl/IndexTensor.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>

#include <executorch/backends/vulkan/runtime/graph/ops/impl/Common.h>

#include <executorch/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h>

namespace vkcompute {

/*
 * Dynamic-shape resize logic for index.Tensor: the output always takes
 * the shape of the index tensor, since each output element corresponds to
 * exactly one index value.
 */
void resize_index_tensor_node(
    ComputeGraph* graph,
    const std::vector<ArgGroup>& args,
    const std::vector<ValueRef>& resize_args) {
  (void)resize_args;
  const ValueRef out_ref = args.at(0).refs.at(0);
  const ValueRef index_ref = args.at(1).refs.at(1);
  graph->virtual_resize(out_ref, graph->sizes_of(index_ref));
}

// Records the dispatch node that executes the index_tensor shader for the
// given (self, index, out) triple. The shader name resolves to
// index_tensor_{storage}_{dtype} to match the variants generated from the
// .yaml shader configs.
void add_index_tensor_node(
    ComputeGraph& graph,
    const ValueRef self,
    const ValueRef index,
    const ValueRef out) {
  std::string kernel_name = "index_tensor";
  kernel_name.reserve(kShaderNameReserve);
  add_storage_type_suffix(kernel_name, graph.storage_type_of(out));
  add_dtype_suffix(kernel_name, graph.dtype_of(out));

  // Metadata UBOs are bound in the same order the shaders declare them:
  // outp, inp (self), index.
  vkapi::ParamsBindList param_ubos = {
      graph.meta_ubo(out), graph.meta_ubo(self), graph.meta_ubo(index)};

  graph.execute_nodes().emplace_back(new DynamicDispatchNode(
      graph,
      VK_KERNEL_FROM_STR(kernel_name),
      default_pick_global_wg_size,
      default_pick_local_wg_size,
      // Inputs and Outputs
      {{out, vkapi::kWrite}, {{self, index}, vkapi::kRead}},
      // Shader params buffers
      param_ubos,
      // Push Constants
      {},
      // Specialization Constants
      {},
      // Resize Args
      {},
      // Resizing Logic
      resize_index_tensor_node));
}

// Entry point for aten.index.Tensor. Unpacks (self, indices, out) and
// validates the restricted pattern supported by the shaders: exactly one
// index tensor, holding int32 values.
void index_tensor(ComputeGraph& graph, const std::vector<ValueRef>& args) {
  const ValueRef self = args.at(0);
  const ValueRef indices_list_ref = args.at(1);
  const ValueRef out = args.at(2);

  ValueListPtr indices_list = graph.get_value_list(indices_list_ref);
  VK_CHECK_COND(
      indices_list->size() == 1,
      "index.Tensor: only one index tensor is supported");

  const ValueRef index = indices_list->at(0);
  // Boolean index tensors have mask semantics and are not implemented;
  // the shaders read the index values as int32 offsets.
  VK_CHECK_COND(
      graph.dtype_of(index) == vkapi::kInt,
      "index.Tensor: only int32 index tensors are supported");

  add_index_tensor_node(graph, self, index, out);
}

// Register the implementation above for the aten.index.Tensor edge op.
REGISTER_OPERATORS {
  VK_REGISTER_OP(aten.index.Tensor, index_tensor);
}

} // namespace vkcompute
27 changes: 27 additions & 0 deletions backends/vulkan/test/op_tests/cases.py
Original file line number Diff line number Diff line change
Expand Up @@ -2024,6 +2024,33 @@ def get_bitwise_and_inputs():
return test_suite


@register_test_suite("aten.index.Tensor")
def get_index_tensor_inputs():
    """Test cases for the restricted aten.index.Tensor pattern supported by
    the Vulkan backend: a 1D self tensor gathered by a single integer index
    tensor of arbitrary rank."""
    Test = namedtuple("IndexTensorTest", ["self", "indices"])

    shape_pairs = [
        # (self shape, index shape) — 1D index tensor
        ((M1,), (S,)),
        ((M1,), (M2,)),
        # 2D index tensor
        ((L,), (S, S1)),
        ((L,), (M1, M2)),
        # 3D index tensor
        ((M1,), (XS, S, S1)),
    ]

    suite = VkTestSuite(
        [
            tuple(Test(self=self_shape, indices=[index_shape]))
            for self_shape, index_shape in shape_pairs
        ]
    )
    suite.layouts = [
        "utils::kWidthPacked",
        "utils::kChannelsPacked",
    ]
    suite.storage_types = ["utils::kBuffer", "utils::kTexture3D"]
    suite.dtypes = ["at::kFloat"]
    # Indices must be integer-typed; randint generation keeps them within
    # valid offsets of the self tensor.
    suite.arg_dtype["indices"] = "at::kInt"
    suite.arg_data_gen_fn["indices"] = "make_casted_randint_tensor"
    return suite


@register_test_suite("aten.pow.Tensor_Scalar")
def get_pow_tensor_scalar_inputs():
test_suite = VkTestSuite(
Expand Down
1 change: 1 addition & 0 deletions backends/vulkan/test/op_tests/utils/aten_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
AT_SCALAR = "at::Scalar"
AT_TENSOR = "at::Tensor"
AT_TENSOR_LIST = "at::TensorList"
# C++ spelling of an optional-tensor list, e.g. the `indices` argument of
# aten.index.Tensor.
OPT_TENSOR_LIST = "c10::List<::std::optional<at::Tensor>>"
BOOL = "bool"
DOUBLE = "double"
INT = "int64_t"
Expand Down
Loading
Loading