|
| 1 | +using System.Threading.Tasks; |
| 2 | +using AiDotNet.ActivationFunctions; |
| 3 | +using AiDotNet.Enums; |
| 4 | +using AiDotNet.Interfaces; |
| 5 | +using AiDotNet.LossFunctions; |
| 6 | +using AiDotNet.Models.Options; |
| 7 | +using AiDotNet.NeuralNetworks; |
| 8 | +using AiDotNet.NeuralNetworks.Layers; |
| 9 | +using AiDotNet.Optimizers; |
| 10 | +using AiDotNet.Tensors.LinearAlgebra; |
| 11 | +using Xunit; |
| 12 | +using Xunit.Abstractions; |
| 13 | + |
| 14 | +namespace AiDotNetTests.IntegrationTests.NeuralNetworks; |
| 15 | + |
| 16 | +/// <summary> |
| 17 | +/// Consumer-side regression coverage for AiDotNet#1346 (FlashAttentionLayer |
| 18 | +/// degenerate output on the compiled fused-Adam path). The original root |
| 19 | +/// cause — Engine.FlashAttention missing its GraphMode.IsActive lazy-graph |
| 20 | +/// recording branch — was fixed by AiDotNet.Tensors PR #362 and ships in |
| 21 | +/// AiDotNet.Tensors NuGet 0.81.3. This file pins TWO things: |
| 22 | +/// <list type="number"> |
| 23 | +/// <item>The engine-side fix actually reaches the AiDotNet fused training |
| 24 | +/// path: a Transformer<float> whose layer stack contains |
| 25 | +/// <see cref="FlashAttentionLayer{T}"/> engages |
| 26 | +/// <see cref="AiDotNet.Training.CompiledTapeTrainingStep{T}.TryStepWithFusedOptimizer"/> |
| 27 | +/// when trained via the public network API (canary test below).</item> |
| 28 | +/// <item>The remaining consumer-side gap that the #1346 investigation |
| 29 | +/// surfaced — Tensors plan-loss-readout silently returning literal 0 |
| 30 | +/// instead of the actual NaN/Inf when a CCE-style chain produces NaN under |
| 31 | +/// many trainable parameters — is tracked at |
| 32 | +/// <a href="https://github.com/ooples/AiDotNet.Tensors/issues/396">AiDotNet.Tensors#396</a>. |
| 33 | +/// The Skip'd regression test below auto-enables once that fix lands and |
| 34 | +/// the consuming NuGet version bumps.</item> |
| 35 | +/// </list> |
| 36 | +/// </summary> |
| 37 | +/// <remarks> |
| 38 | +/// PR #1386 review (CodeRabbit C8Bm6 + Copilot Drjj5): both tests reset and |
| 39 | +/// read <see cref="AiDotNet.Training.CompiledTapeTrainingStep{T}"/>'s |
| 40 | +/// thread-static fused-step counter and cache. Default xUnit per-class |
| 41 | +/// parallelization would race those resets/reads against any other test |
| 42 | +/// touching the same global state (FusedOptimizerIntegrationTests etc.), |
| 43 | +/// producing flaky engaged-count assertions or cross-test counter leak. |
| 44 | +/// Join the existing "FusedOptimizerGlobalState" collection (defined in |
| 45 | +/// <see cref="FusedOptimizerCollection"/>) so xUnit serializes every test |
| 46 | +/// in this class with every other CompiledTapeTrainingStep-mutating test. |
| 47 | +/// </remarks> |
| 48 | +[Collection("FusedOptimizerGlobalState")] |
| 49 | +public class FlashAttentionFusedCompiledTrainingIssue1346Tests |
| 50 | +{ |
| 51 | + private readonly ITestOutputHelper _output; |
| 52 | + |
| 53 | + public FlashAttentionFusedCompiledTrainingIssue1346Tests(ITestOutputHelper output) |
| 54 | + { |
| 55 | + _output = output; |
| 56 | + } |
| 57 | + |
| 58 | + private const int SeqLen = 4; |
| 59 | + private const int EmbedDim = 16; |
| 60 | + private const int HeadCount = 2; |
| 61 | + private const int NumClasses = 8; |
| 62 | + |
| 63 | + /// <summary> |
| 64 | + /// Builds a small Transformer whose explicit layer list contains |
| 65 | + /// <see cref="FlashAttentionLayer{T}"/> as the attention block — the same |
| 66 | + /// drop-in-replacement pattern AiDotNet#1346 documented as broken on the |
| 67 | + /// fused-Adam path before AiDotNet.Tensors PR #362 landed. |
| 68 | + /// </summary> |
| 69 | + private static Transformer<float> BuildFlashAttentionTransformer(double learningRate = 0.01) |
| 70 | + { |
| 71 | + // No EmbeddingLayer: input is continuous-valued [1, seq, embed]. |
| 72 | + // EmbeddingLayer-first trips a pre-existing TransformerArchitecture |
| 73 | + // input-dim validator quirk (see TransformerCustomLayerValidationIssue1317IntegrationTests |
| 74 | + // .CustomTransformerLayerStack_AcceptsFlashAttentionLayerAsDropInReplacement) |
| 75 | + // that is unrelated to #1346. |
| 76 | + var layers = new List<ILayer<float>> |
| 77 | + { |
| 78 | + new FlashAttentionLayer<float>(SeqLen, EmbedDim, HeadCount), |
| 79 | + new LayerNormalizationLayer<float>(), |
| 80 | + new SequenceTokenSliceLayer<float>(SequenceTokenSliceLayer<float>.Position.Last), |
| 81 | + new DenseLayer<float>(NumClasses, (IActivationFunction<float>)new IdentityActivation<float>()) |
| 82 | + }; |
| 83 | + |
| 84 | + var arch = new TransformerArchitecture<float>( |
| 85 | + inputType: InputType.TwoDimensional, |
| 86 | + taskType: NeuralNetworkTaskType.SequenceClassification, |
| 87 | + numEncoderLayers: 0, // explicit layers: list replaces the auto-built encoder block (#1382) |
| 88 | + numDecoderLayers: 0, |
| 89 | + numHeads: HeadCount, |
| 90 | + modelDimension: EmbedDim, |
| 91 | + feedForwardDimension: EmbedDim, |
| 92 | + complexity: NetworkComplexity.Medium, |
| 93 | + inputSize: SeqLen * EmbedDim, |
| 94 | + outputSize: NumClasses, |
| 95 | + dropoutRate: 0.0, |
| 96 | + maxSequenceLength: SeqLen, |
| 97 | + vocabularySize: NumClasses, |
| 98 | + usePositionalEncoding: false, |
| 99 | + temperature: 1.0, |
| 100 | + sequencePooling: null, |
| 101 | + layers: layers); |
| 102 | + |
| 103 | + var optOptions = new AdamOptimizerOptions<float, Tensor<float>, Tensor<float>> |
| 104 | + { |
| 105 | + InitialLearningRate = learningRate, |
| 106 | + Beta1 = 0.9, |
| 107 | + Beta2 = 0.999, |
| 108 | + Epsilon = 1e-8 |
| 109 | + }; |
| 110 | + var optimizer = new AdamOptimizer<float, Tensor<float>, Tensor<float>>(null, optOptions); |
| 111 | + |
| 112 | + return new Transformer<float>( |
| 113 | + arch, |
| 114 | + lossFunction: new CategoricalCrossEntropyLoss<float>(), |
| 115 | + optimizer: optimizer); |
| 116 | + } |
| 117 | + |
| 118 | + private static Tensor<float> BuildFingerprintInput(int classIndex, int seed) |
| 119 | + { |
| 120 | + var t = new Tensor<float>([1, SeqLen, EmbedDim]); |
| 121 | + var rng = new System.Random(seed * 1000 + classIndex); |
| 122 | + for (int s = 0; s < SeqLen; s++) |
| 123 | + { |
| 124 | + for (int e = 0; e < EmbedDim; e++) |
| 125 | + { |
| 126 | + t[0, s, e] = (float)(classIndex + 0.05 * rng.NextDouble()); |
| 127 | + } |
| 128 | + } |
| 129 | + return t; |
| 130 | + } |
| 131 | + |
| 132 | + private static Tensor<float> BuildOneHotTarget(int classIndex) |
| 133 | + { |
| 134 | + var t = new Tensor<float>([1, NumClasses]); |
| 135 | + t[0, classIndex] = 1f; |
| 136 | + return t; |
| 137 | + } |
| 138 | + |
| 139 | + /// <summary> |
| 140 | + /// CANARY for AiDotNet.Tensors PR #362's reach into the public AiDotNet |
| 141 | + /// fused-Adam training path. A Transformer whose layer stack contains |
| 142 | + /// FlashAttentionLayer must engage |
| 143 | + /// <see cref="AiDotNet.Training.CompiledTapeTrainingStep{T}.TryStepWithFusedOptimizer"/> |
| 144 | + /// on the first Train() call. Pre-fix (before #362) the GraphMode lazy |
| 145 | + /// trace inside the fused path would still record everything except |
| 146 | + /// FlashAttention; the fused step would run successfully (so this |
| 147 | + /// counter would still increment) but downstream gradient flow would |
| 148 | + /// be broken. This test specifically verifies the canary — a regression |
| 149 | + /// that prevents fused-path engagement at all (e.g. a future Tensors |
| 150 | + /// change that throws during GraphMode trace) would flip this red. |
| 151 | + /// </summary> |
| 152 | + [Fact(Timeout = 60000)] |
| 153 | + public async Task FlashAttentionLayer_TrainViaFusedCompiledAdam_EngagesFusedPath() |
| 154 | + { |
| 155 | + await Task.Yield(); |
| 156 | + AiDotNet.Training.CompiledTapeTrainingStep<float>.ResetFusedStepCount(); |
| 157 | + AiDotNet.Training.CompiledTapeTrainingStep<float>.Invalidate(); |
| 158 | + |
| 159 | + var model = BuildFlashAttentionTransformer(); |
| 160 | + model.SetTrainingMode(true); |
| 161 | + |
| 162 | + var input = BuildFingerprintInput(0, seed: 7); |
| 163 | + var target = BuildOneHotTarget(0); |
| 164 | + |
| 165 | + model.Train(input, target); |
| 166 | + |
| 167 | + long fusedSteps = AiDotNet.Training.CompiledTapeTrainingStep<float>.GetFusedStepCount(); |
| 168 | + _output.WriteLine($"Fused step count after 1 Train() call: {fusedSteps}"); |
| 169 | + |
| 170 | + Assert.True(fusedSteps > 0, |
| 171 | + $"FlashAttentionLayer Transformer fell back to eager on first Train() — " + |
| 172 | + $"CompiledTapeTrainingStep<float>.GetFusedStepCount() = {fusedSteps}. " + |
| 173 | + "This indicates Engine.FlashAttention threw during GraphMode trace OR a " + |
| 174 | + "downstream compile gate rejected the FA-containing graph. See AiDotNet.Tensors " + |
| 175 | + "PR #362 and AiDotNet issue #1346."); |
| 176 | + } |
| 177 | + |
| 178 | + /// <summary> |
| 179 | + /// Future-fix regression for the consumer-side gap surfaced during AiDotNet#1346 |
| 180 | + /// investigation. Tracks <a href="https://github.com/ooples/AiDotNet.Tensors/issues/396">AiDotNet.Tensors#396</a>: |
| 181 | + /// when a model with multiple trainable parameters routes raw logits through |
| 182 | + /// <see cref="CategoricalCrossEntropyLoss{T}"/> on the fused-Adam path, |
| 183 | + /// the loss chain's <c>log(negative_logit + eps)</c> produces NaN that |
| 184 | + /// SHOULD propagate to <see cref="NeuralNetworkBase{T}.GetLastLoss"/> but |
| 185 | + /// instead surfaces as literal float 0. This silent zeroing was the |
| 186 | + /// actual reason the original #1346 reporter's HE PathB sanity test |
| 187 | + /// stayed at top1=0% / top5=100% / ppl=V after the engine-side |
| 188 | + /// FlashAttention fix landed — the consumer's loss-curve looked |
| 189 | + /// "converged" while gradients were corrupted. |
| 190 | + /// <para> |
| 191 | + /// Skipped until AiDotNet.Tensors#396 ships and the NuGet version bumps. |
| 192 | + /// On enabling: this test must report <c>lastLoss = NaN</c> (or a finite |
| 193 | + /// positive value if the underlying chain doesn't actually NaN at this seed), |
| 194 | + /// but MUST NOT report literal 0. |
| 195 | + /// </para> |
| 196 | + /// </summary> |
| 197 | + [Fact(Timeout = 60000, Skip = "Blocked on AiDotNet.Tensors#396 — fused-Adam loss-readout returns literal 0 instead of NaN under CCE+raw-logits+many-params. Unskip once that fix lands and the AiDotNet.Tensors NuGet version bumps past the build containing it.")] |
| 198 | + public async Task DenseIdentity_CCE_OnFusedAdam_DoesNotSilentlyZeroNaN() |
| 199 | + { |
| 200 | + await Task.Yield(); |
| 201 | + AiDotNet.Training.CompiledTapeTrainingStep<float>.ResetFusedStepCount(); |
| 202 | + AiDotNet.Training.CompiledTapeTrainingStep<float>.Invalidate(); |
| 203 | + |
| 204 | + // Force-negative-logit setup: Dense layer with IdentityActivation passes |
| 205 | + // raw logits (potentially negative) into CCE, whose log(p + 1e-7) goes NaN. |
| 206 | + var layers = new List<ILayer<float>> |
| 207 | + { |
| 208 | + new DenseLayer<float>(EmbedDim, (IActivationFunction<float>)new IdentityActivation<float>()), |
| 209 | + new LayerNormalizationLayer<float>(), |
| 210 | + new SequenceTokenSliceLayer<float>(SequenceTokenSliceLayer<float>.Position.Last), |
| 211 | + new DenseLayer<float>(NumClasses, (IActivationFunction<float>)new IdentityActivation<float>()) |
| 212 | + }; |
| 213 | + var arch = new TransformerArchitecture<float>( |
| 214 | + inputType: InputType.TwoDimensional, |
| 215 | + taskType: NeuralNetworkTaskType.SequenceClassification, |
| 216 | + numEncoderLayers: 0, numDecoderLayers: 0, |
| 217 | + numHeads: HeadCount, modelDimension: EmbedDim, feedForwardDimension: EmbedDim, |
| 218 | + complexity: NetworkComplexity.Medium, |
| 219 | + inputSize: SeqLen * EmbedDim, outputSize: NumClasses, |
| 220 | + dropoutRate: 0.0, maxSequenceLength: SeqLen, vocabularySize: NumClasses, |
| 221 | + usePositionalEncoding: false, temperature: 1.0, |
| 222 | + sequencePooling: null, layers: layers); |
| 223 | + var optimizer = new AdamOptimizer<float, Tensor<float>, Tensor<float>>(null, |
| 224 | + new AdamOptimizerOptions<float, Tensor<float>, Tensor<float>> |
| 225 | + { |
| 226 | + InitialLearningRate = 0.01, |
| 227 | + Beta1 = 0.9, Beta2 = 0.999, Epsilon = 1e-8 |
| 228 | + }); |
| 229 | + var model = new Transformer<float>(arch, |
| 230 | + lossFunction: new CategoricalCrossEntropyLoss<float>(), |
| 231 | + optimizer: optimizer); |
| 232 | + model.SetTrainingMode(true); |
| 233 | + |
| 234 | + var input = BuildFingerprintInput(0, seed: 42); |
| 235 | + var target = BuildOneHotTarget(0); |
| 236 | + |
| 237 | + model.Train(input, target); |
| 238 | + long fusedSteps = AiDotNet.Training.CompiledTapeTrainingStep<float>.GetFusedStepCount(); |
| 239 | + float lastLoss = model.GetLastLoss(); |
| 240 | + |
| 241 | + _output.WriteLine($"Identity+CCE on fused-Adam: fusedSteps={fusedSteps}, lastLoss={lastLoss}, " + |
| 242 | + $"IsNaN={float.IsNaN(lastLoss)}, IsInfinity={float.IsInfinity(lastLoss)}, " + |
| 243 | + $"IsZero={lastLoss == 0f}"); |
| 244 | + |
| 245 | + Assert.True(fusedSteps > 0, "Fused path must have engaged"); |
| 246 | + |
| 247 | + // The signal: lastLoss must be either a sane positive number OR NaN/Inf. |
| 248 | + // Literal 0 means the silent-failure mode behind AiDotNet#1346 / Tensors#396 — |
| 249 | + // NaN was produced inside the loss chain but the fused readout silently |
| 250 | + // zeroed it, so the consumer thinks training is converging while |
| 251 | + // gradients are corrupted and the model never moves off random init. |
| 252 | + bool isSilentlyZero = lastLoss == 0f && !float.IsNaN(lastLoss); |
| 253 | + Assert.False(isSilentlyZero, |
| 254 | + $"AiDotNet.Tensors#396 regression: Identity+CCE on fused-Adam reports " + |
| 255 | + $"lastLoss=0 (literal 0, not NaN). The fused readout is silently zeroing " + |
| 256 | + "the NaN that the CCE log(negative_logit+eps) chain produces. Consumer " + |
| 257 | + "would see 'loss converged' while gradients are corrupted."); |
| 258 | + } |
| 259 | +} |
0 commit comments