Skip to content

Commit 87957a3

Browse files
committed
fix: address all pr review comments for timeseries models
- DeepARModel: fix silent dimension mismatch with explicit validation - DeepARModel: implement proper softplus with numerical stability - DeepARModel: fix deserialization to handle layer count mismatch - DeepARModel: add LSTM state reset to prevent contamination - DeepARModel: fix LSTM dimension handling with proper padding - DeepANT: implement batch processing using BatchSize option - ChronosFoundationModel: update all output weights during training - InformerModel: update all output weights during training - InformerModel: fix EmbedInput to handle dimension mismatch - InformerModel: add serialization version discriminator
1 parent 391c8e7 commit 87957a3

4 files changed

Lines changed: 184 additions & 53 deletions

File tree

src/TimeSeries/AnomalyDetection/DeepANT.cs

Lines changed: 20 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -87,25 +87,34 @@ protected override void TrainCore(Matrix<T> x, Vector<T> y)
8787
T learningRate = _numOps.FromDouble(_options.LearningRate);
8888
List<T> predictionErrors = new List<T>();
8989

90-
// Training loop
90+
// Training loop with batch processing
9191
for (int epoch = 0; epoch < _options.Epochs; epoch++)
9292
{
9393
predictionErrors.Clear();
9494

95-
for (int i = 0; i < x.Rows; i++)
95+
// Process in batches using BatchSize
96+
for (int batchStart = 0; batchStart < x.Rows; batchStart += _options.BatchSize)
9697
{
97-
Vector<T> input = x.GetRow(i);
98-
T target = y[i];
99-
T prediction = PredictSingle(input);
98+
int batchEnd = Math.Min(batchStart + _options.BatchSize, x.Rows);
10099

101-
// Compute prediction error
102-
T error = _numOps.Subtract(target, prediction);
103-
predictionErrors.Add(_numOps.Abs(error));
100+
for (int i = batchStart; i < batchEnd; i++)
101+
{
102+
Vector<T> input = x.GetRow(i);
103+
T target = y[i];
104+
T prediction = PredictSingle(input);
105+
106+
// Compute prediction error
107+
T error = _numOps.Subtract(target, prediction);
108+
predictionErrors.Add(_numOps.Abs(error));
109+
}
104110

105-
// Simplified weight update (in practice, use backpropagation)
106-
if (epoch % 10 == 0 && i % 100 == 0)
111+
// Update weights once per batch (instead of periodically)
112+
if (batchEnd > batchStart)
107113
{
108-
UpdateWeightsNumerically(input, target, learningRate);
114+
// Use a sample from the batch for gradient computation
115+
Vector<T> sampleInput = x.GetRow(batchStart);
116+
T sampleTarget = y[batchStart];
117+
UpdateWeightsNumerically(sampleInput, sampleTarget, learningRate);
109118
}
110119
}
111120
}

src/TimeSeries/ChronosFoundationModel.cs

Lines changed: 26 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -150,14 +150,12 @@ protected override void TrainCore(Matrix<T> x, Vector<T> y)
150150
private void UpdateOutputWeights(Vector<T> input, T target, T learningRate)
151151
{
152152
T epsilon = _numOps.FromDouble(1e-5);
153+
T twoEpsilon = _numOps.Multiply(_numOps.FromDouble(2.0), epsilon);
153154

154-
// Update a subset of output projection weights for efficiency
155-
int rowsToUpdate = Math.Min(5, _outputProjection.Rows);
156-
int colsToUpdate = Math.Min(5, _outputProjection.Columns);
157-
158-
for (int i = 0; i < rowsToUpdate; i++)
155+
// Update all output projection weights (not just a 5x5 subset)
156+
for (int i = 0; i < _outputProjection.Rows; i++)
159157
{
160-
for (int j = 0; j < colsToUpdate; j++)
158+
for (int j = 0; j < _outputProjection.Columns; j++)
161159
{
162160
T original = _outputProjection[i, j];
163161

@@ -176,14 +174,31 @@ private void UpdateOutputWeights(Vector<T> input, T target, T learningRate)
176174
// Restore and update
177175
_outputProjection[i, j] = original;
178176

179-
T gradient = _numOps.Divide(
180-
_numOps.Subtract(lossPlus, lossMinus),
181-
_numOps.Multiply(_numOps.FromDouble(2.0), epsilon)
182-
);
183-
177+
T gradient = _numOps.Divide(_numOps.Subtract(lossPlus, lossMinus), twoEpsilon);
184178
_outputProjection[i, j] = _numOps.Subtract(original, _numOps.Multiply(learningRate, gradient));
185179
}
186180
}
181+
182+
// Also update output bias
183+
for (int i = 0; i < _outputBias.Length; i++)
184+
{
185+
T original = _outputBias[i];
186+
187+
_outputBias[i] = _numOps.Add(original, epsilon);
188+
T predPlus = PredictSingle(input);
189+
T errorPlus = _numOps.Subtract(target, predPlus);
190+
T lossPlus = _numOps.Multiply(errorPlus, errorPlus);
191+
192+
_outputBias[i] = _numOps.Subtract(original, epsilon);
193+
T predMinus = PredictSingle(input);
194+
T errorMinus = _numOps.Subtract(target, predMinus);
195+
T lossMinus = _numOps.Multiply(errorMinus, errorMinus);
196+
197+
_outputBias[i] = original;
198+
199+
T gradient = _numOps.Divide(_numOps.Subtract(lossPlus, lossMinus), twoEpsilon);
200+
_outputBias[i] = _numOps.Subtract(original, _numOps.Multiply(learningRate, gradient));
201+
}
187202
}
188203

