-
Notifications
You must be signed in to change notification settings - Fork 1k
Expand file tree
/
Copy pathQ8taBinary.cpp
More file actions
153 lines (132 loc) · 5.33 KB
/
Q8taBinary.cpp
File metadata and controls
153 lines (132 loc) · 5.33 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>
#include <executorch/backends/vulkan/runtime/graph/ops/impl/Common.h>
#include <executorch/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h>
namespace vkcompute {
//
// Dispatch nodes
//
/*
 * Adds a dispatch node for a quantized int8 ("q8ta") elementwise binary op.
 *
 * Preconditions (checked): both inputs must use the same packed dimension and
 * the same packed-dimension block size as the output.
 *
 * The per-tensor quantization parameters (scale / zero point of each input and
 * of the output) and `alpha` are read from the graph once, when the node is
 * built, and forwarded to the shader as push constants in this order:
 *   input_a_scale, input_a_zp, input_b_scale, input_b_zp,
 *   1 / output_scale, output_zp, alpha
 * The output scale is sent as its reciprocal, so it must be non-zero (checked).
 *
 * `alpha` defaults to 1.0 when it is not a valid ValueRef or when it refers to
 * a string value (some ops pass an unused string argument in that slot).
 *
 * The shader name is "q8ta_" + `op_name`, plus the storage-type suffix of the
 * output tensor.
 */
void add_q8ta_binary_node(
    ComputeGraph& graph,
    const ValueRef packed_int8_input_a,
    const ValueRef packed_int8_input_b,
    const ValueRef input_a_scale,
    const ValueRef input_a_zp,
    const ValueRef input_b_scale,
    const ValueRef input_b_zp,
    const ValueRef output_scale,
    const ValueRef output_zp,
    const ValueRef alpha,
    const ValueRef packed_int8_output,
    const std::string& op_name) {
  // The implementation assumes that all participating tensors have the same
  // packed dimension, and that they all have the same block size for the
  // packed dimension
  const api::PackedDimInfo& output_info =
      graph.packed_dim_info_of(packed_int8_output);
  const api::PackedDimInfo& input_a_info =
      graph.packed_dim_info_of(packed_int8_input_a);
  const api::PackedDimInfo& input_b_info =
      graph.packed_dim_info_of(packed_int8_input_b);

  VK_CHECK_COND(input_a_info.packed_dim == output_info.packed_dim);
  VK_CHECK_COND(input_b_info.packed_dim == output_info.packed_dim);
  VK_CHECK_COND(
      input_a_info.packed_dim_block_size == output_info.packed_dim_block_size);
  VK_CHECK_COND(
      input_b_info.packed_dim_block_size == output_info.packed_dim_block_size);

  float input_a_scale_val = graph.extract_scalar<float>(input_a_scale);
  int32_t input_a_zp_val = graph.extract_scalar<int32_t>(input_a_zp);

  float input_b_scale_val = graph.extract_scalar<float>(input_b_scale);
  int32_t input_b_zp_val = graph.extract_scalar<int32_t>(input_b_zp);

  // The shader receives the reciprocal of the output scale. A zero scale
  // would produce an infinite value here and feed inf to the shader's
  // requantization, so reject it up front.
  const float output_scale_val = graph.extract_scalar<float>(output_scale);
  VK_CHECK_COND(output_scale_val != 0.0f);
  float output_inv_scale_val = 1.0f / output_scale_val;
  int32_t output_zp_val = graph.extract_scalar<int32_t>(output_zp);

  float alpha_val = 1.0f;
  // String is checked since some ops pass in an unused string argument in
  // place of alpha
  if (is_valid(alpha) && !graph.val_is_string(alpha)) {
    alpha_val = graph.extract_scalar<float>(alpha);
  }

  std::string kernel_name = "q8ta_" + op_name;
  add_storage_type_suffix(
      kernel_name, graph.storage_type_of(packed_int8_output));

  // Pass metadata for output and input tensors
  vkapi::ParamsBindList param_buffers;
  param_buffers.append(graph.buffer_meta_ubo(packed_int8_output));
  param_buffers.append(graph.buffer_meta_ubo(packed_int8_input_a));
  param_buffers.append(graph.buffer_meta_ubo(packed_int8_input_b));

  // Order here must match the push constant layout declared by the shader.
  std::vector<PushConstantDataInfo> push_constants = {
      PushConstantDataInfo(&input_a_scale_val, sizeof(input_a_scale_val)),
      PushConstantDataInfo(&input_a_zp_val, sizeof(input_a_zp_val)),
      PushConstantDataInfo(&input_b_scale_val, sizeof(input_b_scale_val)),
      PushConstantDataInfo(&input_b_zp_val, sizeof(input_b_zp_val)),
      PushConstantDataInfo(&output_inv_scale_val, sizeof(output_inv_scale_val)),
      PushConstantDataInfo(&output_zp_val, sizeof(output_zp_val)),
      PushConstantDataInfo(&alpha_val, sizeof(alpha_val)),
  };

  // Create block config for output tensor: inner_dim = output's packed_dim
  const BlockConfig block_config =
      create_block_config_for_tensor(graph, packed_int8_output);

  // The workgroup picker reads the packed block config from the resize args,
  // so the packed int is smuggled through as a ValueRef.
  const ValueRef block_config_ref =
      static_cast<ValueRef>(block_config.as_packed_int());

  // NOTE(review): no resize function is registered (nullptr below), so the
  // output is not resized if input shapes change between executions —
  // confirm dynamic shapes are not expected for this op.
  graph.execute_nodes().emplace_back(new DynamicDispatchNode(
      graph,
      VK_KERNEL_FROM_STR(kernel_name),
      pick_linear_global_wg_with_block_config,
      pick_square_local_wg_with_block_config,
      // Inputs and Outputs
      {{packed_int8_output, vkapi::kWrite},
       {{packed_int8_input_a, packed_int8_input_b}, vkapi::kRead}},
      // Shader params buffers
      param_buffers,
      // Push Constants
      push_constants,
      // Specialization Constants
      {graph.hashed_layout_of(packed_int8_output),
       graph.hashed_layout_of(packed_int8_input_a),
       graph.hashed_layout_of(packed_int8_input_b),
       block_config.as_packed_int()},
      // Resize args
      {block_config_ref},
      // Resizing Logic
      nullptr));
}
//
// High level operator impl
//
/*
 * Graph-level entry point for et_vk.q8ta_add.default.
 *
 * Positional argument layout of `args`:
 *   [0] packed int8 input A      [1] packed int8 input B
 *   [2] input A scale            [3] input A zero point
 *   [4] input B scale            [5] input B zero point
 *   [6] output scale             [7] output zero point
 *   [8] alpha                    [9] packed int8 output
 *
 * Each slot is fetched with bounds-checked access and forwarded unchanged to
 * add_q8ta_binary_node with op name "add".
 */
void q8ta_add(ComputeGraph& graph, const std::vector<ValueRef>& args) {
  const ValueRef in_a = args.at(0);
  const ValueRef in_b = args.at(1);
  const ValueRef a_scale = args.at(2);
  const ValueRef a_zero_point = args.at(3);
  const ValueRef b_scale = args.at(4);
  const ValueRef b_zero_point = args.at(5);
  const ValueRef out_scale = args.at(6);
  const ValueRef out_zero_point = args.at(7);
  const ValueRef alpha_ref = args.at(8);
  const ValueRef out = args.at(9);

  add_q8ta_binary_node(
      graph,
      in_a,
      in_b,
      a_scale,
      a_zero_point,
      b_scale,
      b_zero_point,
      out_scale,
      out_zero_point,
      alpha_ref,
      out,
      "add");
}
// Registers q8ta_add with the Vulkan operator registry under the
// "et_vk.q8ta_add.default" operator name.
REGISTER_OPERATORS {
VK_REGISTER_OP(et_vk.q8ta_add.default, q8ta_add);
}
} // namespace vkcompute