@@ -135,6 +135,7 @@ public class DenseLayer<T> : LayerBase<T>
135135 /// </para>
136136 /// </remarks>
137137 private Tensor < T > ? _lastInput ;
138+ private Tensor < T > ? _lastPreActivation ;
138139
139140 /// <summary>
140141 /// Gets the total number of trainable parameters in the layer.
@@ -370,16 +371,17 @@ public override Tensor<T> Forward(Tensor<T> input)
370371 _lastInput = input ;
371372
372373 var ( input2D , squeezeOutput ) = EnsureRank2BatchFirst ( input ) ;
373- var output = input2D . Multiply ( _weights . Transpose ( ) ) . Add ( _biases ) ;
374+ var preActivation = input2D . Multiply ( _weights . Transpose ( ) ) . Add ( _biases ) ;
375+ _lastPreActivation = preActivation ;
374376
375377 if ( UsingVectorActivation )
376378 {
377- var activated = VectorActivation ! . Activate ( output ) ;
379+ var activated = VectorActivation ! . Activate ( preActivation ) ;
378380 return squeezeOutput ? activated . Reshape ( [ activated . Shape [ 1 ] ] ) : activated ;
379381 }
380382 else
381383 {
382- var activated = ApplyActivation ( output ) ;
384+ var activated = ApplyActivation ( preActivation ) ;
383385 return squeezeOutput ? activated . Reshape ( [ activated . Shape [ 1 ] ] ) : activated ;
384386 }
385387 }
@@ -411,7 +413,7 @@ public override Tensor<T> Forward(Tensor<T> input)
411413 /// </remarks>
412414 public override Tensor < T > Backward ( Tensor < T > outputGradient )
413415 {
414- if ( _lastInput == null )
416+ if ( _lastInput == null || _lastPreActivation == null )
415417 throw new InvalidOperationException ( "Forward pass must be called before backward pass." ) ;
416418
417419 var ( lastInput2D , _) = EnsureRank2BatchFirst ( _lastInput ) ;
@@ -430,20 +432,7 @@ public override Tensor<T> Backward(Tensor<T> outputGradient)
430432 outputGradient2D = outputGradient . Reshape ( [ batchSize , outputGradient . Length / batchSize ] ) ;
431433 }
432434
433- Tensor < T > activationGradient ;
434- if ( UsingVectorActivation )
435- {
436- activationGradient = VectorActivation ! . Derivative ( outputGradient2D ) ;
437- }
438- else
439- {
440- // Apply scalar activation derivative element-wise
441- activationGradient = new Tensor < T > ( outputGradient2D . Shape ) ;
442- for ( int i = 0 ; i < outputGradient2D . Length ; i ++ )
443- {
444- activationGradient [ i ] = ScalarActivation ! . Derivative ( outputGradient2D [ i ] ) ;
445- }
446- }
435+ var activationGradient = ApplyActivationDerivative ( _lastPreActivation , outputGradient2D ) ;
447436
448437 _weightsGradient = activationGradient . Transpose ( [ 1 , 0 ] ) . ToMatrix ( ) . Multiply ( lastInput2D . ToMatrix ( ) ) ;
449438 _biasesGradient = activationGradient . Sum ( [ 0 ] ) . ToVector ( ) ;
@@ -638,6 +627,7 @@ public override void ResetState()
638627 {
639628 // Clear cached values from forward and backward passes
640629 _lastInput = null ;
630+ _lastPreActivation = null ;
641631 _weightsGradient = null ;
642632 _biasesGradient = null ;
643633 }
0 commit comments