-
Notifications
You must be signed in to change notification settings - Fork 1k
Expand file tree
/
Copy pathSplit.cpp
More file actions
100 lines (82 loc) · 2.98 KB
/
Split.cpp
File metadata and controls
100 lines (82 loc) · 2.98 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>
#include <executorch/backends/vulkan/runtime/graph/ops/impl/Common.h>
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h>
#include <executorch/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h>
namespace vkcompute {
/*
 * Records a single dispatch of the "split" shader that writes one output chunk
 * of `input` into `out`.
 *
 * graph       - compute graph that receives the new execute node.
 * input       - tensor being split; bound read-only.
 * split_sizes - extent of every chunk along `dim`, in output order.
 * dim         - split dimension in NCHW order; it is converted to WHCN order
 *               with nchw_dim_to_whcn_dim() before being handed to the shader.
 * out         - destination tensor for chunk `split_idx`; bound for writing.
 * split_idx   - index into `split_sizes` of the chunk this dispatch produces.
 *
 * The shader variant is selected from the storage type and dtype of `out`.
 * The WHCN dim, the chunk index and the chunk's starting offset along the
 * split dimension are passed as specialization constants. No resize function
 * is attached to the node.
 */
void add_split_node(
    ComputeGraph& graph,
    const ValueRef input,
    const std::vector<int64_t>& split_sizes,
    const int64_t dim,
    const ValueRef out,
    const int split_idx) {
  // The offset computation below reads split_sizes[0 .. split_idx), and the
  // chunk itself is described by split_sizes[split_idx]; reject indices that
  // do not name an existing chunk instead of reading out of bounds.
  VK_CHECK_COND(
      split_idx >= 0 &&
      static_cast<size_t>(split_idx) < split_sizes.size());

  std::string kernel_name = "split";
  kernel_name.reserve(kShaderNameReserve);
  add_storage_type_suffix(kernel_name, graph.storage_type_of(out));
  add_dtype_suffix(kernel_name, graph.dtype_of(out));

  vkapi::ParamsBindList param_ubos = {
      graph.meta_ubo(out), graph.meta_ubo(input)};

  const int64_t dim_whcn = nchw_dim_to_whcn_dim(dim, graph.dim_of(input));

  // Starting position of this chunk along the split dimension: the sum of the
  // sizes of all chunks that precede it.
  int64_t split_offset = 0;
  for (int i = 0; i < split_idx; i++) {
    split_offset += split_sizes[i];
  }

  graph.execute_nodes().emplace_back(new DynamicDispatchNode(
      graph,
      VK_KERNEL_FROM_STR(kernel_name),
      default_pick_global_wg_size,
      default_pick_local_wg_size,
      // Inputs and Outputs
      {{out, vkapi::kWrite}, {input, vkapi::kRead}},
      // Shader params buffers
      param_ubos,
      // Push Constants
      {},
      // Specialization Constants
      // Both 64-bit values go through the checked downcast so that an
      // out-of-range value is reported rather than silently truncated.
      {utils::safe_downcast<int32_t>(dim_whcn),
       static_cast<int32_t>(split_idx),
       utils::safe_downcast<int32_t>(split_offset)},
      // Resize Args
      {},
      // Resizing Logic
      nullptr));
}
/*
 * Adds one split dispatch per output tensor so that `input` is partitioned
 * along `dim` into chunks of the sizes listed in `split_sizes`.
 *
 * graph        - compute graph that receives the new execute nodes.
 * input        - tensor being split.
 * split_sizes  - extent of each chunk along `dim`, in output order.
 * dim          - split dimension, forwarded unchanged to add_split_node().
 * out_list_ref - reference to the value list holding the output tensors; it
 *                must contain exactly one entry per element of `split_sizes`.
 */
void add_split_with_sizes_node(
    ComputeGraph& graph,
    const ValueRef input,
    const std::vector<int64_t>& split_sizes,
    const int64_t dim,
    const ValueRef out_list_ref) {
  const ValueListPtr out_list = graph.get_value_list(out_list_ref);
  const size_t num_splits = split_sizes.size();
  VK_CHECK_COND(out_list->size() == num_splits);

  // Dispatch a shader for each output tensor. The index is kept unsigned so
  // it compares cleanly against the container size; it is narrowed only at the
  // call site, where add_split_node() expects an int.
  for (size_t idx = 0; idx < num_splits; ++idx) {
    const ValueRef out_ref = out_list->at(idx);
    add_split_node(
        graph, input, split_sizes, dim, out_ref, static_cast<int>(idx));
  }
}
/*
 * Operator entry point that unpacks the argument list and forwards it to
 * add_split_with_sizes_node().
 *
 * Expected layout of `args`:
 *   args[0] - input tensor
 *   args[1] - list of split sizes (int or symint list)
 *   args[2] - scalar split dimension
 *   args[3] - output value list
 */
void split_with_sizes_copy_default(
    ComputeGraph& graph,
    const std::vector<ValueRef>& args) {
  // Resolve the scalar dimension first, then the size list, matching the
  // order in which the graph is queried.
  const int64_t axis = graph.extract_scalar<int64_t>(args[2]);
  const std::vector<int64_t> chunk_sizes =
      graph.extract_int_or_symint_list(args[1]);

  add_split_with_sizes_node(graph, args[0], chunk_sizes, axis, args[3]);
}
// Operator registration for this file: associates the operator name
// aten.split_with_sizes_copy.default with split_with_sizes_copy_default.
// NOTE(review): the exact registration mechanics live in the
// REGISTER_OPERATORS / VK_REGISTER_OP macros, which are defined outside this
// file (presumably in OperatorRegistry.h) — confirm there.
REGISTER_OPERATORS {
VK_REGISTER_OP(
aten.split_with_sizes_copy.default, split_with_sizes_copy_default);
}
} // namespace vkcompute