1010#include "nntoolkitcore/core/loop.h"
1111#include "nntoolkitcore/core/ops.h"
1212#include "nntoolkitcore/core/memory.h"
13+ #include "nntoolkitcore/layers/private/weights_private.h"
1314
1415typedef struct {
1516 ConvTrainingConfig config ;
1617 float * input_transposed ;
18+ DefaultGradient * * batch_gradients ;
1719} Conv1dTrainingData ;
1820
1921typedef struct {
@@ -27,26 +29,43 @@ struct Conv1dStruct {
2729 Conv1dTrainingData * training_data ;
2830};
2931
32+ ConvWeightsSize conv1d_weight_size_from_config (Conv1dConfig config ) {
33+ int w_size = config .kernel_size * config .input_feature_channels * config .output_feature_channels ;
34+ int sum = w_size + config .output_feature_channels ;
35+ return (DefaultWeightsSize ) {.w = w_size , .b = config .output_feature_channels , .sum = sum };
36+ }
37+
38+
39+ static Conv1dInferenceData * conv1d_inference_data_create (Conv1dConfig config ) {
40+ Conv1dInferenceData * data = malloc (sizeof (Conv1dInferenceData ));
41+ data -> buffer = malloc (config .input_size * config .input_feature_channels * sizeof (float ));
42+ return data ;
43+ }
44+
3045static Conv1dTrainingData * conv1d_training_data_create (Conv1dConfig config , ConvTrainingConfig training_config ) {
3146 Conv1dTrainingData * data = malloc (sizeof (Conv1dTrainingData ));
47+ int b = training_config .mini_batch_size ;
3248 data -> config = training_config ;
3349 data -> input_transposed = malloc (config .input_feature_channels * config .input_size
34- * training_config .mini_batch_size * sizeof (float ));
50+ * b * sizeof (float ));
51+ data -> batch_gradients = malloc (b * sizeof (DefaultGradient * ));
52+ for (int i = 0 ; i < b ; ++ i ) {
53+ data -> batch_gradients [i ] = default_gradient_create (conv1d_weight_size_from_config (config ), 0 );
54+ }
3555 return data ;
3656}
3757
3858static void conv_training_data_destroy (Conv1dTrainingData * training_data ) {
59+ for (int i = 0 ; i < training_data -> config .mini_batch_size ; ++ i ) {
60+ default_gradient_destroy (training_data -> batch_gradients [i ]);
61+ }
62+ free (training_data -> batch_gradients );
3963 free (training_data -> input_transposed );
4064 free (training_data );
4165}
4266
43- static Conv1dInferenceData * conv1d_inference_data_create (Conv1dConfig config ) {
44- Conv1dInferenceData * data = malloc (sizeof (Conv1dInferenceData ));
45- data -> buffer = malloc (config .input_size * config .input_feature_channels * sizeof (float ));
46- return data ;
47- }
4867
49- static void conv1d_inference_data_destroy (Conv1dInferenceData * data ) {
68+ static void conv1d_inference_data_destroy (Conv1dInferenceData * data ) {
5069 free (data -> buffer );
5170 free (data );
5271}
@@ -67,20 +86,10 @@ Conv1dConfig Conv1dConfigCreate(int input_feature_channels, int output_feature_c
6786 return config ;
6887}
6988
70- ConvWeightsSize conv1d_weight_size_from_config (Conv1dConfig config ){
71- int w_size = config .kernel_size * config .input_feature_channels * config .output_feature_channels ;
72- int sum = w_size + config .output_feature_channels ;
73- return (DefaultWeightsSize ) { .w = w_size , .b = config .output_feature_channels , .sum = sum };
74- }
75-
7689Conv1d conv1d_create (Conv1dConfig config ) {
7790 Conv1d filter = malloc (sizeof (struct Conv1dStruct ));
7891 filter -> config = config ;
79- filter -> weights = malloc (sizeof (ConvWeights ));
80- int W_size = config .kernel_size * config .input_feature_channels * config .output_feature_channels ;
81- int weights_size = W_size + config .output_feature_channels ;
82- filter -> weights -> W = f_malloc (weights_size );
83- filter -> weights -> b = filter -> weights -> W + W_size ;
92+ filter -> weights = default_weights_create (conv1d_weight_size_from_config (config ));
8493 filter -> training_data = NULL ;
8594 filter -> inference_data = NULL ;
8695 return filter ;
@@ -146,19 +155,13 @@ int Conv1dApplyInference(Conv1d filter, const float *input, float *output) {
146155}
147156
148157ConvGradient * Conv1dCreateGradient (Conv1dConfig config , ConvTrainingConfig training_config ) {
149- ConvGradient * gradient = malloc (sizeof (ConvGradient ));
150- int d_x_size = config .input_size * config .input_feature_channels * training_config .mini_batch_size ;
151- int d_w_size = config .input_feature_channels * config .output_feature_channels * config .kernel_size * training_config .mini_batch_size ;
152- int grad_size = d_x_size + d_w_size + config .output_feature_channels * training_config .mini_batch_size ;
153- gradient -> d_W = f_malloc (grad_size );
154- gradient -> d_X = gradient -> d_W + d_w_size ;
155- gradient -> d_b = gradient -> d_X + d_x_size ;
156- return gradient ;
158+ return default_gradient_create (conv1d_weight_size_from_config (config ),
159+ training_config .mini_batch_size *
160+ config .input_size * config .input_feature_channels );
157161}
158162
159163void ConvGradientDestroy (ConvGradient * gradient ) {
160- free (gradient -> d_W );
161- free (gradient );
164+ default_gradient_destroy (gradient );
162165}
163166
164167int Conv1dApplyTrainingBatch (Conv1d filter , const float * input , float * output ) {
@@ -181,9 +184,9 @@ int Conv1dApplyTrainingBatch(Conv1d filter, const float *input, float *output) {
181184
182185void Conv1dCalculateGradient (Conv1d filter , ConvGradient * gradient , const float * d_out ) {
183186 int db_size = filter -> config .output_feature_channels *
184- filter -> training_data -> config .mini_batch_size ;
185- for (int o = 0 ; o < filter -> config .output_size ; ++ o ){
186- op_vec_add (gradient -> d_b , d_out + o * db_size , gradient -> d_b , db_size );
187+ filter -> training_data -> config .mini_batch_size ;
188+ for (int o = 0 ; o < filter -> config .output_size ; ++ o ) {
189+ op_vec_add (gradient -> d_b ,d_out + o * db_size , gradient -> d_b , db_size );
187190 }
188191
189192
@@ -225,7 +228,7 @@ void Conv1dCalculateGradient(Conv1d filter, ConvGradient *gradient, const float
225228
226229 float d_kernel [k_size ];
227230 op_vec_mul_sc (row_ptr , d_o , d_kernel , k_size );
228- float * d_W = gradient -> d_W + W_size * b + weights_offset ;
231+ float * d_W = filter -> training_data -> batch_gradients [ b ] -> d_W + weights_offset ;
229232 op_vec_add (d_W , d_kernel , d_W , k_size );
230233
231234 // d_X;
@@ -239,6 +242,7 @@ void Conv1dCalculateGradient(Conv1d filter, ConvGradient *gradient, const float
239242 }
240243 op_mat_transp (d_x_transposed + b * inp_size , gradient -> d_X + b * inp_size , filter -> config .input_size , in_ftrs );
241244 }
245+
242246}
243247
244248
0 commit comments