Skip to content

Commit ea8d8cb

Browse files
author
Oleksii Moiseenko
committed
minor things;
1 parent 47649f6 commit ea8d8cb

5 files changed

Lines changed: 15 additions & 35 deletions

File tree

NNToolkitCore.podspec

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,7 @@ Pod::Spec.new do |s|
2525
s.source_files =
2626
'nntoolkitcore/core/*.h',
2727
'nntoolkitcore/core/debug.c',
28-
# 'nntoolkitcore/core/apple_ops.c',
29-
'nntoolkitcore/core/default_ops.cc',
28+
'nntoolkitcore/core/apple_ops.c',
3029
'nntoolkitcore/core/memory.c',
3130
'nntoolkitcore/layers/**/*',
3231
'nntoolkitcore/train/*.{h,c}',

nntoolkitcore/layers/conv_1d.c

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -183,11 +183,6 @@ int Conv1dApplyTrainingBatch(Conv1d filter, const float *input, float *output) {
183183
}
184184

185185
void Conv1dCalculateGradient(Conv1d filter, ConvGradient *gradient, const float *d_out) {
186-
// int db_size = filter->config.output_feature_channels *
187-
// filter->training_data->config.mini_batch_size;
188-
// for (int o = 0; o < filter->config.output_size; ++o) {
189-
// op_vec_add(gradient->d_b,d_out + o * db_size, gradient->d_b, db_size);
190-
// }
191186
int k_size = filter->config.kernel_size;
192187
int batch = filter->training_data->config.mini_batch_size;
193188
int in_ftrs = filter->config.input_feature_channels;

nntoolkitcore/layers/dense.c

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,14 @@ DenseGradient *DenseGradientCreate(DenseConfig config, DenseTrainingConfig train
103103
);
104104
}
105105

106+
DenseGradient *DenseGradientCreateFromFilter(Dense dense) {
107+
if (dense->training_data == NULL){
108+
return NULL;
109+
}
110+
return DenseGradientCreate(dense->config, dense->training_data->config);
111+
}
112+
113+
106114
void DenseGradientDestroy(DenseGradient *gradient) {
107115
default_gradient_destroy(gradient);
108116
}
@@ -176,3 +184,4 @@ void DenseCalculateGradient(Dense filter, DenseGradient *gradient, float *d_out)
176184
default_gradient_sum(filter->training_data->batch_gradients, gradient, dense_weight_size_from_config(filter->config), batch);
177185
}
178186

187+

nntoolkitcore/layers/dense.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,8 @@ void DenseGradientDestroy(DenseGradient *gradient);
3636

3737
typedef struct DenseStruct* Dense;
3838

39+
DenseGradient* DenseGradientCreateFromFilter(Dense dense);
40+
3941
DenseWeights* DenseGetWeights(Dense filter);
4042

4143
DenseConfig DenseConfigCreate(int input_size, int output_size, ActivationFunction activation);

nntoolkitcore/layers/time_distributed_dense.c

Lines changed: 3 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -5,21 +5,13 @@
55
// Created by Alex on 21.10.2020.
66
//
77

8-
#include <nntoolkitcore/core/memory.h>
98
#include "nntoolkitcore/layers/time_distributed_dense.h"
10-
#include "nntoolkitcore/core/ops.h"
119
#include "nntoolkitcore/core/loop.h"
1210
#include "stdlib.h"
1311

1412

15-
typedef struct {
16-
TimeDistributedDenseTrainingConfig config;
17-
DenseTrainingConfig dense_config;
18-
} TimeDistributedDenseTrainingData;
19-
2013
struct TimeDistributedDenseStruct {
2114
TimeDistributedDenseConfig config;
22-
TimeDistributedDenseTrainingData *training_data;
2315
Dense dense;
2416
};
2517

@@ -39,18 +31,13 @@ TimeDistributedDense TimeDistributedDenseCreate(TimeDistributedDenseConfig confi
3931
TimeDistributedDense TimeDistributedDenseCreateForInference(TimeDistributedDenseConfig config){
4032
TimeDistributedDense ts_filter = TimeDistributedDenseCreate(config);
4133
ts_filter->dense = DenseCreateForInference(config.dense);
42-
ts_filter->training_data = NULL;
4334
return ts_filter;
4435
}
4536

46-
4737
TimeDistributedDense TimeDistributedDenseCreateForTraining(TimeDistributedDenseConfig config, TimeDistributedDenseTrainingConfig training_config) {
4838
TimeDistributedDense ts_filter = TimeDistributedDenseCreate(config);
4939
DenseTrainingConfig dense_training_config = DefaultTrainingConfigCreate(training_config.mini_batch_size * config.ts);
5040
ts_filter->dense = DenseCreateForTraining(config.dense, dense_training_config);
51-
ts_filter->training_data = malloc(sizeof(TimeDistributedDenseTrainingData));
52-
ts_filter->training_data->config = training_config;
53-
ts_filter->training_data->dense_config = dense_training_config;
5441
return ts_filter;
5542
}
5643

@@ -59,18 +46,10 @@ DenseWeights* TimeDistributedDenseGetWeights(TimeDistributedDense filter) {
5946
}
6047

6148
DenseGradient* TimeDistributedDenseGradientCreate(TimeDistributedDense filter){
62-
return DenseGradientCreate(
63-
filter->config.dense,
64-
DefaultTrainingConfigCreate(filter->config.ts *
65-
filter->training_data->config.mini_batch_size
66-
));
49+
return DenseGradientCreateFromFilter(filter->dense);
6750
}
6851

69-
7052
int TimeDistributedDenseApplyInference(TimeDistributedDense filter, const float *input, float* output){
71-
if(filter->training_data != NULL){
72-
return -1;
73-
}
7453
P_LOOP_START(filter->config.ts, ts)
7554
DenseApplyInference(filter->dense, input + ts * filter->config.dense.input_size,
7655
output + ts * filter->config.dense.output_size);
@@ -79,11 +58,7 @@ int TimeDistributedDenseApplyInference(TimeDistributedDense filter, const float
7958
}
8059

8160
int TimeDistributedDenseApplyTrainingBatch(TimeDistributedDense filter, const float *input, float* output){
82-
if (filter->training_data == NULL){
83-
return -1;
84-
}
85-
DenseApplyTrainingBatch(filter->dense, input, output);
86-
return 0;
61+
return DenseApplyTrainingBatch(filter->dense, input, output);
8762
}
8863

8964
void TimeDistributedDenseCalculateGradient(TimeDistributedDense filter, DenseGradient *gradient, float *d_out){
@@ -93,4 +68,4 @@ void TimeDistributedDenseCalculateGradient(TimeDistributedDense filter, DenseGra
9368
void TimeDistributedDenseDestroy(TimeDistributedDense filter){
9469
DenseDestroy(filter->dense);
9570
free(filter);
96-
}
71+
}

0 commit comments

Comments
 (0)