-
Notifications
You must be signed in to change notification settings - Fork 998
Expand file tree
/
Copy pathRepeat.cpp
More file actions
94 lines (75 loc) · 2.94 KB
/
Repeat.cpp
File metadata and controls
94 lines (75 loc) · 2.94 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/DimUtils.h>
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/KernelUtils.h>
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h>
#include <executorch/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h>
#include <executorch/backends/vulkan/runtime/graph/ops/impl/Copy.h>
namespace vkcompute {
namespace {
/*
 * Validates that `out` has the shape produced by repeating `in` according to
 * `repeats` (one repeat count per dimension, accessed in WHCN order through
 * dim_at).
 *
 * Checks:
 *  - in/out share the same packed dim and the same storage type
 *  - 2D-texture storage only supports tensors with dim <= 2
 *  - `repeats` supplies at least one entry per input dimension
 *  - each output extent equals the matching input extent * repeat count
 *
 * Aborts via VK_CHECK_COND on any violation.
 */
void check_args(
    const api::vTensor& in,
    const std::vector<int64_t>& repeats,
    const api::vTensor& out) {
  VK_CHECK_COND(check_same_packed_dim(in, out));
  VK_CHECK_COND(in.storage_type() == out.storage_type());
  if (in.storage_type() == utils::kTexture2D) {
    VK_CHECK_COND(in.dim() <= 2);
  }

  const int64_t in_dim = in.dim();
  // Cast size() to int64_t to avoid a signed/unsigned comparison; a repeat
  // list is tiny so the cast cannot overflow.
  VK_CHECK_COND(
      in_dim <= static_cast<int64_t>(repeats.size()),
      "Input tensor dim size must be not greater than the repeat argument's size");

  VK_CHECK_COND(
      dim_at<kWidth4D>(in.sizes()) * dim_at<kWidth4D>(repeats) ==
          dim_at<kWidth4D>(out.sizes()),
      "Output's width doesn't match input's width * repeat count");

  VK_CHECK_COND(
      dim_at<kHeight4D>(in.sizes()) * dim_at<kHeight4D>(repeats) ==
          dim_at<kHeight4D>(out.sizes()),
      "Output's height doesn't match input's height * repeat count");

  VK_CHECK_COND(
      dim_at<kChannel4D>(in.sizes()) * dim_at<kChannel4D>(repeats) ==
          dim_at<kChannel4D>(out.sizes()),
      "Output's channel doesn't match input's channel * repeat count");

  VK_CHECK_COND(
      dim_at<kBatch4D>(in.sizes()) * dim_at<kBatch4D>(repeats) ==
          dim_at<kBatch4D>(out.sizes()),
      "Output's batch doesn't match input's batch * repeat count");
}
} // namespace
/*
 * Adds a node to `graph` that repeats the tensor `in` along each dimension by
 * the counts held in `repeats_ref`, writing the result into `out`.
 *
 * Implemented on top of the packed-dim offset copy node: the source-offset
 * vector carries the input extents and the destination-offset vector carries
 * the repeat counts (both laid out as {W, H, C, N}).
 */
void add_repeat_node(
    ComputeGraph& graph,
    ValueRef in,
    ValueRef repeats_ref,
    ValueRef out) {
  const std::vector<int64_t> repeats = *(graph.get_int_list(repeats_ref));

  vTensorPtr in_tensor = graph.get_tensor(in);
  vTensorPtr out_tensor = graph.get_tensor(out);
  check_args(*in_tensor, repeats, *out_tensor);

  // Input extents, handed to the copy shader through the src offset slot.
  const utils::ivec4 src_offset{
      dim_at<kWidth4D>(in_tensor->sizes()),
      dim_at<kHeight4D>(in_tensor->sizes()),
      dim_at<kChannel4D>(in_tensor->sizes()),
      dim_at<kBatch4D>(in_tensor->sizes())};

  // Per-dimension repeat counts, handed through the dst offset slot.
  const utils::ivec4 dst_offset{
      dim_at<kWidth4D>(repeats),
      dim_at<kHeight4D>(repeats),
      dim_at<kChannel4D>(repeats),
      dim_at<kBatch4D>(repeats)};

  add_copy_packed_dim_offset_node(
      graph,
      in,
      out_tensor->logical_limits(),
      src_offset,
      dst_offset,
      out,
      true);
}
/*
 * Resolver for aten.repeat.default: unpacks the argument list as
 * (input, repeats, output) and delegates to add_repeat_node.
 */
void repeat(ComputeGraph& graph, const std::vector<ValueRef>& args) {
  const ValueRef input = args[0];
  const ValueRef repeats = args[1];
  const ValueRef output = args[2];
  add_repeat_node(graph, input, repeats, output);
}
// Registers the `repeat` resolver with the operator registry so that
// aten.repeat.default is lowered via add_repeat_node at graph-build time.
REGISTER_OPERATORS {
VK_REGISTER_OP(aten.repeat.default, repeat);
}
} // namespace vkcompute