Skip to content

Commit 4e5a6c6

Browse files
authored
SplitLayer (#192)
1 parent b35c658 commit 4e5a6c6

3 files changed

Lines changed: 384 additions & 0 deletions

File tree

include/layers/SplitLayer.hpp

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
#pragma once
2+
#include <optional>
3+
#include <stdexcept>
4+
#include <vector>
5+
6+
#include "layers/Layer.hpp"
7+
#include "layers/Tensor.hpp"
8+
9+
namespace it_lab_ai {
10+
11+
class SplitLayer : public Layer {
12+
public:
13+
SplitLayer(int axis, std::vector<int> splits)
14+
: axis_(axis), splits_(std::move(splits)) {}
15+
16+
SplitLayer(int axis, int num_outputs)
17+
: axis_(axis), num_outputs_(num_outputs) {}
18+
void run(const Tensor& input, Tensor& output) override;
19+
void run(const Tensor& input, std::vector<Tensor>& outputs);
20+
21+
static std::string get_name() { return "SplitLayer"; }
22+
23+
#ifdef ENABLE_STATISTIC_WEIGHTS
24+
Tensor get_weights() override { return Tensor(); }
25+
#endif
26+
27+
private:
28+
int axis_;
29+
std::optional<std::vector<int>> splits_;
30+
std::optional<int> num_outputs_;
31+
32+
void validate(const Tensor& input) const;
33+
int get_normalized_axis(int rank) const;
34+
template <typename T>
35+
void split_impl(const Tensor& input, std::vector<Tensor>& outputs) const;
36+
};
37+
38+
} // namespace it_lab_ai

src/layers/SplitLayer.cpp

Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
#include "layers/SplitLayer.hpp"
2+
3+
#include <algorithm>
4+
#include <cstring>
5+
6+
namespace it_lab_ai {
7+
8+
// Single-output overload of run(): no splitting is performed, the input
// tensor is simply copied into `output`.
void SplitLayer::run(const Tensor& input, Tensor& output) {
  output = input;
}
9+
10+
void SplitLayer::run(const Tensor& input, std::vector<Tensor>& outputs) {
11+
validate(input);
12+
13+
switch (input.get_type()) {
14+
case Type::kFloat:
15+
split_impl<float>(input, outputs);
16+
break;
17+
case Type::kInt:
18+
split_impl<int>(input, outputs);
19+
break;
20+
default:
21+
throw std::runtime_error("Unsupported tensor data type");
22+
}
23+
}
24+
25+
template <typename T>
26+
void SplitLayer::split_impl(const Tensor& input,
27+
std::vector<Tensor>& outputs) const {
28+
const auto& input_data = *input.as<T>();
29+
const Shape& shape = input.get_shape();
30+
const int axis = get_normalized_axis(static_cast<int>(shape.dims()));
31+
32+
std::vector<int> part_sizes;
33+
if (splits_) {
34+
part_sizes = *splits_;
35+
} else {
36+
const int total_size = static_cast<int>(shape[axis]);
37+
const int base_size = total_size / *num_outputs_;
38+
const int remainder = total_size % *num_outputs_;
39+
40+
part_sizes.reserve(*num_outputs_);
41+
for (int i = 0; i < *num_outputs_; ++i) {
42+
part_sizes.push_back(i < remainder ? base_size + 1 : base_size);
43+
}
44+
}
45+
46+
size_t outer_size = 1;
47+
for (int i = 0; i < axis; ++i) {
48+
outer_size *= shape[i];
49+
}
50+
51+
size_t inner_size = 1;
52+
for (size_t i = axis + 1; i < shape.dims(); ++i) {
53+
inner_size *= shape[i];
54+
}
55+
56+
const size_t input_axis_stride = shape[axis] * inner_size;
57+
58+
outputs.clear();
59+
outputs.reserve(part_sizes.size());
60+
61+
size_t input_offset = 0;
62+
for (const auto part_size : part_sizes) {
63+
const auto output_axis_size = static_cast<size_t>(part_size);
64+
65+
std::vector<size_t> output_shape_vec(shape.dims());
66+
for (size_t i = 0; i < shape.dims(); ++i) {
67+
output_shape_vec[i] =
68+
(static_cast<int>(i) == axis) ? output_axis_size : shape[i];
69+
}
70+
Shape output_shape(output_shape_vec);
71+
72+
outputs.emplace_back(output_shape, input.get_type());
73+
auto& output_data = *outputs.back().as<T>();
74+
75+
const size_t output_part_size = output_axis_size * inner_size;
76+
77+
for (size_t outer = 0; outer < outer_size; ++outer) {
78+
const T* input_start =
79+
&input_data[outer * input_axis_stride + input_offset * inner_size];
80+
T* output_start = &output_data[outer * output_part_size];
81+
82+
std::copy_n(input_start, output_part_size, output_start);
83+
}
84+
85+
input_offset += output_axis_size;
86+
}
87+
}
88+
89+
void SplitLayer::validate(const Tensor& input) const {
90+
if (input.get_shape().dims() == 0) {
91+
throw std::runtime_error("Cannot split scalar tensor");
92+
}
93+
94+
const int axis =
95+
get_normalized_axis(static_cast<int>(input.get_shape().dims()));
96+
const int axis_size = static_cast<int>(input.get_shape()[axis]);
97+
98+
if (splits_) {
99+
int sum = 0;
100+
for (int s : *splits_) {
101+
if (s <= 0) throw std::runtime_error("Split size must be positive");
102+
sum += s;
103+
}
104+
if (sum != axis_size) {
105+
throw std::runtime_error("Sum of splits must match axis size");
106+
}
107+
} else {
108+
if (*num_outputs_ <= 0) {
109+
throw std::runtime_error("num_outputs must be positive");
110+
}
111+
if (*num_outputs_ > axis_size) {
112+
throw std::runtime_error("num_outputs (" + std::to_string(*num_outputs_) +
113+
") cannot be greater than axis size (" +
114+
std::to_string(axis_size) + ")");
115+
}
116+
}
117+
}
118+
119+
int SplitLayer::get_normalized_axis(int rank) const {
120+
if (axis_ < -rank || axis_ >= rank) {
121+
throw std::runtime_error("Axis out of bounds");
122+
}
123+
return (axis_ < 0) ? axis_ + rank : axis_;
124+
}
125+
126+
// Explicit instantiations for the element types that run() dispatches to;
// split_impl is defined in this translation unit only.
template void SplitLayer::split_impl<float>(
    const Tensor&, std::vector<Tensor>&) const;
template void SplitLayer::split_impl<int>(
    const Tensor&, std::vector<Tensor>&) const;
130+
131+
} // namespace it_lab_ai
Lines changed: 215 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,215 @@
1+
#include <numeric>
#include <vector>

#include "gtest/gtest.h"
#include "layers/SplitLayer.hpp"
#include "layers/Tensor.hpp"
6+
7+
using namespace it_lab_ai;
8+
9+
// Six elements split into three pieces along axis 0: every piece has two
// elements and the pieces start at 1, 3 and 5.
TEST(SplitLayerTests, SplitEqualParts1D) {
  const Tensor source = make_tensor<float>({1, 2, 3, 4, 5, 6}, {6});
  SplitLayer layer(0, 3);

  std::vector<Tensor> parts;
  layer.run(source, parts);

  ASSERT_EQ(parts.size(), 3);

  EXPECT_EQ(parts[0].get_shape(), Shape({2}));
  EXPECT_FLOAT_EQ(parts[0].get<float>({0}), 1.0f);

  EXPECT_EQ(parts[1].get_shape(), Shape({2}));
  EXPECT_FLOAT_EQ(parts[1].get<float>({0}), 3.0f);

  EXPECT_EQ(parts[2].get_shape(), Shape({2}));
  EXPECT_FLOAT_EQ(parts[2].get<float>({0}), 5.0f);
}
24+
25+
// Explicit sizes {2, 4} on a six-element vector: checks both shapes and the
// last element of each piece.
TEST(SplitLayerTests, SplitVariableParts1D) {
  const Tensor source = make_tensor<float>({1, 2, 3, 4, 5, 6}, {6});
  SplitLayer layer(0, {2, 4});

  std::vector<Tensor> parts;
  layer.run(source, parts);

  ASSERT_EQ(parts.size(), 2);

  EXPECT_EQ(parts[0].get_shape(), Shape({2}));
  EXPECT_FLOAT_EQ(parts[0].get<float>({1}), 2.0f);

  EXPECT_EQ(parts[1].get_shape(), Shape({4}));
  EXPECT_FLOAT_EQ(parts[1].get<float>({3}), 6.0f);
}
38+
39+
// A 2x3 matrix split row-wise: each piece is one complete row.
TEST(SplitLayerTests, Split2DAlongAxis0) {
  const Tensor matrix = make_tensor<float>({1, 2, 3, 4, 5, 6}, {2, 3});
  SplitLayer layer(0, {1, 1});

  std::vector<Tensor> rows;
  layer.run(matrix, rows);

  ASSERT_EQ(rows.size(), 2);

  EXPECT_EQ(rows[0].get_shape(), Shape({1, 3}));
  EXPECT_FLOAT_EQ(rows[0].get<float>({0, 2}), 3.0f);

  EXPECT_EQ(rows[1].get_shape(), Shape({1, 3}));
  EXPECT_FLOAT_EQ(rows[1].get<float>({0, 0}), 4.0f);
}
52+
53+
// A 2x3 matrix split column-wise into widths 1 and 2.
TEST(SplitLayerTests, Split2DAlongAxis1) {
  const Tensor matrix = make_tensor<float>({1, 2, 3, 4, 5, 6}, {2, 3});
  SplitLayer layer(1, {1, 2});

  std::vector<Tensor> columns;
  layer.run(matrix, columns);

  ASSERT_EQ(columns.size(), 2);

  EXPECT_EQ(columns[0].get_shape(), Shape({2, 1}));
  EXPECT_FLOAT_EQ(columns[0].get<float>({1, 0}), 4.0f);

  EXPECT_EQ(columns[1].get_shape(), Shape({2, 2}));
  EXPECT_FLOAT_EQ(columns[1].get<float>({0, 1}), 3.0f);
}
66+
67+
// A 2x3x4 tensor filled with 0..23 is split into three equal pieces along
// axis 1. Every piece must have shape {2, 1, 4}; element {1, 0, 3} of the
// middle piece is input element {1, 1, 3}, i.e. 1*12 + 1*4 + 3 = 19.
TEST(SplitLayerTests, Split3DEqualParts) {
  std::vector<float> data(2 * 3 * 4);
  std::iota(data.begin(), data.end(), 0.0f);
  Tensor input = make_tensor<float>(data, {2, 3, 4});

  SplitLayer splitter(1, 3);

  std::vector<Tensor> outputs;
  splitter.run(input, outputs);

  // Unsigned literal avoids a signed/unsigned comparison against size().
  ASSERT_EQ(outputs.size(), 3U);
  for (const Tensor& part : outputs) {
    EXPECT_EQ(part.get_shape(), Shape({2, 1, 4}));
  }
  // Float values are compared with EXPECT_FLOAT_EQ, as in the other tests.
  EXPECT_FLOAT_EQ(outputs[1].get<float>({1, 0, 3}), 19.0f);
}
81+
82+
// A 1x3x2x4 tensor filled with 0..23 is split along axis 2 into two pieces of
// extent 1. Element {0, 2, 0, 3} of the second piece is input element
// {0, 2, 1, 3}, i.e. 2*8 + 1*4 + 3 = 23.
TEST(SplitLayerTests, Split4DVariableParts) {
  std::vector<float> data(1 * 3 * 2 * 4);
  std::iota(data.begin(), data.end(), 0.0f);
  Tensor input = make_tensor<float>(data, {1, 3, 2, 4});

  SplitLayer splitter(2, {1, 1});

  std::vector<Tensor> outputs;
  splitter.run(input, outputs);

  // Unsigned literal avoids a signed/unsigned comparison against size().
  ASSERT_EQ(outputs.size(), 2U);
  EXPECT_EQ(outputs[0].get_shape(), Shape({1, 3, 1, 4}));
  EXPECT_EQ(outputs[1].get_shape(), Shape({1, 3, 1, 4}));
  // Float values are compared with EXPECT_FLOAT_EQ, as in the other tests.
  EXPECT_FLOAT_EQ(outputs[1].get<float>({0, 2, 0, 3}), 23.0f);
}
96+
97+
// Axis -1 on a 2x3 matrix must behave like axis 1: pieces of width 1 and 2.
// Besides the shapes, one value per piece is checked so that a wrong axis
// mapping that happens to yield the right shapes would still be caught.
TEST(SplitLayerTests, SplitNegativeAxis) {
  Tensor input = make_tensor<float>({1, 2, 3, 4, 5, 6}, {2, 3});
  SplitLayer splitter(-1, {1, 2});

  std::vector<Tensor> outputs;
  splitter.run(input, outputs);

  // Unsigned literal avoids a signed/unsigned comparison against size().
  ASSERT_EQ(outputs.size(), 2U);
  EXPECT_EQ(outputs[0].get_shape(), Shape({2, 1}));
  EXPECT_EQ(outputs[1].get_shape(), Shape({2, 2}));
  // Same expectations as the axis-1 test on identical input.
  EXPECT_FLOAT_EQ(outputs[0].get<float>({1, 0}), 4.0f);
  EXPECT_FLOAT_EQ(outputs[1].get<float>({0, 1}), 3.0f);
}
108+
109+
// Split sizes {1, 2} sum to 3, which does not match the axis size of 4, so
// run() is expected to throw.
TEST(SplitLayerTests, InvalidSplitSizes) {
  const Tensor source = make_tensor<float>({1, 2, 3, 4}, {4});
  SplitLayer layer(0, {1, 2});

  std::vector<Tensor> parts;
  EXPECT_THROW(layer.run(source, parts), std::runtime_error);
}
117+
118+
// Splitting a zero-length tensor is expected to throw.
// NOTE(review): the `{}` argument appears to select the
// `SplitLayer(int axis, int num_outputs)` overload with num_outputs == 0, so
// this likely exercises the "num_outputs must be positive" check rather than
// an empty split list — confirm which path is intended.
TEST(SplitLayerTests, EmptyInputTensor) {
  const Tensor empty_source = make_tensor<float>({}, {0});
  SplitLayer layer(0, {});

  std::vector<Tensor> parts;
  EXPECT_THROW(layer.run(empty_source, parts), std::runtime_error);
}
126+
127+
// Realistically sized case: a 1x192x56x56 tensor filled with 0, 1, 2, ... is
// halved along the channel axis. The second half must start at the element
// whose flat index is 96 * 56 * 56.
TEST(SplitLayerTests, Split192IntoTwo96) {
  std::vector<float> values(1 * 192 * 56 * 56);
  std::iota(values.begin(), values.end(), 0.0f);
  const Tensor source = make_tensor<float>(values, {1, 192, 56, 56});

  SplitLayer layer(1, {96, 96});

  std::vector<Tensor> halves;
  layer.run(source, halves);

  ASSERT_EQ(halves.size(), 2);

  EXPECT_EQ(halves[0].get_shape(), Shape({1, 96, 56, 56}));
  EXPECT_FLOAT_EQ(halves[0].get<float>({0, 0, 0, 0}), 0.0f);

  EXPECT_EQ(halves[1].get_shape(), Shape({1, 96, 56, 56}));
  EXPECT_FLOAT_EQ(halves[1].get<float>({0, 0, 0, 0}), 96 * 56 * 56);
}
142+
143+
// Five elements into three pieces: the remainder goes to the leading pieces,
// giving sizes 2, 2 and 1.
TEST(SplitLayerTests, UnevenSplitWithRemainder) {
  const Tensor source = make_tensor<float>({1, 2, 3, 4, 5}, {5});
  SplitLayer layer(0, 3);

  std::vector<Tensor> parts;
  layer.run(source, parts);

  ASSERT_EQ(parts.size(), 3);

  EXPECT_EQ(parts[0].get_shape(), Shape({2}));
  EXPECT_FLOAT_EQ(parts[0].get<float>({1}), 2.0f);

  EXPECT_EQ(parts[1].get_shape(), Shape({2}));
  EXPECT_FLOAT_EQ(parts[1].get<float>({1}), 4.0f);

  EXPECT_EQ(parts[2].get_shape(), Shape({1}));
  EXPECT_FLOAT_EQ(parts[2].get<float>({0}), 5.0f);
}
158+
159+
// Asking for five pieces of a three-element axis is rejected.
TEST(SplitLayerTests, NumOutputsGreaterThanAxisSize) {
  const Tensor source = make_tensor<float>({1, 2, 3}, {3});
  SplitLayer layer(0, 5);

  std::vector<Tensor> parts;
  EXPECT_THROW(layer.run(source, parts), std::runtime_error);
}
166+
167+
// Same column-wise split as the float 2x3 case, but on an int tensor, to
// cover the integer dispatch path.
TEST(SplitLayerTests, IntegerDataType) {
  const Tensor matrix = make_tensor<int>({1, 2, 3, 4, 5, 6}, {2, 3});
  SplitLayer layer(1, {1, 2});

  std::vector<Tensor> columns;
  layer.run(matrix, columns);

  ASSERT_EQ(columns.size(), 2);

  EXPECT_EQ(columns[0].get_shape(), Shape({2, 1}));
  EXPECT_EQ(columns[0].get<int>({1, 0}), 4);

  EXPECT_EQ(columns[1].get_shape(), Shape({2, 2}));
  EXPECT_EQ(columns[1].get<int>({0, 1}), 3);
}
180+
181+
// Axis -2 on a 2x2 matrix must behave like axis 0: each piece is one row.
// Values are checked as well as shapes, because for a square input a split
// along the wrong axis could still only be told apart by its contents.
TEST(SplitLayerTests, NegativeAxis2D) {
  Tensor input = make_tensor<float>({1, 2, 3, 4}, {2, 2});
  SplitLayer splitter(-2, {1, 1});

  std::vector<Tensor> outputs;
  splitter.run(input, outputs);

  // Unsigned literal avoids a signed/unsigned comparison against size().
  ASSERT_EQ(outputs.size(), 2U);
  EXPECT_EQ(outputs[0].get_shape(), Shape({1, 2}));
  EXPECT_EQ(outputs[1].get_shape(), Shape({1, 2}));
  // First row is {1, 2}, second row is {3, 4}.
  EXPECT_FLOAT_EQ(outputs[0].get<float>({0, 1}), 2.0f);
  EXPECT_FLOAT_EQ(outputs[1].get<float>({0, 0}), 3.0f);
}
192+
193+
// Axis -1 on a 2x3x4 tensor filled with 1..24, split into widths 1 and 3.
// Element {1, 2, 0} of the first piece is input element {1, 2, 0}, i.e.
// 1 + (1*12 + 2*4 + 0) = 21.
TEST(SplitLayerTests, NegativeAxis3D) {
  std::vector<float> values(2 * 3 * 4);
  std::iota(values.begin(), values.end(), 1.0f);
  const Tensor source = make_tensor<float>(values, {2, 3, 4});

  SplitLayer layer(-1, {1, 3});

  std::vector<Tensor> parts;
  layer.run(source, parts);

  ASSERT_EQ(parts.size(), 2);
  EXPECT_EQ(parts[0].get_shape(), Shape({2, 3, 1}));
  EXPECT_EQ(parts[1].get_shape(), Shape({2, 3, 3}));
  EXPECT_FLOAT_EQ(parts[0].get<float>({1, 2, 0}), 21.0f);
}
208+
209+
// An axis index far beyond the tensor's rank is rejected.
TEST(SplitLayerTests, LargeAxisValue) {
  const Tensor matrix = make_tensor<float>({1, 2, 3, 4}, {2, 2});
  SplitLayer layer(10, {1, 1});

  std::vector<Tensor> parts;
  EXPECT_THROW(layer.run(matrix, parts), std::runtime_error);
}

0 commit comments

Comments
 (0)