Skip to content

Commit ed54226

Browse files
authored
Sigmoid in EWLayer (#194)
1 parent acba1bb commit ed54226

2 files changed

Lines changed: 88 additions & 1 deletion

File tree

include/layers/EWLayer.hpp

Lines changed: 18 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -2,6 +2,7 @@
22
#include <algorithm>
33
#include <cmath>
44
#include <string>
5+
#include <type_traits>
56
#include <utility>
67

78
#include "layers/Layer.hpp"
@@ -87,9 +88,26 @@ std::vector<ValueType> EWLayerImpl<ValueType>::run(
8788
static_cast<ValueType>(beta_);
8889
};
8990
std::transform(input.begin(), input.end(), res.begin(), linear);
91+
} else if (func_ == "sigmoid") {
92+
auto sigmoid = [](ValueType x) -> ValueType {
93+
if constexpr (std::is_integral_v<ValueType>) {
94+
auto x_float = static_cast<float>(x);
95+
float result = 1.0F / (1.0F + std::exp(-x_float));
96+
return static_cast<ValueType>(std::round(result));
97+
} else {
98+
if (x >= ValueType(0)) {
99+
ValueType z = std::exp(-x);
100+
return ValueType(1) / (ValueType(1) + z);
101+
}
102+
ValueType z = std::exp(x);
103+
return z / (ValueType(1) + z);
104+
}
105+
};
106+
std::transform(input.cbegin(), input.cend(), res.begin(), sigmoid);
90107
} else {
91108
throw std::invalid_argument("No such function for EWLayer");
92109
}
110+
93111
return res;
94112
}
95113

test/single_layer/test_ewlayer.cpp

Lines changed: 70 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -54,7 +54,25 @@ INSTANTIATE_TEST_SUITE_P(
5454
std::make_tuple(basic_data2,
5555
EWLayerImpl<double>({2, 2}, "linear", 2.0F, 1.0F),
5656
std::vector<double>({3.0, -1.0, 5.0, -3.0}),
57-
std::function<double(double)>())));
57+
std::function<double(double)>()),
58+
std::make_tuple(std::vector<double>({0.0, 1.0, -1.0}),
59+
EWLayerImpl<double>({3}, "sigmoid"),
60+
std::vector<double>(),
61+
std::function<double(double)>([](double x) {
62+
return 1.0 / (1.0 + std::exp(-x));
63+
})),
64+
std::make_tuple(std::vector<double>{-100.0, -50.0, 0.0, 50.0, 100.0},
65+
EWLayerImpl<double>({5}, "sigmoid"),
66+
std::vector<double>(),
67+
std::function<double(double)>([](double x) {
68+
if (x >= 0) {
69+
double z = std::exp(-x);
70+
return 1.0 / (1.0 + z);
71+
} else {
72+
double z = std::exp(x);
73+
return z / (1.0 + z);
74+
}
75+
}))));
5876

5977
TEST(ewlayer, new_ewlayer_can_relu_float) {
6078
EWLayer layer("relu");
@@ -100,3 +118,54 @@ TEST(ewlayer, new_ewlayer_throws_with_invalid_function) {
100118
// The statically queried layer name must equal "Element-wise layer".
TEST(ewlayer, get_layer_name) {
  const auto layer_name = EWLayer::get_name();
  EXPECT_EQ(layer_name, "Element-wise layer");
}
121+
122+
// Checks that an EWLayer built with "sigmoid" reproduces the reference
// formula 1 / (1 + exp(-x)) on a few small float inputs (tolerance 1e-5).
TEST(ewlayer, new_ewlayer_can_sigmoid_float) {
  EWLayer layer("sigmoid");
  Tensor input = make_tensor<float>({0.0F, -1.0F, 1.0F, 2.0F});
  Tensor output;
  // Expected values are listed in the same order as the inputs above.
  const std::vector<float> expected_output = {
      0.5F, 1.0F / (1.0F + std::exp(1.0F)), 1.0F / (1.0F + std::exp(-1.0F)),
      1.0F / (1.0F + std::exp(-2.0F))};
  layer.run(input, output);

  const auto& actual = *output.as<float>();
  // Stop here with a clear failure if the layer produced a different number
  // of elements; indexing past the end below would be undefined behaviour.
  ASSERT_EQ(actual.size(), expected_output.size());
  // Bound the loop by the expectation vector rather than a literal count so
  // the test stays correct when cases are added or removed.
  for (size_t i = 0; i < expected_output.size(); i++) {
    EXPECT_NEAR(actual[i], expected_output[i], 1e-5F);
  }
}
134+
135+
// Runs the "sigmoid" layer on an int tensor and compares element by element:
// inputs 0, 100 and 1 are expected to yield 1, inputs -100 and -1 to yield 0.
TEST(ewlayer, new_ewlayer_can_sigmoid_int) {
  EWLayer sigmoid_layer("sigmoid");
  Tensor in = make_tensor<int>({0, -100, 100, 1, -1});
  Tensor out;
  sigmoid_layer.run(in, out);

  const std::vector<int> want = {1, 0, 1, 1, 0};
  const auto& got = *out.as<int>();
  for (size_t idx = 0; idx != want.size(); ++idx) {
    EXPECT_EQ(got[idx], want[idx]);
  }
}
146+
147+
// Compares the "sigmoid" layer against a locally defined reference that only
// ever exponentiates a non-positive argument, on inputs that include the
// large-magnitude values -100 and 100 (tolerance 1e-5).
TEST(ewlayer, new_ewlayer_can_sigmoid_float_extreme_values) {
  EWLayer sigmoid_layer("sigmoid");
  Tensor in = make_tensor<float>({0.0F, -1.0F, 1.0F, 2.0F, -100.0F, 100.0F});
  Tensor out;

  // Reference sigmoid: z = exp(-x) for x >= 0 and exp(x) otherwise; the
  // result is 1 / (1 + z) on the non-negative side and z / (1 + z) on the
  // negative side.
  auto reference = [](float x) {
    const bool non_negative = x >= 0;
    const float z = std::exp(non_negative ? -x : x);
    return non_negative ? 1.0F / (1.0F + z) : z / (1.0F + z);
  };

  // Same sample points, in the same order, as the input tensor above.
  std::vector<float> want;
  for (float sample : {0.0F, -1.0F, 1.0F, 2.0F, -100.0F, 100.0F}) {
    want.push_back(reference(sample));
  }

  sigmoid_layer.run(in, out);

  const auto& got = *out.as<float>();
  for (size_t idx = 0; idx != want.size(); ++idx) {
    EXPECT_NEAR(got[idx], want[idx], 1e-5F);
  }
}

0 commit comments

Comments
 (0)