-
Notifications
You must be signed in to change notification settings - Fork 992
Expand file tree
/
Copy pathIndexTensor.cpp
More file actions
80 lines (65 loc) · 2.27 KB
/
IndexTensor.cpp
File metadata and controls
80 lines (65 loc) · 2.27 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>
#include <executorch/backends/vulkan/runtime/graph/ops/impl/Common.h>
#include <executorch/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h>
namespace vkcompute {
/*
 * Resize callback for the index_tensor dispatch node.
 *
 * Reads the current sizes of the index tensor and applies them to the output
 * tensor through ComputeGraph::virtual_resize.
 *
 * graph:       graph that owns the tensors being resized.
 * args:        argument groups of the dispatch node. Group 0 holds the output
 *              as its first ref; group 1 holds the inputs, with the index
 *              tensor as its second ref ({self, index}).
 * resize_args: unused by this callback.
 */
void resize_index_tensor_node(
    ComputeGraph* graph,
    const std::vector<ArgGroup>& args,
    const std::vector<ValueRef>& resize_args) {
  // This op needs no extra resize arguments; silence unused-parameter warnings.
  (void)resize_args;

  // Bounds-checked lookups, performed in the same order as before: output
  // first, then the index tensor.
  const ArgGroup& written_group = args.at(0);
  const ValueRef out_ref = written_group.refs.at(0);

  const ArgGroup& read_group = args.at(1);
  const ValueRef index_ref = read_group.refs.at(1);

  // The output takes on exactly the sizes of the index tensor.
  std::vector<int64_t> new_sizes = graph->sizes_of(index_ref);
  graph->virtual_resize(out_ref, new_sizes);
}
/*
 * Appends a DynamicDispatchNode for the "index_tensor" shader to the graph's
 * list of execute nodes.
 *
 * graph: graph that receives the new execute node.
 * self:  tensor bound as a read-only input.
 * index: tensor bound as a read-only input; its sizes drive the output sizes
 *        when the node is resized (see resize_index_tensor_node).
 * out:   tensor bound as the write target; its storage type and dtype select
 *        the shader variant.
 */
void add_index_tensor_node(
    ComputeGraph& graph,
    const ValueRef self,
    const ValueRef index,
    const ValueRef out) {
  // Build the shader variant name: base name + storage suffix + dtype suffix,
  // both derived from the output tensor.
  std::string shader_name = "index_tensor";
  shader_name.reserve(kShaderNameReserve);
  add_storage_type_suffix(shader_name, graph.storage_type_of(out));
  add_dtype_suffix(shader_name, graph.dtype_of(out));

  // Metadata UBOs are bound in the order: out, self, index.
  vkapi::ParamsBindList meta_ubos = {
      graph.meta_ubo(out),
      graph.meta_ubo(self),
      graph.meta_ubo(index),
  };

  graph.execute_nodes().emplace_back(new DynamicDispatchNode(
      graph,
      VK_KERNEL_FROM_STR(shader_name),
      // Global / local work group size pickers
      default_pick_global_wg_size,
      default_pick_local_wg_size,
      // Arg groups: group 0 is written (out), group 1 is read (self, index)
      {{out, vkapi::kWrite}, {{self, index}, vkapi::kRead}},
      // Parameter buffers
      meta_ubos,
      // Push constants (none)
      {},
      // Specialization constants (none)
      {},
      // Resize args (none)
      {},
      // Resize callback
      resize_index_tensor_node));
}
/*
 * Operator entry point registered for aten.index.Tensor.
 *
 * args layout: args[0] = self tensor, args[1] = list of index tensors,
 * args[2] = output tensor.
 *
 * Only a single index tensor is supported; VK_CHECK_COND is used to reject an
 * indices list of any other length.
 */
void index_tensor(ComputeGraph& graph, const std::vector<ValueRef>& args) {
  // Resolve the indices list value into its underlying list of refs.
  ValueListPtr index_refs = graph.get_value_list(args[1]);

  VK_CHECK_COND(
      index_refs->size() == 1,
      "index.Tensor: only one index tensor is supported");

  // Forward (self, the sole index tensor, out) to the node builder.
  add_index_tensor_node(graph, args[0], index_refs->at(0), args[2]);
}
// Operator registration block: associates the aten.index.Tensor operator name
// with the index_tensor entry point defined in this file.
// NOTE(review): the exact registration mechanics come from the
// REGISTER_OPERATORS / VK_REGISTER_OP macros (presumably declared in
// OperatorRegistry.h) -- confirm there.
REGISTER_OPERATORS {
VK_REGISTER_OP(aten.index.Tensor, index_tensor);
}
} // namespace vkcompute