-
Notifications
You must be signed in to change notification settings - Fork 1k
Expand file tree
/
Copy pathop_fully_connected.cpp
More file actions
68 lines (58 loc) · 2.14 KB
/
op_fully_connected.cpp
File metadata and controls
68 lines (58 loc) · 2.14 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
#include <executorch/backends/cadence/generic/operators/op_fully_connected.h>
#include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>
#include <executorch/runtime/core/exec_aten/util/tensor_util.h>
namespace impl {
namespace generic {
namespace native {
using ::executorch::aten::optional;
using ::executorch::aten::Tensor;
using ::executorch::runtime::getLeadingDims;
using ::executorch::runtime::KernelRuntimeContext;
// Computes a fully-connected (linear) layer on float tensors:
//   output = input @ weight^T + bias
//
// Layout contract (established by the indexing below):
//   input:  [d0, d1, ..., d_{N-2}, in_dim] — all leading dims flattened
//           into one batch dimension
//   weight: [out_dim, in_dim], row-major, one row per output channel
//   bias:   optional [out_dim], added per output channel when present
//   output: preallocated by the caller, [d0, ..., d_{N-2}, out_dim]
void linear(
    const Tensor& input,
    const Tensor& weight,
    const optional<Tensor>& bias,
    Tensor& output) {
  const float* __restrict__ input_data = input.const_data_ptr<float>();
  const float* __restrict__ weight_data = weight.const_data_ptr<float>();
  const float* __restrict__ bias_data =
      bias.has_value() ? bias.value().const_data_ptr<float>() : nullptr;
  float* __restrict__ output_data = output.mutable_data_ptr<float>();

  const int64_t out_dim = weight.size(0);
  const int64_t in_dim = weight.size(1);
  // Flatten every dimension of `input` except the innermost one into a
  // single batch count: d0 * d1 * ... * d_{N-2}.
  const int64_t leading_dims = getLeadingDims(input, input.dim() - 1);

  // NOTE: loop indices are int64_t to match the tensor-sized bounds above;
  // `int` counters would overflow (signed-overflow UB) on very large
  // batch/feature counts.
  for (int64_t b = 0; b < leading_dims; ++b) {
    for (int64_t o = 0; o < out_dim; ++o) {
      // Start from the channel bias (or 0 when no bias was supplied).
      float acc = bias_data != nullptr ? bias_data[o] : 0.0f;
      for (int64_t k = 0; k < in_dim; ++k) {
        acc += input_data[b * in_dim + k] * weight_data[o * in_dim + k];
      }
      output_data[b * out_dim + o] = acc;
    }
  }
}
// Kernel entry point for the fully_connected.out op.
//
// Thin wrapper that forwards to linear(); the runtime context is unused
// (ET_UNUSED). `output` must already be allocated by the caller — this
// function only fills it and returns it by reference, per the *_out
// kernel convention.
Tensor& fully_connected_out(
ET_UNUSED KernelRuntimeContext& ctx,
const Tensor& input,
const Tensor& weight,
const optional<Tensor>& bias,
Tensor& output) {
linear(input, weight, bias, output);
return output;
}
} // namespace native
} // namespace generic
} // namespace impl