@@ -32,11 +32,20 @@ namespace AiDotNet.NeuralNetworks;
3232/// <typeparam name="T">The numeric type used for calculations, typically float or double.</typeparam>
3333public class ACGAN < T > : NeuralNetworkBase < T >
3434{
35- private Vector < T > _momentum ;
36- private Vector < T > _secondMoment ;
37- private T _beta1Power ;
38- private T _beta2Power ;
39- private double _currentLearningRate ;
35+ // Generator optimizer state
36+ private Vector < T > _genMomentum ;
37+ private Vector < T > _genSecondMoment ;
38+ private T _genBeta1Power ;
39+ private T _genBeta2Power ;
40+ private double _genCurrentLearningRate ;
41+
42+ // Discriminator optimizer state
43+ private Vector < T > _discMomentum ;
44+ private Vector < T > _discSecondMoment ;
45+ private T _discBeta1Power ;
46+ private T _discBeta2Power ;
47+ private double _discCurrentLearningRate ;
48+
4049 private readonly double _initialLearningRate ;
4150 private readonly double _learningRateDecay ;
4251 private readonly List < T > _generatorLosses = [ ] ;
@@ -102,17 +111,24 @@ public ACGAN(
102111 {
103112 _numClasses = numClasses ;
104113 _initialLearningRate = initialLearningRate ;
105- _currentLearningRate = initialLearningRate ;
114+ _genCurrentLearningRate = initialLearningRate ;
115+ _discCurrentLearningRate = initialLearningRate ;
106116 _learningRateDecay = 0.9999 ;
107117
108- // Initialize optimizer parameters
109- _beta1Power = NumOps . One ;
110- _beta2Power = NumOps . One ;
118+ // Initialize generator optimizer state
119+ _genBeta1Power = NumOps . One ;
120+ _genBeta2Power = NumOps . One ;
121+ _genMomentum = Vector < T > . Empty ( ) ;
122+ _genSecondMoment = Vector < T > . Empty ( ) ;
123+
124+ // Initialize discriminator optimizer state
125+ _discBeta1Power = NumOps . One ;
126+ _discBeta2Power = NumOps . One ;
127+ _discMomentum = Vector < T > . Empty ( ) ;
128+ _discSecondMoment = Vector < T > . Empty ( ) ;
111129
112130 Generator = new ConvolutionalNeuralNetwork < T > ( generatorArchitecture ) ;
113131 Discriminator = new ConvolutionalNeuralNetwork < T > ( discriminatorArchitecture ) ;
114- _momentum = Vector < T > . Empty ( ) ;
115- _secondMoment = Vector < T > . Empty ( ) ;
116132 _lossFunction = lossFunction ?? NeuralNetworkHelper < T > . GetDefaultLossFunction ( generatorArchitecture . TaskType ) ;
117133
118134 InitializeLayers ( ) ;
@@ -170,7 +186,7 @@ public ACGAN(
170186 // Backpropagate for real images
171187 var realGradients = CalculateDiscriminatorGradients ( realDiscOutput , realLabels , isReal : true , batchSize ) ;
172188 Discriminator . Backpropagate ( realGradients ) ;
173- UpdateNetworkParameters ( Discriminator ) ;
189+ UpdateDiscriminatorParameters ( ) ;
174190
175191 // Train discriminator on fake images
176192 var fakeDiscOutput = Discriminator . Predict ( fakeImages ) ;
@@ -181,7 +197,7 @@ public ACGAN(
181197 // Backpropagate for fake images
182198 var fakeGradients = CalculateDiscriminatorGradients ( fakeDiscOutput , fakeLabels , isReal : false , batchSize ) ;
183199 Discriminator . Backpropagate ( fakeGradients ) ;
184- UpdateNetworkParameters ( Discriminator ) ;
200+ UpdateDiscriminatorParameters ( ) ;
185201
186202 // Total discriminator loss
187203 T discriminatorLoss = NumOps . Divide ( NumOps . Add ( realLoss , fakeLoss ) , NumOps . FromDouble ( 2.0 ) ) ;
@@ -207,7 +223,7 @@ public ACGAN(
207223 var genGradients = CalculateDiscriminatorGradients ( genDiscOutput , fakeLabels , isReal : true , batchSize ) ;
208224 var discInputGradients = Discriminator . Backpropagate ( genGradients ) ;
209225 Generator . Backpropagate ( discInputGradients ) ;
210- UpdateNetworkParameters ( Generator ) ;
226+ UpdateGeneratorParameters ( ) ;
211227
212228 Discriminator . SetTrainingMode ( true ) ;
213229
@@ -403,26 +419,89 @@ public Tensor<T> GenerateRandomNoiseTensor(int batchSize, int noiseSize)
403419 }
404420
405421 /// <summary>
406- /// Updates network parameters using Adam optimizer.
422+ /// Updates generator parameters using Adam optimizer with generator-specific state.
423+ /// </summary>
424+ private void UpdateGeneratorParameters ( )
425+ {
426+ var parameters = Generator . GetParameters ( ) ;
427+ var gradients = Generator . GetParameterGradients ( ) ;
428+
429+ if ( _genMomentum == null || _genMomentum . Length != parameters . Length )
430+ {
431+ _genMomentum = new Vector < T > ( parameters . Length ) ;
432+ _genMomentum . Fill ( NumOps . Zero ) ;
433+ }
434+
435+ if ( _genSecondMoment == null || _genSecondMoment . Length != parameters . Length )
436+ {
437+ _genSecondMoment = new Vector < T > ( parameters . Length ) ;
438+ _genSecondMoment . Fill ( NumOps . Zero ) ;
439+ }
440+
441+ var learningRate = NumOps . FromDouble ( _genCurrentLearningRate ) ;
442+ var beta1 = NumOps . FromDouble ( 0.5 ) ;
443+ var beta2 = NumOps . FromDouble ( 0.999 ) ;
444+ var epsilon = NumOps . FromDouble ( 1e-8 ) ;
445+
446+ var updatedParameters = new Vector < T > ( parameters . Length ) ;
447+
448+ for ( int i = 0 ; i < parameters . Length ; i ++ )
449+ {
450+ _genMomentum [ i ] = NumOps . Add (
451+ NumOps . Multiply ( beta1 , _genMomentum [ i ] ) ,
452+ NumOps . Multiply ( NumOps . Subtract ( NumOps . One , beta1 ) , gradients [ i ] )
453+ ) ;
454+
455+ _genSecondMoment [ i ] = NumOps . Add (
456+ NumOps . Multiply ( beta2 , _genSecondMoment [ i ] ) ,
457+ NumOps . Multiply (
458+ NumOps . Subtract ( NumOps . One , beta2 ) ,
459+ NumOps . Multiply ( gradients [ i ] , gradients [ i ] )
460+ )
461+ ) ;
462+
463+ var momentumCorrected = NumOps . Divide ( _genMomentum [ i ] , NumOps . Subtract ( NumOps . One , _genBeta1Power ) ) ;
464+ var secondMomentCorrected = NumOps . Divide ( _genSecondMoment [ i ] , NumOps . Subtract ( NumOps . One , _genBeta2Power ) ) ;
465+
466+ var adaptiveLR = NumOps . Divide (
467+ learningRate ,
468+ NumOps . Add ( NumOps . Sqrt ( secondMomentCorrected ) , epsilon )
469+ ) ;
470+
471+ updatedParameters [ i ] = NumOps . Subtract (
472+ parameters [ i ] ,
473+ NumOps . Multiply ( adaptiveLR , momentumCorrected )
474+ ) ;
475+ }
476+
477+ _genBeta1Power = NumOps . Multiply ( _genBeta1Power , beta1 ) ;
478+ _genBeta2Power = NumOps . Multiply ( _genBeta2Power , beta2 ) ;
479+ _genCurrentLearningRate *= _learningRateDecay ;
480+
481+ Generator . UpdateParameters ( updatedParameters ) ;
482+ }
483+
484+ /// <summary>
485+ /// Updates discriminator parameters using Adam optimizer with discriminator-specific state.
407486 /// </summary>
408- private void UpdateNetworkParameters ( ConvolutionalNeuralNetwork < T > network )
487+ private void UpdateDiscriminatorParameters ( )
409488 {
410- var parameters = network . GetParameters ( ) ;
411- var gradients = network . GetParameterGradients ( ) ;
489+ var parameters = Discriminator . GetParameters ( ) ;
490+ var gradients = Discriminator . GetParameterGradients ( ) ;
412491
413- if ( _momentum == null || _momentum . Length != parameters . Length )
492+ if ( _discMomentum == null || _discMomentum . Length != parameters . Length )
414493 {
415- _momentum = new Vector < T > ( parameters . Length ) ;
416- _momentum . Fill ( NumOps . Zero ) ;
494+ _discMomentum = new Vector < T > ( parameters . Length ) ;
495+ _discMomentum . Fill ( NumOps . Zero ) ;
417496 }
418497
419- if ( _secondMoment == null || _secondMoment . Length != parameters . Length )
498+ if ( _discSecondMoment == null || _discSecondMoment . Length != parameters . Length )
420499 {
421- _secondMoment = new Vector < T > ( parameters . Length ) ;
422- _secondMoment . Fill ( NumOps . Zero ) ;
500+ _discSecondMoment = new Vector < T > ( parameters . Length ) ;
501+ _discSecondMoment . Fill ( NumOps . Zero ) ;
423502 }
424503
425- var learningRate = NumOps . FromDouble ( _currentLearningRate ) ;
504+ var learningRate = NumOps . FromDouble ( _discCurrentLearningRate ) ;
426505 var beta1 = NumOps . FromDouble ( 0.5 ) ;
427506 var beta2 = NumOps . FromDouble ( 0.999 ) ;
428507 var epsilon = NumOps . FromDouble ( 1e-8 ) ;
@@ -431,21 +510,21 @@ private void UpdateNetworkParameters(ConvolutionalNeuralNetwork<T> network)
431510
432511 for ( int i = 0 ; i < parameters . Length ; i ++ )
433512 {
434- _momentum [ i ] = NumOps . Add (
435- NumOps . Multiply ( beta1 , _momentum [ i ] ) ,
513+ _discMomentum [ i ] = NumOps . Add (
514+ NumOps . Multiply ( beta1 , _discMomentum [ i ] ) ,
436515 NumOps . Multiply ( NumOps . Subtract ( NumOps . One , beta1 ) , gradients [ i ] )
437516 ) ;
438517
439- _secondMoment [ i ] = NumOps . Add (
440- NumOps . Multiply ( beta2 , _secondMoment [ i ] ) ,
518+ _discSecondMoment [ i ] = NumOps . Add (
519+ NumOps . Multiply ( beta2 , _discSecondMoment [ i ] ) ,
441520 NumOps . Multiply (
442521 NumOps . Subtract ( NumOps . One , beta2 ) ,
443522 NumOps . Multiply ( gradients [ i ] , gradients [ i ] )
444523 )
445524 ) ;
446525
447- var momentumCorrected = NumOps . Divide ( _momentum [ i ] , NumOps . Subtract ( NumOps . One , _beta1Power ) ) ;
448- var secondMomentCorrected = NumOps . Divide ( _secondMoment [ i ] , NumOps . Subtract ( NumOps . One , _beta2Power ) ) ;
526+ var momentumCorrected = NumOps . Divide ( _discMomentum [ i ] , NumOps . Subtract ( NumOps . One , _discBeta1Power ) ) ;
527+ var secondMomentCorrected = NumOps . Divide ( _discSecondMoment [ i ] , NumOps . Subtract ( NumOps . One , _discBeta2Power ) ) ;
449528
450529 var adaptiveLR = NumOps . Divide (
451530 learningRate ,
@@ -458,11 +537,11 @@ private void UpdateNetworkParameters(ConvolutionalNeuralNetwork<T> network)
458537 ) ;
459538 }
460539
461- _beta1Power = NumOps . Multiply ( _beta1Power , beta1 ) ;
462- _beta2Power = NumOps . Multiply ( _beta2Power , beta2 ) ;
463- _currentLearningRate *= _learningRateDecay ;
540+ _discBeta1Power = NumOps . Multiply ( _discBeta1Power , beta1 ) ;
541+ _discBeta2Power = NumOps . Multiply ( _discBeta2Power , beta2 ) ;
542+ _discCurrentLearningRate *= _learningRateDecay ;
464543
465- network . UpdateParameters ( updatedParameters ) ;
544+ Discriminator . UpdateParameters ( updatedParameters ) ;
466545 }
467546
468547 protected override void InitializeLayers ( )
@@ -500,7 +579,8 @@ public override ModelMetadata<T> GetModelMetadata()
500579
501580 protected override void SerializeNetworkSpecificData ( BinaryWriter writer )
502581 {
503- writer . Write ( _currentLearningRate ) ;
582+ writer . Write ( _genCurrentLearningRate ) ;
583+ writer . Write ( _discCurrentLearningRate ) ;
504584 writer . Write ( _numClasses ) ;
505585
506586 var generatorBytes = Generator . Serialize ( ) ;
@@ -514,7 +594,8 @@ protected override void SerializeNetworkSpecificData(BinaryWriter writer)
514594
515595 protected override void DeserializeNetworkSpecificData ( BinaryReader reader )
516596 {
517- _currentLearningRate = reader . ReadDouble ( ) ;
597+ _genCurrentLearningRate = reader . ReadDouble ( ) ;
598+ _discCurrentLearningRate = reader . ReadDouble ( ) ;
518599 _numClasses = reader . ReadInt32 ( ) ;
519600
520601 int generatorDataLength = reader . ReadInt32 ( ) ;
0 commit comments