@@ -40,9 +40,6 @@ void SplitLayer::split_impl(const Tensor& input,
4040 }
4141 }
4242
43- outputs.clear ();
44- outputs.reserve (part_sizes.size ());
45-
4643 size_t outer_size = 1 ;
4744 for (int i = 0 ; i < axis; ++i) {
4845 outer_size *= shape[i];
@@ -53,6 +50,12 @@ void SplitLayer::split_impl(const Tensor& input,
5350 inner_size *= shape[i];
5451 }
5552
53+ const size_t input_axis_stride = shape[axis] * inner_size;
54+ const size_t input_inner_stride = inner_size;
55+
56+ outputs.clear ();
57+ outputs.reserve (part_sizes.size ());
58+
5659 size_t input_offset = 0 ;
5760 for (size_t part = 0 ; part < part_sizes.size (); ++part) {
5861 const size_t output_axis_size = part_sizes[part];
@@ -66,17 +69,17 @@ void SplitLayer::split_impl(const Tensor& input,
6669 outputs.emplace_back (output_shape, input.get_type ());
6770 auto & output_data = *outputs.back ().as <T>();
6871
72+ const size_t output_part_size = output_axis_size * inner_size;
73+ const size_t input_part_size = output_part_size;
74+
6975 for (size_t outer = 0 ; outer < outer_size; ++outer) {
70- for (size_t a = 0 ; a < output_axis_size; ++a) {
71- for (size_t inner = 0 ; inner < inner_size; ++inner) {
72- size_t input_pos = outer * shape[axis] * inner_size +
73- (input_offset + a) * inner_size + inner;
74- size_t output_pos =
75- outer * output_axis_size * inner_size + a * inner_size + inner;
76- output_data[output_pos] = input_data[input_pos];
77- }
78- }
76+ const T* input_start =
77+ &input_data[outer * input_axis_stride + input_offset * inner_size];
78+ T* output_start = &output_data[outer * output_part_size];
79+
80+ std::copy_n (input_start, output_part_size, output_start);
7981 }
82+
8083 input_offset += output_axis_size;
8184 }
8285}
@@ -107,6 +110,14 @@ void SplitLayer::validate(const Tensor& input) const {
107110 throw std::runtime_error (" num_outputs cannot be greater than axis size" );
108111 }
109112 }
113+
114+ if (!splits_ && num_outputs_) {
115+ if (*num_outputs_ > axis_size) {
116+ throw std::runtime_error (" num_outputs (" + std::to_string (*num_outputs_) +
117+ " ) cannot be greater than axis size (" +
118+ std::to_string (axis_size) + " )" );
119+ }
120+ }
110121}
111122
112123int SplitLayer::get_normalized_axis (int rank) const {
0 commit comments