-
Notifications
You must be signed in to change notification settings - Fork 475
Expand file tree
/
Copy pathprimitives.h
More file actions
235 lines (188 loc) · 7.67 KB
/
primitives.h
File metadata and controls
235 lines (188 loc) · 7.67 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
// Low-level (BLAS-like) primitives.
#pragma once
#include "ctranslate2/devices.h"
#include "ctranslate2/types.h"
namespace ctranslate2 {
// Low-level numeric routines, parameterized on the device `D` (CPU by default).
// This header holds the declarations plus a few inline convenience overloads;
// every member declared without a body is defined elsewhere (presumably
// specialized/instantiated per device and element type — TODO confirm).
//
// Conventions visible in these signatures:
//  - Buffers are raw pointers with an element count of type dim_t.
//  - A `const T*` parameter is read-only through that pointer; only non-const
//    `T*` parameters can be written by a routine.
//  - Several binary routines have a shorter overload that forwards to the full
//    form with `y` passed as BOTH an input and the output buffer, e.g.
//    add(a, y, size) -> add(a, y, y, size). The full overloads are therefore
//    called with aliased input/output buffers and must tolerate that.
template <Device D = Device::CPU>
struct primitives {
// NOTE(review): presumably returns x[index]; a dedicated accessor suggests `x`
// may be device memory not directly dereferenceable on the host — confirm.
template <typename T>
static T at(const T* x, dim_t index);
// Fill-style writers into `x` (non-const). Presumably: fill sets `size`
// elements to `a`; strided_fill does the same with a stride of `inc_x`;
// indexed_fill sets only the positions listed in `indices`;
// indexed_pointwise_multiply multiplies the listed positions by `values` —
// TODO confirm against the implementations.
template <typename T>
static void fill(T* x, T a, dim_t size);
template <typename T>
static void strided_fill(T* x, T a, dim_t inc_x, dim_t size);
template <typename T>
static void indexed_fill(T* x, T a, const int32_t* indices, dim_t num_indices);
template <typename T>
static void indexed_pointwise_multiply(T* x, const T* values, const int32_t* indices, dim_t num_indices);
// Copies `size` elements from `x` into `y`.
template <typename T>
static void copy(const T* x, T* y, dim_t size);
// Element-type conversion from U to V (presumably elementwise over `size` values).
template <typename U, typename V>
static void convert(const U* x, V* y, dim_t size);
// Reductions over `size` elements. max_element returns a dim_t (presumably the
// index of the maximum); amax presumably returns the maximum absolute value —
// confirm.
template <typename T>
static T sum(const T* array, dim_t size);
// Arithmetic mean computed as sum(array, size) / size.
// The division is unguarded, so callers must pass size > 0. For an integral
// T it is integer division and truncates.
template <typename T>
static T mean(const T* array, dim_t size) {
return sum(array, size) / size;
}
template <typename T>
static dim_t max_element(const T* array, dim_t size);
template <typename T>
static T max(const T* array, dim_t size);
template <typename T>
static T amax(const T* array, dim_t size);
// Scalar + vector: reads `x`, writes `y`.
template <typename T>
static void add(T a, const T* x, T* y, dim_t size);
// In-place scalar add: forwards to add(a, y, y, size).
template <typename T>
static void add(T a, T* y, dim_t size) {
add(a, y, y, size);
}
// Vector + vector: reads `a` and `b`, writes `c`.
template <typename T>
static void add(const T* a, const T* b, T* c, dim_t size);
// In-place vector add: forwards to add(x, y, y, size), so `y` is both an
// operand and the destination.
template <typename T>
static void add(const T* x, T* y, dim_t size) {
add(x, y, y, size);
}
// Broadcasting adds between operands of different sizes (a_size vs b_size).
// NOTE(review): the exact "batch" vs "depth" broadcasting layout is defined by
// the implementations, not visible here — confirm before relying on it.
template <typename T>
static void add_batch_broadcast(const T* a, const T* b, T* c, dim_t a_size, dim_t b_size);
// In-place form: forwards to add_batch_broadcast(x, y, y, x_size, y_size).
template <typename T>
static void add_batch_broadcast(const T* x, T* y, dim_t x_size, dim_t y_size) {
add_batch_broadcast(x, y, y, x_size, y_size);
}
template <typename T>
static void add_depth_broadcast(const T* a, const T* b, T* c, dim_t a_size, dim_t b_size);
// In-place form: forwards to add_depth_broadcast(x, y, y, x_size, y_size).
template <typename T>
static void add_depth_broadcast(const T* x, T* y, dim_t x_size, dim_t y_size) {
add_depth_broadcast(x, y, y, x_size, y_size);
}
// Scalar subtraction, implemented by negating `a` and delegating to the scalar
// add overload. This requires T to support unary minus.
// For a signed T, the most negative value has no representable negation.
// For sub-int types, -a is computed in int and converted back to T.
template <typename T>
static void sub(T a, const T* x, T* y, dim_t size) {
T a_rev = -a;
add(a_rev, x, y, size);
}
// In-place scalar subtraction: forwards to sub(a, y, y, size).
template <typename T>
static void sub(T a, T* y, dim_t size) {
sub(a, y, y, size);
}
// Vector - vector: reads `a` and `b`, writes `c`. Unlike add and mul, there is
// no two-pointer in-place overload.
template <typename T>
static void sub(const T* a, const T* b, T* c, dim_t size);
// Scalar/vector and vector/vector maximum (presumably elementwise).
template <typename T>
static void max(T a, const T* x, T* y, dim_t size);
template <typename T>
static void max(const T* a, const T* b, T* c, dim_t size);
// In-place scalar form: forwards to max(a, y, y, size).
template <typename T>
static void max(T a, T* y, dim_t size) {
max(a, y, y, size);
}
// Scalar/vector and vector/vector minimum (presumably elementwise).
template <typename T>
static void min(T a, const T* x, T* y, dim_t size);
template <typename T>
static void min(const T* a, const T* b, T* c, dim_t size);
// In-place scalar form: forwards to min(a, y, y, size).
template <typename T>
static void min(T a, T* y, dim_t size) {
min(a, y, y, size);
}
// Scalar * vector: reads `x`, writes `y`.
template <typename T>
static void mul(T a, const T* x, T* y, dim_t size);
// In-place scalar form: forwards to mul(a, y, y, size).
template <typename T>
static void mul(T a, T* y, dim_t size) {
mul(a, y, y, size);
}
// Broadcasting multiply between operands of different sizes (see the
// broadcasting note on add_batch_broadcast).
template <typename T>
static void mul_batch_broadcast(const T* a, const T* b, T* c, dim_t a_size, dim_t b_size);
// In-place form: forwards to mul_batch_broadcast(x, y, y, x_size, y_size).
template <typename T>
static void mul_batch_broadcast(const T* x, T* y, dim_t x_size, dim_t y_size) {
mul_batch_broadcast(x, y, y, x_size, y_size);
}
// Vector * vector: reads `a` and `b`, writes `c`.
template <typename T>
static void mul(const T* a, const T* b, T* c, dim_t size);
// In-place vector form: forwards to mul(x, y, y, size).
template <typename T>
static void mul(const T* x, T* y, dim_t size) {
mul(x, y, y, size);
}
// Writes into `scores` (the only non-const buffer) using `previous_scores`,
// `previous_ids` and `penalty`.
// NOTE(review): presumably `scores` is batch_size x vocabulary_size, and
// previous_ids holds `length` token ids per batch entry whose scores get
// penalized — confirm against the implementation.
template <typename T>
static void penalize_previous_tokens(T* scores,
const T* previous_scores,
const int32_t* previous_ids,
T penalty,
dim_t batch_size,
dim_t length,
dim_t vocabulary_size);
// Non-template member: builds an int32 `mask` from per-batch `lengths`.
// NOTE(review): presumably used for attention masking, with mask_future
// enabling causal masking and multi_query selecting a multi-query attention
// layout — confirm the exact output layout in the implementation.
static void prepare_length_mask(const int32_t* lengths,
dim_t batch_size,
dim_t num_heads,
dim_t num_queries,
bool mask_future,
bool multi_query,
int32_t* mask);
// Transposes `a` into `b`. NOTE(review): presumably `dims` holds the 2/3/4 input
// dimensions and `perm` the axis permutation (the 2-D form takes none) — confirm.
template <typename T>
static void transpose_2d(const T* a, const dim_t* dims, T* b);
template <typename T>
static void transpose_3d(const T* a, const dim_t* dims, const dim_t* perm, T* b);
template <typename T>
static void transpose_4d(const T* a, const dim_t* dims, const dim_t* perm, T* b);
// Returns float whatever the element type T is.
template <typename T>
static float logsumexp(const T* x, dim_t size);
// Unary functions reading `x` and writing `y` (presumably elementwise over `size`
// values). The three gelu variants are presumably the exact form plus tanh- and
// sigmoid-based approximations — confirm in the implementations.
template <typename T>
static void exp(const T* x, T* y, dim_t size);
template <typename T>
static void log(const T* x, T* y, dim_t size);
template <typename T>
static void cos(const T* x, T* y, dim_t size);
template <typename T>
static void sin(const T* x, T* y, dim_t size);
template <typename T>
static void tanh(const T* x, T* y, dim_t size);
template <typename T>
static void relu(const T* x, T* y, dim_t size);
template <typename T>
static void gelu(const T* x, T* y, dim_t size);
template <typename T>
static void gelu_tanh(const T* x, T* y, dim_t size);
template <typename T>
static void gelu_sigmoid(const T* x, T* y, dim_t size);
template <typename T>
static void sigmoid(const T* x, T* y, dim_t size);
template <typename T>
static void swish(const T* x, T* y, dim_t size);
// Non-template member: derives int32 `compensation` values from an int8 matrix
// `b` (k x n, optionally transposed) scaled by `alpha`.
// NOTE(review): presumably the correction term consumed through gemm's
// `a_shift_compensation` when the A operand is shifted to unsigned (u8) — confirm.
static void compute_u8_compensation(const int8_t* b,
bool transpose_b,
dim_t k,
dim_t n,
float alpha,
int32_t* compensation);
// If dest is not passed (it defaults to nullptr), returns the number of bytes
// required to store the packed data, or 0 if packing is not supported.
// NOTE(review): when `dest` is given, presumably writes the packed form of `b`
// there, for later use with gemm(..., b_is_packed = true, ...) — confirm.
template <typename T>
static dim_t gemm_pack_b(const T* b,
const bool transpose_b,
const dim_t k,
const dim_t n,
const float alpha,
T* dest = nullptr);
// General matrix multiply with BLAS-style parameters: transpose flags, m/n/k,
// alpha/beta and leading dimensions lda/ldb/ldc. Separate In/Out element
// types allow the accumulator/output type to differ from the inputs.
// NOTE(review): presumably c = alpha * op(a) * op(b) + beta * c, with the
// *_is_packed flags marking buffers produced by gemm_pack_b — confirm per device.
template <typename In, typename Out>
static void gemm(bool a_is_packed, bool b_is_packed,
bool transpose_a, bool transpose_b,
dim_t m, dim_t n, dim_t k,
float alpha,
const In* a, dim_t lda,
const In* b, dim_t ldb,
float beta,
Out* c, dim_t ldc,
const Out* a_shift_compensation = nullptr);
// Batched variant of gemm over `batch_size` problems.
// NOTE(review): stridea/strideb/stridec are presumably the element offsets
// between consecutive matrices of each operand — confirm.
template <typename In, typename Out>
static void gemm_batch_strided(bool transpose_a, bool transpose_b,
dim_t m, dim_t n, dim_t k,
float alpha,
const In* a, dim_t lda, dim_t stridea,
const In* b, dim_t ldb, dim_t strideb,
float beta,
Out* c, dim_t ldc, dim_t stridec,
dim_t batch_size);
};
// Primitives parameterized on a pair of devices (D1, D2) fixed at compile time.
// Its only member is copy. NOTE(review): presumably this is for moving data
// between the two devices — confirm in the implementation.
template <Device D1, Device D2>
struct cross_device_primitives {
// Copies `size` elements from `x` (read-only) into `y`.
// NOTE(review): presumably `x` resides on D1 and `y` on D2, matching the order
// of the template parameters — confirm before relying on it.
template <typename T>
static void copy(const T* x, T* y, dim_t size);
};
}