Skip to content

Commit dcb0b3f

Browse files
franklinicclaude
andcommitted
fix: keep discriminator/critic in training mode during generator backprop
The Backpropagate method requires training mode to be enabled. Previous code set discriminator/critic to eval mode before calling Backpropagate, which would throw InvalidOperationException. Changed files: - ACGAN.cs: Keep discriminator in training mode, use BackwardWithInputGradient - StyleGAN.cs: Keep discriminator in training mode, use BackwardWithInputGradient - InfoGAN.cs: Keep discriminator and QNetwork in training mode - WGAN.cs: Keep critic in training mode - WGANGP.cs: Keep critic in training mode 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
1 parent 7ad1c4b commit dcb0b3f

5 files changed

Lines changed: 27 additions & 32 deletions

File tree

src/NeuralNetworks/ACGAN.cs

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -205,7 +205,8 @@ public ACGAN(
205205
// ----- Train Generator -----
206206

207207
Generator.SetTrainingMode(true);
208-
Discriminator.SetTrainingMode(false);
208+
// Keep Discriminator in training mode - required for backpropagation
209+
// We just don't call UpdateDiscriminatorParameters() during generator training
209210

210211
// Generate new fake images
211212
var newGeneratorInput = ConcatenateTensors(noise, fakeLabels);
@@ -219,14 +220,12 @@ public ACGAN(
219220
T genClassLoss = CalculateClassificationLoss(genDiscOutput, fakeLabels, batchSize);
220221
T generatorLoss = NumOps.Add(genAuthLoss, genClassLoss);
221222

222-
// Backpropagate through discriminator and generator
223+
// Backpropagate through discriminator to get input gradients, then through generator
223224
var genGradients = CalculateDiscriminatorGradients(genDiscOutput, fakeLabels, isReal: true, batchSize);
224-
var discInputGradients = Discriminator.Backpropagate(genGradients);
225-
Generator.Backpropagate(discInputGradients);
225+
var discInputGradients = Discriminator.BackwardWithInputGradient(genGradients);
226+
Generator.Backward(discInputGradients);
226227
UpdateGeneratorParameters();
227228

228-
Discriminator.SetTrainingMode(true);
229-
230229
// Track losses
231230
_discriminatorLosses.Add(discriminatorLoss);
232231
_generatorLosses.Add(generatorLoss);

src/NeuralNetworks/InfoGAN.cs

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -269,7 +269,8 @@ public InfoGAN(
269269
// ----- Train Generator and Q Network -----
270270

271271
Generator.SetTrainingMode(true);
272-
Discriminator.SetTrainingMode(false);
272+
// Keep Discriminator and QNetwork in training mode - required for backpropagation
273+
// We just don't call UpdateDiscriminatorParameters() during generator training
273274
QNetwork.SetTrainingMode(true);
274275

275276
// Generate new fake images
@@ -289,13 +290,13 @@ public InfoGAN(
289290
T miCoeff = NumOps.FromDouble(_mutualInfoCoefficient);
290291
T generatorLoss = NumOps.Add(ganLoss, NumOps.Multiply(miCoeff, mutualInfoLoss));
291292

292-
// Backpropagate through discriminator (for GAN loss)
293+
// Backpropagate through discriminator (for GAN loss) to get input gradients
293294
var ganGradients = CalculateBinaryGradients(genPredictions, allRealLabels, batchSize);
294-
var discInputGradients = Discriminator.Backpropagate(ganGradients);
295+
var discInputGradients = Discriminator.BackwardWithInputGradient(ganGradients);
295296

296-
// Backpropagate through Q network (for MI loss)
297+
// Backpropagate through Q network (for MI loss) to get input gradients
297298
var miGradients = CalculateMutualInfoGradients(predictedCodes, latentCodes, batchSize);
298-
var qInputGradients = QNetwork.Backpropagate(miGradients);
299+
var qInputGradients = QNetwork.BackwardWithInputGradient(miGradients);
299300

300301
// Combine gradients
301302
var combinedGradients = new Tensor<T>(discInputGradients.Shape);
@@ -309,12 +310,10 @@ public InfoGAN(
309310
}
310311

311312
// Backpropagate through generator
312-
Generator.Backpropagate(combinedGradients);
313+
Generator.Backward(combinedGradients);
313314
UpdateGeneratorParameters();
314315
UpdateQNetworkParameters();
315316

316-
Discriminator.SetTrainingMode(true);
317-
318317
// Track losses
319318
_discriminatorLosses.Add(discriminatorLoss);
320319
_generatorLosses.Add(generatorLoss);

src/NeuralNetworks/StyleGAN.cs

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -288,7 +288,8 @@ public StyleGAN(
288288

289289
MappingNetwork.SetTrainingMode(true);
290290
SynthesisNetwork.SetTrainingMode(true);
291-
Discriminator.SetTrainingMode(false);
291+
// Keep Discriminator in training mode - required for backpropagation
292+
// We just don't call UpdateDiscriminatorParameters() during generator training
292293

293294
// Generate new images
294295
var newLatentCodes = GenerateRandomLatentCodes(batchSize);
@@ -300,22 +301,20 @@ public StyleGAN(
300301
var allRealLabels = CreateLabelTensor(batchSize, NumOps.One);
301302
T generatorLoss = CalculateBinaryLoss(genPredictions, allRealLabels, batchSize);
302303

303-
// Backpropagate
304+
// Backpropagate through discriminator to get input gradients
304305
var genGradients = CalculateBinaryGradients(genPredictions, allRealLabels, batchSize);
305-
var discInputGradients = Discriminator.Backpropagate(genGradients);
306+
var discInputGradients = Discriminator.BackwardWithInputGradient(genGradients);
306307

307308
// Backprop through synthesis network
308-
var styleGradients = SynthesisNetwork.Backpropagate(discInputGradients);
309+
var styleGradients = SynthesisNetwork.BackwardWithInputGradient(discInputGradients);
309310

310311
// Backprop through mapping network
311-
MappingNetwork.Backpropagate(styleGradients);
312+
MappingNetwork.Backward(styleGradients);
312313

313314
// Update both generator networks
314315
UpdateSynthesisNetworkParameters();
315316
UpdateMappingNetworkParameters();
316317

317-
Discriminator.SetTrainingMode(true);
318-
319318
return (discriminatorLoss, generatorLoss);
320319
}
321320

src/NeuralNetworks/WGAN.cs

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -337,7 +337,8 @@ private T TrainCriticBatch(Tensor<T> images, bool isReal)
337337
private T TrainGeneratorBatch(Tensor<T> noise)
338338
{
339339
Generator.SetTrainingMode(true);
340-
Critic.SetTrainingMode(false); // Freeze critic
340+
// Keep Critic in training mode - required for backpropagation
341+
// We just don't call UpdateCriticParameters() during generator training
341342

342343
// Generate fake images
343344
var generatedImages = Generator.Predict(noise);
@@ -368,16 +369,14 @@ private T TrainGeneratorBatch(Tensor<T> noise)
368369
}
369370

370371
// Backpropagate through critic to get gradients for generator output
371-
var criticInputGradients = Critic.Backpropagate(gradients);
372+
var criticInputGradients = Critic.BackwardWithInputGradient(gradients);
372373

373374
// Backpropagate through generator
374-
Generator.Backpropagate(criticInputGradients);
375+
Generator.Backward(criticInputGradients);
375376

376377
// Update generator parameters
377378
UpdateGeneratorParameters();
378379

379-
Critic.SetTrainingMode(true);
380-
381380
return loss;
382381
}
383382

src/NeuralNetworks/WGANGP.cs

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -531,7 +531,8 @@ private T ComputeGradientPenalty(Tensor<T> realImages, Tensor<T> fakeImages, int
531531
private T TrainGeneratorBatch(Tensor<T> noise)
532532
{
533533
Generator.SetTrainingMode(true);
534-
Critic.SetTrainingMode(false); // Freeze critic
534+
// Keep Critic in training mode - required for backpropagation
535+
// We just don't call UpdateCriticParameters() during generator training
535536

536537
// Generate fake images
537538
var generatedImages = Generator.Predict(noise);
@@ -558,17 +559,15 @@ private T TrainGeneratorBatch(Tensor<T> noise)
558559
gradients[i, 0] = NumOps.Divide(NumOps.One, NumOps.FromDouble(batchSize));
559560
}
560561

561-
// Backpropagate through critic (frozen) to get gradients for generator
562-
var criticInputGradients = Critic.Backpropagate(gradients);
562+
// Backpropagate through critic to get gradients for generator
563+
var criticInputGradients = Critic.BackwardWithInputGradient(gradients);
563564

564565
// Backpropagate through generator
565-
Generator.Backpropagate(criticInputGradients);
566+
Generator.Backward(criticInputGradients);
566567

567568
// Update generator parameters
568569
UpdateGeneratorParameters();
569570

570-
Critic.SetTrainingMode(true);
571-
572571
return loss;
573572
}
574573

0 commit comments

Comments
 (0)