This repository was archived by the owner on Nov 17, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 6.7k
Expand file tree
/
Copy pathelemwise_binary_op.cc
More file actions
448 lines (400 loc) · 16.5 KB
/
elemwise_binary_op.cc
File metadata and controls
448 lines (400 loc) · 16.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file elemwise_binary_op.cc
* \brief CPU implementation of elementwise binary operators
*/
#include "./elemwise_binary_op.h"
#if MXNET_USE_CUDA
#include "../../common/cuda/rtc/vectorization-inl.h"
#include "../../common/cuda/rtc.h"
#endif // MXNET_USE_CUDA
namespace mxnet {
namespace op {
/*!
 * \brief Storage-type inference for binary operators that produce a dense
 *        result even when both inputs are row_sparse.
 *
 * Rules are tried in order; a rule whose storage_type_assign call fails
 * falls through to the next one:
 *   - dns, dns -> dns via FCompute;
 *   - rsp, rsp -> dns via FComputeEx on CPU contexts (FComputeFallback on
 *     any other device);
 *   - any other combination (including mixed dns/rsp inputs) is handed to
 *     dispatch_fallback.
 *
 * \param attrs         operator attributes; only the name is used, in check messages
 * \param dev_mask      device mask of the context the operator will run on
 * \param dispatch_mode [out] dispatch mode selected for the operator
 * \param in_attrs      storage types of the two inputs (lhs, rhs)
 * \param out_attrs     [in/out] storage type of the single output
 * \return true if a dispatch mode was assigned
 */
bool ElemwiseBinaryOp::SparseSparseWithDenseResult(const nnvm::NodeAttrs& attrs,
                                                   const int dev_mask,
                                                   DispatchMode* dispatch_mode,
                                                   std::vector<int>* in_attrs,
                                                   std::vector<int>* out_attrs) {
  CHECK_EQ(in_attrs->size(), 2U) << " in operator " << attrs.name;
  CHECK_EQ(out_attrs->size(), 1U) << " in operator " << attrs.name;
  const auto& lhs_stype = in_attrs->at(0);
  const auto& rhs_stype = in_attrs->at(1);
  auto& out_stype       = out_attrs->at(0);
  bool dispatched       = false;
  // Non-CPU contexts never select FComputeEx here; they use the fallback mode.
  const bool invalid_ctx = dev_mask != mshadow::cpu::kDevMask;
  const auto dispatch_ex =
      invalid_ctx ? DispatchMode::kFComputeFallback : DispatchMode::kFComputeEx;
  if (!dispatched && lhs_stype == kDefaultStorage && rhs_stype == kDefaultStorage) {
    // dns, dns -> dns. BOTH inputs must be dense for the plain FCompute path;
    // a mixed dense/sparse pair is left to dispatch_fallback below.
    dispatched =
        storage_type_assign(&out_stype, kDefaultStorage, dispatch_mode, DispatchMode::kFCompute);
  }
  if (!dispatched && lhs_stype == kRowSparseStorage && rhs_stype == kRowSparseStorage) {
    // rsp, rsp -> dns
    dispatched = storage_type_assign(&out_stype, kDefaultStorage, dispatch_mode, dispatch_ex);
  }
  if (!dispatched) {
    dispatched = dispatch_fallback(out_attrs, dispatch_mode);
  }
  return dispatched;
}
/*!
 * \brief Storage-type inference for the backward pass of a binary operator
 *        that reads its forward inputs (three inputs, two outputs).
 *
 * Rules, tried in order; a rule whose storage_type_assign call fails falls
 * through to the next one:
 *   - every input dense                       -> dense outputs via FCompute;
 *   - every input and every output row_sparse -> row_sparse outputs via
 *     FComputeEx on CPU (FComputeFallback on other devices);
 *   - otherwise                               -> dispatch_fallback.
 *
 * \param attrs         operator attributes (unused)
 * \param dev_mask      device mask of the context the operator will run on
 * \param dispatch_mode [out] dispatch mode selected for the operator
 * \param in_attrs      storage types of the three inputs
 * \param out_attrs     [in/out] storage types of the two outputs
 * \return true if a dispatch mode was assigned
 */
bool ElemwiseBinaryOp::BackwardUseInStorageType(const nnvm::NodeAttrs& attrs,
                                                const int dev_mask,
                                                DispatchMode* dispatch_mode,
                                                std::vector<int>* in_attrs,
                                                std::vector<int>* out_attrs) {
  using namespace common;
  CHECK_EQ(in_attrs->size(), 3U);
  CHECK_EQ(out_attrs->size(), 2U);

  // FComputeEx is only chosen for CPU contexts; anything else falls back.
  const DispatchMode sparse_mode = (dev_mask == mshadow::cpu::kDevMask)
                                       ? DispatchMode::kFComputeEx
                                       : DispatchMode::kFComputeFallback;

  // dns inputs -> dns outputs.
  if (common::ContainsOnlyStorage(*in_attrs, kDefaultStorage) &&
      storage_type_assign(out_attrs, kDefaultStorage, dispatch_mode, DispatchMode::kFCompute)) {
    return true;
  }

  // rsp inputs with rsp outputs -> rsp outputs.
  if (common::ContainsOnlyStorage(*in_attrs, kRowSparseStorage) &&
      common::ContainsOnlyStorage(*out_attrs, kRowSparseStorage) &&
      storage_type_assign(out_attrs, kRowSparseStorage, dispatch_mode, sparse_mode)) {
    return true;
  }

  return dispatch_fallback(out_attrs, dispatch_mode);
}
#if MXNET_USE_CUDA
// Host-side argument pack handed to VectorizedKernelRTCLauncher. Each kernel
// source string in this file redeclares a struct with this exact name and
// field layout, so the two definitions must stay in sync.
struct binary_kernel_params {
  void const* inputs[3];  // read-only input buffers; callers pass nullptr for unused slots
  void* outputs[2];       // output buffers; callers pass nullptr for unused slots
};
// CUDA source for the forward elementwise binary kernel `binary_kernel`.
// ElemwiseBinaryRTCCompute passes this string to VectorizedKernelRTCLauncher
// together with a preamble that defines `req` (the OpReqType of the output)
// and the `OP` macro (op::<name>). The other names used here (InputType0/1,
// OutputType0, nvec, aligned, index_t, AccType, VectorizedLoader/Storer,
// op::add, kRTCMaxThreadsPerBlock) are not defined in this file; presumably
// the RTC launcher / support headers provide them -- TODO confirm.
//
// Each thread walks the aligned vector chunks in a grid-stride loop. For every
// element it converts both inputs via AccType<...>::from and evaluates
// OP(input0, input1). It then overwrites the output, or adds to the stored
// value when req == kAddTo. lead_dim and other_dim are accepted but unused.
//
// The struct at the top of the string must keep the same layout as the
// host-side binary_kernel_params struct that the launcher receives.
// Everything between R"code( and )code" is compiled kernel source, so
// commentary belongs out here; edits inside change the generated kernel.
const char binary_kernel_fwd[] = R"code(
struct binary_kernel_params {
const void *inputs[3];
void *outputs[2];
};
__launch_bounds__(kRTCMaxThreadsPerBlock)
__global__ void binary_kernel(const binary_kernel_params params,
const index_t lead_dim,
const index_t other_dim,
const index_t N,
const index_t num_aligned_elements) {
using namespace vector;
VectorizedLoader<InputType0, nvec, aligned> loader0(
reinterpret_cast<const InputType0*>(params.inputs[0]), N);
VectorizedLoader<InputType1, nvec, aligned> loader1(
reinterpret_cast<const InputType1*>(params.inputs[1]), N);
VectorizedStorer<OutputType0, nvec, aligned> storer(
reinterpret_cast<OutputType0*>(params.outputs[0]), N);
using IType0 = AccType<InputType0>;
using IType1 = AccType<InputType1>;
using OType = AccType<OutputType0>;
const index_t M = num_aligned_elements;
for (index_t tid = blockIdx.x * blockDim.x + threadIdx.x;
tid < M;
tid += gridDim.x * blockDim.x) {
loader0.load(tid, N);
loader1.load(tid, N);
if (req == OpReqType::kAddTo) {
storer.load(tid, N);
}
#pragma unroll
for (int i = 0; i < nvec; ++i) {
const auto input0 = IType0::from(loader0.separate()[i]);
const auto input1 = IType1::from(loader1.separate()[i]);
const auto temp = OP(input0, input1); // enables returning different type
if (req == OpReqType::kAddTo) {
// temp2 may have a wider type than either temp
// or OType
const auto temp2 = op::add(temp, OType::from(storer.separate()[i]));
storer.separate()[i] = OType::to(temp2);
} else {
storer.separate()[i] = OType::to(temp);
}
}
storer.store(tid, N);
}
}
)code";
/*!
 * \brief Forward pass: runtime-compile and launch `binary_kernel` for OP.
 *
 * Builds a preamble defining `req` (the write request of the single output)
 * and the `OP` macro, then launches binary_kernel_fwd over every element of
 * outputs[0]. Does nothing when the output request is kNullOp.
 *
 * \param attrs   operator attributes (unused)
 * \param ctx     operator context supplying the GPU stream and device id
 * \param inputs  exactly two input blobs (lhs, rhs)
 * \param req     write request; only req[0] is read
 * \param outputs exactly one output blob
 */
void ElemwiseBinaryRTCCompute::operator()(const nnvm::NodeAttrs& attrs,
                                          const OpContext& ctx,
                                          const std::vector<TBlob>& inputs,
                                          const std::vector<OpReqType>& req,
                                          const std::vector<TBlob>& outputs) {
  using namespace mxnet::common::cuda::rtc;
  if (req[0] == kNullOp) {
    return;
  }
  mshadow::Stream<gpu>* const stream = ctx.get_stream<gpu>();
  CHECK_EQ(inputs.size(), 2U);
  CHECK_EQ(outputs.size(), 1U);

  // Preamble prepended to the kernel source: the request type and the op macro.
  const std::string preamble = std::string("const OpReqType req = ") + util::to_string(req[0]) +
                               ";\n#define OP op::" + OP + "\n";

  // Elements handled per vector access: 2 for float64 outputs, 4 otherwise.
  const int vec_width = (outputs[0].type_flag_ == mshadow::kFloat64) ? 2 : 4;
  const index_t total = outputs[0].Size();

  // Value-initialized, so the unused input/output slots stay nullptr.
  binary_kernel_params args{};
  args.inputs[0]  = inputs[0].dptr_;
  args.inputs[1]  = inputs[1].dptr_;
  args.outputs[0] = outputs[0].dptr_;

  VectorizedKernelRTCLauncher(preamble,
                              "binary_kernel",
                              binary_kernel_fwd,
                              vec_width,
                              total,
                              1,
                              stream,
                              args,
                              inputs,
                              outputs,
                              ctx.run_ctx.get_ctx().dev_id);
}
// CUDA source for `binary_kernel_bwd` as launched by ElemwiseBinaryRTCBwdUseNone.
// It reads a single input (params.inputs[0]) and produces up to two outputs:
// outputs[0] from LOP(input) and outputs[1] from ROP(input).
// The preamble built by ElemwiseBinaryRTCBwdUseNone defines lreq/rreq, the
// LOP/ROP macros, and the write_left_output/write_right_output flags. An
// output is computed and stored only when its flag is set. A kAddTo request
// loads the existing output and accumulates into it.
// The type/vector names (InputType0, OutputType0/1, nvec, aligned, AccType,
// VectorizedLoader/Storer, ...) are presumably supplied by the RTC launcher
// -- TODO confirm. lead_dim and other_dim are accepted but unused.
// The leading struct must match the host-side binary_kernel_params layout.
// Everything inside R"code( ... )code" is compiled kernel source.
const char binary_kernel_bwd_use_none[] = R"code(
struct binary_kernel_params {
const void *inputs[3];
void *outputs[2];
};
__launch_bounds__(kRTCMaxThreadsPerBlock)
__global__ void binary_kernel_bwd(const binary_kernel_params params,
const index_t lead_dim,
const index_t other_dim,
const index_t N,
const index_t num_aligned_elements) {
using namespace vector;
VectorizedLoader<InputType0, nvec, aligned> loader(
reinterpret_cast<const InputType0*>(params.inputs[0]), N);
VectorizedStorer<OutputType0, nvec, aligned> lstorer(
reinterpret_cast<OutputType0*>(params.outputs[0]), N);
VectorizedStorer<OutputType1, nvec, aligned> rstorer(
reinterpret_cast<OutputType1*>(params.outputs[1]), N);
using IType = AccType<InputType0>;
using OType0 = AccType<OutputType0>;
using OType1 = AccType<OutputType1>;
const index_t M = num_aligned_elements;
for (index_t tid = blockIdx.x * blockDim.x + threadIdx.x;
tid < M;
tid += gridDim.x * blockDim.x) {
loader.load(tid, N);
if (lreq == OpReqType::kAddTo) {
lstorer.load(tid, N);
}
if (rreq == OpReqType::kAddTo) {
rstorer.load(tid, N);
}
#pragma unroll
for (int i = 0; i < nvec; ++i) {
const auto input = IType::from(loader.separate()[i]);
if (write_left_output) {
const auto temp = LOP(input);
if (lreq == OpReqType::kAddTo) {
// temp2 may have a wider type than either temp
// or OType
const auto temp2 = op::add(temp, OType0::from(lstorer.separate()[i]));
lstorer.separate()[i] = OType0::to(temp2);
} else {
lstorer.separate()[i] = OType0::to(temp);
}
}
if (write_right_output) {
const auto temp = ROP(input);
if (rreq == OpReqType::kAddTo) {
// temp2 may have a wider type than either temp
// or OType
const auto temp2 = op::add(temp, OType1::from(rstorer.separate()[i]));
rstorer.separate()[i] = OType1::to(temp2);
} else {
rstorer.separate()[i] = OType1::to(temp);
}
}
}
if (write_left_output) {
lstorer.store(tid, N);
}
if (write_right_output) {
rstorer.store(tid, N);
}
}
}
)code";
/*!
 * \brief Backward pass for binary ops whose gradients need only one input
 *        tensor: runtime-compiles and launches binary_kernel_bwd_use_none.
 *
 * inputs:  [0] is the only tensor the kernel reads (presumably the output
 *          gradient -- confirm with the operator registration).
 * outputs: [0] receives LOP(input), [1] receives ROP(input).
 * An output is skipped when its request is kNullOp, or when its request is
 * kWriteInplace and its op is "identity". Returns early if both requests are
 * kNullOp.
 *
 * \param attrs   operator attributes (unused)
 * \param ctx     operator context supplying the GPU stream and device id
 * \param inputs  exactly one input blob
 * \param req     write requests for the two outputs
 * \param outputs exactly two output blobs
 */
