Skip to content

Commit 68ec25c

Browse files
authored
ConCatLayer (#189)
1 parent 499c1e7 commit 68ec25c

3 files changed

Lines changed: 380 additions & 0 deletions

File tree

include/layers/ConcatLayer.hpp

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
#pragma once
2+
#include <cstdint>
3+
#include <numeric>
4+
#include <stdexcept>
5+
#include <vector>
6+
7+
#include "layers/Layer.hpp"
8+
#include "layers/Tensor.hpp"
9+
10+
namespace it_lab_ai {
11+
12+
class ConcatLayer : public Layer {
13+
public:
14+
explicit ConcatLayer(int64_t axis = 0) : axis_(axis) {}
15+
16+
void run(const Tensor& input, Tensor& output) override;
17+
void run(const std::vector<Tensor>& inputs, Tensor& output);
18+
19+
static std::string get_name() { return "ConcatLayer"; }
20+
21+
private:
22+
int64_t axis_;
23+
24+
void validate_inputs(const std::vector<Tensor>& inputs) const;
25+
int64_t normalize_axis(size_t rank) const;
26+
Shape calculate_output_shape(const std::vector<Tensor>& inputs) const;
27+
28+
template <typename T>
29+
void concatenate(const std::vector<Tensor>& inputs, Tensor& output) const;
30+
};
31+
32+
} // namespace it_lab_ai

src/layers/ConcatLayer.cpp

Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,150 @@
1+
#include "layers/ConcatLayer.hpp"
2+
3+
namespace it_lab_ai {
4+
5+
void ConcatLayer::run(const Tensor& input, Tensor& output) { output = input; }
6+
7+
void ConcatLayer::run(const std::vector<Tensor>& inputs, Tensor& output) {
8+
if (inputs.empty()) {
9+
throw std::runtime_error("ConcatLayer: No input tensors provided");
10+
}
11+
12+
validate_inputs(inputs);
13+
14+
switch (inputs[0].get_type()) {
15+
case Type::kFloat:
16+
concatenate<float>(inputs, output);
17+
break;
18+
case Type::kInt:
19+
concatenate<int>(inputs, output);
20+
break;
21+
default:
22+
throw std::runtime_error("ConcatLayer: Unsupported input tensor type");
23+
}
24+
}
25+
26+
void ConcatLayer::validate_inputs(const std::vector<Tensor>& inputs) const {
27+
if (inputs.empty()) return;
28+
29+
const Shape& first_shape = inputs[0].get_shape();
30+
Type first_type = inputs[0].get_type();
31+
const int64_t normalized_axis = normalize_axis(first_shape.dims());
32+
33+
for (size_t i = 1; i < inputs.size(); ++i) {
34+
const Shape& shape = inputs[i].get_shape();
35+
if (shape.dims() != first_shape.dims()) {
36+
throw std::runtime_error(
37+
"ConcatLayer: All input tensors must have the same rank");
38+
}
39+
40+
if (inputs[i].get_type() != first_type) {
41+
throw std::runtime_error(
42+
"ConcatLayer: All input tensors must have the same type");
43+
}
44+
45+
for (size_t dim = 0; dim < shape.dims(); ++dim) {
46+
if (dim != static_cast<size_t>(normalized_axis) &&
47+
shape[dim] != first_shape[dim]) {
48+
throw std::runtime_error(
49+
"ConcatLayer: All input tensors must have the same shape except "
50+
"for the concatenation axis");
51+
}
52+
}
53+
}
54+
}
55+
56+
int64_t ConcatLayer::normalize_axis(size_t rank) const {
57+
if (rank == 0) {
58+
throw std::runtime_error("ConcatLayer: Cannot concatenate scalar tensors");
59+
}
60+
61+
int64_t axis = axis_;
62+
63+
if (axis < 0) {
64+
axis += static_cast<int64_t>(rank);
65+
}
66+
67+
if (axis < 0 || axis >= static_cast<int64_t>(rank)) {
68+
throw std::runtime_error("ConcatLayer: Axis " + std::to_string(axis_) +
69+
" out of range for tensor rank " +
70+
std::to_string(rank));
71+
}
72+
73+
return axis;
74+
}
75+
76+
Shape ConcatLayer::calculate_output_shape(
77+
const std::vector<Tensor>& inputs) const {
78+
if (inputs.empty()) return Shape({});
79+
80+
const Shape& first_shape = inputs[0].get_shape();
81+
std::vector<size_t> output_dims(first_shape.dims());
82+
for (size_t i = 0; i < first_shape.dims(); ++i) {
83+
output_dims[i] = first_shape[i];
84+
}
85+
86+
const int64_t normalized_axis = normalize_axis(first_shape.dims());
87+
output_dims[normalized_axis] = 0;
88+
for (const auto& input : inputs) {
89+
output_dims[normalized_axis] += input.get_shape()[normalized_axis];
90+
}
91+
92+
return Shape(output_dims);
93+
}
94+
95+
template <typename T>
96+
void ConcatLayer::concatenate(const std::vector<Tensor>& inputs,
97+
Tensor& output) const {
98+
Shape output_shape = calculate_output_shape(inputs);
99+
std::vector<T> output_data(output_shape.count(), 0);
100+
101+
const int64_t axis = normalize_axis(inputs[0].get_shape().dims());
102+
const size_t outer_size = [&]() {
103+
size_t size = 1;
104+
for (int64_t i = 0; i < axis; ++i) {
105+
size *= output_shape[i];
106+
}
107+
return size;
108+
}();
109+
110+
const size_t inner_size = [&]() {
111+
size_t size = 1;
112+
for (size_t i = axis + 1; i < output_shape.dims(); ++i) {
113+
size *= output_shape[i];
114+
}
115+
return size;
116+
}();
117+
118+
size_t output_offset = 0;
119+
120+
for (const auto& input : inputs) {
121+
const auto& input_data = *input.as<T>();
122+
const Shape& input_shape = input.get_shape();
123+
const size_t input_axis_size = input_shape[axis];
124+
125+
for (size_t outer = 0; outer < outer_size; ++outer) {
126+
for (size_t a = 0; a < input_axis_size; ++a) {
127+
for (size_t inner = 0; inner < inner_size; ++inner) {
128+
size_t input_pos =
129+
outer * input_axis_size * inner_size + a * inner_size + inner;
130+
131+
size_t output_pos = outer * output_shape[axis] * inner_size +
132+
(output_offset + a) * inner_size + inner;
133+
134+
output_data[output_pos] = input_data[input_pos];
135+
}
136+
}
137+
}
138+
139+
output_offset += input_axis_size;
140+
}
141+
142+
output = make_tensor(output_data, output_shape);
143+
}
144+
145+
template void ConcatLayer::concatenate<float>(const std::vector<Tensor>&,
146+
Tensor&) const;
147+
template void ConcatLayer::concatenate<int>(const std::vector<Tensor>&,
148+
Tensor&) const;
149+
150+
} // namespace it_lab_ai
Lines changed: 198 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,198 @@
1+
#include <vector>
2+
3+
#include "gtest/gtest.h"
4+
#include "layers/ConcatLayer.hpp"
5+
#include "layers/Tensor.hpp"
6+
7+
using namespace it_lab_ai;
8+
9+
TEST(ConcatLayerTests, ConcatEmptyTensors) {
10+
ConcatLayer layer(0);
11+
12+
Tensor empty1 = make_tensor<float>({}, {0});
13+
Tensor empty2 = make_tensor<float>({}, {2, 0, 3});
14+
15+
Tensor output;
16+
17+
EXPECT_THROW(layer.run({empty1, empty2}, output), std::runtime_error);
18+
}
19+
20+
TEST(ConcatLayerTests, ConcatSingleElementTensors) {
21+
ConcatLayer layer(0);
22+
23+
Tensor single1 = make_tensor<float>({42.0f}, {1});
24+
Tensor single2 = make_tensor<float>({99.0f}, {1});
25+
26+
Tensor output;
27+
28+
layer.run({single1, single2}, output);
29+
30+
ASSERT_EQ(output.get_shape(), Shape({2}));
31+
EXPECT_FLOAT_EQ(output.get<float>({0}), 42.0f);
32+
EXPECT_FLOAT_EQ(output.get<float>({1}), 99.0f);
33+
}
34+
35+
TEST(ConcatLayerTests, ConcatAlongAxisWithSize1) {
36+
ConcatLayer layer(0);
37+
38+
Tensor input1 = make_tensor<float>({1, 2, 3, 4, 5, 6}, {1, 3, 2});
39+
Tensor input2 = make_tensor<float>({7, 8, 9, 10, 11, 12}, {1, 3, 2});
40+
41+
Tensor output;
42+
43+
layer.run({input1, input2}, output);
44+
45+
ASSERT_EQ(output.get_shape(), Shape({2, 3, 2}));
46+
47+
EXPECT_FLOAT_EQ(output.get<float>({0, 0, 0}), 1.0f);
48+
EXPECT_FLOAT_EQ(output.get<float>({0, 0, 1}), 2.0f);
49+
EXPECT_FLOAT_EQ(output.get<float>({0, 1, 0}), 3.0f);
50+
EXPECT_FLOAT_EQ(output.get<float>({0, 1, 1}), 4.0f);
51+
EXPECT_FLOAT_EQ(output.get<float>({0, 2, 0}), 5.0f);
52+
EXPECT_FLOAT_EQ(output.get<float>({0, 2, 1}), 6.0f);
53+
54+
EXPECT_FLOAT_EQ(output.get<float>({1, 0, 0}), 7.0f);
55+
EXPECT_FLOAT_EQ(output.get<float>({1, 0, 1}), 8.0f);
56+
EXPECT_FLOAT_EQ(output.get<float>({1, 1, 0}), 9.0f);
57+
EXPECT_FLOAT_EQ(output.get<float>({1, 1, 1}), 10.0f);
58+
EXPECT_FLOAT_EQ(output.get<float>({1, 2, 0}), 11.0f);
59+
EXPECT_FLOAT_EQ(output.get<float>({1, 2, 1}), 12.0f);
60+
}
61+
62+
TEST(ConcatLayerTests, ConcatScalars) {
63+
ConcatLayer layer(0);
64+
65+
Tensor scalar1 = make_tensor<float>({42.0f}, {});
66+
Tensor scalar2 = make_tensor<float>({99.0f}, {});
67+
68+
Tensor output;
69+
70+
EXPECT_THROW(layer.run({scalar1, scalar2}, output), std::runtime_error);
71+
}
72+
73+
TEST(ConcatLayerTests, ConcatSameShapeFloatAxis0) {
74+
ConcatLayer layer;
75+
Tensor input1 = make_tensor<float>({1.0f, 2.0f, 3.0f, 4.0f}, {2, 2});
76+
Tensor input2 = make_tensor<float>({5.0f, 6.0f, 7.0f, 8.0f}, {2, 2});
77+
Tensor output;
78+
79+
layer.run({input1, input2}, output);
80+
81+
ASSERT_EQ(output.get_shape(), Shape({4, 2}));
82+
83+
EXPECT_FLOAT_EQ(output.get<float>({0, 0}), 1.0f);
84+
EXPECT_FLOAT_EQ(output.get<float>({0, 1}), 2.0f);
85+
EXPECT_FLOAT_EQ(output.get<float>({1, 0}), 3.0f);
86+
EXPECT_FLOAT_EQ(output.get<float>({1, 1}), 4.0f);
87+
88+
EXPECT_FLOAT_EQ(output.get<float>({2, 0}), 5.0f);
89+
EXPECT_FLOAT_EQ(output.get<float>({2, 1}), 6.0f);
90+
EXPECT_FLOAT_EQ(output.get<float>({3, 0}), 7.0f);
91+
EXPECT_FLOAT_EQ(output.get<float>({3, 1}), 8.0f);
92+
}
93+
94+
TEST(ConcatLayerTests, ConcatSameShapeIntAxis1) {
95+
ConcatLayer layer(1);
96+
Tensor input1 = make_tensor<int>({1, 2, 3, 4}, {2, 2});
97+
Tensor input2 = make_tensor<int>({1, 2, 3, 4}, {2, 2});
98+
Tensor output;
99+
100+
layer.run({input1, input2}, output);
101+
102+
ASSERT_EQ(output.get_shape(), Shape({2, 4}));
103+
104+
EXPECT_EQ(output.get<int>({0, 0}), 1);
105+
EXPECT_EQ(output.get<int>({0, 1}), 2);
106+
EXPECT_EQ(output.get<int>({0, 2}), 1);
107+
EXPECT_EQ(output.get<int>({0, 3}), 2);
108+
109+
EXPECT_EQ(output.get<int>({1, 0}), 3);
110+
EXPECT_EQ(output.get<int>({1, 1}), 4);
111+
EXPECT_EQ(output.get<int>({1, 2}), 3);
112+
EXPECT_EQ(output.get<int>({1, 3}), 4);
113+
}
114+
115+
TEST(ConcatLayerTests, Concat3DTensorsAxis2) {
116+
ConcatLayer layer(2);
117+
Tensor input1 = make_tensor<float>({1, 2, 3, 4, 5, 6, 7, 8}, {2, 2, 2});
118+
Tensor input2 =
119+
make_tensor<float>({9, 10, 11, 12, 13, 14, 15, 16}, {2, 2, 2});
120+
Tensor output;
121+
122+
layer.run({input1, input2}, output);
123+
124+
ASSERT_EQ(output.get_shape(), Shape({2, 2, 4}));
125+
126+
EXPECT_FLOAT_EQ(output.get<float>({0, 0, 0}), 1.0f);
127+
EXPECT_FLOAT_EQ(output.get<float>({0, 0, 1}), 2.0f);
128+
EXPECT_FLOAT_EQ(output.get<float>({0, 1, 0}), 3.0f);
129+
EXPECT_FLOAT_EQ(output.get<float>({0, 1, 1}), 4.0f);
130+
131+
EXPECT_FLOAT_EQ(output.get<float>({0, 0, 2}), 9.0f);
132+
EXPECT_FLOAT_EQ(output.get<float>({0, 0, 3}), 10.0f);
133+
EXPECT_FLOAT_EQ(output.get<float>({0, 1, 2}), 11.0f);
134+
EXPECT_FLOAT_EQ(output.get<float>({0, 1, 3}), 12.0f);
135+
136+
EXPECT_FLOAT_EQ(output.get<float>({1, 0, 0}), 5.0f);
137+
EXPECT_FLOAT_EQ(output.get<float>({1, 0, 1}), 6.0f);
138+
EXPECT_FLOAT_EQ(output.get<float>({1, 1, 0}), 7.0f);
139+
EXPECT_FLOAT_EQ(output.get<float>({1, 1, 1}), 8.0f);
140+
141+
EXPECT_FLOAT_EQ(output.get<float>({1, 0, 2}), 13.0f);
142+
EXPECT_FLOAT_EQ(output.get<float>({1, 0, 3}), 14.0f);
143+
EXPECT_FLOAT_EQ(output.get<float>({1, 1, 2}), 15.0f);
144+
EXPECT_FLOAT_EQ(output.get<float>({1, 1, 3}), 16.0f);
145+
}
146+
147+
TEST(ConcatLayerTests, NegativeAxis) {
148+
ConcatLayer layer(-1);
149+
Tensor input1 = make_tensor<float>({1.0f, 2.0f, 3.0f, 4.0f}, {2, 2});
150+
Tensor input2 = make_tensor<float>({5.0f, 6.0f, 7.0f, 8.0f}, {2, 2});
151+
Tensor output;
152+
153+
layer.run({input1, input2}, output);
154+
155+
ASSERT_EQ(output.get_shape(), Shape({2, 4}));
156+
157+
EXPECT_FLOAT_EQ(output.get<float>({0, 0}), 1.0f);
158+
EXPECT_FLOAT_EQ(output.get<float>({0, 1}), 2.0f);
159+
EXPECT_FLOAT_EQ(output.get<float>({0, 2}), 5.0f);
160+
EXPECT_FLOAT_EQ(output.get<float>({0, 3}), 6.0f);
161+
162+
EXPECT_FLOAT_EQ(output.get<float>({1, 0}), 3.0f);
163+
EXPECT_FLOAT_EQ(output.get<float>({1, 1}), 4.0f);
164+
EXPECT_FLOAT_EQ(output.get<float>({1, 2}), 7.0f);
165+
EXPECT_FLOAT_EQ(output.get<float>({1, 3}), 8.0f);
166+
}
167+
168+
TEST(ConcatLayerTests, ConcatResNetStyle) {
169+
ConcatLayer layer(1);
170+
Tensor input1 = make_tensor<float>({1, 2, 3, 4, 5, 6, 7, 8}, {1, 2, 2, 2});
171+
Tensor input2 =
172+
make_tensor<float>({9, 10, 11, 12, 13, 14, 15, 16}, {1, 2, 2, 2});
173+
Tensor output;
174+
175+
layer.run({input1, input2}, output);
176+
177+
ASSERT_EQ(output.get_shape(), Shape({1, 4, 2, 2}));
178+
179+
EXPECT_FLOAT_EQ(output.get<float>({0, 0, 0, 0}), 1.0f);
180+
EXPECT_FLOAT_EQ(output.get<float>({0, 0, 0, 1}), 2.0f);
181+
EXPECT_FLOAT_EQ(output.get<float>({0, 0, 1, 0}), 3.0f);
182+
EXPECT_FLOAT_EQ(output.get<float>({0, 0, 1, 1}), 4.0f);
183+
184+
EXPECT_FLOAT_EQ(output.get<float>({0, 1, 0, 0}), 5.0f);
185+
EXPECT_FLOAT_EQ(output.get<float>({0, 1, 0, 1}), 6.0f);
186+
EXPECT_FLOAT_EQ(output.get<float>({0, 1, 1, 0}), 7.0f);
187+
EXPECT_FLOAT_EQ(output.get<float>({0, 1, 1, 1}), 8.0f);
188+
189+
EXPECT_FLOAT_EQ(output.get<float>({0, 2, 0, 0}), 9.0f);
190+
EXPECT_FLOAT_EQ(output.get<float>({0, 2, 0, 1}), 10.0f);
191+
EXPECT_FLOAT_EQ(output.get<float>({0, 2, 1, 0}), 11.0f);
192+
EXPECT_FLOAT_EQ(output.get<float>({0, 2, 1, 1}), 12.0f);
193+
194+
EXPECT_FLOAT_EQ(output.get<float>({0, 3, 0, 0}), 13.0f);
195+
EXPECT_FLOAT_EQ(output.get<float>({0, 3, 0, 1}), 14.0f);
196+
EXPECT_FLOAT_EQ(output.get<float>({0, 3, 1, 0}), 15.0f);
197+
EXPECT_FLOAT_EQ(output.get<float>({0, 3, 1, 1}), 16.0f);
198+
}

0 commit comments

Comments
 (0)