Skip to content

Commit c9a3891

Browse files
franklinic and claude committed
fix: properly accumulate gradient penalty gradients in WGANGP
- Restructure WGANGP to capture and combine gradients from all sources: real image gradients, fake image gradients, and gradient penalty gradients
- Create ComputeGradientPenaltyWithGradients to return both penalty value and parameter gradients
- Add UpdateCriticParametersWithGradients for combined gradient updates
- Fix DCGAN InputType to use ThreeDimensional for CNN compatibility
- Fix serialization in multiple GAN implementations

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
1 parent e052e7a commit c9a3891

6 files changed

Lines changed: 668 additions & 118 deletions

File tree

src/NeuralNetworks/ConditionalGAN.cs

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -369,7 +369,9 @@ private T TrainDiscriminatorBatch(Tensor<T> images, Tensor<T> labels)
369369
private T TrainGeneratorBatch(Tensor<T> generatorInput, Tensor<T> fakeImagesWithConditions, Tensor<T> targetLabels)
370370
{
371371
Generator.SetTrainingMode(true);
372-
Discriminator.SetTrainingMode(false);
372+
// Keep Discriminator in training mode for backpropagation (required for Backpropagate to work)
373+
// We just won't update its parameters
374+
Discriminator.SetTrainingMode(true);
373375

374376
// Get discriminator output
375377
var discriminatorOutput = Discriminator.Predict(fakeImagesWithConditions);
@@ -380,7 +382,7 @@ private T TrainGeneratorBatch(Tensor<T> generatorInput, Tensor<T> fakeImagesWith
380382
// Calculate gradients
381383
var outputGradients = CalculateBatchGradients(discriminatorOutput, targetLabels);
382384

383-
// Backpropagate through discriminator
385+
// Backpropagate through discriminator to get input gradients (but don't update discriminator weights)
384386
var discriminatorInputGradients = Discriminator.Backpropagate(outputGradients);
385387

386388
// Extract gradients for the image part (not the condition part)

src/NeuralNetworks/DCGAN.cs

Lines changed: 17 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -72,7 +72,7 @@ public DCGAN(
7272
: base(
7373
CreateDCGANGeneratorArchitecture(latentSize, imageChannels, imageHeight, imageWidth, generatorFeatureMaps),
7474
CreateDCGANDiscriminatorArchitecture(imageChannels, imageHeight, imageWidth, discriminatorFeatureMaps),
75-
InputType.TwoDimensional,
75+
InputType.ThreeDimensional,
7676
lossFunction,
7777
initialLearningRate)
7878
{
@@ -113,11 +113,20 @@ private static NeuralNetworkArchitecture<T> CreateDCGANGeneratorArchitecture(
113113
int imageWidth,
114114
int featureMaps)
115115
{
116+
// For DCGAN generator, the latent vector is first projected and reshaped to an initial
117+
// 3D feature map. The typical starting spatial size is 4x4 which gets upsampled through
118+
// transposed convolutions. The depth represents the number of feature channels.
119+
// Note: The actual latent vector (1D) handling is done by the first projection layer.
120+
int initialSpatialSize = 4;
121+
int initialChannels = featureMaps * 8; // Standard DCGAN uses 8x feature maps initially
122+
116123
return new NeuralNetworkArchitecture<T>(
117-
InputType.OneDimensional,
124+
InputType.ThreeDimensional,
118125
NeuralNetworkTaskType.Generative,
119126
NetworkComplexity.Medium,
120-
inputSize: latentSize,
127+
inputDepth: initialChannels,
128+
inputHeight: initialSpatialSize,
129+
inputWidth: initialSpatialSize,
121130
outputSize: imageChannels * imageHeight * imageWidth);
122131
}
123132

@@ -154,12 +163,15 @@ private static NeuralNetworkArchitecture<T> CreateDCGANDiscriminatorArchitecture
154163
int imageWidth,
155164
int featureMaps)
156165
{
166+
// DCGAN discriminator takes 3D images as input (channels x height x width)
167+
// and outputs a single probability value for real/fake classification
157168
return new NeuralNetworkArchitecture<T>(
158-
InputType.TwoDimensional,
169+
InputType.ThreeDimensional,
159170
NeuralNetworkTaskType.BinaryClassification,
160171
NetworkComplexity.Medium,
172+
inputDepth: imageChannels,
161173
inputHeight: imageHeight,
162-
inputWidth: imageWidth * imageChannels,
174+
inputWidth: imageWidth,
163175
outputSize: 1);
164176
}
165177
}

0 commit comments

Comments
 (0)