55// Created by Alex on 21.10.2020.
66//
77
8- #include <nntoolkitcore/core/memory.h>
98#include "nntoolkitcore/layers/time_distributed_dense.h"
10- #include "nntoolkitcore/core/ops.h"
119#include "nntoolkitcore/core/loop.h"
1210#include "stdlib.h"
1311
1412
15- typedef struct {
16- TimeDistributedDenseTrainingConfig config ;
17- DenseTrainingConfig dense_config ;
18- } TimeDistributedDenseTrainingData ;
19-
2013struct TimeDistributedDenseStruct {
2114 TimeDistributedDenseConfig config ;
22- TimeDistributedDenseTrainingData * training_data ;
2315 Dense dense ;
2416};
2517
@@ -39,18 +31,13 @@ TimeDistributedDense TimeDistributedDenseCreate(TimeDistributedDenseConfig confi
3931TimeDistributedDense TimeDistributedDenseCreateForInference (TimeDistributedDenseConfig config ){
4032 TimeDistributedDense ts_filter = TimeDistributedDenseCreate (config );
4133 ts_filter -> dense = DenseCreateForInference (config .dense );
42- ts_filter -> training_data = NULL ;
4334 return ts_filter ;
4435}
4536
46-
4737TimeDistributedDense TimeDistributedDenseCreateForTraining (TimeDistributedDenseConfig config , TimeDistributedDenseTrainingConfig training_config ) {
4838 TimeDistributedDense ts_filter = TimeDistributedDenseCreate (config );
4939 DenseTrainingConfig dense_training_config = DefaultTrainingConfigCreate (training_config .mini_batch_size * config .ts );
5040 ts_filter -> dense = DenseCreateForTraining (config .dense , dense_training_config );
51- ts_filter -> training_data = malloc (sizeof (TimeDistributedDenseTrainingData ));
52- ts_filter -> training_data -> config = training_config ;
53- ts_filter -> training_data -> dense_config = dense_training_config ;
5441 return ts_filter ;
5542}
5643
@@ -59,18 +46,10 @@ DenseWeights* TimeDistributedDenseGetWeights(TimeDistributedDense filter) {
5946}
6047
6148DenseGradient * TimeDistributedDenseGradientCreate (TimeDistributedDense filter ){
62- return DenseGradientCreate (
63- filter -> config .dense ,
64- DefaultTrainingConfigCreate (filter -> config .ts *
65- filter -> training_data -> config .mini_batch_size
66- ));
49+ return DenseGradientCreateFromFilter (filter -> dense );
6750}
6851
69-
7052int TimeDistributedDenseApplyInference (TimeDistributedDense filter , const float * input , float * output ){
71- if (filter -> training_data != NULL ){
72- return -1 ;
73- }
7453 P_LOOP_START (filter -> config .ts , ts )
7554 DenseApplyInference (filter -> dense , input + ts * filter -> config .dense .input_size ,
7655 output + ts * filter -> config .dense .output_size );
@@ -79,11 +58,7 @@ int TimeDistributedDenseApplyInference(TimeDistributedDense filter, const float
7958}
8059
8160int TimeDistributedDenseApplyTrainingBatch (TimeDistributedDense filter , const float * input , float * output ){
82- if (filter -> training_data == NULL ){
83- return -1 ;
84- }
85- DenseApplyTrainingBatch (filter -> dense , input , output );
86- return 0 ;
61+ return DenseApplyTrainingBatch (filter -> dense , input , output );
8762}
8863
8964void TimeDistributedDenseCalculateGradient (TimeDistributedDense filter , DenseGradient * gradient , float * d_out ){
@@ -93,4 +68,4 @@ void TimeDistributedDenseCalculateGradient(TimeDistributedDense filter, DenseGra
9368void TimeDistributedDenseDestroy (TimeDistributedDense filter ){
9469 DenseDestroy (filter -> dense );
9570 free (filter );
96- }
71+ }
0 commit comments