Skip to content

Commit e052e7a

Browse files
franklinicclaude
andcommitted
fix: split optimizer state per network in GANs
- Split Adam optimizer state between Generator and Discriminator in ACGAN and ConditionalGAN - Add readonly modifiers to fields in WGAN and WGANGP - Use ternary operators in SAGAN for cleaner code 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
1 parent 81ad554 commit e052e7a

5 files changed

Lines changed: 250 additions & 103 deletions

File tree

src/NeuralNetworks/ACGAN.cs

Lines changed: 118 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -32,11 +32,20 @@ namespace AiDotNet.NeuralNetworks;
3232
/// <typeparam name="T">The numeric type used for calculations, typically float or double.</typeparam>
3333
public class ACGAN<T> : NeuralNetworkBase<T>
3434
{
35-
private Vector<T> _momentum;
36-
private Vector<T> _secondMoment;
37-
private T _beta1Power;
38-
private T _beta2Power;
39-
private double _currentLearningRate;
35+
// Generator optimizer state
36+
private Vector<T> _genMomentum;
37+
private Vector<T> _genSecondMoment;
38+
private T _genBeta1Power;
39+
private T _genBeta2Power;
40+
private double _genCurrentLearningRate;
41+
42+
// Discriminator optimizer state
43+
private Vector<T> _discMomentum;
44+
private Vector<T> _discSecondMoment;
45+
private T _discBeta1Power;
46+
private T _discBeta2Power;
47+
private double _discCurrentLearningRate;
48+
4049
private readonly double _initialLearningRate;
4150
private readonly double _learningRateDecay;
4251
private readonly List<T> _generatorLosses = [];
@@ -102,17 +111,24 @@ public ACGAN(
102111
{
103112
_numClasses = numClasses;
104113
_initialLearningRate = initialLearningRate;
105-
_currentLearningRate = initialLearningRate;
114+
_genCurrentLearningRate = initialLearningRate;
115+
_discCurrentLearningRate = initialLearningRate;
106116
_learningRateDecay = 0.9999;
107117

108-
// Initialize optimizer parameters
109-
_beta1Power = NumOps.One;
110-
_beta2Power = NumOps.One;
118+
// Initialize generator optimizer state
119+
_genBeta1Power = NumOps.One;
120+
_genBeta2Power = NumOps.One;
121+
_genMomentum = Vector<T>.Empty();
122+
_genSecondMoment = Vector<T>.Empty();
123+
124+
// Initialize discriminator optimizer state
125+
_discBeta1Power = NumOps.One;
126+
_discBeta2Power = NumOps.One;
127+
_discMomentum = Vector<T>.Empty();
128+
_discSecondMoment = Vector<T>.Empty();
111129

112130
Generator = new ConvolutionalNeuralNetwork<T>(generatorArchitecture);
113131
Discriminator = new ConvolutionalNeuralNetwork<T>(discriminatorArchitecture);
114-
_momentum = Vector<T>.Empty();
115-
_secondMoment = Vector<T>.Empty();
116132
_lossFunction = lossFunction ?? NeuralNetworkHelper<T>.GetDefaultLossFunction(generatorArchitecture.TaskType);
117133

118134
InitializeLayers();
@@ -170,7 +186,7 @@ public ACGAN(
170186
// Backpropagate for real images
171187
var realGradients = CalculateDiscriminatorGradients(realDiscOutput, realLabels, isReal: true, batchSize);
172188
Discriminator.Backpropagate(realGradients);
173-
UpdateNetworkParameters(Discriminator);
189+
UpdateDiscriminatorParameters();
174190

175191
// Train discriminator on fake images
176192
var fakeDiscOutput = Discriminator.Predict(fakeImages);
@@ -181,7 +197,7 @@ public ACGAN(
181197
// Backpropagate for fake images
182198
var fakeGradients = CalculateDiscriminatorGradients(fakeDiscOutput, fakeLabels, isReal: false, batchSize);
183199
Discriminator.Backpropagate(fakeGradients);
184-
UpdateNetworkParameters(Discriminator);
200+
UpdateDiscriminatorParameters();
185201

186202
// Total discriminator loss
187203
T discriminatorLoss = NumOps.Divide(NumOps.Add(realLoss, fakeLoss), NumOps.FromDouble(2.0));
@@ -207,7 +223,7 @@ public ACGAN(
207223
var genGradients = CalculateDiscriminatorGradients(genDiscOutput, fakeLabels, isReal: true, batchSize);
208224
var discInputGradients = Discriminator.Backpropagate(genGradients);
209225
Generator.Backpropagate(discInputGradients);
210-
UpdateNetworkParameters(Generator);
226+
UpdateGeneratorParameters();
211227

212228
Discriminator.SetTrainingMode(true);
213229

@@ -403,26 +419,89 @@ public Tensor<T> GenerateRandomNoiseTensor(int batchSize, int noiseSize)
403419
}
404420

405421
/// <summary>
406-
/// Updates network parameters using Adam optimizer.
422+
/// Updates generator parameters using Adam optimizer with generator-specific state.
423+
/// </summary>
424+
private void UpdateGeneratorParameters()
425+
{
426+
var parameters = Generator.GetParameters();
427+
var gradients = Generator.GetParameterGradients();
428+
429+
if (_genMomentum == null || _genMomentum.Length != parameters.Length)
430+
{
431+
_genMomentum = new Vector<T>(parameters.Length);
432+
_genMomentum.Fill(NumOps.Zero);
433+
}
434+
435+
if (_genSecondMoment == null || _genSecondMoment.Length != parameters.Length)
436+
{
437+
_genSecondMoment = new Vector<T>(parameters.Length);
438+
_genSecondMoment.Fill(NumOps.Zero);
439+
}
440+
441+
var learningRate = NumOps.FromDouble(_genCurrentLearningRate);
442+
var beta1 = NumOps.FromDouble(0.5);
443+
var beta2 = NumOps.FromDouble(0.999);
444+
var epsilon = NumOps.FromDouble(1e-8);
445+
446+
var updatedParameters = new Vector<T>(parameters.Length);
447+
448+
for (int i = 0; i < parameters.Length; i++)
449+
{
450+
_genMomentum[i] = NumOps.Add(
451+
NumOps.Multiply(beta1, _genMomentum[i]),
452+
NumOps.Multiply(NumOps.Subtract(NumOps.One, beta1), gradients[i])
453+
);
454+
455+
_genSecondMoment[i] = NumOps.Add(
456+
NumOps.Multiply(beta2, _genSecondMoment[i]),
457+
NumOps.Multiply(
458+
NumOps.Subtract(NumOps.One, beta2),
459+
NumOps.Multiply(gradients[i], gradients[i])
460+
)
461+
);
462+
463+
var momentumCorrected = NumOps.Divide(_genMomentum[i], NumOps.Subtract(NumOps.One, _genBeta1Power));
464+
var secondMomentCorrected = NumOps.Divide(_genSecondMoment[i], NumOps.Subtract(NumOps.One, _genBeta2Power));
465+
466+
var adaptiveLR = NumOps.Divide(
467+
learningRate,
468+
NumOps.Add(NumOps.Sqrt(secondMomentCorrected), epsilon)
469+
);
470+
471+
updatedParameters[i] = NumOps.Subtract(
472+
parameters[i],
473+
NumOps.Multiply(adaptiveLR, momentumCorrected)
474+
);
475+
}
476+
477+
_genBeta1Power = NumOps.Multiply(_genBeta1Power, beta1);
478+
_genBeta2Power = NumOps.Multiply(_genBeta2Power, beta2);
479+
_genCurrentLearningRate *= _learningRateDecay;
480+
481+
Generator.UpdateParameters(updatedParameters);
482+
}
483+
484+
/// <summary>
485+
/// Updates discriminator parameters using Adam optimizer with discriminator-specific state.
407486
/// </summary>
408-
private void UpdateNetworkParameters(ConvolutionalNeuralNetwork<T> network)
487+
private void UpdateDiscriminatorParameters()
409488
{
410-
var parameters = network.GetParameters();
411-
var gradients = network.GetParameterGradients();
489+
var parameters = Discriminator.GetParameters();
490+
var gradients = Discriminator.GetParameterGradients();
412491

413-
if (_momentum == null || _momentum.Length != parameters.Length)
492+
if (_discMomentum == null || _discMomentum.Length != parameters.Length)
414493
{
415-
_momentum = new Vector<T>(parameters.Length);
416-
_momentum.Fill(NumOps.Zero);
494+
_discMomentum = new Vector<T>(parameters.Length);
495+
_discMomentum.Fill(NumOps.Zero);
417496
}
418497

419-
if (_secondMoment == null || _secondMoment.Length != parameters.Length)
498+
if (_discSecondMoment == null || _discSecondMoment.Length != parameters.Length)
420499
{
421-
_secondMoment = new Vector<T>(parameters.Length);
422-
_secondMoment.Fill(NumOps.Zero);
500+
_discSecondMoment = new Vector<T>(parameters.Length);
501+
_discSecondMoment.Fill(NumOps.Zero);
423502
}
424503

425-
var learningRate = NumOps.FromDouble(_currentLearningRate);
504+
var learningRate = NumOps.FromDouble(_discCurrentLearningRate);
426505
var beta1 = NumOps.FromDouble(0.5);
427506
var beta2 = NumOps.FromDouble(0.999);
428507
var epsilon = NumOps.FromDouble(1e-8);
@@ -431,21 +510,21 @@ private void UpdateNetworkParameters(ConvolutionalNeuralNetwork<T> network)
431510

432511
for (int i = 0; i < parameters.Length; i++)
433512
{
434-
_momentum[i] = NumOps.Add(
435-
NumOps.Multiply(beta1, _momentum[i]),
513+
_discMomentum[i] = NumOps.Add(
514+
NumOps.Multiply(beta1, _discMomentum[i]),
436515
NumOps.Multiply(NumOps.Subtract(NumOps.One, beta1), gradients[i])
437516
);
438517

439-
_secondMoment[i] = NumOps.Add(
440-
NumOps.Multiply(beta2, _secondMoment[i]),
518+
_discSecondMoment[i] = NumOps.Add(
519+
NumOps.Multiply(beta2, _discSecondMoment[i]),
441520
NumOps.Multiply(
442521
NumOps.Subtract(NumOps.One, beta2),
443522
NumOps.Multiply(gradients[i], gradients[i])
444523
)
445524
);
446525

447-
var momentumCorrected = NumOps.Divide(_momentum[i], NumOps.Subtract(NumOps.One, _beta1Power));
448-
var secondMomentCorrected = NumOps.Divide(_secondMoment[i], NumOps.Subtract(NumOps.One, _beta2Power));
526+
var momentumCorrected = NumOps.Divide(_discMomentum[i], NumOps.Subtract(NumOps.One, _discBeta1Power));
527+
var secondMomentCorrected = NumOps.Divide(_discSecondMoment[i], NumOps.Subtract(NumOps.One, _discBeta2Power));
449528

450529
var adaptiveLR = NumOps.Divide(
451530
learningRate,
@@ -458,11 +537,11 @@ private void UpdateNetworkParameters(ConvolutionalNeuralNetwork<T> network)
458537
);
459538
}
460539

461-
_beta1Power = NumOps.Multiply(_beta1Power, beta1);
462-
_beta2Power = NumOps.Multiply(_beta2Power, beta2);
463-
_currentLearningRate *= _learningRateDecay;
540+
_discBeta1Power = NumOps.Multiply(_discBeta1Power, beta1);
541+
_discBeta2Power = NumOps.Multiply(_discBeta2Power, beta2);
542+
_discCurrentLearningRate *= _learningRateDecay;
464543

465-
network.UpdateParameters(updatedParameters);
544+
Discriminator.UpdateParameters(updatedParameters);
466545
}
467546

468547
protected override void InitializeLayers()
@@ -500,7 +579,8 @@ public override ModelMetadata<T> GetModelMetadata()
500579

501580
protected override void SerializeNetworkSpecificData(BinaryWriter writer)
502581
{
503-
writer.Write(_currentLearningRate);
582+
writer.Write(_genCurrentLearningRate);
583+
writer.Write(_discCurrentLearningRate);
504584
writer.Write(_numClasses);
505585

506586
var generatorBytes = Generator.Serialize();
@@ -514,7 +594,8 @@ protected override void SerializeNetworkSpecificData(BinaryWriter writer)
514594

515595
protected override void DeserializeNetworkSpecificData(BinaryReader reader)
516596
{
517-
_currentLearningRate = reader.ReadDouble();
597+
_genCurrentLearningRate = reader.ReadDouble();
598+
_discCurrentLearningRate = reader.ReadDouble();
518599
_numClasses = reader.ReadInt32();
519600

520601
int generatorDataLength = reader.ReadInt32();

0 commit comments

Comments
 (0)