1010#include "nntoolkitcore/core/loop.h"
1111#include "nntoolkitcore/core/ops.h"
1212#include "nntoolkitcore/core/memory.h"
13+ #include "nntoolkitcore/layers/private/weights_private.h"
1314
1415typedef struct {
1516 ConvTrainingConfig config ;
1617 float * input_transposed ;
18+ DefaultGradient * * batch_gradients ;
1719} Conv1dTrainingData ;
1820
1921typedef struct {
@@ -27,26 +29,43 @@ struct Conv1dStruct {
2729 Conv1dTrainingData * training_data ;
2830};
2931
32+ ConvWeightsSize conv1d_weight_size_from_config (Conv1dConfig config ) {
33+ int w_size = config .kernel_size * config .input_feature_channels * config .output_feature_channels ;
34+ int sum = w_size + config .output_feature_channels ;
35+ return (DefaultWeightsSize ) {.w = w_size , .b = config .output_feature_channels , .sum = sum };
36+ }
37+
38+
39+ static Conv1dInferenceData * conv1d_inference_data_create (Conv1dConfig config ) {
40+ Conv1dInferenceData * data = malloc (sizeof (Conv1dInferenceData ));
41+ data -> buffer = malloc (config .input_size * config .input_feature_channels * sizeof (float ));
42+ return data ;
43+ }
44+
3045static Conv1dTrainingData * conv1d_training_data_create (Conv1dConfig config , ConvTrainingConfig training_config ) {
3146 Conv1dTrainingData * data = malloc (sizeof (Conv1dTrainingData ));
47+ int b = training_config .mini_batch_size ;
3248 data -> config = training_config ;
3349 data -> input_transposed = malloc (config .input_feature_channels * config .input_size
34- * training_config .mini_batch_size * sizeof (float ));
50+ * b * sizeof (float ));
51+ data -> batch_gradients = malloc (b * sizeof (DefaultGradient * ));
52+ for (int i = 0 ; i < b ; ++ i ) {
53+ data -> batch_gradients [i ] = default_gradient_create (conv1d_weight_size_from_config (config ), 0 );
54+ }
3555 return data ;
3656}
3757
3858static void conv_training_data_destroy (Conv1dTrainingData * training_data ) {
59+ for (int i = 0 ; i < training_data -> config .mini_batch_size ; ++ i ) {
60+ default_gradient_destroy (training_data -> batch_gradients [i ]);
61+ }
62+ free (training_data -> batch_gradients );
3963 free (training_data -> input_transposed );
4064 free (training_data );
4165}
4266
43- static Conv1dInferenceData * conv1d_inference_data_create (Conv1dConfig config ) {
44- Conv1dInferenceData * data = malloc (sizeof (Conv1dInferenceData ));
45- data -> buffer = malloc (config .input_size * config .input_feature_channels * sizeof (float ));
46- return data ;
47- }
4867
49- static void conv1d_inference_data_destroy (Conv1dInferenceData * data ) {
68+ static void conv1d_inference_data_destroy (Conv1dInferenceData * data ) {
5069 free (data -> buffer );
5170 free (data );
5271}
@@ -70,11 +89,7 @@ Conv1dConfig Conv1dConfigCreate(int input_feature_channels, int output_feature_c
7089Conv1d conv1d_create (Conv1dConfig config ) {
7190 Conv1d filter = malloc (sizeof (struct Conv1dStruct ));
7291 filter -> config = config ;
73- filter -> weights = malloc (sizeof (ConvWeights ));
74- int W_size = config .kernel_size * config .input_feature_channels * config .output_feature_channels ;
75- int weights_size = W_size + config .output_feature_channels ;
76- filter -> weights -> W = f_malloc (weights_size );
77- filter -> weights -> b = filter -> weights -> W + W_size ;
92+ filter -> weights = default_weights_create (conv1d_weight_size_from_config (config ));
7893 filter -> training_data = NULL ;
7994 filter -> inference_data = NULL ;
8095 return filter ;
@@ -140,19 +155,13 @@ int Conv1dApplyInference(Conv1d filter, const float *input, float *output) {
140155}
141156
142157ConvGradient * Conv1dCreateGradient (Conv1dConfig config , ConvTrainingConfig training_config ) {
143- ConvGradient * gradient = malloc (sizeof (ConvGradient ));
144- int d_x_size = config .input_size * config .input_feature_channels * training_config .mini_batch_size ;
145- int d_w_size = config .input_feature_channels * config .output_feature_channels * config .kernel_size * training_config .mini_batch_size ;
146- int grad_size = d_x_size + d_w_size + config .output_feature_channels * training_config .mini_batch_size ;
147- gradient -> d_W = f_malloc (grad_size );
148- gradient -> d_X = gradient -> d_W + d_w_size ;
149- gradient -> d_b = gradient -> d_X + d_x_size ;
150- return gradient ;
158+ return default_gradient_create (conv1d_weight_size_from_config (config ),
159+ training_config .mini_batch_size *
160+ config .input_size * config .input_feature_channels );
151161}
152162
153163void ConvGradientDestroy (ConvGradient * gradient ) {
154- free (gradient -> d_W );
155- free (gradient );
164+ default_gradient_destroy (gradient );
156165}
157166
158167int Conv1dApplyTrainingBatch (Conv1d filter , const float * input , float * output ) {
@@ -174,13 +183,6 @@ int Conv1dApplyTrainingBatch(Conv1d filter, const float *input, float *output) {
174183}
175184
176185void Conv1dCalculateGradient (Conv1d filter , ConvGradient * gradient , const float * d_out ) {
177- int db_size = filter -> config .output_feature_channels *
178- filter -> training_data -> config .mini_batch_size ;
179- for (int o = 0 ; o < filter -> config .output_size ; ++ o ){
180- op_vec_add (gradient -> d_b , d_out + o * db_size , gradient -> d_b , db_size );
181- }
182-
183-
184186 int k_size = filter -> config .kernel_size ;
185187 int batch = filter -> training_data -> config .mini_batch_size ;
186188 int in_ftrs = filter -> config .input_feature_channels ;
@@ -199,6 +201,12 @@ void Conv1dCalculateGradient(Conv1d filter, ConvGradient *gradient, const float
199201 // out_n d4 d5 d6
200202
201203 for (int b = 0 ; b < batch ; ++ b ) {
204+ //db
205+ float * db_batched = filter -> training_data -> batch_gradients [b ]-> d_b ;
206+ for (int o = 0 ; o < filter -> config .output_size ; ++ o ){
207+ op_vec_add (db_batched , d_out + o * out_ftrs + b * out_size , db_batched , out_ftrs );
208+ }
209+
202210 for (int out_f = 0 ; out_f < out_ftrs ; ++ out_f ) {
203211 for (int out_n = 0 ; out_n < filter -> config .output_size ; ++ out_n ) {
204212
@@ -219,7 +227,7 @@ void Conv1dCalculateGradient(Conv1d filter, ConvGradient *gradient, const float
219227
220228 float d_kernel [k_size ];
221229 op_vec_mul_sc (row_ptr , d_o , d_kernel , k_size );
222- float * d_W = gradient -> d_W + W_size * b + weights_offset ;
230+ float * d_W = filter -> training_data -> batch_gradients [ b ] -> d_W + weights_offset ;
223231 op_vec_add (d_W , d_kernel , d_W , k_size );
224232
225233 // d_X;
@@ -233,6 +241,7 @@ void Conv1dCalculateGradient(Conv1d filter, ConvGradient *gradient, const float
233241 }
234242 op_mat_transp (d_x_transposed + b * inp_size , gradient -> d_X + b * inp_size , filter -> config .input_size , in_ftrs );
235243 }
244+ default_gradient_sum (filter -> training_data -> batch_gradients , gradient , conv1d_weight_size_from_config (filter -> config ), batch );
236245}
237246
238247
0 commit comments