Skip to content

Commit 81ad554

Browse files
franklinicclaude
andcommitted
fix: address PR review comments for GAN implementations
- Fix unused variable assignments (WGANGP, ProgressiveGAN, InceptionScore) - Add readonly modifier to fields in ACGAN - Fix potential integer overflow in InfoGAN - Fix precision loss in ProgressiveGAN 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
1 parent 19f00a5 commit 81ad554

15 files changed

Lines changed: 2094 additions & 1690 deletions

src/Metrics/FrechetInceptionDistance.cs

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
using System;
22
using System.Linq;
3-
using AiDotNet.Mathematics;
43
using AiDotNet.NeuralNetworks;
54

65
namespace AiDotNet.Metrics
@@ -68,7 +67,7 @@ public FrechetInceptionDistance(
6867
ConvolutionalNeuralNetwork<T>? inceptionNetwork = null,
6968
int featureDimension = 2048)
7069
{
71-
NumOps = NumericOperations<T>.Instance;
70+
NumOps = MathHelper.GetNumericOperations<T>();
7271
InceptionNetwork = inceptionNetwork;
7372
FeatureDimension = featureDimension;
7473
FeatureLayer = -2; // Second to last layer (before classification)
@@ -103,18 +102,18 @@ public double ComputeFID(Tensor<T> realImages, Tensor<T> generatedImages)
103102
/// <returns>Feature matrix (num_images × feature_dim)</returns>
104103
private Matrix<T> ExtractFeatures(Tensor<T> images)
105104
{
105+
var numImages = images.Shape[0];
106+
106107
if (InceptionNetwork == null)
107108
{
108109
// If no Inception network provided, return dummy features
109110
// In a real implementation, you would load a pre-trained InceptionV3
110-
var numImages = images.Shape[0];
111111
return CreateDummyFeatures(numImages);
112112
}
113113

114114
// Set to inference mode
115115
InceptionNetwork.SetTrainingMode(false);
116116

117-
var numImages = images.Shape[0];
118117
var features = new Matrix<T>(numImages, FeatureDimension);
119118

120119
// Process each image
@@ -123,7 +122,12 @@ private Matrix<T> ExtractFeatures(Tensor<T> images)
123122
// Extract single image
124123
var imageSize = images.Length / numImages;
125124
var singleImage = new Tensor<T>(new[] { 1, images.Shape[1], images.Shape[2], images.Shape[3] });
126-
Array.Copy(images.Data, i * imageSize, singleImage.Data, 0, imageSize);
125+
126+
// Copy data from source tensor to single image tensor
127+
for (int k = 0; k < imageSize; k++)
128+
{
129+
singleImage.SetFlat(k, images.GetFlat(i * imageSize + k));
130+
}
127131

128132
// Forward pass through Inception network
129133
var output = InceptionNetwork.Predict(singleImage);
@@ -133,7 +137,7 @@ private Matrix<T> ExtractFeatures(Tensor<T> images)
133137
// For now, use output
134138
for (int j = 0; j < Math.Min(output.Length, FeatureDimension); j++)
135139
{
136-
features[i, j] = output.Data[j];
140+
features[i, j] = output.GetFlat(j);
137141
}
138142
}
139143

src/Metrics/InceptionScore.cs

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
using System;
22
using System.Linq;
3-
using AiDotNet.Mathematics;
43
using AiDotNet.NeuralNetworks;
54

65
namespace AiDotNet.Metrics
@@ -78,7 +77,7 @@ public InceptionScore(
7877
int numClasses = 1000,
7978
int numSplits = 10)
8079
{
81-
NumOps = NumericOperations<T>.Instance;
80+
NumOps = MathHelper.GetNumericOperations<T>();
8281
InceptionNetwork = inceptionNetwork;
8382
NumClasses = numClasses;
8483
NumSplits = numSplits;
@@ -91,7 +90,7 @@ public InceptionScore(
9190
/// <returns>Inception Score (higher is better, typical range 1-15+)</returns>
9291
public double ComputeIS(Tensor<T> generatedImages)
9392
{
94-
var (mean, std) = ComputeISWithUncertainty(generatedImages);
93+
var (mean, _) = ComputeISWithUncertainty(generatedImages);
9594
return mean;
9695
}
9796

@@ -184,7 +183,10 @@ private Matrix<T> GetPredictions(Tensor<T> images)
184183
// Extract single image
185184
var imageSize = images.Length / numImages;
186185
var singleImage = new Tensor<T>(new[] { 1, images.Shape[1], images.Shape[2], images.Shape[3] });
187-
Array.Copy(images.Data, i * imageSize, singleImage.Data, 0, imageSize);
186+
for (int idx = 0; idx < imageSize; idx++)
187+
{
188+
singleImage.SetFlat(idx, images.GetFlat(i * imageSize + idx));
189+
}
188190

189191
// Forward pass
190192
var output = InceptionNetwork.Predict(singleImage);
@@ -195,7 +197,7 @@ private Matrix<T> GetPredictions(Tensor<T> images)
195197
// Store predictions
196198
for (int j = 0; j < Math.Min(probs.Length, NumClasses); j++)
197199
{
198-
predictions[i, j] = probs.Data[j];
200+
predictions[i, j] = probs.GetFlat(j);
199201
}
200202
}
201203

@@ -213,26 +215,26 @@ private Tensor<T> Softmax(Tensor<T> logits)
213215
// Find max for numerical stability
214216
for (int i = 0; i < logits.Length; i++)
215217
{
216-
if (NumOps.Compare(logits.Data[i], maxLogit) > 0)
218+
if (NumOps.GreaterThan(logits.GetFlat(i), maxLogit))
217219
{
218-
maxLogit = logits.Data[i];
220+
maxLogit = logits.GetFlat(i);
219221
}
220222
}
221223

222224
// Compute exp(x - max) and sum
223225
var sum = NumOps.Zero;
224226
for (int i = 0; i < logits.Length; i++)
225227
{
226-
var shifted = NumOps.Subtract(logits.Data[i], maxLogit);
228+
var shifted = NumOps.Subtract(logits.GetFlat(i), maxLogit);
227229
var expVal = NumOps.Exp(shifted);
228-
result.Data[i] = expVal;
230+
result.SetFlat(i, expVal);
229231
sum = NumOps.Add(sum, expVal);
230232
}
231233

232234
// Normalize
233235
for (int i = 0; i < result.Length; i++)
234236
{
235-
result.Data[i] = NumOps.Divide(result.Data[i], sum);
237+
result.SetFlat(i, NumOps.Divide(result.GetFlat(i), sum));
236238
}
237239

238240
return result;
@@ -325,7 +327,10 @@ private Tensor<T> ExtractImageSubset(Tensor<T> images, int startIdx, int count)
325327
}
326328

327329
var subset = new Tensor<T>(subsetShape);
328-
Array.Copy(images.Data, startIdx * imageSize, subset.Data, 0, count * imageSize);
330+
for (int idx = 0; idx < count * imageSize; idx++)
331+
{
332+
subset.SetFlat(idx, images.GetFlat(startIdx * imageSize + idx));
333+
}
329334

330335
return subset;
331336
}

