Skip to content

Commit 845c9a1

Browse files
committed
fix: correct DenseLayer activation gradients
1 parent 08ec84e commit 845c9a1

1 file changed

Lines changed: 8 additions & 18 deletions

File tree

src/NeuralNetworks/Layers/DenseLayer.cs

Lines changed: 8 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,7 @@ public class DenseLayer<T> : LayerBase<T>
135135
/// </para>
136136
/// </remarks>
137137
private Tensor<T>? _lastInput;
138+
private Tensor<T>? _lastPreActivation;
138139

139140
/// <summary>
140141
/// Gets the total number of trainable parameters in the layer.
@@ -370,16 +371,17 @@ public override Tensor<T> Forward(Tensor<T> input)
370371
_lastInput = input;
371372

372373
var (input2D, squeezeOutput) = EnsureRank2BatchFirst(input);
373-
var output = input2D.Multiply(_weights.Transpose()).Add(_biases);
374+
var preActivation = input2D.Multiply(_weights.Transpose()).Add(_biases);
375+
_lastPreActivation = preActivation;
374376

375377
if (UsingVectorActivation)
376378
{
377-
var activated = VectorActivation!.Activate(output);
379+
var activated = VectorActivation!.Activate(preActivation);
378380
return squeezeOutput ? activated.Reshape([activated.Shape[1]]) : activated;
379381
}
380382
else
381383
{
382-
var activated = ApplyActivation(output);
384+
var activated = ApplyActivation(preActivation);
383385
return squeezeOutput ? activated.Reshape([activated.Shape[1]]) : activated;
384386
}
385387
}
@@ -411,7 +413,7 @@ public override Tensor<T> Forward(Tensor<T> input)
411413
/// </remarks>
412414
public override Tensor<T> Backward(Tensor<T> outputGradient)
413415
{
414-
if (_lastInput == null)
416+
if (_lastInput == null || _lastPreActivation == null)
415417
throw new InvalidOperationException("Forward pass must be called before backward pass.");
416418

417419
var (lastInput2D, _) = EnsureRank2BatchFirst(_lastInput);
@@ -430,20 +432,7 @@ public override Tensor<T> Backward(Tensor<T> outputGradient)
430432
outputGradient2D = outputGradient.Reshape([batchSize, outputGradient.Length / batchSize]);
431433
}
432434

433-
Tensor<T> activationGradient;
434-
if (UsingVectorActivation)
435-
{
436-
activationGradient = VectorActivation!.Derivative(outputGradient2D);
437-
}
438-
else
439-
{
440-
// Apply scalar activation derivative element-wise
441-
activationGradient = new Tensor<T>(outputGradient2D.Shape);
442-
for (int i = 0; i < outputGradient2D.Length; i++)
443-
{
444-
activationGradient[i] = ScalarActivation!.Derivative(outputGradient2D[i]);
445-
}
446-
}
435+
var activationGradient = ApplyActivationDerivative(_lastPreActivation, outputGradient2D);
447436

448437
_weightsGradient = activationGradient.Transpose([1, 0]).ToMatrix().Multiply(lastInput2D.ToMatrix());
449438
_biasesGradient = activationGradient.Sum([0]).ToVector();
@@ -638,6 +627,7 @@ public override void ResetState()
638627
{
639628
// Clear cached values from forward and backward passes
640629
_lastInput = null;
630+
_lastPreActivation = null;
641631
_weightsGradient = null;
642632
_biasesGradient = null;
643633
}

0 commit comments

Comments
 (0)