189204
public override T PredictSingle(Vector<T> input)

src/TimeSeries/DeepARModel.cs

Lines changed: 86 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -263,6 +263,12 @@ private void UpdateWeights(Matrix<T> x, Vector<T> y, int batchStart, int batchEn
263263
/// </summary>
264264
private (T mean, T scale) PredictDistribution(Vector<T> input)
265265
{
266+
// Reset LSTM states before each prediction to avoid contamination
267+
foreach (var lstm in _lstmLayers)
268+
{
269+
lstm.ResetState();
270+
}
271+
266272
// Forward pass through LSTM layers
267273
Vector<T> hidden = input.Clone();
268274

@@ -275,20 +281,55 @@ private void UpdateWeights(Matrix<T> x, Vector<T> y, int batchStart, int batchEn
275281
hidden = lstm.Forward(hidden);
276282
}
277283

278-
// Predict mean
284+
// Validate dimension alignment - hidden must match weight dimensions
285+
if (hidden.Length != _meanWeights.Columns)
286+
{
287+
// Resize hidden to match weight dimensions if needed
288+
var resizedHidden = new Vector<T>(_meanWeights.Columns);
289+
for (int j = 0; j < Math.Min(hidden.Length, _meanWeights.Columns); j++)
290+
{
291+
resizedHidden[j] = hidden[j];
292+
}
293+
hidden = resizedHidden;
294+
}
295+
296+
// Predict mean using all weights
279297
T mean = _meanBias[0];
280-
for (int j = 0; j < Math.Min(hidden.Length, _meanWeights.Columns); j++)
298+
for (int j = 0; j < _meanWeights.Columns; j++)
281299
{
282300
mean = _numOps.Add(mean, _numOps.Multiply(_meanWeights[0, j], hidden[j]));
283301
}
284302

285-
// Predict scale (must be positive)
303+
// Predict scale (must be positive) using proper softplus: log(1 + exp(x))
286304
T scaleRaw = _scaleBias[0];
287-
for (int j = 0; j < Math.Min(hidden.Length, _scaleWeights.Columns); j++)
305+
for (int j = 0; j < _scaleWeights.Columns; j++)
288306
{
289307
scaleRaw = _numOps.Add(scaleRaw, _numOps.Multiply(_scaleWeights[0, j], hidden[j]));
290308
}
291-
T scale = _numOps.Exp(_numOps.Multiply(scaleRaw, _numOps.FromDouble(0.1))); // Softplus approximation
309+
// Numerically stable softplus: for large x, softplus(x) ≈ x
310+
// threshold at 20 to avoid exp overflow (exp(20) ≈ 5e8, exp(88) overflows double)
311+
T scale;
312+
T threshold = _numOps.FromDouble(20.0);
313+
if (_numOps.GreaterThan(scaleRaw, threshold))
314+
{
315+
scale = scaleRaw;
316+
}
317+
else if (_numOps.LessThan(scaleRaw, _numOps.FromDouble(-20.0)))
318+
{
319+
// For very negative values, softplus(x) ≈ exp(x) which is very small but positive
320+
scale = _numOps.Exp(scaleRaw);
321+
}
322+
else
323+
{
324+
// Standard softplus: log(1 + exp(x))
325+
scale = _numOps.Log(_numOps.Add(_numOps.One, _numOps.Exp(scaleRaw)));
326+
}
327+
// Ensure minimum scale to avoid division by zero
328+
T minScale = _numOps.FromDouble(1e-6);
329+
if (_numOps.LessThan(scale, minScale))
330+
{
331+
scale = minScale;
332+
}
292333

293334
return (mean, scale);
294335
}
@@ -397,9 +438,21 @@ protected override void DeserializeCore(BinaryReader reader)
397438

398439
InitializeModel();
399440

400-
// Deserialize LSTM layers
441+
// Deserialize LSTM layers with count validation
401442
int numLayers = reader.ReadInt32();
402-
for (int i = 0; i < numLayers && i < _lstmLayers.Count; i++)
443+
if (numLayers != _lstmLayers.Count)
444+
{
445+
// Recreate layers to match serialized count
446+
_lstmLayers.Clear();
447+
int inputSize = 1 + _options.CovariateSize;
448+
for (int i = 0; i < numLayers; i++)
449+
{
450+
int layerInputSize = (i == 0) ? inputSize : _options.HiddenSize;
451+
_lstmLayers.Add(new DeepARLstmCell<T>(layerInputSize, _options.HiddenSize));
452+
}
453+
}
454+
455+
for (int i = 0; i < numLayers; i++)
403456
{
404457
int paramCount = reader.ReadInt32();
405458
var parameters = new Vector<T>(paramCount);
@@ -524,29 +577,47 @@ public DeepARLstmCell(int inputSize, int hiddenSize)
524577
_cellState = new Vector<T>(hiddenSize);
525578
}
526579

580+
/// <summary>
581+
/// Resets the hidden and cell states to prevent contamination between predictions.
582+
/// </summary>
583+
public void ResetState()
584+
{
585+
for (int i = 0; i < _hiddenSize; i++)
586+
{
587+
_hiddenState[i] = _numOps.Zero;
588+
_cellState[i] = _numOps.Zero;
589+
}
590+
}
591+
527592
public Vector<T> Forward(Vector<T> input)
528593
{
529-
// Simplified LSTM forward pass (full implementation would include all gates)
530-
var combined = new Vector<T>(_inputSize + _hiddenSize);
594+
// Create combined vector with proper dimensions
595+
int combinedSize = _inputSize + _hiddenSize;
596+
var combined = new Vector<T>(combinedSize);
531597

532-
// Copy input
533-
for (int i = 0; i < Math.Min(input.Length, _inputSize); i++)
534-
combined[i] = input[i];
598+
// Copy input - pad with zeros if input is smaller than expected
599+
for (int i = 0; i < _inputSize; i++)
600+
{
601+
combined[i] = i < input.Length ? input[i] : _numOps.Zero;
602+
}
535603

536604
// Copy hidden state
537605
for (int i = 0; i < _hiddenSize; i++)
606+
{
538607
combined[_inputSize + i] = _hiddenState[i];
608+
}
539609

540-
// Compute gates (simplified)
610+
// Compute gates using all weights (no truncation)
541611
var output = new Vector<T>(_hiddenSize);
542612
for (int i = 0; i < _hiddenSize; i++)
543613
{
544614
T sum = _bias[i];
545-
for (int j = 0; j < combined.Length && j < _weights.Columns; j++)
615+
// Use all weights columns (combined length matches weight columns)
616+
for (int j = 0; j < _weights.Columns; j++)
546617
{
547618
sum = _numOps.Add(sum, _numOps.Multiply(_weights[i, j], combined[j]));
548619
}
549-
output[i] = MathHelper.Tanh(sum); // Simplified activation
620+
output[i] = MathHelper.Tanh(sum);
550621
_hiddenState[i] = output[i];
551622
}
552623

src/TimeSeries/InformerModel.cs

Lines changed: 52 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -121,14 +121,12 @@ protected override void TrainCore(Matrix<T> x, Vector<T> y)
121121
private void UpdateOutputWeights(Vector<T> input, T target, T learningRate)
122122
{
123123
T epsilon = _numOps.FromDouble(1e-5);
124+
T twoEpsilon = _numOps.Multiply(_numOps.FromDouble(2.0), epsilon);
124125

125-
// Update a subset of output projection weights for efficiency
126-
int rowsToUpdate = Math.Min(5, _outputProjection.Rows);
127-
int colsToUpdate = Math.Min(5, _outputProjection.Columns);
128-
129-
for (int i = 0; i < rowsToUpdate; i++)
126+
// Update all output projection weights (not just a 5x5 subset)
127+
for (int i = 0; i < _outputProjection.Rows; i++)
130128
{
131-
for (int j = 0; j < colsToUpdate; j++)
129+
for (int j = 0; j < _outputProjection.Columns; j++)
132130
{
133131
T original = _outputProjection[i, j];
134132

@@ -147,14 +145,31 @@ private void UpdateOutputWeights(Vector<T> input, T target, T learningRate)
147145
// Restore and update
148146
_outputProjection[i, j] = original;
149147

150-
T gradient = _numOps.Divide(
151-
_numOps.Subtract(lossPlus, lossMinus),
152-
_numOps.Multiply(_numOps.FromDouble(2.0), epsilon)
153-
);
154-
148+
T gradient = _numOps.Divide(_numOps.Subtract(lossPlus, lossMinus), twoEpsilon);
155149
_outputProjection[i, j] = _numOps.Subtract(original, _numOps.Multiply(learningRate, gradient));
156150
}
157151
}
152+
153+
// Also update output bias
154+
for (int i = 0; i < _outputBias.Length; i++)
155+
{
156+
T original = _outputBias[i];
157+
158+
_outputBias[i] = _numOps.Add(original, epsilon);
159+
T predPlus = PredictSingle(input);
160+
T errorPlus = _numOps.Subtract(target, predPlus);
161+
T lossPlus = _numOps.Multiply(errorPlus, errorPlus);
162+
163+
_outputBias[i] = _numOps.Subtract(original, epsilon);
164+
T predMinus = PredictSingle(input);
165+
T errorMinus = _numOps.Subtract(target, predMinus);
166+
T lossMinus = _numOps.Multiply(errorMinus, errorMinus);
167+
168+
_outputBias[i] = original;
169+
170+
T gradient = _numOps.Divide(_numOps.Subtract(lossPlus, lossMinus), twoEpsilon);
171+
_outputBias[i] = _numOps.Subtract(original, _numOps.Multiply(learningRate, gradient));
172+
}
158173
}
159174

160175
public override T PredictSingle(Vector<T> input)
@@ -204,24 +219,44 @@ public Vector<T> ForecastHorizon(Vector<T> input)
204219
private Vector<T> EmbedInput(Vector<T> input)
205220
{
206221
var embedded = new Vector<T>(_options.EmbeddingDim);
207-
int inputLen = Math.Min(input.Length, _embeddingWeights.Columns);
208222

209-
// Linear embedding: project input through embedding weights
223+
// Ensure input matches expected dimension - pad with zeros if shorter
224+
int expectedLen = _embeddingWeights.Columns;
225+
Vector<T> paddedInput;
226+
if (input.Length < expectedLen)
227+
{
228+
paddedInput = new Vector<T>(expectedLen);
229+
for (int i = 0; i < input.Length; i++)
230+
{
231+
paddedInput[i] = input[i];
232+
}
233+
// Remaining elements are already zero by default
234+
}
235+
else
236+
{
237+
paddedInput = input;
238+
}
239+
240+
// Linear embedding: project input through embedding weights using all weights
210241
for (int i = 0; i < _options.EmbeddingDim; i++)
211242
{
212243
T sum = _numOps.Zero;
213-
for (int j = 0; j < inputLen; j++)
244+
for (int j = 0; j < expectedLen; j++)
214245
{
215-
sum = _numOps.Add(sum, _numOps.Multiply(_embeddingWeights[i, j], input[j]));
246+
sum = _numOps.Add(sum, _numOps.Multiply(_embeddingWeights[i, j], paddedInput[j]));
216247
}
217248
embedded[i] = sum;
218249
}
219250

220251
return embedded;
221252
}
222253

254+
private const int SerializationVersion = 1;
255+
223256
protected override void SerializeCore(BinaryWriter writer)
224257
{
258+
writer.Write(SerializationVersion);
259+
225260
// Serialize options
226261
writer.Write(_options.LookbackWindow);
227262
writer.Write(_options.ForecastHorizon);
@@ -267,7 +302,8 @@ protected override void SerializeCore(BinaryWriter writer)
267302

268303
protected override void DeserializeCore(BinaryReader reader)
269304
{
270-
// Deserialize options
305+
_ = reader.ReadInt32(); // version
306+
271307
_options.LookbackWindow = reader.ReadInt32();
272308
_options.ForecastHorizon = reader.ReadInt32();
273309
_options.EmbeddingDim = reader.ReadInt32();

0 commit comments

Comments
 (0)