src/NeuralNetworks/ACGAN.cs

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -37,10 +37,10 @@ public class ACGAN<T> : NeuralNetworkBase<T>
3737
private T _beta1Power;
3838
private T _beta2Power;
3939
private double _currentLearningRate;
40-
private double _initialLearningRate;
41-
private double _learningRateDecay;
42-
private List<T> _generatorLosses = [];
43-
private List<T> _discriminatorLosses = [];
40+
private readonly double _initialLearningRate;
41+
private readonly double _learningRateDecay;
42+
private readonly List<T> _generatorLosses = [];
43+
private readonly List<T> _discriminatorLosses = [];
4444

4545
/// <summary>
4646
/// The number of classes for classification.
@@ -72,7 +72,7 @@ public class ACGAN<T> : NeuralNetworkBase<T>
7272
/// </remarks>
7373
public ConvolutionalNeuralNetwork<T> Discriminator { get; private set; }
7474

75-
private ILossFunction<T> _lossFunction;
75+
private readonly ILossFunction<T> _lossFunction;
7676

7777
/// <summary>
7878
/// Initializes a new instance of the <see cref="ACGAN{T}"/> class.
@@ -536,4 +536,26 @@ protected override IFullModel<T, Tensor<T>, Tensor<T>> CreateNewInstance()
536536
_lossFunction,
537537
_initialLearningRate);
538538
}
539+
540+
/// <summary>
541+
/// Updates the parameters of all networks in the ACGAN.
542+
/// </summary>
543+
/// <param name="parameters">The new parameters vector containing parameters for all networks.</param>
544+
public override void UpdateParameters(Vector<T> parameters)
545+
{
546+
int generatorCount = Generator.GetParameterCount();
547+
int discriminatorCount = Discriminator.GetParameterCount();
548+
549+
// Update Generator parameters
550+
var generatorParams = new Vector<T>(generatorCount);
551+
for (int i = 0; i < generatorCount; i++)
552+
generatorParams[i] = parameters[i];
553+
Generator.UpdateParameters(generatorParams);
554+
555+
// Update Discriminator parameters
556+
var discriminatorParams = new Vector<T>(discriminatorCount);
557+
for (int i = 0; i < discriminatorCount; i++)
558+
discriminatorParams[i] = parameters[generatorCount + i];
559+
Discriminator.UpdateParameters(discriminatorParams);
560+
}
539561
}

0 commit comments

Comments
 (0)