Skip to content

Commit c817b3d

Browse files
committed
change to optional, add tests, fix check with large negative axis
1 parent 38a1f2c commit c817b3d

3 files changed

Lines changed: 127 additions & 44 deletions

File tree

include/layers/SplitLayer.hpp

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
#pragma once
2+
#include <optional>
3+
#include <stdexcept>
24
#include <vector>
35

46
#include "layers/Layer.hpp"
@@ -8,9 +10,10 @@ namespace it_lab_ai {
810

911
class SplitLayer : public Layer {
1012
public:
11-
SplitLayer(int axis, const std::vector<int>& splits)
12-
: axis_(axis), splits_(splits) {}
13+
SplitLayer(int axis, std::vector<int> splits)
14+
: axis_(axis), splits_(std::move(splits)) {}
1315

16+
// Ðåæèì 2: êîëè÷åñòâî âûõîäíûõ òåíçîðîâ
1417
SplitLayer(int axis, int num_outputs)
1518
: axis_(axis), num_outputs_(num_outputs) {}
1619
void run(const Tensor& input, Tensor& output) override;
@@ -24,8 +27,8 @@ class SplitLayer : public Layer {
2427

2528
private:
2629
int axis_;
27-
std::vector<int> splits_;
28-
int num_outputs_ = 0;
30+
std::optional<std::vector<int>> splits_;
31+
std::optional<int> num_outputs_;
2932

3033
void validate(const Tensor& input) const;
3134
int get_normalized_axis(int rank) const;

src/layers/SplitLayer.cpp

Lines changed: 45 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -6,27 +6,6 @@ void SplitLayer::run(const Tensor& input, Tensor& output) { output = input; }
66

77
void SplitLayer::run(const Tensor& input, std::vector<Tensor>& outputs) {
88
validate(input);
9-
const auto& shape = input.get_shape();
10-
const int axis = get_normalized_axis(static_cast<int>(shape.dims()));
11-
12-
std::vector<int> part_sizes;
13-
if (!splits_.empty()) {
14-
part_sizes = splits_;
15-
} else {
16-
const int base_size = static_cast<int>(shape[axis]) / num_outputs_;
17-
const int remainder = static_cast<int>(shape[axis]) % num_outputs_;
18-
part_sizes.assign(num_outputs_, base_size);
19-
if (remainder > 0) {
20-
part_sizes.back() += remainder;
21-
}
22-
}
23-
24-
outputs.clear();
25-
for (int size : part_sizes) {
26-
Shape out_shape = shape;
27-
out_shape[axis] = static_cast<size_t>(size);
28-
outputs.emplace_back(out_shape, input.get_type());
29-
}
309

3110
switch (input.get_type()) {
3211
case Type::kFloat:
@@ -36,7 +15,7 @@ void SplitLayer::run(const Tensor& input, std::vector<Tensor>& outputs) {
3615
split_impl<int>(input, outputs);
3716
break;
3817
default:
39-
throw std::runtime_error("Unsupported tensor type");
18+
throw std::runtime_error("Unsupported tensor data type");
4019
}
4120
}
4221

@@ -46,11 +25,23 @@ void SplitLayer::split_impl(const Tensor& input,
4625
const auto& input_data = *input.as<T>();
4726
const Shape& shape = input.get_shape();
4827
const int axis = get_normalized_axis(static_cast<int>(shape.dims()));
49-
const auto& part_sizes =
50-
splits_.empty()
51-
? std::vector<int>(num_outputs_,
52-
static_cast<int>(shape[axis]) / num_outputs_)
53-
: splits_;
28+
29+
std::vector<int> part_sizes;
30+
if (splits_) {
31+
part_sizes = *splits_;
32+
} else {
33+
const int total_size = static_cast<int>(shape[axis]);
34+
const int base_size = total_size / *num_outputs_;
35+
const int remainder = total_size % *num_outputs_;
36+
37+
part_sizes.reserve(*num_outputs_);
38+
for (int i = 0; i < *num_outputs_; ++i) {
39+
part_sizes.push_back(i < remainder ? base_size + 1 : base_size);
40+
}
41+
}
42+
43+
outputs.clear();
44+
outputs.reserve(part_sizes.size());
5445

5546
size_t outer_size = 1;
5647
for (int i = 0; i < axis; ++i) {
@@ -63,9 +54,17 @@ void SplitLayer::split_impl(const Tensor& input,
6354
}
6455

6556
size_t input_offset = 0;
66-
for (auto& output : outputs) {
67-
auto& output_data = *output.as<T>();
68-
const size_t output_axis_size = output.get_shape()[axis];
57+
for (size_t part = 0; part < part_sizes.size(); ++part) {
58+
const size_t output_axis_size = part_sizes[part];
59+
60+
std::vector<size_t> output_shape_vec(shape.dims());
61+
for (size_t i = 0; i < shape.dims(); ++i) {
62+
output_shape_vec[i] = (i == axis) ? output_axis_size : shape[i];
63+
}
64+
Shape output_shape(output_shape_vec);
65+
66+
outputs.emplace_back(output_shape, input.get_type());
67+
auto& output_data = *outputs.back().as<T>();
6968

7069
for (size_t outer = 0; outer < outer_size; ++outer) {
7170
for (size_t a = 0; a < output_axis_size; ++a) {
@@ -84,31 +83,37 @@ void SplitLayer::split_impl(const Tensor& input,
8483

8584
void SplitLayer::validate(const Tensor& input) const {
8685
if (input.get_shape().dims() == 0) {
87-
throw std::runtime_error("SplitLayer: Cannot split scalar tensor");
86+
throw std::runtime_error("Cannot split scalar tensor");
8887
}
8988

9089
const int axis =
9190
get_normalized_axis(static_cast<int>(input.get_shape().dims()));
92-
const size_t axis_size = input.get_shape()[axis];
91+
const int axis_size = static_cast<int>(input.get_shape()[axis]);
9392

94-
if (!splits_.empty()) {
93+
if (splits_) {
9594
int sum = 0;
96-
for (int s : splits_) {
95+
for (int s : *splits_) {
9796
if (s <= 0) throw std::runtime_error("Split size must be positive");
9897
sum += s;
9998
}
100-
if (sum != static_cast<int>(axis_size)) {
99+
if (sum != axis_size) {
101100
throw std::runtime_error("Sum of splits must match axis size");
102101
}
103-
} else if (num_outputs_ <= 0) {
104-
throw std::runtime_error("num_outputs must be positive");
102+
} else {
103+
if (*num_outputs_ <= 0) {
104+
throw std::runtime_error("num_outputs must be positive");
105+
}
106+
if (*num_outputs_ > axis_size) {
107+
throw std::runtime_error("num_outputs cannot be greater than axis size");
108+
}
105109
}
106110
}
107111

108112
int SplitLayer::get_normalized_axis(int rank) const {
109-
if (axis_ < 0) return axis_ + rank;
110-
if (axis_ >= rank) throw std::runtime_error("Axis out of bounds");
111-
return axis_;
113+
if (axis_ < -rank || axis_ >= rank) {
114+
throw std::runtime_error("Axis out of bounds");
115+
}
116+
return (axis_ < 0) ? axis_ + rank : axis_;
112117
}
113118

114119
template void SplitLayer::split_impl<float>(const Tensor&,

test/single_layer/test_splitlayer.cpp

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,4 +138,79 @@ TEST(SplitLayerTests, Split192IntoTwo96) {
138138
EXPECT_EQ(outputs[1].get_shape(), Shape({1, 96, 56, 56}));
139139
EXPECT_FLOAT_EQ(outputs[0].get<float>({0, 0, 0, 0}), 0.0f);
140140
EXPECT_FLOAT_EQ(outputs[1].get<float>({0, 0, 0, 0}), 96 * 56 * 56);
141+
}
142+
143+
TEST(SplitLayerTests, UnevenSplitWithRemainder) {
144+
Tensor input = make_tensor<float>({1, 2, 3, 4, 5}, {5});
145+
SplitLayer splitter(0, 3);
146+
147+
std::vector<Tensor> outputs;
148+
splitter.run(input, outputs);
149+
150+
ASSERT_EQ(outputs.size(), 3);
151+
EXPECT_EQ(outputs[0].get_shape(), Shape({2}));
152+
EXPECT_EQ(outputs[1].get_shape(), Shape({2}));
153+
EXPECT_EQ(outputs[2].get_shape(), Shape({1}));
154+
EXPECT_FLOAT_EQ(outputs[0].get<float>({1}), 2.0f);
155+
EXPECT_FLOAT_EQ(outputs[1].get<float>({1}), 4.0f);
156+
EXPECT_FLOAT_EQ(outputs[2].get<float>({0}),
157+
5.0f);
158+
}
159+
160+
TEST(SplitLayerTests, NumOutputsGreaterThanAxisSize) {
161+
Tensor input = make_tensor<float>({1, 2, 3}, {3});
162+
SplitLayer splitter(0, 5);
163+
164+
std::vector<Tensor> outputs;
165+
EXPECT_THROW(splitter.run(input, outputs), std::runtime_error);
166+
}
167+
168+
TEST(SplitLayerTests, IntegerDataType) {
169+
Tensor input = make_tensor<int>({1, 2, 3, 4, 5, 6}, {2, 3});
170+
SplitLayer splitter(1, {1, 2});
171+
172+
std::vector<Tensor> outputs;
173+
splitter.run(input, outputs);
174+
175+
ASSERT_EQ(outputs.size(), 2);
176+
EXPECT_EQ(outputs[0].get_shape(), Shape({2, 1}));
177+
EXPECT_EQ(outputs[1].get_shape(), Shape({2, 2}));
178+
EXPECT_EQ(outputs[0].get<int>({1, 0}), 4);
179+
EXPECT_EQ(outputs[1].get<int>({0, 1}), 3);
180+
}
181+
182+
TEST(SplitLayerTests, NegativeAxis2D) {
183+
Tensor input = make_tensor<float>({1, 2, 3, 4}, {2, 2});
184+
SplitLayer splitter(-2, {1, 1});
185+
186+
std::vector<Tensor> outputs;
187+
splitter.run(input, outputs);
188+
189+
ASSERT_EQ(outputs.size(), 2);
190+
EXPECT_EQ(outputs[0].get_shape(), Shape({1, 2}));
191+
EXPECT_EQ(outputs[1].get_shape(), Shape({1, 2}));
192+
}
193+
194+
TEST(SplitLayerTests, NegativeAxis3D) {
195+
std::vector<float> data(2 * 3 * 4);
196+
std::iota(data.begin(), data.end(), 1.0f);
197+
Tensor input = make_tensor<float>(data, {2, 3, 4});
198+
199+
SplitLayer splitter(-1, {1, 3});
200+
201+
std::vector<Tensor> outputs;
202+
splitter.run(input, outputs);
203+
204+
ASSERT_EQ(outputs.size(), 2);
205+
EXPECT_EQ(outputs[0].get_shape(), Shape({2, 3, 1}));
206+
EXPECT_EQ(outputs[1].get_shape(), Shape({2, 3, 3}));
207+
EXPECT_FLOAT_EQ(outputs[0].get<float>({1, 2, 0}), 21.0f);
208+
}
209+
210+
TEST(SplitLayerTests, LargeAxisValue) {
211+
Tensor input = make_tensor<float>({1, 2, 3, 4}, {2, 2});
212+
213+
SplitLayer splitter(10, {1, 1});
214+
std::vector<Tensor> outputs;
215+
EXPECT_THROW(splitter.run(input, outputs), std::runtime_error);
141216
}

0 commit comments

Comments
 (0)