Skip to content

Commit f96de58

Browse files
committed
use block copies, fix check num_outputs_ > size_along_axis
1 parent c817b3d commit f96de58

3 files changed

Lines changed: 24 additions & 15 deletions

File tree

include/layers/SplitLayer.hpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@ class SplitLayer : public Layer {
1313
SplitLayer(int axis, std::vector<int> splits)
1414
: axis_(axis), splits_(std::move(splits)) {}
1515

16-
// Ðåæèì 2: êîëè÷åñòâî âûõîäíûõ òåíçîðîâ
1716
SplitLayer(int axis, int num_outputs)
1817
: axis_(axis), num_outputs_(num_outputs) {}
1918
void run(const Tensor& input, Tensor& output) override;

src/layers/SplitLayer.cpp

Lines changed: 23 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -40,9 +40,6 @@ void SplitLayer::split_impl(const Tensor& input,
4040
}
4141
}
4242

43-
outputs.clear();
44-
outputs.reserve(part_sizes.size());
45-
4643
size_t outer_size = 1;
4744
for (int i = 0; i < axis; ++i) {
4845
outer_size *= shape[i];
@@ -53,6 +50,12 @@ void SplitLayer::split_impl(const Tensor& input,
5350
inner_size *= shape[i];
5451
}
5552

53+
const size_t input_axis_stride = shape[axis] * inner_size;
54+
const size_t input_inner_stride = inner_size;
55+
56+
outputs.clear();
57+
outputs.reserve(part_sizes.size());
58+
5659
size_t input_offset = 0;
5760
for (size_t part = 0; part < part_sizes.size(); ++part) {
5861
const size_t output_axis_size = part_sizes[part];
@@ -66,17 +69,17 @@ void SplitLayer::split_impl(const Tensor& input,
6669
outputs.emplace_back(output_shape, input.get_type());
6770
auto& output_data = *outputs.back().as<T>();
6871

72+
const size_t output_part_size = output_axis_size * inner_size;
73+
const size_t input_part_size = output_part_size;
74+
6975
for (size_t outer = 0; outer < outer_size; ++outer) {
70-
for (size_t a = 0; a < output_axis_size; ++a) {
71-
for (size_t inner = 0; inner < inner_size; ++inner) {
72-
size_t input_pos = outer * shape[axis] * inner_size +
73-
(input_offset + a) * inner_size + inner;
74-
size_t output_pos =
75-
outer * output_axis_size * inner_size + a * inner_size + inner;
76-
output_data[output_pos] = input_data[input_pos];
77-
}
78-
}
76+
const T* input_start =
77+
&input_data[outer * input_axis_stride + input_offset * inner_size];
78+
T* output_start = &output_data[outer * output_part_size];
79+
80+
std::copy_n(input_start, output_part_size, output_start);
7981
}
82+
8083
input_offset += output_axis_size;
8184
}
8285
}
@@ -107,6 +110,14 @@ void SplitLayer::validate(const Tensor& input) const {
107110
throw std::runtime_error("num_outputs cannot be greater than axis size");
108111
}
109112
}
113+
114+
if (!splits_ && num_outputs_) {
115+
if (*num_outputs_ > axis_size) {
116+
throw std::runtime_error("num_outputs (" + std::to_string(*num_outputs_) +
117+
") cannot be greater than axis size (" +
118+
std::to_string(axis_size) + ")");
119+
}
120+
}
110121
}
111122

112123
int SplitLayer::get_normalized_axis(int rank) const {

test/single_layer/test_splitlayer.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -153,8 +153,7 @@ TEST(SplitLayerTests, UnevenSplitWithRemainder) {
153153
EXPECT_EQ(outputs[2].get_shape(), Shape({1}));
154154
EXPECT_FLOAT_EQ(outputs[0].get<float>({1}), 2.0f);
155155
EXPECT_FLOAT_EQ(outputs[1].get<float>({1}), 4.0f);
156-
EXPECT_FLOAT_EQ(outputs[2].get<float>({0}),
157-
5.0f);
156+
EXPECT_FLOAT_EQ(outputs[2].get<float>({0}), 5.0f);
158157
}
159158

160159
TEST(SplitLayerTests, NumOutputsGreaterThanAxisSize) {

0 commit comments

Comments
 (0)