void ElemwiseBinaryRTCBwdUseNone::operator()(const nnvm::NodeAttrs& attrs,
                                             const OpContext& ctx,
                                             const std::vector<TBlob>& inputs,
                                             const std::vector<OpReqType>& req,
                                             const std::vector<TBlob>& outputs) {
  using namespace mxnet::common::cuda::rtc;
  if (req[0] == kNullOp && req[1] == kNullOp) {
    return;
  }
  mshadow::Stream<gpu>* const stream = ctx.get_stream<gpu>();
  CHECK_EQ(inputs.size(), 1U);
  CHECK_EQ(outputs.size(), 2U);

  // An in-place request whose op is "identity" is treated as needing no write.
  const bool left_needed  = req[0] != kNullOp && (req[0] != kWriteInplace || LOP != "identity");
  const bool right_needed = req[1] != kNullOp && (req[1] != kWriteInplace || ROP != "identity");

  // Preamble prepended to the kernel source: request types, op macros, write flags.
  std::string code = "const OpReqType lreq = ";
  code += util::to_string(req[0]);
  code += ";\nconst OpReqType rreq = ";
  code += util::to_string(req[1]);
  code += ";\n#define ROP op::";
  code += ROP;
  code += "\n#define LOP op::";
  code += LOP;
  code += "\nconst bool write_left_output = ";
  code += std::to_string(left_needed);
  code += ";\nconst bool write_right_output = ";
  code += std::to_string(right_needed);
  code += ";\n";

  // Elements handled per vector access: 2 for float64 outputs, 4 otherwise.
  const int vec_width = (outputs[0].type_flag_ == mshadow::kFloat64) ? 2 : 4;
  const index_t total = outputs[0].Size();

  // Value-initialized, so the unused input slots stay nullptr.
  binary_kernel_params args{};
  args.inputs[0]  = inputs[0].dptr_;
  args.outputs[0] = outputs[0].dptr_;
  args.outputs[1] = outputs[1].dptr_;

  VectorizedKernelRTCLauncher(code,
                              "binary_kernel_bwd",
                              binary_kernel_bwd_use_none,
                              vec_width,
                              total,
                              1,
                              stream,
                              args,
                              inputs,
                              outputs,
                              ctx.run_ctx.get_ctx().dev_id);
}
// CUDA source for `binary_kernel_bwd` as launched by ElemwiseBinaryRTCBwdUseIn.
// Inputs: params.inputs[0] (ograd), [1] (linput), [2] (rinput). Per element:
//   left  = ograd * LOP(linput, rinput)  -> outputs[0]
//   right = ograd * ROP(linput, rinput)  -> outputs[1]
// Any side whose request (lreq / rreq) is kNullOp is neither computed nor
// stored. A kAddTo request accumulates into the existing output.
// lreq, rreq, LOP and ROP come from the preamble built in
// ElemwiseBinaryRTCBwdUseIn. The remaining type/vector names (InputType0..2,
// OutputType0/1, nvec, aligned, AccType, ...) are presumably supplied by the
// RTC launcher -- TODO confirm. lead_dim and other_dim are accepted but unused.
// The leading struct must match the host-side binary_kernel_params layout.
// Everything inside R"code( ... )code" is compiled kernel source.
const char binary_kernel_bwd_use_in[] = R"code(
struct binary_kernel_params {
const void *inputs[3];
void *outputs[2];
};
__launch_bounds__(kRTCMaxThreadsPerBlock)
__global__ void binary_kernel_bwd(const binary_kernel_params params,
const index_t lead_dim,
const index_t other_dim,
const index_t N,
const index_t num_aligned_elements) {
using namespace vector;
VectorizedLoader<InputType0, nvec, aligned> ograd_loader(
reinterpret_cast<const InputType0*>(params.inputs[0]), N);
VectorizedLoader<InputType1, nvec, aligned> linput_loader(
reinterpret_cast<const InputType1*>(params.inputs[1]), N);
VectorizedLoader<InputType2, nvec, aligned> rinput_loader(
reinterpret_cast<const InputType2*>(params.inputs[2]), N);
VectorizedStorer<OutputType0, nvec, aligned> lstorer(
reinterpret_cast<OutputType0*>(params.outputs[0]), N);
VectorizedStorer<OutputType1, nvec, aligned> rstorer(
reinterpret_cast<OutputType1*>(params.outputs[1]), N);
using IType0 = AccType<InputType0>;
using IType1 = AccType<InputType1>;
using IType2 = AccType<InputType2>;
using OType0 = AccType<OutputType0>;
using OType1 = AccType<OutputType1>;
const index_t M = num_aligned_elements;
for (index_t tid = blockIdx.x * blockDim.x + threadIdx.x;
tid < M;
tid += gridDim.x * blockDim.x) {
ograd_loader.load(tid, N);
linput_loader.load(tid, N);
rinput_loader.load(tid, N);
if (lreq == OpReqType::kAddTo) {
lstorer.load(tid, N);
}
if (rreq == OpReqType::kAddTo) {
rstorer.load(tid, N);
}
#pragma unroll
for (int i = 0; i < nvec; ++i) {
const auto ograd = IType0::from(ograd_loader.separate()[i]);
const auto linput = IType1::from(linput_loader.separate()[i]);
const auto rinput = IType2::from(rinput_loader.separate()[i]);
if (lreq != OpReqType::kNullOp) {
const auto temp = op::mul(ograd, LOP(linput, rinput));
if (lreq == OpReqType::kAddTo) {
const auto temp2 = op::add(temp, OType0::from(lstorer.separate()[i]));
lstorer.separate()[i] = OType0::to(temp2);
} else {
lstorer.separate()[i] = OType0::to(temp);
}
}
if (rreq != OpReqType::kNullOp) {
const auto temp = op::mul(ograd, ROP(linput, rinput));
if (rreq == OpReqType::kAddTo) {
const auto temp2 = op::add(temp, OType1::from(rstorer.separate()[i]));
rstorer.separate()[i] = OType1::to(temp2);
} else {
rstorer.separate()[i] = OType1::to(temp);
}
}
}
if (lreq != OpReqType::kNullOp) {
lstorer.store(tid, N);
}
if (rreq != OpReqType::kNullOp) {
rstorer.store(tid, N);
}
}
}
)code";
/*!
 * \brief Backward pass for binary ops whose gradients depend on the forward
 *        inputs: runtime-compiles and launches binary_kernel_bwd_use_in.
 *
 * inputs:  [0] ograd, [1] lhs forward input, [2] rhs forward input.
 * outputs: [0] lhs gradient (ograd * LOP), [1] rhs gradient (ograd * ROP).
 * Returns early if both requests are kNullOp.
 *
 * \param attrs   operator attributes (unused)
 * \param ctx     operator context supplying the GPU stream and device id
 * \param inputs  exactly three input blobs
 * \param req     write requests for the two outputs
 * \param outputs exactly two output blobs
 */
