Skip to content

Commit 7833258

Browse files
authored
fix(#1406): pinn train silently no-op when reusing fused-plan thread cache (#1411)
both `scientificmltests.hamiltoniannn_trainupdatesparameters` and `...lagrangiannn_trainupdatesparameters` failed only when they ran after `universaldifferentialequation_trainupdatesparameters` in the same class (each passes in isolation). symptom: post-train params identical to pre-train, lastloss nonzero but inconsistent with the real forward output, train returned in ~4 ms (way too fast for an eager forward+backward+adam step). root cause: `compiledtapetrainingstep<t>` keeps a `[threadstatic]` cache of compiled plans, `_cachedparameters`, `_persistentinput`, `_persistenttarget`, etc. the cache key for both `step` and `trystepwithfusedoptimizer` is just `(input shape, target shape)`. ude trains first → caches a compiled plan + parameter array whose leaves are ude's tensor refs. hamilton/lagrangian then call train with the same `[1,2]→[1,1]` shapes → cache hit → plan replays against ude's (now uninvolved) parameter tensors and the optimizer "trains" them, leaving the new model's tensors untouched. fix: track the trainable-layer set's reference identities alongside the cached parameters. when the next call arrives with a different layer set (element-wise `referenceequals`), force `invalidate()` so the next call rebuilds the cache against the new model's tensors. the steady-state cost is one ref-compare per layer per train step; on a model switch we pay one cache rebuild, which is exactly the correctness boundary. also added `_cachedlayersetidentities` to `invalidate()` so manual invalidation paths reset it too. Closes #1406
1 parent 159db1b commit 7833258

1 file changed

Lines changed: 88 additions & 0 deletions

File tree

src/Training/CompiledTapeTrainingStep.cs

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,27 @@ public static class CompiledTapeTrainingStep<T>
3434
[ThreadStatic]
3535
private static Tensor<T>[]? _cachedParameters;
3636

37+
/// <summary>
38+
/// AiDotNet#1406: identity of the trainable-layer set that produced
39+
/// <see cref="_cachedParameters"/> and the cached compiled plan. The
40+
/// per-thread cache above keys plans by tensor shape only, so two
41+
/// distinct models with the same input/target shapes — created in
42+
/// sequence on the same test thread, for example — would otherwise
43+
/// share a single compiled plan. The plan's tensor leaves capture the
44+
/// FIRST model's parameter refs, and a replay then "trains" model A's
45+
/// (stale, but still cache-referenced) tensors while model B's params sit
46+
/// untouched. Surfaced as
47+
/// <c>ScientificMLTests.{Hamiltonian,Lagrangian}NeuralNetwork_TrainUpdatesParameters</c>
48+
/// failing only when run after
49+
/// <c>UniversalDifferentialEquation_TrainUpdatesParameters</c>.
50+
/// We track the layer-set identity (element-wise <see cref="object.ReferenceEquals"/>)
51+
/// and force <see cref="Invalidate"/> when it diverges from the cached
52+
/// one. Per-instance optimizer state is reset as part of Invalidate, so
53+
/// the next model gets a clean compile.
54+
/// </summary>
55+
[ThreadStatic]
56+
private static object?[]? _cachedLayerSetIdentities;
57+
3758
/// <summary>
3859
/// AiDotNet#1331: persistent input tensor reused across <see cref="TryStepWithFusedOptimizer"/>
3960
/// calls. The compiled plan captures whatever tensor ref the trace lambda saw — if every call
@@ -170,6 +191,17 @@ public static T Step(
170191

171192
try
172193
{
194+
// AiDotNet#1406: drop the cached compiled plan + parameter array
195+
// when the caller has switched to a different layer set. Without
196+
// this the next non-matching model on the same thread would
197+
// replay the previous model's plan against its own (uninvolved)
198+
// tensors — the plan was compiled with the FIRST model's leaf
199+
// refs and the optimizer step would update those tensors, not
200+
// the caller's.
201+
if (InvalidateIfLayerSetChanged(layers))
202+
{
203+
// Caches cleared; cache field rebound below.
204+
}
173205
var cache = _cache ??= new CompiledModelCache<T>();
174206

175207
// Force layer initialization before collecting parameters.
@@ -183,7 +215,9 @@ public static T Step(
183215
// TryStepWithFusedOptimizer. Shared/tied weights would otherwise
184216
// appear twice in the array — wrong for both eager-SGD here AND
185217
// wrong for the fused kernel's m/v buffers downstream.
218+
bool firstCollectThisLifecycle = _cachedParameters is null;
186219
var parameters = _cachedParameters ??= CollectDeduplicatedParameters(layers);
220+
if (firstCollectThisLifecycle) RememberLayerSet(layers);
187221

188222
// Zero gradients before forward pass
189223
foreach (var layer in layers)
@@ -236,10 +270,52 @@ public static T Step(
236270
/// <summary>
/// AiDotNet#1406: invalidates the per-thread compiled-plan cache when the
/// supplied trainable-layer set is not the same one that was remembered by
/// <see cref="RememberLayerSet{TLayer}"/> (compared by count, then
/// element-wise by reference identity).
/// </summary>
/// <remarks>
/// This method never allocates: the steady-state cost is one reference
/// comparison per layer. When a mismatch is found it calls
/// <see cref="Invalidate"/>, which also clears the remembered layer set, so
/// the caller re-collects parameters and re-remembers the new set afterwards.
/// </remarks>
/// <typeparam name="TLayer">Reference type of the layer objects being compared.</typeparam>
/// <param name="layers">Trainable layers passed to the current training call.</param>
/// <returns>
/// <c>true</c> if the cache was invalidated; <c>false</c> if no layer set has
/// been remembered yet or the supplied set matches the remembered one.
/// </returns>
private static bool InvalidateIfLayerSetChanged<TLayer>(IReadOnlyList<TLayer> layers) where TLayer : class
{
    var remembered = _cachedLayerSetIdentities;
    if (remembered is null)
    {
        // Nothing remembered yet (first call, or right after Invalidate()):
        // there is no stale cache to drop.
        return false;
    }

    // A different count is a mismatch by itself; otherwise compare each slot
    // by reference and stop at the first difference.
    bool layerSetChanged = remembered.Length != layers.Count;
    for (int i = 0; !layerSetChanged && i < remembered.Length; i++)
    {
        layerSetChanged = !ReferenceEquals(remembered[i], layers[i]);
    }

    if (!layerSetChanged)
    {
        return false;
    }

    Invalidate();
    return true;
}

/// <summary>
/// Captures the current trainable-layer set's reference identities so a
/// subsequent call can detect a model switch. Called immediately after
/// the cache is populated for the first time in a given lifecycle.
/// </summary>
/// <remarks>
/// The references are copied into a fresh array rather than keeping the
/// caller's list, so the snapshot is independent of later changes to that list.
/// </remarks>
/// <typeparam name="TLayer">Reference type of the layer objects being remembered.</typeparam>
/// <param name="layers">Trainable layers whose identities are snapshotted.</param>
private static void RememberLayerSet<TLayer>(IReadOnlyList<TLayer> layers) where TLayer : class
{
    var identities = new object?[layers.Count];
    for (int i = 0; i < identities.Length; i++)
    {
        identities[i] = layers[i];
    }

    _cachedLayerSetIdentities = identities;
}

/// <summary>
/// Invalidates the compiled plan cache. Call when model structure changes.
/// </summary>
313+
239314
public static void Invalidate()
240315
{
241316
_cache?.Invalidate();
242317
_cachedParameters = null;
318+
_cachedLayerSetIdentities = null;
243319
_configuredPlan = null;
244320
_configuredOptimizerConfig = null;
245321
// AiDotNet#1331: drop the persistent input/target tensors so the next
@@ -326,6 +402,16 @@ or AiDotNet.Tensors.Engines.Compilation.OptimizerType.Adam
326402

327403
try
328404
{
405+
// AiDotNet#1406: drop the cached compiled plan + parameter array
406+
// when the caller has switched to a different layer set. The
407+
// per-thread cache keys plans by shape only, so two distinct
408+
// models with the same (input, target) shapes — chained on the
409+
// same test thread, for example — would otherwise replay the
410+
// FIRST model's compiled plan against tensors that no longer
411+
// exist on the live model. Symptom: post-Train params identical
412+
// to pre-Train, even though LastLoss reports a non-zero loss
413+
// (the plan ran on the previous model's now-stale tensors).
414+
InvalidateIfLayerSetChanged(layers);
329415
var cache = _cache ??= new CompiledModelCache<T>();
330416

331417
// AiDotNet#1331: ensure the persistent input/target tensors exist
@@ -406,7 +492,9 @@ or AiDotNet.Tensors.Engines.Compilation.OptimizerType.Adam
406492
// weights (same Tensor<T> instance referenced by multiple layers)
407493
// would otherwise drive the fused kernel's m/v buffers to update
408494
// the same parameter twice per step, breaking Adam's moment math.
495+
bool firstCollectThisLifecycle = _cachedParameters is null;
409496
var parameters = _cachedParameters ??= CollectDeduplicatedParameters(layers);
497+
if (firstCollectThisLifecycle) RememberLayerSet(layers);
410498

411499
foreach (var layer in layers)
412500
layer.ZeroGrad();

0 commit comments

Comments
 (0)