@@ -120,15 +120,59 @@ private T Detokenize(int tokenIdx)
120120
121121 protected override void TrainCore ( Matrix < T > x , Vector < T > y )
122122 {
123- // Note: This is a simplified training loop. In practice, gradients would be computed and applied.
123+ T learningRate = _numOps . FromDouble ( _options . LearningRate ) ;
124+
124125 for ( int epoch = 0 ; epoch < _options . Epochs ; epoch ++ )
125126 {
126127 for ( int i = 0 ; i < x . Rows ; i ++ )
127128 {
128129 Vector < T > input = x . GetRow ( i ) ;
130+ T target = y [ i ] ;
129131
130- // Forward pass - prediction computed for gradient calculation in full implementation
131- _ = PredictSingle ( input ) ;
132+ // Update output projection weights using numerical gradients
133+ UpdateOutputWeights ( input , target , learningRate ) ;
134+ }
135+ }
136+ }
137+
138+ /// <summary>
139+ /// Updates output projection weights using numerical gradient estimation.
140+ /// </summary>
141+ private void UpdateOutputWeights ( Vector < T > input , T target , T learningRate )
142+ {
143+ T epsilon = _numOps . FromDouble ( 1e-5 ) ;
144+
145+ // Update a subset of output projection weights for efficiency
146+ int rowsToUpdate = Math . Min ( 5 , _outputProjection . Rows ) ;
147+ int colsToUpdate = Math . Min ( 5 , _outputProjection . Columns ) ;
148+
149+ for ( int i = 0 ; i < rowsToUpdate ; i ++ )
150+ {
151+ for ( int j = 0 ; j < colsToUpdate ; j ++ )
152+ {
153+ T original = _outputProjection [ i , j ] ;
154+
155+ // Compute loss with perturbed weight (positive)
156+ _outputProjection [ i , j ] = _numOps . Add ( original , epsilon ) ;
157+ T predPlus = PredictSingle ( input ) ;
158+ T errorPlus = _numOps . Subtract ( target , predPlus ) ;
159+ T lossPlus = _numOps . Multiply ( errorPlus , errorPlus ) ;
160+
161+ // Compute loss with perturbed weight (negative)
162+ _outputProjection [ i , j ] = _numOps . Subtract ( original , epsilon ) ;
163+ T predMinus = PredictSingle ( input ) ;
164+ T errorMinus = _numOps . Subtract ( target , predMinus ) ;
165+ T lossMinus = _numOps . Multiply ( errorMinus , errorMinus ) ;
166+
167+ // Restore and update
168+ _outputProjection [ i , j ] = original ;
169+
170+ T gradient = _numOps . Divide (
171+ _numOps . Subtract ( lossPlus , lossMinus ) ,
172+ _numOps . Multiply ( _numOps . FromDouble ( 2.0 ) , epsilon )
173+ ) ;
174+
175+ _outputProjection [ i , j ] = _numOps . Subtract ( original , _numOps . Multiply ( learningRate , gradient ) ) ;
132176 }
133177 }
134178 }
@@ -238,10 +282,15 @@ public Dictionary<double, Vector<T>> ForecastWithQuantiles(Vector<T> history, do
238282
239283 protected override void SerializeCore ( BinaryWriter writer )
240284 {
285+ // Serialize options
241286 writer . Write ( _vocabularySize ) ;
242287 writer . Write ( _options . EmbeddingDim ) ;
288+ writer . Write ( _options . ContextLength ) ;
289+ writer . Write ( _options . ForecastHorizon ) ;
290+ writer . Write ( _options . NumLayers ) ;
291+ writer . Write ( _options . NumHeads ) ;
243292
244- // Serialize vocabulary
293+ // Serialize vocabulary centroids
245294 for ( int i = 0 ; i < _vocabularyCentroids . Length ; i ++ )
246295 writer . Write ( Convert . ToDouble ( _vocabularyCentroids [ i ] ) ) ;
247296
@@ -251,23 +300,71 @@ protected override void SerializeCore(BinaryWriter writer)
251300 for ( int i = 0 ; i < _tokenEmbeddings . Rows ; i ++ )
252301 for ( int j = 0 ; j < _tokenEmbeddings . Columns ; j ++ )
253302 writer . Write ( Convert . ToDouble ( _tokenEmbeddings [ i , j ] ) ) ;
303+
304+ // Serialize transformer layers
305+ writer . Write ( _transformerLayers . Count ) ;
306+ foreach ( var layer in _transformerLayers )
307+ {
308+ layer . Serialize ( writer ) ;
309+ }
310+
311+ // Serialize output projection
312+ writer . Write ( _outputProjection . Rows ) ;
313+ writer . Write ( _outputProjection . Columns ) ;
314+ for ( int i = 0 ; i < _outputProjection . Rows ; i ++ )
315+ for ( int j = 0 ; j < _outputProjection . Columns ; j ++ )
316+ writer . Write ( Convert . ToDouble ( _outputProjection [ i , j ] ) ) ;
317+
318+ // Serialize output bias
319+ writer . Write ( _outputBias . Length ) ;
320+ for ( int i = 0 ; i < _outputBias . Length ; i ++ )
321+ writer . Write ( Convert . ToDouble ( _outputBias [ i ] ) ) ;
254322 }
255323
256324 protected override void DeserializeCore ( BinaryReader reader )
257325 {
326+ // Deserialize options
258327 _vocabularySize = reader . ReadInt32 ( ) ;
259328 _options . EmbeddingDim = reader . ReadInt32 ( ) ;
329+ _options . ContextLength = reader . ReadInt32 ( ) ;
330+ _options . ForecastHorizon = reader . ReadInt32 ( ) ;
331+ _options . NumLayers = reader . ReadInt32 ( ) ;
332+ _options . NumHeads = reader . ReadInt32 ( ) ;
260333
334+ // Deserialize vocabulary centroids
261335 _vocabularyCentroids = new Vector < T > ( _vocabularySize ) ;
262336 for ( int i = 0 ; i < _vocabularySize ; i ++ )
263337 _vocabularyCentroids [ i ] = _numOps . FromDouble ( reader . ReadDouble ( ) ) ;
264338
339+ // Deserialize token embeddings
265340 int rows = reader . ReadInt32 ( ) ;
266341 int cols = reader . ReadInt32 ( ) ;
267342 _tokenEmbeddings = new Matrix < T > ( rows , cols ) ;
268343 for ( int i = 0 ; i < rows ; i ++ )
269344 for ( int j = 0 ; j < cols ; j ++ )
270345 _tokenEmbeddings [ i , j ] = _numOps . FromDouble ( reader . ReadDouble ( ) ) ;
346+
347+ // Deserialize transformer layers
348+ int numLayers = reader . ReadInt32 ( ) ;
349+ _transformerLayers = new List < TransformerBlock < T > > ( numLayers ) ;
350+ for ( int i = 0 ; i < numLayers ; i ++ )
351+ {
352+ _transformerLayers . Add ( TransformerBlock < T > . Deserialize ( reader ) ) ;
353+ }
354+
355+ // Deserialize output projection
356+ rows = reader . ReadInt32 ( ) ;
357+ cols = reader . ReadInt32 ( ) ;
358+ _outputProjection = new Matrix < T > ( rows , cols ) ;
359+ for ( int i = 0 ; i < rows ; i ++ )
360+ for ( int j = 0 ; j < cols ; j ++ )
361+ _outputProjection [ i , j ] = _numOps . FromDouble ( reader . ReadDouble ( ) ) ;
362+
363+ // Deserialize output bias
364+ int biasLen = reader . ReadInt32 ( ) ;
365+ _outputBias = new Vector < T > ( biasLen ) ;
366+ for ( int i = 0 ; i < biasLen ; i ++ )
367+ _outputBias [ i ] = _numOps . FromDouble ( reader . ReadDouble ( ) ) ;
271368 }
272369
273370 public override ModelMetadata < T > GetModelMetadata ( )
@@ -342,7 +439,7 @@ public ChronosOptions(ChronosOptions<T> other)
342439internal class TransformerBlock < T >
343440{
344441 private readonly INumericOperations < T > _numOps ;
345- private readonly Matrix < T > _weights ;
442+ private Matrix < T > _weights ;
346443
347444 public int ParameterCount => _weights . Rows * _weights . Columns ;
348445
@@ -358,6 +455,15 @@ public TransformerBlock(int embeddingDim, int numHeads)
358455 _weights [ i , j ] = _numOps . FromDouble ( ( random . NextDouble ( ) * 2 - 1 ) * stddev ) ;
359456 }
360457
458+ /// <summary>
459+ /// Creates a TransformerBlock for deserialization.
460+ /// </summary>
461+ private TransformerBlock ( )
462+ {
463+ _numOps = MathHelper . GetNumericOperations < T > ( ) ;
464+ _weights = new Matrix < T > ( 0 , 0 ) ;
465+ }
466+
361467 public Vector < T > Forward ( Vector < T > input )
362468 {
363469 var output = new Vector < T > ( input . Length ) ;
@@ -372,4 +478,34 @@ public Vector<T> Forward(Vector<T> input)
372478 }
373479 return output ;
374480 }
481+
482+ /// <summary>
483+ /// Serializes the transformer block weights.
484+ /// </summary>
485+ public void Serialize ( BinaryWriter writer )
486+ {
487+ writer . Write ( _weights . Rows ) ;
488+ writer . Write ( _weights . Columns ) ;
489+ for ( int i = 0 ; i < _weights . Rows ; i ++ )
490+ for ( int j = 0 ; j < _weights . Columns ; j ++ )
491+ writer . Write ( Convert . ToDouble ( _weights [ i , j ] ) ) ;
492+ }
493+
494+ /// <summary>
495+ /// Deserializes a transformer block from binary data.
496+ /// </summary>
497+ public static TransformerBlock < T > Deserialize ( BinaryReader reader )
498+ {
499+ var block = new TransformerBlock < T > ( ) ;
500+ var numOps = MathHelper . GetNumericOperations < T > ( ) ;
501+
502+ int rows = reader . ReadInt32 ( ) ;
503+ int cols = reader . ReadInt32 ( ) ;
504+ block . _weights = new Matrix < T > ( rows , cols ) ;
505+ for ( int i = 0 ; i < rows ; i ++ )
506+ for ( int j = 0 ; j < cols ; j ++ )
507+ block . _weights [ i , j ] = numOps . FromDouble ( reader . ReadDouble ( ) ) ;
508+
509+ return block ;
510+ }
375511}
0 commit comments