void ElemwiseBinaryRTCBwdUseIn::operator()(const nnvm::NodeAttrs& attrs,
                                           const OpContext& ctx,
                                           const std::vector<TBlob>& inputs,
                                           const std::vector<OpReqType>& req,
                                           const std::vector<TBlob>& outputs) {
  using namespace mxnet::common::cuda::rtc;
  if (req[0] == kNullOp && req[1] == kNullOp) {
    return;
  }
  mshadow::Stream<gpu>* const stream = ctx.get_stream<gpu>();
  CHECK_EQ(inputs.size(), 3U);
  CHECK_EQ(outputs.size(), 2U);

  // Preamble prepended to the kernel source: request types and op macros.
  std::string code = "const OpReqType lreq = ";
  code += util::to_string(req[0]);
  code += ";\nconst OpReqType rreq = ";
  code += util::to_string(req[1]);
  code += ";\n#define ROP op::";
  code += ROP;
  code += "\n#define LOP op::";
  code += LOP;
  code += "\n";

  // Size the vector so each access moves 64 bits, to reduce register pressure;
  // element types wider than 64 bits use one element per access.
  const size_t elem_bytes = common::mshadow_type_info(outputs[0].type_flag_).size;
  int vec_width = 1;
  if (elem_bytes <= sizeof(uint64_t)) {
    vec_width = static_cast<int>(sizeof(uint64_t) / elem_bytes);
  }
  const index_t total = outputs[0].Size();

  binary_kernel_params args{};
  args.inputs[0]  = inputs[0].dptr_;
  args.inputs[1]  = inputs[1].dptr_;
  args.inputs[2]  = inputs[2].dptr_;
  args.outputs[0] = outputs[0].dptr_;
  args.outputs[1] = outputs[1].dptr_;

  VectorizedKernelRTCLauncher(code,
                              "binary_kernel_bwd",
                              binary_kernel_bwd_use_in,
                              vec_width,
                              total,
                              1,
                              stream,
                              args,
                              inputs,
                              outputs,
                              ctx.run_ctx.get_ctx().dev_id);
}
#endif // MXNET_USE_CUDA
} // namespace op
} // namespace mxnet