diff --git a/.github/workflows/sonarcloud.yml b/.github/workflows/sonarcloud.yml index 2731f1d7a6..d2aaeba1f6 100644 --- a/.github/workflows/sonarcloud.yml +++ b/.github/workflows/sonarcloud.yml @@ -189,10 +189,10 @@ jobs: project: tests/AiDotNet.Tests/AiDotNetTests.csproj framework: net10.0 filter: 'Category!=GPU&Category!=Stress&(FullyQualifiedName~UnitTests.Helpers|FullyQualifiedName~UnitTests.Inference|FullyQualifiedName~UnitTests.Interpretability)' - - name: Unit - 06 JIT/KD/Schedulers/LA + - name: Unit - 06 KD/Schedulers/LA project: tests/AiDotNet.Tests/AiDotNetTests.csproj framework: net10.0 - filter: 'Category!=GPU&Category!=Stress&(FullyQualifiedName~UnitTests.JitCompiler|FullyQualifiedName~UnitTests.KnowledgeDistillation|FullyQualifiedName~UnitTests.LearningRateSchedulers|FullyQualifiedName~UnitTests.LinearAlgebra)' + filter: 'Category!=GPU&Category!=Stress&(FullyQualifiedName~UnitTests.KnowledgeDistillation|FullyQualifiedName~UnitTests.LearningRateSchedulers|FullyQualifiedName~UnitTests.LinearAlgebra)' - name: Unit - 07 Logging/Loss/Meta/Mixed/Compression project: tests/AiDotNet.Tests/AiDotNetTests.csproj framework: net10.0 @@ -270,7 +270,6 @@ jobs: FullyQualifiedName!~UnitTests.Helpers& FullyQualifiedName!~UnitTests.Inference& FullyQualifiedName!~UnitTests.Interpretability& - FullyQualifiedName!~UnitTests.JitCompiler& FullyQualifiedName!~UnitTests.KnowledgeDistillation& FullyQualifiedName!~UnitTests.LearningRateSchedulers& FullyQualifiedName!~UnitTests.LinearAlgebra& @@ -294,10 +293,6 @@ jobs: # ===================================================================== # OTHER NON-UNIT TESTS (InferenceOpt, PromptEng, Recovery, Playground) # ===================================================================== - - name: Other - InferenceOptimization - project: tests/AiDotNet.Tests/AiDotNetTests.csproj - framework: net10.0 - filter: 'Category!=GPU&Category!=Stress&FullyQualifiedName~AiDotNet.Tests.InferenceOptimization' - name: Other - PromptEngineering project: tests/AiDotNet.Tests/AiDotNetTests.csproj framework: net10.0 diff --git a/AiDotNetBenchmarkTests/Helpers/MockNeuralNetwork.cs b/AiDotNetBenchmarkTests/Helpers/MockNeuralNetwork.cs index 3536e0c6e2..6ff3e190c4 100644 --- a/AiDotNetBenchmarkTests/Helpers/MockNeuralNetwork.cs +++ b/AiDotNetBenchmarkTests/Helpers/MockNeuralNetwork.cs @@ -137,13 +137,5 @@ public void ApplyGradients(Vector gradients, T learningRate) // Mock implementation } - // IJitCompilable implementation - public bool SupportsJitCompilation => false; - - public ComputationNode ExportComputationGraph(List> inputNodes) - { - throw new NotSupportedException("JIT compilation not supported in mock model"); - } - public Vector SanitizeParameters(Vector parameters) => parameters; } diff --git a/AiDotNetBenchmarkTests/InferenceOptimization/AttentionBenchmark.cs b/AiDotNetBenchmarkTests/InferenceOptimization/AttentionBenchmark.cs deleted file mode 100644 index 775ff211e7..0000000000 --- a/AiDotNetBenchmarkTests/InferenceOptimization/AttentionBenchmark.cs +++ /dev/null @@ -1,146 +0,0 @@ -using System; -using AiDotNet.InferenceOptimization; -using AiDotNet.InferenceOptimization.Kernels; -using AiDotNet.LinearAlgebra; -using BenchmarkDotNet.Attributes; -using BenchmarkDotNet.Jobs; - -namespace AiDotNetBenchmarkTests.InferenceOptimization -{ - /// - /// Benchmarks for fused attention kernel - /// - [SimpleJob(RuntimeMoniker.Net80)] - [MemoryDiagnoser] - [CsvExporter] - [HtmlExporter] - public class AttentionBenchmark - { - private Tensor _q = null!; - private Tensor _k = null!; - private Tensor _v = null!; - private AttentionKernel _attentionKernel = null!; - - [Params(64, 128, 256)] - public int SequenceLength { get; set; } - - [Params(32, 64)] - public int FeatureDim { get; set; } - - [GlobalSetup] - public void Setup() - { - OptimizationInitializer.Initialize(enableProfiling: false); - - _attentionKernel = new AttentionKernel(); - - // Initialize Q, K, V tensors with deterministic data - int tensorSize = SequenceLength * FeatureDim; - var dataQ = new float[tensorSize]; - var dataK = new float[tensorSize]; - var dataV = new float[tensorSize]; - - for (int i = 0; i < tensorSize; i++) - { - dataQ[i] = DeterministicValue(i); - dataK[i] = DeterministicValue(i + 1_000_000); - dataV[i] = DeterministicValue(i + 2_000_000); - } - - _q = new Tensor(dataQ, new[] { 1, SequenceLength, FeatureDim }); - _k = new Tensor(dataK, new[] { 1, SequenceLength, FeatureDim }); - _v = new Tensor(dataV, new[] { 1, SequenceLength, FeatureDim }); - } - - private static float DeterministicValue(int i) - { - // Stable deterministic value in [0, 1) without PRNG APIs (avoids security hotspot noise in analysis). - unchecked - { - uint x = (uint)(i * 1664525 + 1013904223); - return (x & 0x00FFFFFF) / 16777216f; - } - } - - [Benchmark(Baseline = true)] - public Tensor NaiveAttention() - { - // Naive implementation: QK^T, softmax, multiply by V - float scale = 1.0f / (float)Math.Sqrt(FeatureDim); - - // Get arrays for direct access in naive benchmark - var qData = _q.ToArray(); - var kData = _k.ToArray(); - var vData = _v.ToArray(); - - // Compute attention scores - var scores = new float[SequenceLength * SequenceLength]; - - for (int i = 0; i < SequenceLength; i++) - { - for (int j = 0; j < SequenceLength; j++) - { - float score = 0.0f; - for (int k = 0; k < FeatureDim; k++) - { - score += qData[i * FeatureDim + k] * kData[j * FeatureDim + k]; - } - scores[i * SequenceLength + j] = score * scale; - } - } - - // Apply softmax - for (int i = 0; i < SequenceLength; i++) - { - float maxVal = float.NegativeInfinity; - for (int j = 0; j < SequenceLength; j++) - { - if (scores[i * SequenceLength + j] > maxVal) - maxVal = scores[i * SequenceLength + j]; - } - - float sum = 0.0f; - for (int j = 0; j < SequenceLength; j++) - { - scores[i * SequenceLength + j] = (float)Math.Exp(scores[i * SequenceLength + j] - maxVal); - sum += scores[i * SequenceLength + j]; - } - - for (int j = 0; j < SequenceLength; j++) - { - scores[i * SequenceLength + j] /= sum; - } - } - - // Multiply by V - var resultData = new float[SequenceLength * FeatureDim]; - - for (int i = 0; i < SequenceLength; i++) - { - for (int j = 0; j < FeatureDim; j++) - { - float sum = 0.0f; - for (int k = 0; k < SequenceLength; k++) - { - sum += scores[i * SequenceLength + k] * vData[k * FeatureDim + j]; - } - resultData[i * FeatureDim + j] = sum; - } - } - - return new Tensor(resultData, new[] { 1, SequenceLength, FeatureDim }); - } - - [Benchmark] - public Tensor OptimizedAttention() - { - return _attentionKernel.Execute(_q, _k, _v); - } - - [Benchmark] - public Tensor MultiHeadAttention() - { - return _attentionKernel.MultiHeadAttention(_q, _k, _v, numHeads: 8); - } - } -} diff --git a/AiDotNetBenchmarkTests/InferenceOptimization/GemmBenchmark.cs b/AiDotNetBenchmarkTests/InferenceOptimization/GemmBenchmark.cs deleted file mode 100644 index 65f97c8d5f..0000000000 --- a/AiDotNetBenchmarkTests/InferenceOptimization/GemmBenchmark.cs +++ /dev/null @@ -1,116 +0,0 @@ -using System; -using AiDotNet.InferenceOptimization; -using AiDotNet.InferenceOptimization.Kernels; -using AiDotNet.LinearAlgebra; -using BenchmarkDotNet.Attributes; -using BenchmarkDotNet.Jobs; - -namespace AiDotNetBenchmarkTests.InferenceOptimization -{ - /// - /// Benchmarks for GEMM (General Matrix Multiplication) kernel - /// Tests optimized implementation against naive implementation - /// - [SimpleJob(RuntimeMoniker.Net80)] - [MemoryDiagnoser] - [CsvExporter] - [HtmlExporter] - public class GemmBenchmark - { - private Tensor? _matrixA; - private Tensor? _matrixB; - private float[]? _matrixAData; - private float[]? _matrixBData; - private GemmKernel? _gemmKernel; - - [Params(64, 128, 256, 512, 1024)] - public int MatrixSize { get; set; } - - [GlobalSetup] - public void Setup() - { - OptimizationInitializer.Initialize(enableProfiling: false); - - _gemmKernel = new GemmKernel(); - - // Initialize matrices with deterministic data (avoids security hotspot noise in analysis) - var dataA = new float[MatrixSize * MatrixSize]; - var dataB = new float[MatrixSize * MatrixSize]; - - for (int i = 0; i < dataA.Length; i++) - { - dataA[i] = DeterministicValue(i); - } - - for (int i = 0; i < dataB.Length; i++) - { - dataB[i] = DeterministicValue(i + 1_000_000); - } - - _matrixA = new Tensor(dataA, new[] { MatrixSize, MatrixSize }); - _matrixB = new Tensor(dataB, new[] { MatrixSize, MatrixSize }); - - // Pre-convert to arrays for naive benchmark (avoid ToArray() overhead in benchmark) - _matrixAData = _matrixA.ToArray(); - _matrixBData = _matrixB.ToArray(); - } - - private static float DeterministicValue(int i) - { - unchecked - { - uint x = (uint)(i * 1664525 + 1013904223); - return (x & 0x00FFFFFF) / 16777216f; - } - } - - [Benchmark(Baseline = true)] - public Tensor NaiveGemm() - { - if (_matrixAData is null || _matrixBData is null) - { - throw new InvalidOperationException("Setup must be called before running benchmarks"); - } - - // Naive triple-nested loop implementation - var resultData = new float[MatrixSize * MatrixSize]; - - for (int i = 0; i < MatrixSize; i++) - { - for (int j = 0; j < MatrixSize; j++) - { - float sum = 0.0f; - for (int k = 0; k < MatrixSize; k++) - { - sum += _matrixAData[i * MatrixSize + k] * _matrixBData[k * MatrixSize + j]; - } - resultData[i * MatrixSize + j] = sum; - } - } - - return new Tensor(resultData, new[] { MatrixSize, MatrixSize }); - } - - [Benchmark] - public Tensor OptimizedGemm() - { - if (_gemmKernel is null || _matrixA is null || _matrixB is null) - { - throw new InvalidOperationException("Setup must be called before running benchmarks"); - } - - return _gemmKernel.Execute(_matrixA, _matrixB); - } - - [Benchmark] - public Tensor OptimizedGemmTranspose() - { - if (_gemmKernel is null || _matrixA is null || _matrixB is null) - { - throw new InvalidOperationException("Setup must be called before running benchmarks"); - } - - return _gemmKernel.GemmTransposeB(_matrixA, _matrixB); - } - } -} diff --git a/AiDotNetBenchmarkTests/InferenceOptimization/SimdBenchmark.cs b/AiDotNetBenchmarkTests/InferenceOptimization/SimdBenchmark.cs deleted file mode 100644 index 7cb96d5825..0000000000 --- a/AiDotNetBenchmarkTests/InferenceOptimization/SimdBenchmark.cs +++ /dev/null @@ -1,162 +0,0 @@ -#nullable disable -using System; -using AiDotNet.InferenceOptimization; -using AiDotNet.Tensors.Engines.Simd; -using BenchmarkDotNet.Attributes; -using BenchmarkDotNet.Configs; -using BenchmarkDotNet.Jobs; - -namespace AiDotNetBenchmarkTests.InferenceOptimization -{ - /// - /// Benchmarks for SIMD-optimized operations - /// - [SimpleJob(RuntimeMoniker.Net80)] - [MemoryDiagnoser] - [CsvExporter] - [HtmlExporter] - [GroupBenchmarksBy(BenchmarkLogicalGroupRule.ByCategory)] - public class SimdBenchmark - { - private float[] _arrayA; - private float[] _arrayB; - private float[] _result; - - [Params(1000, 10000, 100000, 1000000)] - public int ArraySize { get; set; } - - [GlobalSetup] - public void Setup() - { - OptimizationInitializer.Initialize(enableProfiling: false); - - _arrayA = new float[ArraySize]; - _arrayB = new float[ArraySize]; - _result = new float[ArraySize]; - - for (int i = 0; i < ArraySize; i++) - { - _arrayA[i] = DeterministicValue(i); - _arrayB[i] = DeterministicValue(i + 1_000_000); - } - } - - #region Vector Addition - - [Benchmark(Baseline = true)] - [BenchmarkCategory("VectorAdd")] - public void VectorAdd_Scalar() - { - for (int i = 0; i < ArraySize; i++) - { - _result[i] = _arrayA[i] + _arrayB[i]; - } - } - - [Benchmark] - [BenchmarkCategory("VectorAdd")] - public void VectorAdd_SIMD() - { - SimdKernels.VectorAdd(_arrayA, _arrayB, _result); - } - - #endregion - - #region Vector Multiplication - - [Benchmark(Baseline = true)] - [BenchmarkCategory("VectorMultiply")] - public void VectorMultiply_Scalar() - { - for (int i = 0; i < ArraySize; i++) - { - _result[i] = _arrayA[i] * _arrayB[i]; - } - } - - [Benchmark] - [BenchmarkCategory("VectorMultiply")] - public void VectorMultiply_SIMD() - { - SimdKernels.VectorMultiply(_arrayA, _arrayB, _result); - } - - #endregion - - #region Dot Product - - [Benchmark(Baseline = true)] - [BenchmarkCategory("DotProduct")] - public float DotProduct_Scalar() - { - float sum = 0.0f; - for (int i = 0; i < ArraySize; i++) - { - sum += _arrayA[i] * _arrayB[i]; - } - return sum; - } - - [Benchmark] - [BenchmarkCategory("DotProduct")] - public float DotProduct_SIMD() - { - return SimdKernels.DotProduct(_arrayA, _arrayB); - } - - #endregion - - #region ReLU Activation - - [Benchmark(Baseline = true)] - [BenchmarkCategory("ReLU")] - public void ReLU_Scalar() - { - for (int i = 0; i < ArraySize; i++) - { - _result[i] = Math.Max(0.0f, _arrayA[i]); - } - } - - [Benchmark] - [BenchmarkCategory("ReLU")] - public void ReLU_SIMD() - { - SimdKernels.ReLU(_arrayA, _result); - } - - #endregion - - #region Sum Reduction - - [Benchmark(Baseline = true)] - [BenchmarkCategory("Sum")] - public float Sum_Scalar() - { - float sum = 0.0f; - for (int i = 0; i < ArraySize; i++) - { - sum += _arrayA[i]; - } - return sum; - } - - [Benchmark] - [BenchmarkCategory("Sum")] - public float Sum_SIMD() - { - return SimdKernels.Sum(_arrayA); - } - - #endregion - - private static float DeterministicValue(int i) - { - unchecked - { - uint x = (uint)(i * 1664525 + 1013904223); - return (x & 0x00FFFFFF) / 16777216f; - } - } - } -} diff --git a/Directory.Packages.props b/Directory.Packages.props index b9ff49519f..3668f59532 100644 --- a/Directory.Packages.props +++ b/Directory.Packages.props @@ -6,7 +6,7 @@ - + diff --git a/docs/GOLDEN_STANDARD_PATTERNS.md b/docs/GOLDEN_STANDARD_PATTERNS.md index 1a7816b370..e02b1b9c85 100644 --- a/docs/GOLDEN_STANDARD_PATTERNS.md +++ b/docs/GOLDEN_STANDARD_PATTERNS.md @@ -102,11 +102,6 @@ public ExampleLayer( /// public override bool SupportsTraining => true; -/// -/// Gets whether this layer supports JIT compilation. -/// -public override bool SupportsJitCompilation => CanActivationBeJitted(); - public override Tensor Forward(Tensor input) { _lastInput = input; @@ -172,11 +167,6 @@ public override void ResetState() _lastInput = new Tensor(InputShape); _lastOutput = new Tensor(OutputShape); } - -public override ComputationNode ExportComputationGraph(List> inputNodes) -{ - // Create computation graph for JIT compilation -} ``` ### Serialization @@ -650,10 +640,6 @@ public abstract IFullModel, Tensor> DeepCopy(); public virtual Vector ComputeGradients(Tensor input, Tensor target, ILossFunction? lossFunction = null); public virtual void ApplyGradients(Vector gradients, T learningRate); -// IJitCompilable -public virtual bool SupportsJitCompilation => false; -public virtual ComputationNode ExportComputationGraph(List> inputNodes); - // IFullModel public ILossFunction DefaultLossFunction => LossFunction; ``` @@ -786,14 +772,12 @@ SomeWeight = NumOps.FromDouble(0.01); - [ ] Uses `NumOps` for all numeric operations - [ ] Uses `RandomHelper` for random number generation - [ ] Implements `SupportsTraining` property -- [ ] Implements `SupportsJitCompilation` property - [ ] Implements `Forward(Tensor)` - [ ] Implements `Backward(Tensor)` - [ ] Implements `UpdateParameters(T)` - [ ] Implements `GetParameters()` - [ ] Implements `SetParameters(Vector)` - [ ] Implements `ResetState()` -- [ ] Implements `ExportComputationGraph(List>)` - [ ] Overrides `Serialize(BinaryWriter)` if additional state - [ ] Overrides `Deserialize(BinaryReader)` if additional state - [ ] Has comprehensive XML documentation with "For Beginners" sections @@ -852,7 +836,6 @@ Key patterns demonstrated: - Forward/Backward pass caching (`_lastInput`, `_lastOutput`) - Gradient storage (`_kernelsGradient`, `_biasesGradient`) - Serialization/Deserialization with `NumOps.ToDouble`/`FromDouble` -- JIT compilation support via `ExportComputationGraph` ### Reference Layer: TimeEmbeddingLayer diff --git a/docs/JIT-Compiler-Usage-Guide.md b/docs/JIT-Compiler-Usage-Guide.md deleted file mode 100644 index 33fff1f60b..0000000000 --- a/docs/JIT-Compiler-Usage-Guide.md +++ /dev/null @@ -1,352 +0,0 @@ -# JIT Compiler Usage Guide - -## Overview - -The AiDotNet JIT (Just-In-Time) Compiler dramatically improves the performance of computation graphs by compiling them to optimized executable code. This can provide **5-10x speedups** for typical neural network operations. - -## Quick Start - -### Basic Usage - -```csharp -using AiDotNet.Autodiff; -using AiDotNet.JitCompiler; - -// Create a computation graph -var x = new ComputationNode(inputTensor, requiresGradient: false); -var weights = new ComputationNode(weightsTensor, requiresGradient: false); -var bias = new ComputationNode(biasTensor, requiresGradient: false); - -var matmul = TensorOperations.MatrixMultiply(x, weights); -var add = TensorOperations.Add(matmul, bias); -var result = TensorOperations.ReLU(add); - -// Create JIT compiler -var jit = new JitCompiler(); - -// Compile the graph -var compiled = jit.Compile(result, new List> { x, weights, bias }); - -// Execute the compiled function (much faster!) -var output = compiled(new[] { inputTensor, weightsTensor, biasTensor }); -``` - -### With Compilation Statistics - -```csharp -// Compile with statistics to see what optimizations were applied -var (compiledFunc, stats) = jit.CompileWithStats(result, inputs); - -Console.WriteLine(stats); -// Output: -// Compilation Stats: -// Original operations: 15 -// Optimized operations: 8 -// Operations eliminated: 7 (46.7%) -// Optimizations applied: Constant Folding, Dead Code Elimination, Operation Fusion -// Compilation time: 12.34ms -// Cache hit: false - -// Use the compiled function -var output = compiledFunc(inputTensors); -``` - -## How It Works - -The JIT compiler follows a multi-stage pipeline: - -### 1. IR Construction -Converts the ComputationNode graph into an Intermediate Representation (IR): -- Each operation becomes an IROp -- Tensors are assigned IDs -- Graph structure is preserved - -### 2. Optimization -Applies multiple optimization passes: - -#### Constant Folding -Evaluates operations with constant inputs at compile time: -``` -Before: t2 = Add(Constant(2), Constant(3)); t3 = Mul(t2, input) -After: t2 = Constant(5); t3 = Mul(t2, input) -``` - -#### Dead Code Elimination -Removes operations whose results are never used: -``` -Before: t2 = Add(a, b); t3 = Mul(a, b); Output: t2 -After: t2 = Add(a, b); Output: t2 (t3 removed!) -``` - -#### Operation Fusion -Combines multiple operations into fused operations: -``` -Before: t2 = MatMul(x, w); t3 = Add(t2, b); t4 = ReLU(t3) -After: t4 = FusedLinearReLU(x, w, b) (3 ops → 1 op!) -``` - -### 3. Code Generation -Generates executable .NET code using Expression Trees: -- Converts each IR operation to a .NET expression -- Builds a lambda function -- Compiles to native code via .NET JIT - -### 4. Caching -Compiled functions are cached by graph structure: -- First compilation: ~10-50ms (depends on graph size) -- Subsequent compilations of same structure: instant! - -## Configuration - -### Custom Compiler Options - -```csharp -var options = new JitCompilerOptions -{ - EnableConstantFolding = true, // Default: true - EnableDeadCodeElimination = true, // Default: true - EnableOperationFusion = true, // Default: true - EnableCaching = true // Default: true -}; - -var jit = new JitCompiler(options); -``` - -### Disabling Optimizations for Debugging - -```csharp -var debugOptions = new JitCompilerOptions -{ - EnableConstantFolding = false, - EnableDeadCodeElimination = false, - EnableOperationFusion = false, - EnableCaching = false // Force recompilation every time -}; - -var debugJit = new JitCompiler(debugOptions); -``` - -## Best Practices - -### 1. Reuse Compiled Functions -The compiled function can be called many times with different tensor values: - -```csharp -// Compile once -var compiled = jit.Compile(modelOutput, modelInputs); - -// Use many times -for (int epoch = 0; epoch < 100; epoch++) -{ - for (int batch = 0; batch < batches.Count; batch++) - { - var output = compiled(batches[batch]); // Fast execution! - // ... training logic ... - } -} -``` - -### 2. Set Operation Metadata for JIT -For optimal JIT compilation, set operation type when creating nodes: - -```csharp -var result = new ComputationNode(value) -{ - OperationType = "Add", - OperationParams = new Dictionary - { - // Include operation-specific parameters if needed - } -}; -``` - -The `TensorOperations` methods will automatically set this metadata in future updates. - -### 3. Cache Management - -```csharp -// Get cache statistics -var cacheStats = jit.GetCacheStats(); -Console.WriteLine($"Cached graphs: {cacheStats.CachedGraphCount}"); -Console.WriteLine($"Memory used: {cacheStats.EstimatedMemoryBytes / 1024} KB"); - -// Clear cache if needed (e.g., memory pressure) -jit.ClearCache(); -``` - -### 4. Monitor Compilation Performance - -```csharp -var (compiledFunc, stats) = jit.CompileWithStats(graph, inputs); - -if (!stats.CacheHit) -{ - Console.WriteLine($"Compiled new graph in {stats.CompilationTime.TotalMilliseconds}ms"); - Console.WriteLine($"Optimized away {stats.OptimizationPercentage:F1}% of operations"); -} -``` - -## Performance Expectations - -### Typical Speedups - -| Graph Type | Operations | Speedup | Notes | -|-----------|-----------|---------|-------| -| Small linear layer | 3-5 ops | 3-5x | Less overhead benefit | -| Deep MLP | 20-50 ops | 5-8x | Good optimization opportunity | -| CNN layer | 10-30 ops | 7-10x | Convolution fusion helps | -| Transformer block | 50-100 ops | 8-12x | Many fusion opportunities | - -### When to Use JIT - -**Best for:** -- Inference (forward pass only) -- Repeated execution of same graph structure -- Large models with many operations -- Production deployments - -**Less beneficial for:** -- Graphs that change structure frequently -- Very small operations (compilation overhead) - -## Common Patterns - -### Model Inference - -```csharp -public class JitCompiledModel -{ - private readonly JitCompiler _jit = new(); - private Func[], Tensor[]>? _compiledForward; - - public Tensor Forward(Tensor input) - { - // Build computation graph - var inputNode = new ComputationNode(input); - var output = BuildGraph(inputNode); - - // Compile on first call - if (_compiledForward == null) - { - _compiledForward = _jit.Compile(output, new List> { inputNode }); - } - - // Execute compiled version - var result = _compiledForward(new[] { input }); - return result[0]; - } -} -``` - -### Batch Processing - -```csharp -var jit = new JitCompiler(); -var compiled = jit.Compile(batchGraph, batchInputs); - -Parallel.ForEach(batches, batch => -{ - var output = compiled(batch); // Thread-safe execution - ProcessOutput(output); -}); -``` - -## Troubleshooting - -### "Node does not have OperationType metadata" - -**Problem:** ComputationNode doesn't have operation type information. - -**Solution:** Ensure you're using TensorOperations methods that set metadata, or manually set: -```csharp -node.OperationType = "Add"; -node.OperationParams = new Dictionary(); -``` - -### Compilation is slow - -**Problem:** Graph compilation takes too long. - -**Solutions:** -1. Enable caching (default) -2. Compile during initialization, not in hot path -3. Reduce graph size if possible -4. Disable expensive optimizations if needed - -### Cache memory usage high - -**Problem:** Too many compiled graphs cached. - -**Solutions:** -```csharp -// Monitor cache -var stats = jit.GetCacheStats(); -if (stats.EstimatedMemoryBytes > threshold) -{ - jit.ClearCache(); -} -``` - -## Future Enhancements - -Planned improvements: -- [x] Support for backward pass (gradient) compilation -- [ ] GPU code generation -- [ ] More fusion patterns -- [ ] Advanced optimizations (loop unrolling, vectorization hints) -- [ ] Profiling and auto-tuning - -## Examples - -See the `examples/JitCompilerExample.cs` file for complete working examples. - -## API Reference - -### JitCompiler - -#### Methods - -- `Func[], Tensor[]> Compile(ComputationNode outputNode, List> inputs)` - - Compiles a computation graph to executable code - -- `(Func[], Tensor[]>, CompilationStats) CompileWithStats(...)` - - Compiles and returns statistics - -- `Func[], Tensor[]> CompileBackward(ComputationNode outputNode, List> inputs)` - - Compiles a backward pass (gradient computation) graph to executable code - -- `(Func[], Tensor[]>, CompilationStats) CompileBackwardWithStats(...)` - - Compiles backward pass and returns statistics - -- `void ClearCache()` - - Clears the compiled graph cache - -- `CacheStats GetCacheStats()` - - Gets cache statistics - -### JitCompilerOptions - -#### Properties - -- `bool EnableConstantFolding` - Enable constant folding optimization (default: true) -- `bool EnableDeadCodeElimination` - Enable dead code elimination (default: true) -- `bool EnableOperationFusion` - Enable operation fusion (default: true) -- `bool EnableCaching` - Enable caching of compiled graphs (default: true) - -### CompilationStats - -#### Properties - -- `int OriginalOperationCount` - Operations before optimization -- `int OptimizedOperationCount` - Operations after optimization -- `List OptimizationsApplied` - Applied optimization passes -- `TimeSpan CompilationTime` - Time to compile -- `bool CacheHit` - Whether result came from cache -- `int OperationsEliminated` - Operations removed by optimization -- `double OptimizationPercentage` - Percentage of operations optimized away - -## Conclusion - -The JIT compiler provides significant performance improvements for computation graph execution with minimal code changes. Simply create a compiler, call `Compile()`, and enjoy 5-10x speedups! - -For questions or issues, please file an issue on GitHub. diff --git a/docs/JIT_ACTIVATION_MAPPING.md b/docs/JIT_ACTIVATION_MAPPING.md deleted file mode 100644 index 94d5915e0e..0000000000 --- a/docs/JIT_ACTIVATION_MAPPING.md +++ /dev/null @@ -1,376 +0,0 @@ -# JIT Activation Mapping Reference - -This document provides a complete reference for all activation functions available in AiDotNet, their JIT compilation support status, and how to use them in your layers. - -## Quick Reference - -**Total Activations**: 37 -**Production-Ready**: 10 -**Available (Pending Integration)**: 27 - ---- - -## Production-Ready Activations (10) - -These activations are fully integrated into DenseLayer and ready for use in JIT compilation. - -### ReLU Family (1) - -| Activation Class | TensorOperations Method | IEngine Method | Parameters | Status | -|------------------|-------------------------|----------------|------------|--------| -| `ReLUActivation` | `TensorOperations.ReLU(node)` | `IEngine.ReLU(tensor)` | None | ✅ Ready | - -**Usage Example:** -```csharp -// In CanActivationBeJitted() -if (ScalarActivation is ReLUActivation) - return true; - -// In ApplyActivationToGraph() -if (ScalarActivation is ReLUActivation) - return TensorOperations.ReLU(input); -``` - -**Forward Function**: `f(x) = max(0, x)` - -**Use Cases**: Default activation for hidden layers in most neural networks. - ---- - -### Sigmoid Family (5) - -| Activation Class | TensorOperations Method | IEngine Method | Parameters | Status | -|------------------|-------------------------|----------------|------------|--------| -| `SigmoidActivation` | `TensorOperations.Sigmoid(node)` | `IEngine.Sigmoid(tensor)` | None | ✅ Ready | -| `TanhActivation` | `TensorOperations.Tanh(node)` | `IEngine.Tanh(tensor)` | None | ✅ Ready | -| `SwishActivation` | `TensorOperations.Swish(node)` | `IEngine.Swish(tensor)` | None | ✅ Ready | -| `SiLUActivation` | `TensorOperations.SiLU(node)` | `IEngine.SiLU(tensor)` | None | ✅ Ready | -| `MishActivation` | `TensorOperations.Mish(node)` | `IEngine.Mish(tensor)` | None | ✅ Ready | - -**Usage Example (Sigmoid):** -```csharp -// In CanActivationBeJitted() -if (ScalarActivation is SigmoidActivation) - return true; - -// In ApplyActivationToGraph() -if (ScalarActivation is SigmoidActivation) - return TensorOperations.Sigmoid(input); -``` - -**Forward Functions**: -- **Sigmoid**: `f(x) = 1 / (1 + e^(-x))` -- **Tanh**: `f(x) = (e^x - e^(-x)) / (e^x + e^(-x))` -- **Swish**: `f(x) = x * sigmoid(x)` (also known as SiLU) -- **SiLU**: Same as Swish -- **Mish**: `f(x) = x * tanh(softplus(x))` - -**Use Cases**: -- **Sigmoid**: Binary classification output layers, LSTM gates -- **Tanh**: RNN hidden states, centered outputs (-1 to 1) -- **Swish/SiLU**: Modern alternative to ReLU with smooth gradients -- **Mish**: Self-regularized activation, good for deep networks - ---- - -### Modern Activations (2) - -| Activation Class | TensorOperations Method | IEngine Method | Parameters | Status | -|------------------|-------------------------|----------------|------------|--------| -| `GELUActivation` | `TensorOperations.GELU(node)` | `IEngine.GELU(tensor)` | None | ✅ Ready | -| `ELUActivation` | `TensorOperations.ELU(node, alpha)` | `IEngine.ELU(tensor, alpha)` | `alpha` (default: 1.0) | ✅ Ready | - -**Usage Example (GELU):** -```csharp -// In CanActivationBeJitted() -if (ScalarActivation is GELUActivation) - return true; - -// In ApplyActivationToGraph() -if (ScalarActivation is GELUActivation) - return TensorOperations.GELU(input); -``` - -**Usage Example (ELU with parameter):** -```csharp -// In CanActivationBeJitted() -if (ScalarActivation is ELUActivation) - return true; - -// In ApplyActivationToGraph() -if (ScalarActivation is ELUActivation elu) - return TensorOperations.ELU(input, elu.Alpha); -``` - -**Forward Functions**: -- **GELU**: `f(x) = x * Φ(x)` where Φ is the cumulative distribution function of the standard normal distribution -- **ELU**: `f(x) = x if x > 0, else alpha * (e^x - 1)` - -**Use Cases**: -- **GELU**: Used in Transformers (BERT, GPT), superior to ReLU for NLP tasks -- **ELU**: Reduces vanishing gradient problem, smooth negative values - ---- - -### Vector Activations (1) - -| Activation Class | TensorOperations Method | IEngine Method | Parameters | Status | -|------------------|-------------------------|----------------|------------|--------| -| `SoftmaxActivation` | `TensorOperations.Softmax(node, axis)` | `IEngine.Softmax(tensor, axis)` | `axis` (default: -1) | ✅ Ready | - -**Usage Example:** -```csharp -// In CanActivationBeJitted() -if (VectorActivation is SoftmaxActivation) - return true; - -// In ApplyActivationToGraph() -if (VectorActivation is SoftmaxActivation) - return TensorOperations.Softmax(input); -``` - -**Forward Function**: `f(x_i) = e^(x_i) / Σ(e^(x_j))` - -**Use Cases**: Multi-class classification output layers, attention mechanisms. - ---- - -### Identity (1) - -| Activation Class | TensorOperations Method | IEngine Method | Parameters | Status | -|------------------|-------------------------|----------------|------------|--------| -| `IdentityActivation` | `input` (no-op) | N/A | None | ✅ Ready | - -**Usage Example:** -```csharp -// In CanActivationBeJitted() -if (ScalarActivation is IdentityActivation) - return true; - -// In ApplyActivationToGraph() -if (ScalarActivation is IdentityActivation) - return input; // No transformation -``` - -**Forward Function**: `f(x) = x` - -**Use Cases**: Linear layers, skip connections, output layers for regression. - ---- - -## Available Activations - Pending Integration (27) - -These activations have TensorOperations methods implemented but are not yet integrated into layer implementations. To use them, follow the pattern shown in the "Production-Ready" section above. - -### ReLU Family (7) - -| Activation Class | TensorOperations Method | Parameters | Forward Function | IEngine Status | -|------------------|-------------------------|------------|------------------|----------------| -| `LeakyReLUActivation` | `TensorOperations.LeakyReLU(node, negativeSlope)` | `negativeSlope` (default: 0.01) | `f(x) = max(negativeSlope*x, x)` | ✅ Integrated | -| `SELUActivation` | `TensorOperations.SELU(node)` | None | `f(x) = scale * (max(0,x) + min(0, alpha*(e^x-1)))` | ✅ Integrated | -| `CELUActivation` | `TensorOperations.CELU(node, alpha)` | `alpha` (default: 1.0) | `f(x) = max(0,x) + min(0, alpha*(e^(x/alpha)-1))` | ✅ Integrated | -| `PReLUActivation` | `TensorOperations.PReLU(node, alpha)` | `alpha` (default: 0.25) | `f(x) = max(alpha*x, x)` | ✅ Integrated | -| `RReLUActivation` | `TensorOperations.RReLU(node, lower, upper)` | `lower` (0.125), `upper` (0.333) | `f(x) = max(a*x, x)` where a ~ U(lower, upper) | ✅ Integrated | -| `ThresholdedReLUActivation` | `TensorOperations.ThresholdedReLU(node, threshold)` | `threshold` (default: 1.0) | `f(x) = x if x > threshold, else 0` | ✅ Integrated | - -**Integration Example (LeakyReLU):** -```csharp -// Add to CanActivationBeJitted() -if (ScalarActivation is LeakyReLUActivation) - return true; - -// Add to ApplyActivationToGraph() -if (ScalarActivation is LeakyReLUActivation leakyRelu) - return TensorOperations.LeakyReLU(input, leakyRelu.NegativeSlope); -``` - ---- - -### Sigmoid Family (9) - -| Activation Class | TensorOperations Method | Parameters | Forward Function | IEngine Status | -|------------------|-------------------------|------------|------------------|----------------| -| `HardSigmoidActivation` | `TensorOperations.HardSigmoid(node)` | None | `f(x) = clip((x+1)/2, 0, 1)` | ✅ Integrated | -| `HardTanhActivation` | `TensorOperations.HardTanh(node)` | None | `f(x) = clip(x, -1, 1)` | ✅ Integrated | -| `ScaledTanhActivation` | `TensorOperations.ScaledTanh(node, alpha, beta)` | `alpha` (1.0), `beta` (1.0) | `f(x) = alpha * tanh(beta * x)` | ✅ Integrated | -| `SoftplusActivation` | `TensorOperations.Softplus(node)` | None | `f(x) = log(1 + e^x)` | ✅ Integrated | -| `SoftsignActivation` | `TensorOperations.Softsign(node)` | None | `f(x) = x / (1 + abs(x))` | ✅ Integrated | -| `BentIdentityActivation` | `TensorOperations.BentIdentity(node)` | None | `f(x) = (sqrt(x^2 + 1) - 1)/2 + x` | ✅ Integrated | - -**Integration Example (Softplus):** -```csharp -// Add to CanActivationBeJitted() -if (ScalarActivation is SoftplusActivation) - return true; - -// Add to ApplyActivationToGraph() -if (ScalarActivation is SoftplusActivation) - return TensorOperations.Softplus(input); -``` - ---- - -### Softmax Family (3) - -| Activation Class | TensorOperations Method | Parameters | Forward Function | IEngine Status | -|------------------|-------------------------|------------|------------------|----------------| -| `SoftminActivation` | `TensorOperations.Softmin(node, axis)` | `axis` (default: -1) | `f(x_i) = e^(-x_i) / Σ(e^(-x_j))` | ✅ Integrated | -| `LogSoftmaxActivation` | `TensorOperations.LogSoftmax(node, axis)` | `axis` (default: -1) | `f(x_i) = log(e^(x_i) / Σ(e^(x_j)))` | ✅ Integrated | -| `LogSoftminActivation` | `TensorOperations.LogSoftmin(node, axis)` | `axis` (default: -1) | `f(x_i) = log(e^(-x_i) / Σ(e^(-x_j)))` | ✅ Integrated | - -**Integration Example (LogSoftmax):** -```csharp -// Add to CanActivationBeJitted() - check VectorActivation -if (VectorActivation is LogSoftmaxActivation) - return true; - -// Add to ApplyActivationToGraph() - check VectorActivation -if (VectorActivation is LogSoftmaxActivation) - return TensorOperations.LogSoftmax(input); -``` - ---- - -### Special Activations (8) - -| Activation Class | TensorOperations Method | Parameters | Forward Function | IEngine Status | -|------------------|-------------------------|------------|------------------|----------------| -| `SignActivation` | `TensorOperations.Sign(node)` | None | `f(x) = 1 if x > 0, -1 if x < 0, 0 if x == 0` | ✅ Integrated | -| `GaussianActivation` | `TensorOperations.Gaussian(node)` | None | `f(x) = e^(-x^2)` | ✅ Integrated | -| `ISRUActivation` | `TensorOperations.ISRU(node, alpha)` | `alpha` (default: 1.0) | `f(x) = x / sqrt(1 + alpha*x^2)` | ✅ Integrated | -| `LiSHTActivation` | `TensorOperations.LiSHT(node)` | None | `f(x) = x * tanh(x)` | ✅ Integrated | -| `SQRBFActivation` | `TensorOperations.SQRBF(node, center, width)` | `center` (0.0), `width` (1.0) | `f(x) = e^(-((x-center)/width)^2)` | ✅ Integrated | -| `SquashActivation` | `TensorOperations.Squash(node)` | None | `f(x) = (norm^2 / (1 + norm^2)) * (x / norm)` | ✅ Integrated | -| `BinarySpikingActivation` | `TensorOperations.BinarySpiking(node, threshold)` | `threshold` (default: 0.0) | `f(x) = 1 if x > threshold, else 0` | ✅ Integrated | - -**Integration Example (Gaussian):** -```csharp -// Add to CanActivationBeJitted() -if (ScalarActivation is GaussianActivation) - return true; - -// Add to ApplyActivationToGraph() -if (ScalarActivation is GaussianActivation) - return TensorOperations.Gaussian(input); -``` - ---- - -### Complex Activations - Placeholder Status (6) - -These activations have placeholder implementations in TensorOperations. Full implementation requires complex algorithms and will be completed in the gradient computation phase. - -| Activation Class | TensorOperations Method | Parameters | Description | Status | -|------------------|-------------------------|------------|-------------|--------| -| `SparsemaxActivation` | `TensorOperations.Sparsemax(node, axis)` | `axis` (default: -1) | Projects onto simplex, produces sparse outputs | ⚠️ Placeholder | -| `SphericalSoftmaxActivation` | `TensorOperations.SphericalSoftmax(node, axis)` | `axis` (default: -1) | Normalizes to unit sphere | ⚠️ Placeholder | -| `GumbelSoftmaxActivation` | `TensorOperations.GumbelSoftmax(node, temp, axis)` | `temp` (1.0), `axis` (-1) | Differentiable sampling | ⚠️ Placeholder | -| `TaylorSoftmaxActivation` | `TensorOperations.TaylorSoftmax(node, order, axis)` | `order` (2), `axis` (-1) | Taylor approximation of softmax | ⚠️ Placeholder | -| `HierarchicalSoftmaxActivation` | `TensorOperations.HierarchicalSoftmax(node)` | None | Tree-structured softmax | ⚠️ Placeholder | -| `MaxoutActivation` | `TensorOperations.Maxout(node, numPieces)` | `numPieces` (default: 2) | Learnable piecewise linear | ⚠️ Placeholder | - -**Note**: These activations currently throw `NotImplementedException` for backward pass. Do not use in production until fully implemented. - ---- - -## Backward Pass Status - -**Current Status**: Placeholder implementations only - -All TensorOperations activation methods currently have placeholder backward functions: - -```csharp -backward: (gradOutput) => -{ - throw new NotImplementedException("Backward pass for [Activation] not yet implemented"); -} -``` - -**Future Work**: Gradient computation will be implemented in a future phase. This includes: -- Analytical gradient formulas for all 37 activations -- Efficient backward pass implementations -- Support for training with JIT-compiled graphs - -**Current Limitation**: JIT compilation is only suitable for **inference** (forward pass only). For **training**, use eager mode until backward pass is implemented. - ---- - -## Activation Selection Guide - -### For Image Classification (CNNs) - -**Recommended**: -- Hidden layers: `ReLUActivation` (fast, effective) -- Modern alternative: `GELUActivation` (smoother gradients) -- Output layer: `SoftmaxActivation` (multi-class) - -**Example**: -```csharp -var conv1 = new ConvolutionalLayer(filters: 32, kernelSize: 3, activation: new ReLUActivation()); -var conv2 = new ConvolutionalLayer(filters: 64, kernelSize: 3, activation: new ReLUActivation()); -var dense = new DenseLayer(inputSize: 1024, outputSize: 10, activation: new SoftmaxActivation()); -``` - -### For Natural Language Processing (Transformers) - -**Recommended**: -- Hidden layers: `GELUActivation` (used in BERT, GPT) -- Alternative: `SwishActivation` or `MishActivation` -- Output layer: `SoftmaxActivation` (classification) or `IdentityActivation` (regression) - -**Example**: -```csharp -var feedForward = new DenseLayer(inputSize: 768, outputSize: 3072, activation: new GELUActivation()); -var output = new DenseLayer(inputSize: 3072, outputSize: 768, activation: new IdentityActivation()); -``` - -### For Recurrent Networks (RNNs, LSTMs, GRUs) - -**Recommended**: -- Gates: `SigmoidActivation` (LSTM/GRU gates) -- Hidden state: `TanhActivation` (LSTM/GRU hidden state) -- Output layer: `SoftmaxActivation` (classification) - -**Example**: -```csharp -// LSTM uses both Sigmoid (for gates) and Tanh (for cell state) -var lstm = new LSTMLayer(inputSize: 100, hiddenSize: 128); -// Gates internally use Sigmoid, cell state uses Tanh -``` - -### For Generative Models (GANs, VAEs) - -**Recommended**: -- Generator hidden: `LeakyReLUActivation` or `ELUActivation` (avoid dying ReLU) -- Generator output: `TanhActivation` (normalize to [-1, 1]) -- Discriminator: `LeakyReLUActivation` (stable gradients) - -**Example**: -```csharp -var genHidden = new DenseLayer(inputSize: 100, outputSize: 256, activation: new LeakyReLUActivation()); -var genOutput = new DenseLayer(inputSize: 256, outputSize: 784, activation: new TanhActivation()); -``` - ---- - -## Integration Checklist - -When adding JIT support for an activation to your layer: - -- [ ] Check if activation is in "Production-Ready" list -- [ ] If not, check "Available Activations - Pending Integration" list -- [ ] Add activation type check to `CanActivationBeJitted()` -- [ ] Add activation mapping to `ApplyActivationToGraph()` -- [ ] Handle parameterized activations correctly (extract parameters) -- [ ] Update `SupportsJitCompilation` property -- [ ] Update XML documentation with supported activations -- [ ] Test with sample data -- [ ] Verify JIT compilation succeeds -- [ ] Benchmark performance - ---- - -## See Also - -- [JIT_COMPILATION_PATTERN_GUIDE.md](JIT_COMPILATION_PATTERN_GUIDE.md) - Complete implementation guide -- [JIT_ROADMAP.md](JIT_ROADMAP.md) - Current status and future work diff --git a/docs/JIT_COMPILATION_PATTERN_GUIDE.md b/docs/JIT_COMPILATION_PATTERN_GUIDE.md deleted file mode 100644 index 2c347ebd79..0000000000 --- a/docs/JIT_COMPILATION_PATTERN_GUIDE.md +++ /dev/null @@ -1,723 +0,0 @@ -# JIT Compilation Pattern Guide - -## Overview - -### What is JIT Compilation in AiDotNet? - -Just-In-Time (JIT) compilation in AiDotNet is a performance optimization technique that compiles neural network layers into optimized computation graphs **before** training or inference begins. This allows the framework to: - -1. **Optimize the computation graph** - Remove redundant operations, fuse operations together, and apply mathematical simplifications -2. **Generate efficient code** - Convert high-level operations into optimized low-level code that runs on CPU or GPU -3. **Accelerate execution** - Execute the compiled graph much faster than interpreting operations one-by-one - -### Performance Benefits - -JIT compilation provides significant performance improvements: - -- **Target speedup**: 5-10x faster execution compared to eager mode -- **Reduced memory overhead**: Optimized graphs use less temporary memory -- **Better hardware utilization**: Compiled code can better leverage CPU/GPU parallelism -- **Batch efficiency**: Symbolic batch dimensions (-1) allow same compiled graph to handle any batch size - -### When to Use JIT Compilation - -**Use JIT compilation when:** -- Training or running inference on production models -- Working with large batch sizes (where compilation overhead is amortized) -- Deploying models to resource-constrained environments -- Performance is critical (real-time inference, large-scale training) - -**Don't use JIT compilation when:** -- Rapidly prototyping and debugging (eager mode is easier to debug) -- Working with dynamic architectures that change structure frequently -- Batch size is 1 and latency is more important than throughput - -### Current Support Status - -As of the latest release: - -- **Foundation**: Complete (TensorOperations, IEngine integration, IR operations) -- **DenseLayer**: Production-ready with 10 supported activations -- **Other layers**: 76 layers pending implementation (following the same pattern) - -**Supported activations (10 ready for production use):** -- ReLU, Sigmoid, Tanh, Softmax, Identity -- GELU, ELU, Mish, Swish, SiLU - -**Additional activations (27 available, pending integration):** -- LeakyReLU, SELU, CELU, PReLU, RReLU, ThresholdedReLU -- HardSigmoid, HardTanh, ScaledTanh, Softplus, Softsign, BentIdentity -- Softmin, LogSoftmax, LogSoftmin -- Sign, Gaussian, ISRU, LiSHT, SQRBF, Squash, BinarySpiking -- Sparsemax, SphericalSoftmax, GumbelSoftmax, TaylorSoftmax, HierarchicalSoftmax, Maxout - ---- - -## Supported Activations - -The following activations are fully implemented and ready for JIT compilation: - -### Scalar Activations (Element-wise) - -| Activation | TensorOperations Method | Description | Use Cases | -|------------|------------------------|-------------|-----------| -| **ReLU** | `TensorOperations.ReLU(node)` | Rectified Linear Unit - outputs max(0, x) | Most common activation, default for hidden layers | -| **Sigmoid** | `TensorOperations.Sigmoid(node)` | Sigmoid function - outputs 1/(1+e^(-x)) | Binary classification output, gates in RNNs | -| **Tanh** | `TensorOperations.Tanh(node)` | Hyperbolic tangent - outputs (e^x - e^(-x))/(e^x + e^(-x)) | Alternative to sigmoid, centers output around 0 | -| **GELU** | `TensorOperations.GELU(node)` | Gaussian Error Linear Unit | Used in Transformers (BERT, GPT) | -| **ELU** | `TensorOperations.ELU(node, alpha)` | Exponential Linear Unit | Reduces vanishing gradient problem | -| **Mish** | `TensorOperations.Mish(node)` | Self-regularized smooth activation | Modern alternative to ReLU | -| **Swish** | `TensorOperations.Swish(node)` | Self-gated activation (x * sigmoid(x)) | Google Brain's smooth alternative to ReLU | -| **SiLU** | `TensorOperations.SiLU(node)` | Sigmoid Linear Unit (same as Swish) | Used in modern architectures | -| **LeakyReLU** | `TensorOperations.LeakyReLU(node, slope)` | ReLU with small negative slope | Prevents dying ReLU problem | -| **Identity** | `input` (no-op) | Returns input unchanged | Linear layers, skip connections | - -### Vector Activations (Operates on entire vectors) - -| Activation | TensorOperations Method | Description | Use Cases | -|------------|------------------------|-------------|-----------| -| **Softmax** | `TensorOperations.Softmax(node, axis)` | Converts logits to probability distribution | Multi-class classification output | - ---- - -## Step-by-Step Implementation Guide - -This section shows you exactly how to add JIT compilation support to any neural network layer. - -### Prerequisites - -Before implementing JIT support, ensure: - -1. ✅ Your layer inherits from `LayerBase` or implements `ILayer` -2. ✅ Your layer has a working `Forward()` method -3. ✅ Your layer uses one of the supported activations listed above -4. ✅ Your layer has properly initialized weights and biases - -### Step 1: Override ExportComputationGraph - -The `ExportComputationGraph` method is the core of JIT compilation. It builds a symbolic representation of your layer's computation that can be optimized and compiled. - -```csharp -public override ComputationNode ExportComputationGraph(List> inputNodes) -{ - // 1. Validate inputs - if (inputNodes == null) - throw new ArgumentNullException(nameof(inputNodes)); - - if (_weights == null) - throw new InvalidOperationException("Layer weights not initialized. Call Initialize() or train the layer first."); - - if (_biases == null) - throw new InvalidOperationException("Layer biases not initialized. Call Initialize() or train the layer first."); - - if (InputShape == null || InputShape.Length == 0) - throw new InvalidOperationException("Layer input shape not configured."); - - if (!CanActivationBeJitted()) - { - var activationType = ScalarActivation?.GetType().Name ?? VectorActivation?.GetType().Name ?? "unknown"; - throw new NotSupportedException( - $"Activation function '{activationType}' is not supported for JIT compilation yet. " + - "Supported activations: ReLU, Sigmoid, Tanh, GELU, ELU, Mish, Swish, SiLU, LeakyReLU, Softmax, Identity"); - } - - // 2. Extract layer dimensions - int inputSize = InputShape[0]; // e.g., 784 for MNIST - int outputSize = OutputShape[0]; // e.g., 128 for hidden layer - - // 3. Create input placeholder with symbolic batch dimension - // The -1 means "any batch size" - allows same compiled graph for batch sizes 1, 32, 128, etc. - var inputPlaceholder = new Tensor(new int[] { 1, inputSize }); // Actual placeholder is batch size 1 - var inputNode = TensorOperations.Variable(inputPlaceholder, "input"); - - // 4. Create parameter nodes for weights and biases - // Weights shape: [outputSize, inputSize] - transposed for efficient computation - var weightsNode = TensorOperations.Variable( - new Tensor(new int[] { _weights.Rows, _weights.Columns }, _weights), - "weights" - ); - - // Biases shape: [outputSize] - var biasesNode = TensorOperations.Variable( - new Tensor(new int[] { _biases.Length }, _biases), - "biases" - ); - - // 5. Add nodes to input list (required by JIT compiler) - inputNodes.Add(inputNode); - inputNodes.Add(weightsNode); - inputNodes.Add(biasesNode); - - // 6. Build computation graph matching Forward() logic - // This example shows DenseLayer: output = (input × weights^T) + biases + activation - - // Step 6a: Transpose weights for matrix multiplication - var weightsTransposed = TensorOperations.Transpose(weightsNode); - - // Step 6b: Matrix multiply: input × weights^T - var matmulResult = TensorOperations.MatrixMultiply(inputNode, weightsTransposed); - - // Step 6c: Add biases (broadcasts across batch dimension) - var outputNode = TensorOperations.Add(matmulResult, biasesNode); - - // Step 6d: Apply activation function - var activatedOutput = ApplyActivationToGraph(outputNode); - - // 7. Return the final output node - return activatedOutput; -} -``` - -**Key Points:** - -- **Symbolic batch dimension**: Use `-1` in shape to indicate "any batch size". This allows the same compiled graph to handle different batch sizes efficiently. -- **Match Forward() exactly**: The computation graph must produce identical results to your existing `Forward()` method. -- **Parameter ordering matters**: Add nodes to `inputNodes` in the order: input, then parameters (weights, biases, etc.) -- **Use TensorOperations, not IEngine**: `TensorOperations` methods return `ComputationNode`, which is what we need. - -### Step 2: Implement ApplyActivationToGraph - -This helper method maps your layer's configured activation to the corresponding TensorOperations method. - -```csharp -/// -/// Applies the layer's activation function to a computation graph node. -/// Maps the layer's configured activation to the corresponding TensorOperations method. -/// -private ComputationNode ApplyActivationToGraph(ComputationNode input) -{ - if (input == null) - throw new ArgumentNullException(nameof(input)); - - // Check scalar activation first (element-wise activations) - if (ScalarActivation is not null) - { - // ReLU family - if (ScalarActivation is ReLUActivation) - return TensorOperations.ReLU(input); - else if (ScalarActivation is LeakyReLUActivation leakyRelu) - return TensorOperations.LeakyReLU(input, leakyRelu.NegativeSlope); - - // Sigmoid family - else if (ScalarActivation is SigmoidActivation) - return TensorOperations.Sigmoid(input); - else if (ScalarActivation is TanhActivation) - return TensorOperations.Tanh(input); - else if (ScalarActivation is SwishActivation) - return TensorOperations.Swish(input); - else if (ScalarActivation is SiLUActivation) - return TensorOperations.SiLU(input); - else if (ScalarActivation is MishActivation) - return TensorOperations.Mish(input); - - // Modern activations - else if (ScalarActivation is GELUActivation) - return TensorOperations.GELU(input); - else if (ScalarActivation is ELUActivation elu) - return TensorOperations.ELU(input, elu.Alpha); - - // Identity (no-op) - else if (ScalarActivation is IdentityActivation) - return input; - - // Unsupported activation - else - throw new NotSupportedException( - $"Activation {ScalarActivation.GetType().Name} is not supported for JIT compilation yet"); - } - - // Check vector activation (operates on entire vectors) - if (VectorActivation is not null) - { - if (VectorActivation is SoftmaxActivation) - return TensorOperations.Softmax(input); - else - throw new NotSupportedException( - $"Activation {VectorActivation.GetType().Name} is not supported for JIT compilation yet"); - } - - // No activation configured (identity) - return input; -} -``` - -**Key Points:** - -- **Check both ScalarActivation and VectorActivation**: Layers can have either type -- **Parameterized activations**: Some activations like LeakyReLU and ELU have parameters - extract and pass them -- **Identity is a no-op**: Just return the input unchanged -- **Clear error messages**: Tell users which activations are not yet supported - -### Step 3: Implement CanActivationBeJitted - -This helper method checks if the layer's current activation is supported for JIT compilation. - -```csharp -/// -/// Checks if the layer's current activation function is supported for JIT compilation. -/// -private bool CanActivationBeJitted() -{ - // Check scalar activations - if (ScalarActivation is ReLUActivation || - ScalarActivation is SigmoidActivation || - ScalarActivation is TanhActivation || - ScalarActivation is GELUActivation || - ScalarActivation is ELUActivation || - ScalarActivation is MishActivation || - ScalarActivation is SwishActivation || - ScalarActivation is SiLUActivation || - ScalarActivation is LeakyReLUActivation || - ScalarActivation is IdentityActivation) - { - return true; - } - - // Check vector activations - if (VectorActivation is SoftmaxActivation) - { - return true; - } - - // No activation is fine (identity) - if (ScalarActivation == null && VectorActivation == null) - { - return true; - } - - return false; -} -``` - -**Key Points:** - -- **Whitelist approach**: Explicitly list supported activations -- **No activation = identity**: Return true if no activation configured -- **Easy to extend**: Just add new activation types as they're implemented - -### Step 4: Update SupportsJitCompilation - -This property tells the framework whether the layer can be JIT compiled in its current configuration. - -```csharp -/// -/// Gets whether this layer currently supports JIT compilation. -/// -/// -/// True if the layer's activation function is supported for JIT compilation. -/// Supported activations: ReLU, Sigmoid, Tanh, GELU, ELU, Mish, Swish, SiLU, LeakyReLU, Softmax, Identity. -/// -public override bool SupportsJitCompilation => CanActivationBeJitted(); -``` - -**Key Points:** - -- **Dynamic check**: Layer might support JIT with ReLU but not with a custom activation -- **Used by JIT compiler**: Framework checks this before attempting compilation -- **Document supported activations**: Keep XML comment updated as you add more activations - -### Step 5: Add Validation (Optional but Recommended) - -For production-quality implementations, add validation to catch common errors early. - -```csharp -/// -/// Validates that the layer is ready for JIT compilation. -/// -private void ValidateForJitCompilation() -{ - if (_weights == null) - throw new InvalidOperationException( - "Layer weights not initialized. Call Initialize() or train the layer first."); - - if (_biases == null) - throw new InvalidOperationException( - "Layer biases not initialized. Call Initialize() or train the layer first."); - - if (InputShape == null || InputShape.Length == 0) - throw new InvalidOperationException( - "Layer input shape not configured. Set InputShape before exporting computation graph."); - - if (OutputShape == null || OutputShape.Length == 0) - throw new InvalidOperationException( - "Layer output shape not configured. This should be set during initialization."); - - if (!CanActivationBeJitted()) - { - var activationType = ScalarActivation?.GetType().Name ?? - VectorActivation?.GetType().Name ?? - "unknown"; - throw new NotSupportedException( - $"Activation function '{activationType}' is not supported for JIT compilation. " + - $"Supported activations: ReLU, Sigmoid, Tanh, GELU, ELU, Mish, Swish, SiLU, LeakyReLU, Softmax, Identity"); - } -} -``` - -Then call it at the start of `ExportComputationGraph`: - -```csharp -public override ComputationNode ExportComputationGraph(List> inputNodes) -{ - ValidateForJitCompilation(); - // ... rest of implementation -} -``` - ---- - -## Common Patterns - -### Pattern 1: Matrix Operations - -Most layers perform matrix multiplication (dense, convolutional, attention, etc.): - -```csharp -// Dense layer: output = input × weights^T -var weightsTransposed = TensorOperations.Transpose(weightsNode); -var output = TensorOperations.MatrixMultiply(inputNode, weightsTransposed); - -// Add bias -output = TensorOperations.Add(output, biasesNode); -``` - -### Pattern 2: Element-wise Operations - -Activation functions, batch normalization, layer normalization use element-wise ops: - -```csharp -// Element-wise multiply -var scaled = TensorOperations.ElementwiseMultiply(input, scaleNode); - -// Element-wise add -var shifted = TensorOperations.Add(scaled, offsetNode); - -// Activation -var activated = TensorOperations.ReLU(shifted); -``` - -### Pattern 3: Convolution Operations - -Convolutional layers use Conv2D: - -```csharp -// Convolution: output = Conv2D(input, kernel) + bias -var convResult = TensorOperations.Conv2D( - inputNode, - kernelNode, - stride: new[] { strideY, strideX }, - padding: new[] { padY, padX }, - dilation: new[] { dilationY, dilationX } -); - -var withBias = TensorOperations.Add(convResult, biasNode); -var activated = ApplyActivationToGraph(withBias); -``` - -### Pattern 4: Pooling Operations - -MaxPooling and AveragePooling layers: - -```csharp -// Max pooling -var pooled = TensorOperations.MaxPool2D( - inputNode, - poolSize: new[] { poolHeight, poolWidth }, - stride: new[] { strideY, strideX }, - padding: new[] { padY, padX } -); - -// Average pooling -var pooled = TensorOperations.AvgPool2D( - inputNode, - poolSize: new[] { poolHeight, poolWidth }, - stride: new[] { strideY, strideX }, - padding: new[] { padY, padX } -); -``` - -### Pattern 5: Normalization Operations - -Batch normalization and layer normalization: - -```csharp -// Batch normalization -var normalized = TensorOperations.BatchNorm( - inputNode, - gammaNode, // Scale parameter - betaNode, // Shift parameter - meanNode, // Running mean - varianceNode, // Running variance - epsilon: 1e-5 -); - -// Layer normalization -var normalized = TensorOperations.LayerNorm( - inputNode, - gammaNode, - betaNode, - epsilon: 1e-5 -); -``` - -### Pattern 6: Concatenation and Splitting - -Combine or split tensors: - -```csharp -// Concatenate multiple inputs -var combined = TensorOperations.Concat( - new List> { input1, input2, input3 }, - axis: 1 // Concatenate along feature dimension -); - -// Reshape to split -var reshaped = TensorOperations.Reshape(inputNode, newShape); -``` - -### Pattern 7: Attention Mechanism - -Self-attention and multi-head attention: - -```csharp -// Query, Key, Value projections -var query = TensorOperations.MatrixMultiply(inputNode, queryWeightsNode); -var key = TensorOperations.MatrixMultiply(inputNode, keyWeightsNode); -var value = TensorOperations.MatrixMultiply(inputNode, valueWeightsNode); - -// Attention scores: Q × K^T / sqrt(d_k) -var keyTransposed = TensorOperations.Transpose(key); -var scores = TensorOperations.MatrixMultiply(query, keyTransposed); - -// Scale -var scaleFactor = Math.Sqrt(embeddingDim); -var scaled = TensorOperations.Divide(scores, TensorOperations.Constant(scaleFactor)); - -// Softmax -var attention = TensorOperations.Softmax(scaled, axis: -1); - -// Apply attention to values -var output = TensorOperations.MatrixMultiply(attention, value); -``` - ---- - -## Troubleshooting - -### Error: "Activation X is not supported for JIT compilation" - -**Cause**: Your layer uses an activation function that hasn't been added to `ApplyActivationToGraph` yet. - -**Solution**: -1. Check if the activation is in the supported list (see "Supported Activations" section) -2. If it's listed but not working, add it to `CanActivationBeJitted()` and `ApplyActivationToGraph()` -3. If it's not listed, add the TensorOperations method first, then update your layer - -**Example fix**: -```csharp -// Add to CanActivationBeJitted() -if (ScalarActivation is SELUActivation) - return true; - -// Add to ApplyActivationToGraph() -else if (ScalarActivation is SELUActivation) - return TensorOperations.SELU(input); -``` - -### Error: "Layer weights not initialized" - -**Cause**: Trying to export computation graph before calling `Initialize()` or training the layer. - -**Solution**: -```csharp -var layer = new DenseLayer(inputSize: 784, outputSize: 128); -layer.Initialize(); // Initialize weights and biases -var graph = layer.ExportComputationGraph(inputNodes); -``` - -### Error: "InputShape not configured" - -**Cause**: Layer doesn't know its input dimensions. - -**Solution**: -```csharp -layer.InputShape = new int[] { 784 }; // Set before exporting graph -``` - -### Build Error: "Cannot convert TensorOperations result to expected type" - -**Cause**: Using IEngine methods instead of TensorOperations methods. - -**Solution**: -```csharp -// ❌ WRONG - IEngine methods don't return ComputationNode -var result = _engine.MatrixMultiply(input, weights); - -// ✅ CORRECT - Use TensorOperations -var result = TensorOperations.MatrixMultiply(inputNode, weightsNode); -``` - -### Error: "Backward function not implemented" - -**Cause**: This is expected! Gradient computation is not yet implemented. - -**Current status**: Forward pass works, backward pass is placeholder. - -**Workaround**: Use JIT compilation for inference only. For training, gradients will be added in a future phase. - -### Performance Issue: Compilation takes too long - -**Cause**: Very large or complex graphs can take time to compile. - -**Solutions**: -1. Compile once, reuse for multiple batches -2. Use smaller subgraphs (compile individual layers instead of entire model) -3. Cache compiled graphs - -**Example**: -```csharp -// Compile once -var compiled = jitCompiler.Compile(layer); - -// Reuse for many batches -for (int i = 0; i < numBatches; i++) -{ - var output = compiled.Execute(batch[i]); -} -``` - -### Shape Mismatch: "Expected shape [X, Y] but got [A, B]" - -**Cause**: Symbolic batch dimension (-1) not handled correctly. - -**Solution**: Use symbolic shapes consistently: -```csharp -// ✅ CORRECT - Symbolic batch dimension -var inputShape = new int[] { -1, inputSize }; - -// ❌ WRONG - Fixed batch dimension -var inputShape = new int[] { 32, inputSize }; -``` - ---- - -## Complete Example: Adding JIT Support to ConvolutionalLayer - -Here's a full example showing how to add JIT compilation to `ConvolutionalLayer`: - -```csharp -public class ConvolutionalLayer : LayerBase -{ - // ... existing fields and properties ... - - public override ComputationNode ExportComputationGraph(List> inputNodes) - { - // 1. Validate - if (inputNodes == null) - throw new ArgumentNullException(nameof(inputNodes)); - - if (_kernels == null) - throw new InvalidOperationException("Kernels not initialized"); - - if (!CanActivationBeJitted()) - throw new NotSupportedException($"Activation not supported for JIT"); - - // 2. Extract dimensions - // InputShape: [channels, height, width] - int channels = InputShape[0]; - int height = InputShape[1]; - int width = InputShape[2]; - - // 3. Create input placeholder with symbolic batch - var inputPlaceholder = new Tensor(new int[] { 1, channels, height, width }); - var inputNode = TensorOperations.Variable(inputPlaceholder, "input"); - - // 4. Create kernel parameters - // Kernels shape: [numFilters, channels, kernelHeight, kernelWidth] - var kernelNode = TensorOperations.Variable( - new Tensor(_kernels.Shape, _kernels.ToArray()), - "kernels" - ); - - // Biases shape: [numFilters] - var biasNode = TensorOperations.Variable( - new Tensor(new int[] { NumFilters }, _biases), - "biases" - ); - - // 5. Add to input list - inputNodes.Add(inputNode); - inputNodes.Add(kernelNode); - inputNodes.Add(biasNode); - - // 6. Build computation graph - var convResult = TensorOperations.Conv2D( - inputNode, - kernelNode, - stride: new[] { StrideY, StrideX }, - padding: new[] { PaddingY, PaddingX }, - dilation: new[] { DilationY, DilationX } - ); - - var withBias = TensorOperations.Add(convResult, biasNode); - var activated = ApplyActivationToGraph(withBias); - - return activated; - } - - private ComputationNode ApplyActivationToGraph(ComputationNode input) - { - if (input == null) - throw new ArgumentNullException(nameof(input)); - - if (ScalarActivation is not null) - { - if (ScalarActivation is ReLUActivation) - return TensorOperations.ReLU(input); - else if (ScalarActivation is SigmoidActivation) - return TensorOperations.Sigmoid(input); - // ... add other activations ... - else - throw new NotSupportedException($"Activation {ScalarActivation.GetType().Name} not supported"); - } - - return input; - } - - private bool CanActivationBeJitted() - { - if (ScalarActivation is ReLUActivation || - ScalarActivation is SigmoidActivation || - ScalarActivation is TanhActivation || - ScalarActivation is IdentityActivation) - { - return true; - } - - if (ScalarActivation == null && VectorActivation == null) - { - return true; - } - - return false; - } - - public override bool SupportsJitCompilation => CanActivationBeJitted(); -} -``` - ---- - -## Next Steps - -After implementing JIT support for your layer: - -1. **Test compilation**: Ensure `ExportComputationGraph` runs without errors -2. **Verify correctness**: Compare JIT output with eager mode output -3. **Measure performance**: Benchmark to confirm speedup -4. **Add more activations**: Extend `ApplyActivationToGraph` as needed -5. **Document**: Update this guide with any new patterns you discover - -For the complete roadmap and list of layers to implement, see [JIT_ROADMAP.md](JIT_ROADMAP.md). - -For activation function reference, see [JIT_ACTIVATION_MAPPING.md](JIT_ACTIVATION_MAPPING.md). diff --git a/docs/JIT_ROADMAP.md b/docs/JIT_ROADMAP.md deleted file mode 100644 index f9173bbe64..0000000000 --- a/docs/JIT_ROADMAP.md +++ /dev/null @@ -1,452 +0,0 @@ -# JIT Compilation Roadmap - -## Current Status - -### Phase 1: Foundation (Complete ✅) - -**Agents 1-5** implemented the core infrastructure for JIT compilation: - -#### Agent 1: TensorOperations Foundation -- ✅ Created `TensorOperations` class with generic type support -- ✅ Implemented core operations: Add, Subtract, ElementwiseMultiply, Divide, Power -- ✅ Implemented mathematical operations: Exp, Log, Sqrt, Tanh, Sigmoid, ReLU -- ✅ Implemented matrix operations: MatrixMultiply, Transpose -- ✅ Implemented reduction operations: Sum, Mean -- ✅ Implemented shape operations: Reshape, Concat, Pad -- ✅ All operations return `ComputationNode` for autodiff support - -#### Agent 2: IR Operations (Group 1 - ReLU Family) -- ✅ Added IR operations for ReLU family activations -- ✅ Integrated with IEngine for GPU acceleration -- ✅ Operations: ReLU, LeakyReLU, GELU, ELU, SELU, CELU, PReLU, RReLU, ThresholdedReLU - -#### Agent 3: IR Operations (Group 2 - Sigmoid Family) -- ✅ Added IR operations for Sigmoid family activations -- ✅ Integrated with IEngine for GPU acceleration -- ✅ Operations: Sigmoid, Tanh, Swish, SiLU, Mish, HardSigmoid, HardTanh, Softplus, Softsign - -#### Agent 4: IR Operations (Group 3 - Softmax & Special) -- ✅ Added IR operations for Softmax family -- ✅ Added IR operations for special activations -- ✅ Operations: Softmax, Softmin, LogSoftmax, LogSoftmin, Sign, Gaussian, ISRU, LiSHT, SQRBF, Squash, BinarySpiking, BentIdentity, Identity -- ✅ Placeholder implementations for complex activations: Sparsemax, SphericalSoftmax, GumbelSoftmax, TaylorSoftmax, HierarchicalSoftmax, Maxout - -#### Agent 5: TensorOperations Method Completion -- ✅ Added TensorOperations methods for all 37 activation functions -- ✅ 27 fully implemented (ReLU, Sigmoid families, special activations) -- ✅ 6 placeholder implementations (complex activations) -- ✅ 4 pre-existing (ReLU, Sigmoid, Tanh, Softmax) -- ✅ All methods integrated with IEngine for hardware acceleration - -**Summary**: Infrastructure is complete. All 37 activation functions have TensorOperations methods and IEngine integration. - ---- - -### Phase 2: DenseLayer Production-Ready (Complete ✅) - -**Agent 6** made DenseLayer production-ready for JIT compilation: - -#### Implementation -- ✅ Implemented `ExportComputationGraph` with symbolic batch dimensions (-1) -- ✅ Implemented `ApplyActivationToGraph` helper method -- ✅ Implemented `CanActivationBeJitted` validation -- ✅ Updated `SupportsJitCompilation` property -- ✅ Added comprehensive validation - -#### Supported Activations (10) -- ✅ ReLU, Sigmoid, Tanh, Softmax, Identity (baseline) -- ✅ GELU, ELU, Mish, Swish, SiLU (modern activations) - -#### Testing & Validation -- ✅ Computation graph exports correctly -- ✅ Symbolic batch dimensions work -- ✅ Parameter nodes (weights, biases) handled correctly -- ✅ Activation mapping verified -- ✅ Build succeeds without errors - -**Summary**: DenseLayer is the reference implementation. Pattern is established and documented. - ---- - -### Phase 3: Rollout to Other Layers (Pending ⏳) - -**Agent 7** created comprehensive documentation (this document and related guides). - -**Next step**: Apply the DenseLayer pattern to 76 remaining layers. - ---- - -## Layer Implementation Priorities - -### Total Layers: 77 -- **Production-Ready**: 1 (DenseLayer) -- **Pending Implementation**: 76 - ---- - -### Priority 1: Core Layers (6 layers) - -These are the most commonly used layers in neural networks. Implementing these will enable JIT compilation for the majority of models. - -| Layer | File | Priority Reason | Estimated Complexity | -|-------|------|----------------|----------------------| -| **ConvolutionalLayer** | `ConvolutionalLayer.cs` | Used in all CNNs (ResNet, VGG, etc.) | Medium - Conv2D operation | -| **LayerNormalizationLayer** | `LayerNormalizationLayer.cs` | Critical for Transformers (BERT, GPT) | Medium - LayerNorm operation | -| **PoolingLayer** | `PoolingLayer.cs` | Used in all CNNs for downsampling | Low - MaxPool2D/AvgPool2D | -| **BatchNormalizationLayer** | `BatchNormalizationLayer.cs` | Used in most modern CNNs | Medium - BatchNorm operation | -| **DropoutLayer** | `DropoutLayer.cs` | Used in almost all models | Low - Element-wise mask | -| **FlattenLayer** | `FlattenLayer.cs` | Connects CNNs to dense layers | Low - Reshape operation | - -**Estimated time**: 1-2 days per layer = 6-12 days total - ---- - -### Priority 2: Recurrent Layers (3 layers) - -Essential for sequence models (NLP, time series). - -| Layer | File | Priority Reason | Estimated Complexity | -|-------|------|----------------|----------------------| -| **LSTMLayer** | `LSTMLayer.cs` | Most popular RNN variant | High - Complex gates | -| **GRULayer** | `GRULayer.cs` | Alternative to LSTM, simpler | High - Complex gates | -| **RecurrentLayer** | `RecurrentLayer.cs` | Basic RNN layer | Medium - Recurrent connections | - -**Estimated time**: 2-3 days per layer = 6-9 days total - ---- - -### Priority 3: Attention Layers (4 layers) - -Critical for Transformers and modern NLP/vision models. - -| Layer | File | Priority Reason | Estimated Complexity | -|-------|------|----------------|----------------------| -| **MultiHeadAttentionLayer** | `MultiHeadAttentionLayer.cs` | Core of Transformer architecture | High - Complex attention mechanism | -| **SelfAttentionLayer** | `SelfAttentionLayer.cs` | Used in Transformers | High - Attention computation | -| **AttentionLayer** | `AttentionLayer.cs` | Basic attention mechanism | Medium - QKV projections | -| **TransformerEncoderLayer** | `TransformerEncoderLayer.cs` | Complete encoder block | High - Combines attention + FFN | - -**Estimated time**: 2-3 days per layer = 8-12 days total - ---- - -### Priority 4: Specialized Convolutional Layers (6 layers) - -Important for advanced vision models. - -| Layer | File | Priority Reason | Estimated Complexity | -|-------|------|----------------|----------------------| -| **DepthwiseSeparableConvolutionalLayer** | `DepthwiseSeparableConvolutionalLayer.cs` | MobileNet, EfficientNet | Medium - Depthwise + Pointwise | -| **DeconvolutionalLayer** | `DeconvolutionalLayer.cs` | GANs, image generation | Medium - ConvTranspose2D | -| **DilatedConvolutionalLayer** | `DilatedConvolutionalLayer.cs` | WaveNet, semantic segmentation | Medium - Dilated convolution | -| **SeparableConvolutionalLayer** | `SeparableConvolutionalLayer.cs` | Efficient CNNs | Medium - Separable convolution | -| **LocallyConnectedLayer** | `LocallyConnectedLayer.cs` | Face recognition, pattern-specific | Medium - Local connections | -| **ConvLSTMLayer** | `ConvLSTMLayer.cs` | Video processing, spatio-temporal | High - Conv + LSTM fusion | - -**Estimated time**: 1-2 days per layer = 6-12 days total - ---- - -### Priority 5: Utility Layers (10 layers) - -Small but frequently used layers. - -| Layer | File | Estimated Complexity | -|-------|------|---------------------| -| **AddLayer** | `AddLayer.cs` | Low - Element-wise add | -| **MultiplyLayer** | `MultiplyLayer.cs` | Low - Element-wise multiply | -| **ConcatenateLayer** | `ConcatenateLayer.cs` | Low - Concat operation | -| **ReshapeLayer** | `ReshapeLayer.cs` | Low - Reshape operation | -| **ActivationLayer** | `ActivationLayer.cs` | Low - Just activation | -| **ResidualLayer** | `ResidualLayer.cs` | Low - Add input to output | -| **PaddingLayer** | `PaddingLayer.cs` | Low - Pad operation | -| **CroppingLayer** | `CroppingLayer.cs` | Low - Crop operation | -| **UpsamplingLayer** | `UpsamplingLayer.cs` | Low - Upsample operation | -| **SplitLayer** | `SplitLayer.cs` | Low - Split operation | - -**Estimated time**: 0.5-1 day per layer = 5-10 days total - ---- - -### Priority 6: Advanced Architecture Layers (8 layers) - -Modern architectural innovations. - -| Layer | File | Priority Reason | Estimated Complexity | -|-------|------|----------------|----------------------| -| **ResidualLayer** | `ResidualLayer.cs` | ResNet, skip connections | Low - Add operation | -| **HighwayLayer** | `HighwayLayer.cs` | Highway networks | Medium - Gated shortcut | -| **SqueezeAndExcitationLayer** | `SqueezeAndExcitationLayer.cs` | SENet, channel attention | Medium - Global pooling + FC | -| **GatedLinearUnitLayer** | `GatedLinearUnitLayer.cs` | Language modeling | Medium - Gated activation | -| **MixtureOfExpertsLayer** | `MixtureOfExpertsLayer.cs` | Sparse models (Switch Transformer) | High - Routing + experts | -| **CapsuleLayer** | `CapsuleLayer.cs` | Capsule Networks | High - Dynamic routing | -| **GraphConvolutionalLayer** | `GraphConvolutionalLayer.cs` | Graph neural networks | High - Graph operations | -| **SpatialTransformerLayer** | `SpatialTransformerLayer.cs` | Spatial attention | High - Affine transformation | - -**Estimated time**: 1-3 days per layer = 8-24 days total - ---- - -### Priority 7: Embedding & Encoding Layers (5 layers) - -Essential for NLP and sequence models. - -| Layer | File | Estimated Complexity | -|-------|------|---------------------| -| **EmbeddingLayer** | `EmbeddingLayer.cs` | Low - Lookup table | -| **PositionalEncodingLayer** | `PositionalEncodingLayer.cs` | Low - Add positional embeddings | -| **PatchEmbeddingLayer** | `PatchEmbeddingLayer.cs` | Medium - Vision Transformers | -| **TransformerDecoderLayer** | `TransformerDecoderLayer.cs` | High - Decoder block | -| **DecoderLayer** | `DecoderLayer.cs` | Medium - Seq2seq decoder | - -**Estimated time**: 1-2 days per layer = 5-10 days total - ---- - -### Priority 8: Specialized & Research Layers (34 layers) - -These are specialized layers for specific use cases, research, or niche applications. - -| Category | Layers | Estimated Time | -|----------|--------|----------------| -| **Pooling Variants** | MaxPoolingLayer, GlobalPoolingLayer | 1-2 days | -| **Normalization** | (Already covered: BatchNorm, LayerNorm) | - | -| **Noise & Regularization** | GaussianNoiseLayer, MaskingLayer | 1-2 days | -| **Memory-Augmented** | MemoryReadLayer, MemoryWriteLayer, ContinuumMemorySystemLayer, TemporalMemoryLayer | 4-6 days | -| **Spiking Neural Networks** | SpikingLayer, SynapticPlasticityLayer | 2-3 days | -| **Quantum** | QuantumLayer | 1-2 days | -| **Capsule Networks** | PrimaryCapsuleLayer, DigitCapsuleLayer | 2-3 days | -| **Specialized Conv** | SubpixelConvolutionalLayer | 1 day | -| **RBF & Kernel Methods** | RBFLayer, LogVarianceLayer | 1-2 days | -| **Anomaly Detection** | AnomalyDetectorLayer | 1 day | -| **Bidirectional** | BidirectionalLayer | 2 days | -| **Time Distributed** | TimeDistributedLayer | 1 day | -| **Readout & Measurement** | ReadoutLayer, MeasurementLayer | 1-2 days | -| **Reconstruction** | ReconstructionLayer | 1 day | -| **Reparameterization** | RepParameterizationLayer | 1 day | -| **Reservoir Computing** | ReservoirLayer | 1-2 days | -| **Spatial Pooler** | SpatialPoolerLayer | 1-2 days | -| **RBM** | RBMLayer | 2-3 days | -| **Feed Forward** | FeedForwardLayer, FullyConnectedLayer | 1 day | -| **Expert** | ExpertLayer | 1 day | -| **Input** | InputLayer | 0.5 day | -| **Lambda** | LambdaLayer | 1 day | -| **Mean** | MeanLayer | 0.5 day | -| **CRF** | ConditionalRandomFieldLayer | 2-3 days | - -**Estimated time**: 30-50 days total - ---- - -## Timeline Estimate - -### Optimistic (Single Developer, Full-Time) - -| Phase | Duration | Cumulative | -|-------|----------|------------| -| Priority 1 (Core) | 6-12 days | 6-12 days | -| Priority 2 (RNN) | 6-9 days | 12-21 days | -| Priority 3 (Attention) | 8-12 days | 20-33 days | -| Priority 4 (Specialized Conv) | 6-12 days | 26-45 days | -| Priority 5 (Utility) | 5-10 days | 31-55 days | -| Priority 6 (Advanced) | 8-24 days | 39-79 days | -| Priority 7 (Embedding) | 5-10 days | 44-89 days | -| Priority 8 (Specialized) | 30-50 days | 74-139 days | - -**Total**: 2.5-5 months (full-time) - -### Realistic (With Testing, Documentation, Reviews) - -Multiply by 1.5-2x for: -- Testing each layer -- Handling edge cases -- Code reviews -- Documentation updates -- Bug fixes - -**Total**: 4-10 months (full-time) - ---- - -## Implementation Strategy - -### Batch Approach - -Instead of implementing layers one-by-one, batch similar layers together: - -**Batch 1: Simple Utility Layers (Week 1)** -- FlattenLayer, ReshapeLayer, AddLayer, MultiplyLayer, ConcatenateLayer -- 5 layers × 1 day = 5 days - -**Batch 2: Core Vision Layers (Week 2)** -- ConvolutionalLayer, PoolingLayer, BatchNormalizationLayer -- 3 layers × 2 days = 6 days - -**Batch 3: Normalization & Regularization (Week 3)** -- LayerNormalizationLayer, DropoutLayer, GaussianNoiseLayer -- 3 layers × 1.5 days = 4-5 days - -**Batch 4: Recurrent Layers (Weeks 4-5)** -- LSTMLayer, GRULayer, RecurrentLayer -- 3 layers × 3 days = 9 days - -**Batch 5: Attention Layers (Weeks 6-7)** -- MultiHeadAttentionLayer, SelfAttentionLayer, AttentionLayer -- 3 layers × 3 days = 9 days - -Continue batching by layer type... - ---- - -## Acceptance Criteria - -For each layer to be considered "production-ready": - -### Code Requirements -- [ ] `ExportComputationGraph` method implemented -- [ ] `ApplyActivationToGraph` helper method implemented -- [ ] `CanActivationBeJitted` validation implemented -- [ ] `SupportsJitCompilation` property updated -- [ ] Symbolic batch dimensions (-1) supported -- [ ] All parameters exported as nodes -- [ ] Computation graph matches Forward() method exactly - -### Documentation Requirements -- [ ] XML documentation updated with JIT support status -- [ ] Supported activations listed in XML comment -- [ ] Code example added to pattern guide (if new pattern) - -### Testing Requirements -- [ ] Build succeeds without errors -- [ ] Computation graph exports without exceptions -- [ ] JIT compilation succeeds -- [ ] Output matches eager mode (forward pass) -- [ ] Works with different batch sizes (1, 32, 128, etc.) -- [ ] Works with all supported activations - -### Integration Requirements -- [ ] IEngine operations used (for GPU acceleration) -- [ ] Error messages are clear and helpful -- [ ] Follows DenseLayer pattern consistently -- [ ] No breaking changes to existing API - ---- - -## Future Work - -### Phase 4: Gradient Computation (Not Scheduled) - -After all layers support forward pass JIT compilation: - -**Tasks**: -- Implement backward functions for all TensorOperations methods -- Add gradient accumulation support -- Implement optimizer integration with JIT graphs -- Test training with JIT compilation - -**Estimated time**: 2-3 months - -**Benefits**: -- Enable JIT compilation for training (not just inference) -- 5-10x speedup for training large models -- Reduced memory usage during backpropagation - ---- - -### Phase 5: Advanced Optimizations (Not Scheduled) - -After gradient computation is complete: - -**Tasks**: -- Graph fusion (combine multiple operations into one) -- Constant folding (pre-compute constant subgraphs) -- Common subexpression elimination -- Memory layout optimizations -- Kernel fusion for GPU - -**Estimated time**: 1-2 months - -**Benefits**: -- Further 2-5x speedup on top of basic JIT -- Reduced memory fragmentation -- Better GPU utilization - ---- - -### Phase 6: Extended Activation Support (Not Scheduled) - -**Tasks**: -- Fully implement 6 placeholder activations (Sparsemax, etc.) -- Add custom activation support -- Add activation fusion optimizations - -**Estimated time**: 2-3 weeks - -**Benefits**: -- 100% activation coverage -- Support for cutting-edge research models -- Custom activation functions for specialized domains - ---- - -## Success Metrics - -### Coverage -- **Current**: 1/77 layers (1.3%) -- **Target (Priority 1-5)**: 35/77 layers (45%) -- **Target (All)**: 77/77 layers (100%) - -### Performance -- **Target speedup**: 5-10x for inference -- **Target memory reduction**: 30-50% - -### Adoption -- **Target**: 80% of models in test suite can use JIT compilation -- **Target**: All major architectures supported (ResNet, BERT, GPT, etc.) - ---- - -## Resources - -### Documentation -- [JIT_COMPILATION_PATTERN_GUIDE.md](JIT_COMPILATION_PATTERN_GUIDE.md) - Implementation guide -- [JIT_ACTIVATION_MAPPING.md](JIT_ACTIVATION_MAPPING.md) - Activation reference - -### Reference Implementation -- `src/NeuralNetworks/Layers/DenseLayer.cs` - Production-ready example - -### Infrastructure -- `src/Autodiff/TensorOperations.cs` - All operations -- `src/Engines/IEngine.cs` - Hardware acceleration -- `src/Autodiff/IR/` - Intermediate representation - ---- - -## Contributing - -To contribute to JIT compilation implementation: - -1. **Pick a layer** from the priority list above -2. **Read the pattern guide** ([JIT_COMPILATION_PATTERN_GUIDE.md](JIT_COMPILATION_PATTERN_GUIDE.md)) -3. **Study DenseLayer** implementation as reference -4. **Implement the pattern** in your chosen layer -5. **Test thoroughly** with various activations and batch sizes -6. **Create a PR** with clear description and test results - -### Questions? - -If you encounter issues or have questions: -- Check the Troubleshooting section in the pattern guide -- Review the DenseLayer implementation -- Ask in the project's discussion forum -- Open an issue with the `jit-compilation` label - ---- - -## Version History - -**v1.0** (2025-11-23) -- Initial roadmap document -- Phases 1-2 complete (foundation + DenseLayer) -- 76 layers pending implementation -- Priority list established diff --git a/docs/JIT_WORKSPACE_DESIGN.md b/docs/JIT_WORKSPACE_DESIGN.md deleted file mode 100644 index a733568048..0000000000 --- a/docs/JIT_WORKSPACE_DESIGN.md +++ /dev/null @@ -1,189 +0,0 @@ -# JIT Compiler + TensorWorkspace: Zero-Allocation Forward Pass Design - -## Problem Statement - -The current CodeGenerator emits calls to `TensorOperations` which wraps every intermediate result in a `ComputationNode` and allocates a new tensor per operation. The "fused" IR ops I added (FusedGroupNormActivationOp, FusedConv2DBiasActivationOp, etc.) just chain the same allocating calls — they provide ZERO actual performance benefit. - -For a production SD15 UNet with 50 denoising steps, the current approach allocates ~60GB of intermediate tensors, causing OOM. - -## Target Architecture - -```text -Inputs: Tensor[] - | - v -WorkspaceCodeGenerator.Generate(IRGraph, MemoryPlan) - | - v -Compiled Function: Action[], TensorWorkspace, IEngine> - | - v -At runtime: - 1. workspace = pre-allocated TensorWorkspace with all slots - 2. engine = current IEngine (CpuEngine or DirectGpuTensorEngine) - 3. compiled(inputs, workspace, engine) // ZERO allocation -``` - -## Key Design Decisions - -### 1. New Code Generator: WorkspaceCodeGenerator - -The existing `CodeGenerator` targets `TensorOperations` (autodiff-enabled, allocating). We need a SEPARATE code generator that targets `IEngine` directly (zero-allocation). - -```csharp -public class WorkspaceCodeGenerator -{ - // Generates: Action[], TensorWorkspace, IEngine> - // NOT: Func[], Tensor[]> - public Action[], TensorWorkspace, IEngine> Generate( - IRGraph graph, - Dictionary tensorToSlot) - { - // For each IR operation, emit the corresponding IEngine call - // using workspace.Get(slot) for all intermediates - } -} -``` - -### 2. Operation Mapping: IROp -> IEngine Method - -Each IR operation maps to a specific IEngine method: - -| IROp | IEngine Method | Allocation | -|------|---------------|------------| -| AddOp | TensorAddInto(dest, a, b) | Zero (dest = workspace slot) | -| Conv2DOp | Conv2DInto(dest, input, kernel, ...) | Zero | -| GroupNormOp | GroupNormInto(dest, input, ...) | Small (mean/var stats) | -| SwishOp | SwishInto(dest, input) **[MISSING]** | Zero | -| ReLUOp | ReLUInto(dest, input) | Zero | -| SigmoidOp | SigmoidInto(dest, input) | Zero | -| MatMulOp | MatMulInto(dest, a, b) **[MISSING]** | Zero | -| ConcatOp | ConcatInto(dest, tensors, axis) **[MISSING]** | Zero | -| ReshapeOp | View only (no data copy needed) | Zero | -| TransposeOp | TransposeInto(dest, input, axes) **[MISSING]** | Zero | - -**Missing IEngine "Into" methods that MUST be added:** -- SwishInto / SwishInPlace (SiLU activation — used in every DiffusionResBlock) -- MishInto / MishInPlace -- GELUInto / GELUInPlace -- TanhInto / TanhInPlace -- MatMulInto (matrix multiply into pre-allocated output) -- ConcatInto (concatenate into pre-allocated output) -- TransposeInto (transpose into pre-allocated output) -- LeakyReLUInto / LeakyReLUInPlace - -### 3. Fused Operations Mapping - -For TRULY fused operations, we need single-pass kernels: - -| Fused IROp | Implementation Strategy | -|-----------|----------------------| -| FusedGroupNormActivationOp | New IEngine method: GroupNormSwishInto(dest, input, gamma, beta, groups, eps) — single pass: normalize + SiLU | -| FusedConv2DBiasActivationOp | Existing: Engine.FusedConv2D(input, kernel, bias, ..., activation) | -| FusedGroupNormActivationConv2DOp | Two calls: GroupNormSwishInto(temp, input, ...) then Conv2DInto(dest, temp, kernel, ...) — temp is a workspace slot | -| FusedAddGroupNormOp | New IEngine method: AddGroupNormInto(dest, a, b, gamma, beta, groups, eps) — single pass: add + normalize | - -**New IEngine methods needed for true fusion:** -- GroupNormSwishInto: GroupNorm + SiLU in single data pass -- GroupNormReLUInto: GroupNorm + ReLU in single data pass -- AddGroupNormInto: Add + GroupNorm in single data pass - -### 4. Memory Plan Integration - -The `MemoryPlanningPass` computes `tensorToSlot` mapping. The `WorkspaceCodeGenerator` uses this mapping: - -```csharp -// For each operation in the IR graph: -int outputSlot = tensorToSlot[op.OutputId]; -var outputTensor = workspace.Get(outputSlot); - -// Emit: Engine.Conv2DInto(outputTensor, inputTensor, kernelTensor, ...) -``` - -The workspace is allocated ONCE at model construction time and reused for every forward pass. - -### 5. Expression Tree Generation - -The WorkspaceCodeGenerator builds expression trees that reference: -- `workspace.Get(slotId)` — workspace parameter -- `Engine.SomeIntoMethod(dest, src, ...)` — engine parameter -- `inputs[i]` — input array parameter - -```csharp -// Generated expression tree (conceptual): -(inputs, workspace, engine) => { - var slot0 = workspace.Get(0); - engine.GroupNormInto(slot0, inputs[0], 32, gamma, beta, 1e-5, out _, out _); - - var slot1 = workspace.Get(1); - engine.SwishInto(slot1, slot0); // SiLU into separate slot (preserves slot0 for potential reuse) - - var slot2 = workspace.Get(2); - engine.Conv2DInto(slot2, slot1, conv1_kernel, 1, 1, 1); - - engine.TensorBroadcastAddInPlace(slot2, conv1_bias_4d); - - // ... second half of ResBlock ... - - engine.TensorAddInPlace(slot4, skip_connection); // residual add - - // slot4 is the output — caller reads it from workspace.Get(4) -} -``` - -### 6. Integration Points - -#### A. NeuralNetworkBase.CompileForward() -1. Export computation graph from all layers -2. Build IR graph via IRBuilder -3. Run optimization passes (fusion, constant folding, dead code elimination) -4. Run MemoryPlanningPass to get tensor-to-slot mapping -5. Create TensorWorkspace with slot shapes from memory plan -6. Generate compiled function via WorkspaceCodeGenerator -7. Store compiled function + workspace - -#### B. UNetNoisePredictor.CompileForward() -Same as above but for the UNet-specific graph structure (encoder + middle + decoder with skip connections). - -#### C. Predict/Forward path -```csharp -if (_compiledForward != null) -{ - _compiledForward(inputs, _workspace, Engine); - return _workspace.Get(_outputSlot); -} -else -{ - // Interpreted fallback - return InterpretedForward(inputs); -} -``` - -## Implementation Plan (Ordered by Priority) - -### Phase 1: IEngine "Into" Methods (AiDotNet.Tensors) — COMPLETE -Added Into/InPlace variants for all major operations to IEngine, CpuEngine, and DirectGpuTensorEngine. - -### Phase 2: WorkspaceCodeGenerator (AiDotNet) — COMPLETE -Emits IEngine calls with workspace slots for 63 IR operations via Into/InPlace equivalents. - -### Phase 3: True Fused Kernels (AiDotNet.Tensors) — FUTURE -GroupNormSwishInto, AddGroupNormInto — single-pass kernels that eliminate intermediates. - -### Phase 4: Integration — COMPLETE -CompileWithWorkspace wired into JitCompiler and UNetNoisePredictor. Automatic fallback to interpreted path. - -### Phase 5: Benchmarks -BenchmarkDotNet project comparing: -- Interpreted vs Compiled forward pass -- Our compiled vs PyTorch (via TorchSharp) -- Operation-level comparisons (Conv2D, GroupNorm, Attention) -- Memory usage comparisons - -## Implementation Status (PR #1018) - -1. **WorkspaceCodeGenerator**: Supports 63 IR ops with Into/InPlace calls targeting TensorWorkspace slots -2. **CompileWithWorkspace**: Wired into JitCompiler, generates workspace-backed compiled functions -3. **UNet CompileForward**: Wired to CompileWithWorkspace with skip connection support -4. **Fused ops (future)**: FusedGroupNormActivationConv2DOp and FusedAddGroupNormOp currently decompose into sequential calls; true kernel fusion requires new IEngine methods (GroupNormSwishInto, AddGroupNormInto) -5. **Integration tests**: WorkspaceCompilationTests validate Add, Multiply, ReLU, Sigmoid, MatMul, and chained ops diff --git a/examples/JitCompiler/BasicUsageExample.cs b/examples/JitCompiler/BasicUsageExample.cs deleted file mode 100644 index 008403957f..0000000000 --- a/examples/JitCompiler/BasicUsageExample.cs +++ /dev/null @@ -1,325 +0,0 @@ -using AiDotNet.Autodiff; -using AiDotNet.Enums; -using AiDotNet.JitCompiler; -using System; -using System.Collections.Generic; -using System.Diagnostics; - -namespace AiDotNet.Examples.JitCompiler; - -/// -/// Basic examples demonstrating JIT compiler usage. -/// -public class BasicUsageExample -{ - /// - /// Example 1: Simple element-wise operation - /// - public static void SimpleElementwiseOperation() - { - Console.WriteLine("=== Example 1: Simple Element-wise Operation ===\n"); - - // Create input tensors - var inputData = new Tensor(new[] { 3, 3 }); - for (int i = 0; i < inputData.Length; i++) - { - inputData[i] = i + 1; // [1, 2, 3, 4, 5, 6, 7, 8, 9] - } - - // Build computation graph - var input = new ComputationNode(inputData) - { - OperationType = OperationType.Input, - Name = "input" - }; - - // result = ReLU(input) - var result = new ComputationNode( - new Tensor(new[] { 3, 3 }), - parents: new List> { input }) - { - OperationType = OperationType.ReLU, - Name = "relu_output" - }; - - // Create JIT compiler and compile - var jit = new global::AiDotNet.JitCompiler.JitCompiler(); - var (compiled, stats) = jit.CompileWithStats(result, new List> { input }); - - Console.WriteLine($"Compilation Stats:"); - Console.WriteLine($" Original operations: {stats.OriginalOperationCount}"); - Console.WriteLine($" Optimized operations: {stats.OptimizedOperationCount}"); - Console.WriteLine($" Compilation time: {stats.CompilationTime.TotalMilliseconds:F2}ms\n"); - - // Execute compiled function - var output = compiled(new[] { inputData }); - - Console.WriteLine("Input: " + string.Join(", ", GetTensorValues(inputData))); - Console.WriteLine("Output (ReLU): " + string.Join(", ", GetTensorValues(output[0]))); - Console.WriteLine(); - } - - /// - /// Example 2: Linear layer (MatMul + Add) - /// - public static void LinearLayerExample() - { - Console.WriteLine("=== Example 2: Linear Layer (MatMul + Add + ReLU) ===\n"); - - // Create inputs - var inputData = new Tensor(new[] { 1, 3 }); - inputData[0] = 1.0f; inputData[1] = 2.0f; inputData[2] = 3.0f; - - var weightsData = new Tensor(new[] { 3, 4 }); - for (int i = 0; i < weightsData.Length; i++) - { - weightsData[i] = 0.1f * (i + 1); - } - - var biasData = new Tensor(new[] { 1, 4 }); - for (int i = 0; i < biasData.Length; i++) - { - biasData[i] = 0.5f; - } - - // Build computation graph: output = ReLU(input @ weights + bias) - var input = new ComputationNode(inputData) { OperationType = OperationType.Input }; - var weights = new ComputationNode(weightsData) { OperationType = OperationType.Input }; - var bias = new ComputationNode(biasData) { OperationType = OperationType.Input }; - - var matmul = new ComputationNode( - new Tensor(new[] { 1, 4 }), - parents: new List> { input, weights }) - { - OperationType = OperationType.MatMul - }; - - var add = new ComputationNode( - new Tensor(new[] { 1, 4 }), - parents: new List> { matmul, bias }) - { - OperationType = OperationType.Add - }; - - var relu = new ComputationNode( - new Tensor(new[] { 1, 4 }), - parents: new List> { add }) - { - OperationType = OperationType.ReLU - }; - - // Compile - var jit = new global::AiDotNet.JitCompiler.JitCompiler(); - var (compiled, stats) = jit.CompileWithStats(relu, new List> { input, weights, bias }); - - Console.WriteLine($"Compilation Stats:"); - Console.WriteLine($" Original operations: {stats.OriginalOperationCount}"); - Console.WriteLine($" Optimized operations: {stats.OptimizedOperationCount}"); - Console.WriteLine($" Operations eliminated: {stats.OperationsEliminated} ({stats.OptimizationPercentage:F1}%)"); - Console.WriteLine($" Optimizations: {string.Join(", ", stats.OptimizationsApplied)}"); - Console.WriteLine($" Compilation time: {stats.CompilationTime.TotalMilliseconds:F2}ms\n"); - - // Execute - var output = compiled(new[] { inputData, weightsData, biasData }); - - Console.WriteLine("Input: " + string.Join(", ", GetTensorValues(inputData))); - Console.WriteLine("Output: " + string.Join(", ", GetTensorValues(output[0]))); - Console.WriteLine(); - } - - /// - /// Example 3: JIT compilation performance benchmark - /// - public static void PerformanceComparisonExample() - { - Console.WriteLine("=== Example 3: JIT Performance Benchmark ===\n"); - - // Create larger tensors for meaningful benchmark - var inputData = new Tensor(new[] { 100, 100 }); - for (int i = 0; i < inputData.Length; i++) - { - inputData[i] = (float)Math.Sin(i * 0.01); - } - - // Build computation graph: exp(relu(input)) - var input = new ComputationNode(inputData) { OperationType = OperationType.Input }; - - var relu = new ComputationNode( - new Tensor(new[] { 100, 100 }), - parents: new List> { input }) - { - OperationType = OperationType.ReLU - }; - - var exp = new ComputationNode( - new Tensor(new[] { 100, 100 }), - parents: new List> { relu }) - { - OperationType = OperationType.Exp - }; - - // Compile - var jit = new global::AiDotNet.JitCompiler.JitCompiler(); - var (compiled, stats) = jit.CompileWithStats(exp, new List> { input }); - - Console.WriteLine($"Graph compiled in {stats.CompilationTime.TotalMilliseconds:F2}ms"); - Console.WriteLine($"Optimizations applied: {string.Join(", ", stats.OptimizationsApplied)}\n"); - - // Warm-up - for (int i = 0; i < 10; i++) - { - compiled(new[] { inputData }); - } - - // Benchmark - const int iterations = 1000; - var sw = Stopwatch.StartNew(); - for (int i = 0; i < iterations; i++) - { - compiled(new[] { inputData }); - } - sw.Stop(); - - double avgTimeMs = sw.Elapsed.TotalMilliseconds / iterations; - Console.WriteLine($"JIT Compiled Execution:"); - Console.WriteLine($" {iterations} iterations in {sw.Elapsed.TotalMilliseconds:F2}ms"); - Console.WriteLine($" Average: {avgTimeMs:F4}ms per iteration"); - Console.WriteLine($" Throughput: {1000.0 / avgTimeMs:F0} operations/second\n"); - } - - /// - /// Example 4: Caching demonstration - /// - public static void CachingExample() - { - Console.WriteLine("=== Example 4: Caching Demonstration ===\n"); - - var jit = new global::AiDotNet.JitCompiler.JitCompiler(); - - // First compilation - var input1 = new ComputationNode(new Tensor(new[] { 2, 3 })) { OperationType = OperationType.Input }; - var relu1 = new ComputationNode( - new Tensor(new[] { 2, 3 }), - parents: new List> { input1 }) - { - OperationType = OperationType.ReLU - }; - - var (compiled1, stats1) = jit.CompileWithStats(relu1, new List> { input1 }); - Console.WriteLine($"First compilation:"); - Console.WriteLine($" Cache hit: {stats1.CacheHit}"); - Console.WriteLine($" Compilation time: {stats1.CompilationTime.TotalMilliseconds:F2}ms\n"); - - // Second compilation with same structure (should hit cache) - var input2 = new ComputationNode(new Tensor(new[] { 2, 3 })) { OperationType = OperationType.Input }; - var relu2 = new ComputationNode( - new Tensor(new[] { 2, 3 }), - parents: new List> { input2 }) - { - OperationType = OperationType.ReLU - }; - - var (compiled2, stats2) = jit.CompileWithStats(relu2, new List> { input2 }); - Console.WriteLine($"Second compilation (same structure):"); - Console.WriteLine($" Cache hit: {stats2.CacheHit}"); - Console.WriteLine($" Compilation time: {stats2.CompilationTime.TotalMilliseconds:F2}ms\n"); - - // Different structure (won't hit cache) - var sigmoid2 = new ComputationNode( - new Tensor(new[] { 2, 3 }), - parents: new List> { input2 }) - { - OperationType = OperationType.Sigmoid - }; - - var (compiled3, stats3) = jit.CompileWithStats(sigmoid2, new List> { input2 }); - Console.WriteLine($"Third compilation (different structure):"); - Console.WriteLine($" Cache hit: {stats3.CacheHit}"); - Console.WriteLine($" Compilation time: {stats3.CompilationTime.TotalMilliseconds:F2}ms\n"); - - // Cache stats - var cacheStats = jit.GetCacheStats(); - Console.WriteLine($"Cache statistics:"); - Console.WriteLine($" Cached graphs: {cacheStats.CachedGraphCount}"); - Console.WriteLine($" Estimated memory: {cacheStats.EstimatedMemoryBytes / 1024.0:F2} KB\n"); - } - - /// - /// Example 5: Custom compiler options - /// - public static void CustomOptionsExample() - { - Console.WriteLine("=== Example 5: Custom Compiler Options ===\n"); - - // Default options (all optimizations enabled) - var jitDefault = new global::AiDotNet.JitCompiler.JitCompiler(); - - // Custom options (selective optimizations) - var customOptions = new JitCompilerOptions - { - EnableConstantFolding = true, - EnableDeadCodeElimination = true, - EnableOperationFusion = false, // Disable fusion - EnableCaching = true - }; - var jitCustom = new global::AiDotNet.JitCompiler.JitCompiler(customOptions); - - // Build a graph - var input = new ComputationNode(new Tensor(new[] { 2, 3 })) { OperationType = OperationType.Input }; - var exp = new ComputationNode( - new Tensor(new[] { 2, 3 }), - parents: new List> { input }) - { - OperationType = OperationType.Exp - }; - - // Compile with default options - var (_, statsDefault) = jitDefault.CompileWithStats(exp, new List> { input }); - Console.WriteLine($"With default options:"); - Console.WriteLine($" Optimizations: {string.Join(", ", statsDefault.OptimizationsApplied)}\n"); - - // Compile with custom options - var (_, statsCustom) = jitCustom.CompileWithStats(exp, new List> { input }); - Console.WriteLine($"With custom options (fusion disabled):"); - Console.WriteLine($" Optimizations: {string.Join(", ", statsCustom.OptimizationsApplied)}\n"); - } - - /// - /// Helper to get tensor values as array - /// - private static float[] GetTensorValues(Tensor tensor) - { - var values = new float[tensor.Length]; - for (int i = 0; i < tensor.Length; i++) - { - values[i] = tensor[i]; - } - return values; - } - - /// - /// Run all examples - /// - public static void RunAllExamples() - { - try - { - SimpleElementwiseOperation(); - LinearLayerExample(); - PerformanceComparisonExample(); - CachingExample(); - CustomOptionsExample(); - - Console.WriteLine("=== All Examples Completed Successfully! ==="); - } - catch (Exception ex) - { - // Rethrow critical exceptions that should not be caught - if (ex is OutOfMemoryException || ex is StackOverflowException || ex is System.Threading.ThreadAbortException) - throw; - - Console.WriteLine($"Error running examples: {ex.Message}"); - Console.WriteLine(ex.StackTrace); - } - } -} diff --git a/examples/JitCompiler/README.md b/examples/JitCompiler/README.md deleted file mode 100644 index f7506c1f04..0000000000 --- a/examples/JitCompiler/README.md +++ /dev/null @@ -1,262 +0,0 @@ -# JIT Compiler Examples - -This directory contains practical examples demonstrating how to use the AiDotNet JIT compiler. - -## Examples Overview - -### BasicUsageExample.cs - -Contains 5 complete examples showing different aspects of JIT compilation: - -1. **Simple Element-wise Operation** - - Shows basic JIT compilation of a single operation - - Demonstrates compilation stats - - Executes compiled function - -2. **Linear Layer Example** - - Demonstrates fusion of MatMul + Add + ReLU - - Shows optimization statistics - - 3 operations → 1 fused operation - -3. **Performance Comparison** - - Benchmarks JIT compiled execution - - Measures throughput and latency - - Demonstrates real performance gains - -4. **Caching Demonstration** - - Shows cache hit/miss behavior - - Demonstrates compilation time savings - - Displays cache statistics - -5. **Custom Compiler Options** - - Shows how to configure optimization passes - - Compares default vs custom configurations - - Demonstrates selective optimization - -## Running the Examples - -### Option 1: From Code - -```csharp -using AiDotNet.Examples.JitCompiler; - -// Run all examples -BasicUsageExample.RunAllExamples(); - -// Or run individual examples -BasicUsageExample.SimpleElementwiseOperation(); -BasicUsageExample.LinearLayerExample(); -BasicUsageExample.PerformanceComparisonExample(); -BasicUsageExample.CachingExample(); -BasicUsageExample.CustomOptionsExample(); -``` - -### Option 2: Create Console App - -Create a simple console application: - -```csharp -using AiDotNet.Examples.JitCompiler; - -class Program -{ - static void Main(string[] args) - { - BasicUsageExample.RunAllExamples(); - } -} -``` - -### Option 3: Interactive (C# Interactive / LINQPad) - -```csharp -#load "BasicUsageExample.cs" - -using AiDotNet.Examples.JitCompiler; - -BasicUsageExample.SimpleElementwiseOperation(); -``` - -## Expected Output - -### Example 1: Simple Element-wise Operation -``` -=== Example 1: Simple Element-wise Operation === - -Compilation Stats: - Original operations: 1 - Optimized operations: 1 - Compilation time: 12.34ms - -Input: 1, 2, 3, 4, 5, 6, 7, 8, 9 -Output (ReLU): 1, 2, 3, 4, 5, 6, 7, 8, 9 -``` - -### Example 2: Linear Layer -``` -=== Example 2: Linear Layer (MatMul + Add + ReLU) === - -Compilation Stats: - Original operations: 3 - Optimized operations: 1 - Operations eliminated: 2 (66.7%) - Optimizations: Constant Folding, Dead Code Elimination, Operation Fusion - Compilation time: 18.56ms - -Input: 1, 2, 3 -Output: 2.3, 3.1, 3.9, 4.7 -``` - -### Example 3: Performance Comparison -``` -=== Example 3: Performance Comparison === - -Graph compiled in 15.23ms -Optimizations applied: Constant Folding, Dead Code Elimination, Operation Fusion - -JIT Compiled Execution: - 1000 iterations in 45.67ms - Average: 0.0457ms per iteration - Throughput: 21882 operations/second -``` - -### Example 4: Caching -``` -=== Example 4: Caching Demonstration === - -First compilation: - Cache hit: False - Compilation time: 12.45ms - -Second compilation (same structure): - Cache hit: True - Compilation time: 0.00ms - -Third compilation (different structure): - Cache hit: False - Compilation time: 11.23ms - -Cache statistics: - Cached graphs: 2 - Estimated memory: 2.00 KB -``` - -### Example 5: Custom Options -``` -=== Example 5: Custom Compiler Options === - -With default options: - Optimizations: Constant Folding, Dead Code Elimination, Operation Fusion - -With custom options (fusion disabled): - Optimizations: Constant Folding, Dead Code Elimination -``` - -## Learning Path - -1. **Start with Example 1** - Understand basic compilation workflow -2. **Move to Example 2** - See real optimization in action -3. **Study Example 3** - Understand performance benefits -4. **Explore Example 4** - Learn about caching behavior -5. **Experiment with Example 5** - Customize compiler settings - -## Tips and Best Practices - -### Setting Operation Metadata - -For JIT compilation to work, ComputationNodes must have `OperationType` set: - -```csharp -var node = new ComputationNode(tensor, parents: inputs) -{ - OperationType = "Add", // Required for JIT! - Name = "my_addition" // Optional, for debugging -}; -``` - -### When to Use JIT - -**Best for:** -- Inference (forward pass only) -- Repeated execution of same graph structure -- Large models with many operations -- Production deployments - -**Less beneficial for:** -- Training (backward pass not yet supported) -- Graphs that change structure frequently -- Very small operations (compilation overhead) - -### Performance Tips - -1. **Compile once, execute many times** - ```csharp - var compiled = jit.Compile(graph, inputs); - for (int i = 0; i < 1000; i++) { - var result = compiled(batchData[i]); // Fast! - } - ``` - -2. **Let caching work for you** - - Same graph structure → cache hit (instant) - - Different data → same compiled function works - -3. **Enable all optimizations** (default) - - Fusion can provide 2-5x speedup alone - - DCE removes overhead - - Constant folding reduces runtime work - -4. **Monitor compilation stats** - ```csharp - var (compiled, stats) = jit.CompileWithStats(graph, inputs); - if (stats.OptimizationPercentage > 50%) { - Console.WriteLine("Great optimizations!"); - } - ``` - -## Common Issues - -### "Node does not have OperationType metadata" - -**Problem:** ComputationNode missing `OperationType` property. - -**Solution:** Set it when creating nodes: -```csharp -node.OperationType = "ReLU"; -``` - -### Slow first execution - -**Problem:** First call includes compilation time. - -**Solution:** This is normal! Compile during initialization: -```csharp -// During setup -var compiled = jit.Compile(graph, inputs); - -// In hot path (fast!) -var result = compiled(data); -``` - -### Cache using too much memory - -**Problem:** Too many compiled graphs cached. - -**Solution:** Monitor and clear cache: -```csharp -var stats = jit.GetCacheStats(); -if (stats.EstimatedMemoryBytes > threshold) { - jit.ClearCache(); -} -``` - -## Next Steps - -- Read the [JIT Compiler Usage Guide](../../docs/JIT-Compiler-Usage-Guide.md) -- Explore the [Architecture README](../../src/JitCompiler/README.md) -- Run the performance benchmarks -- Integrate into your own models - -## Feedback - -Found an issue or have a question? Please file an issue on GitHub! diff --git a/scripts/run-tests-sharded.ps1 b/scripts/run-tests-sharded.ps1 index 415e6d78a6..97210f80bc 100644 --- a/scripts/run-tests-sharded.ps1 +++ b/scripts/run-tests-sharded.ps1 @@ -214,7 +214,6 @@ try { "Helpers", "Inference", "Interpretability", - "JitCompiler", "KnowledgeDistillation", "LearningRateSchedulers", "LinearAlgebra", @@ -242,14 +241,13 @@ try { $shards.Add((New-TestShard -Name "AiDotNet.Tests - Unit - 04 Diagnostics/Diffusion/Encoding" -Project "tests\AiDotNet.Tests\AiDotNetTests.csproj" -Filter (New-UnitShardFilter -CategoryFilter $categoryFilter -Roots $unitNamespaceRoots -Segments @("Diagnostics", "Diffusion", "Encoding")))) $shards.Add((New-TestShard -Name "AiDotNet.Tests - Unit - 05 Feature/Fit/Fitness" -Project "tests\AiDotNet.Tests\AiDotNetTests.csproj" -Filter (New-UnitShardFilter -CategoryFilter $categoryFilter -Roots $unitNamespaceRoots -Segments @("FeatureSelectors", "FitDetectors", "FitnessCalculators")))) $shards.Add((New-TestShard -Name "AiDotNet.Tests - Unit - 06 Genetics/Helpers/Inference" -Project "tests\AiDotNet.Tests\AiDotNetTests.csproj" -Filter (New-UnitShardFilter -CategoryFilter $categoryFilter -Roots $unitNamespaceRoots -Segments @("Genetics", "Helpers", "Inference")))) - $shards.Add((New-TestShard -Name "AiDotNet.Tests - Unit - 07 Interpretability/JIT/KD" -Project "tests\AiDotNet.Tests\AiDotNetTests.csproj" -Filter (New-UnitShardFilter -CategoryFilter $categoryFilter -Roots $unitNamespaceRoots -Segments @("Interpretability", "JitCompiler", "KnowledgeDistillation")))) + $shards.Add((New-TestShard -Name "AiDotNet.Tests - Unit - 07 Interpretability/KD" -Project "tests\AiDotNet.Tests\AiDotNetTests.csproj" -Filter (New-UnitShardFilter -CategoryFilter $categoryFilter -Roots $unitNamespaceRoots -Segments @("Interpretability", "KnowledgeDistillation")))) $shards.Add((New-TestShard -Name "AiDotNet.Tests - Unit - 08 Schedulers/LA/Logging/Loss" -Project "tests\AiDotNet.Tests\AiDotNetTests.csproj" -Filter (New-UnitShardFilter -CategoryFilter $categoryFilter -Roots $unitNamespaceRoots -Segments @("LearningRateSchedulers", "LinearAlgebra", "Logging", "LossFunctions")))) $shards.Add((New-TestShard -Name "AiDotNet.Tests - Unit - 09 Meta/Mixed/Compression" -Project "tests\AiDotNet.Tests\AiDotNetTests.csproj" -Filter (New-UnitShardFilter -CategoryFilter $categoryFilter -Roots $unitNamespaceRoots -Segments @("MetaLearning", "MixedPrecision", "ModelCompression")))) $shards.Add((New-TestShard -Name "AiDotNet.Tests - Unit - 10 NN/Optimizers/RAG" -Project "tests\AiDotNet.Tests\AiDotNetTests.csproj" -Filter (New-UnitShardFilter -CategoryFilter $categoryFilter -Roots $unitNamespaceRoots -Segments @("NeuralNetworks", "Optimizers", "RAG")))) $shards.Add((New-TestShard -Name "AiDotNet.Tests - Unit - 11 Regularization/RL/RAG2" -Project "tests\AiDotNet.Tests\AiDotNetTests.csproj" -Filter (New-UnitShardFilter -CategoryFilter $categoryFilter -Roots $unitNamespaceRoots -Segments @("Regularization", "ReinforcementLearning", "RetrievalAugmentedGeneration")))) $shards.Add((New-TestShard -Name "AiDotNet.Tests - Unit - 12 Serving/TimeSeries/Token/Transfer" -Project "tests\AiDotNet.Tests\AiDotNetTests.csproj" -Filter (New-UnitShardFilter -CategoryFilter $categoryFilter -Roots $unitNamespaceRoots -Segments @("Serving", "TimeSeries", "Tokenization", "TransferLearning")))) - $shards.Add((New-TestShard -Name "AiDotNet.Tests - Other - 13 InferenceOptimization" -Project "tests\AiDotNet.Tests\AiDotNetTests.csproj" -Filter "$categoryFilter&FullyQualifiedName~AiDotNet.Tests.InferenceOptimization")) - $shards.Add((New-TestShard -Name "AiDotNet.Tests - Other - 14 PromptEngineering" -Project "tests\AiDotNet.Tests\AiDotNetTests.csproj" -Filter "$categoryFilter&FullyQualifiedName~AiDotNet.Tests.PromptEngineering")) + $shards.Add((New-TestShard -Name "AiDotNet.Tests - Other - 13 PromptEngineering" -Project "tests\AiDotNet.Tests\AiDotNetTests.csproj" -Filter "$categoryFilter&FullyQualifiedName~AiDotNet.Tests.PromptEngineering")) $shards.Add((New-TestShard -Name "AiDotNet.Tests - Other - 15 Recovery/Concurrency" -Project "tests\AiDotNet.Tests\AiDotNetTests.csproj" -Filter "$categoryFilter&FullyQualifiedName~AiDotNet.Tests.Concurrency|$categoryFilter&FullyQualifiedName~AiDotNet.Tests.Recovery")) $shards.Add((New-TestShard -Name "AiDotNet.Tests - Other - 16 ActivationFunctions" -Project "tests\AiDotNet.Tests\AiDotNetTests.csproj" -Filter (New-NamespaceFilter -CategoryFilter $categoryFilter -Namespaces @("AiDotNet.Tests.ActivationFunctions")))) $shards.Add((New-TestShard -Name "AiDotNet.Tests - Other - 17 Factories" -Project "tests\AiDotNet.Tests\AiDotNetTests.csproj" -Filter (New-NamespaceFilter -CategoryFilter $categoryFilter -Namespaces @("AiDotNet.Tests.Factories")))) diff --git a/src/AiDotNet.Generators/TestScaffoldGenerator.cs b/src/AiDotNet.Generators/TestScaffoldGenerator.cs index 4ffd29ee50..709bfdda70 100644 --- a/src/AiDotNet.Generators/TestScaffoldGenerator.cs +++ b/src/AiDotNet.Generators/TestScaffoldGenerator.cs @@ -3839,8 +3839,6 @@ private static void EmitMockModelFactory(StringBuilder sb) sb.AppendLine(" public void SetActiveFeatureIndices(System.Collections.Generic.IEnumerable f) { }"); sb.AppendLine(" public bool IsFeatureUsed(int i) => false;"); sb.AppendLine(" public System.Collections.Generic.Dictionary GetFeatureImportance() => new();"); - sb.AppendLine(" public bool SupportsJitCompilation => false;"); - sb.AppendLine(" public AiDotNet.Autodiff.ComputationNode ExportComputationGraph(System.Collections.Generic.List> n) => throw new System.NotSupportedException();"); sb.AppendLine(" }"); } diff --git a/src/AiModelBuilder.cs b/src/AiModelBuilder.cs index 06fa27a9b7..176d63fed5 100644 --- a/src/AiModelBuilder.cs +++ b/src/AiModelBuilder.cs @@ -185,6 +185,11 @@ public partial class AiModelBuilder : IAiModelBuilder? _accelerationLogger; + private AiDotNet.Diagnostics.AccelerationSnapshot? _accelerationSnapshot; + private bool _tensorsOpProfilingEnabled; + private AiDotNet.Diagnostics.TensorsOperationProfile? _tensorsOperationProfile; /// /// When true, does NOT force deterministic @@ -927,6 +932,90 @@ public IAiModelBuilder ConfigureJitCompilation( return this; } + /// + /// Captures a snapshot of the active acceleration environment (SIMD, GPU, native BLAS) + /// at build time, logs it through if supplied, and surfaces the + /// structured snapshot on PredictionModelResult.AccelerationSnapshot. + /// + /// + /// Optional callback that receives the formatted report string. When null, the + /// report is written to . + /// + /// This builder for fluent chaining. + /// + /// + /// Useful for production observability — shows exactly which of AVX2/AVX-512/NEON, + /// CUDA/OpenCL/HIP, OpenBLAS/CLBlast/MKL are active on the target host. Users can + /// assert against the returned snapshot in CI (e.g., fail build if AVX-512 isn't + /// detected on an Intel Xeon host). + /// + /// + /// Wraps AiDotNet.Tensors.Engines.PlatformDetector and + /// AiDotNet.Tensors.Engines.NativeLibraryDetector. + /// + /// + public IAiModelBuilder ReportAccelerationStatus(Action? logger = null) + { + _reportAccelerationAtBuild = true; + _accelerationLogger = logger; + return this; + } + + /// + /// Enables disk-backed caching of compiled inference plans. Plans are saved after + /// the first compilation and loaded transparently on subsequent process starts, + /// skipping the trace+compile cost of cold start. + /// + /// + /// Filesystem directory where plan files are stored. Created if missing. Plans are + /// keyed by (concrete model type, element type, structure version, input shape) + /// so one directory can host plans for multiple models. + /// + /// This builder for fluent chaining. + /// + /// + /// PyTorch-parity equivalent: torch.jit.save + torch.jit.load. + /// Plans are tied to the host's hardware fingerprint (via Tensors' + /// PlanCompatibilityInfo); plans compiled on one host are rejected on + /// incompatible hardware, triggering a fresh compile. + /// + /// + /// Caching is opt-in — without this call, plans live only for the process + /// lifetime via the in-memory CompiledModelCache. + /// + /// + public IAiModelBuilder ConfigurePlanCaching(string directory) + { + if (string.IsNullOrWhiteSpace(directory)) + throw new ArgumentException("Plan cache directory must be a non-empty path.", nameof(directory)); + + AiDotNet.NeuralNetworks.PlanCache.SetCurrent(new AiDotNet.NeuralNetworks.PlanCache(directory)); + return this; + } + + /// + /// Enables low-level per-tensor-op profiling via Tensors' + /// PerformanceProfiler.Instance. After BuildAsync() returns, kernel + /// timings are captured on AiModelResult.TensorsOperationProfile. + /// Orthogonal to the higher-level ConfigureProfiling (AiDotNet workflow + /// timings) — both can be enabled together to get a full picture. + /// + /// This builder for fluent chaining. + /// + /// + /// The profiler wraps every engine op in a dispatchable scope — expect a small + /// overhead (~1-3%) during the measured window. Disable in production latency- + /// sensitive paths. PyTorch-parity equivalent: low-level + /// torch.profiler.profile CUDA/CPU op breakdown. + /// + /// + public IAiModelBuilder EnableTensorsOpProfiling() + { + _tensorsOpProfilingEnabled = true; + AiDotNet.Tensors.Engines.Optimization.PerformanceProfiler.Instance.Enabled = true; + return this; + } + /// /// Opts out of the builder's deterministic-by-default policy. Call this when /// you want the engine to pick the fastest available kernels even if they @@ -1505,7 +1594,7 @@ private AiModelResult BuildProgramSynthesisInferenceOnlyResu MemoryConfig = _memoryConfig }; - var programSynthesisResult = new AiModelResult(options); + var programSynthesisResult = AttachDiagnostics(new AiModelResult(options)); ProcessKnowledgeGraphOptions(programSynthesisResult); AttachSafetyPipeline(programSynthesisResult); return programSynthesisResult; @@ -1855,7 +1944,7 @@ private async Task> BuildStreamingSupervisedAs MemoryConfig = _memoryConfig }; - var nnResult = new AiModelResult(options); + var nnResult = AttachDiagnostics(new AiModelResult(options)); ProcessKnowledgeGraphOptions(nnResult); AttachSafetyPipeline(nnResult); return nnResult; @@ -3071,7 +3160,7 @@ T ObjectiveFunction(Dictionary trialHyperparameters) HyperparameterTrialId = bestHyperparameterTrialId }; - var finalResult = new AiModelResult(options); + var finalResult = AttachDiagnostics(new AiModelResult(options)); finalResult.SetUncertaintyQuantificationOptions(_uncertaintyQuantificationOptions); TryComputeAndAttachDeepEnsembleModels(finalResult, deepEnsembleTemplate, optimizationInputData, optimizer, _uncertaintyQuantificationOptions); @@ -3223,7 +3312,7 @@ private AiModelResult BuildMetaLearningInternalAsync() MemoryConfig = _memoryConfig }; - var result = new AiModelResult(metaOptions); + var result = AttachDiagnostics(new AiModelResult(metaOptions)); ProcessKnowledgeGraphOptions(result); AttachSafetyPipeline(result); @@ -3548,7 +3637,7 @@ private async Task> BuildRLInternalAsync(int e MemoryConfig = _memoryConfig }; - var result = new AiModelResult(rlOptions); + var result = AttachDiagnostics(new AiModelResult(rlOptions)); ProcessKnowledgeGraphOptions(result); AttachSafetyPipeline(result); @@ -3798,7 +3887,7 @@ public AiModelResult LoadModel(string filePath) trialManager.RecordOperationOrThrow(); } - var result = new AiModelResult(); + var result = AttachDiagnostics(new AiModelResult()); result.Model = fullModel; // Reattach Graph RAG components if configured @@ -3889,7 +3978,7 @@ public AiModelResult DeserializeModel(byte[] modelData) throw new ArgumentException("Model data cannot be empty.", nameof(modelData)); ModelPersistenceGuard.EnforceBeforeLoad(); - var result = new AiModelResult(); + var result = AttachDiagnostics(new AiModelResult()); using (ModelPersistenceGuard.InternalOperation()) { result.Deserialize(modelData); @@ -6516,6 +6605,38 @@ private void ConfigureDocumentTransformers(IFullModel? model } private void ApplyGpuConfiguration() + { + ApplyGpuConfigurationCore(); + ReportAccelerationIfRequested(); + } + + private void ReportAccelerationIfRequested() + { + if (!_reportAccelerationAtBuild) + { + return; + } + + _accelerationSnapshot = AiDotNet.Diagnostics.AccelerationDiagnostics.GetSnapshot(); + var report = AiDotNet.Diagnostics.AccelerationDiagnostics.GetReport(); + (_accelerationLogger ?? Console.WriteLine).Invoke(report); + } + + private AiModelResult AttachDiagnostics(AiModelResult result) + { + if (_accelerationSnapshot is not null) + { + result.AccelerationSnapshot = _accelerationSnapshot; + } + if (_tensorsOpProfilingEnabled) + { + _tensorsOperationProfile = AiDotNet.Diagnostics.TensorsOperationProfile.Capture(); + result.TensorsOperationProfile = _tensorsOperationProfile; + } + return result; + } + + private void ApplyGpuConfigurationCore() { // Skip if no GPU configuration was provided (null = default = auto-detect with CPU fallback) if (_gpuAccelerationConfig == null) diff --git a/src/AutoML/AutoMLModelBase.cs b/src/AutoML/AutoMLModelBase.cs index 6978375f6c..04d9a17bf3 100644 --- a/src/AutoML/AutoMLModelBase.cs +++ b/src/AutoML/AutoMLModelBase.cs @@ -1156,35 +1156,6 @@ public virtual void ApplyGradients(Vector gradients, T learningRate) InterfaceGuard.GradientComputable(BestModel).ApplyGradients(gradients, learningRate); } - #endregion - #region IJitCompilable Implementation - - /// - /// Gets whether this model currently supports JIT compilation. - /// - /// True if the best model found supports JIT compilation, false otherwise. - /// - /// - /// AutoML models delegate JIT compilation support to their best model. - /// If no best model has been found yet, JIT compilation is not supported. - /// - /// For Beginners: AutoML models can only be JIT compiled if the best model they found supports it. - /// - /// Since AutoML searches across multiple model types, JIT support depends on: - /// - Whether a best model has been selected - /// - Whether that specific model supports JIT compilation - /// - /// Before running SearchAsync, this will return false. - /// After finding a best model, it delegates to that model's JIT support. - /// - /// - public virtual bool SupportsJitCompilation => false; - - public virtual ComputationNode ExportComputationGraph(List> inputNodes) - { - throw new NotSupportedException("JIT compilation has been removed."); - } - #endregion /// diff --git a/src/CausalInference/CausalModelBase.cs b/src/CausalInference/CausalModelBase.cs index 4de1f5cf84..09424a1eaa 100644 --- a/src/CausalInference/CausalModelBase.cs +++ b/src/CausalInference/CausalModelBase.cs @@ -83,11 +83,6 @@ public abstract class CausalModelBase : ICausalModel, IModelShape /// public virtual bool SupportsParameterInitialization => ParameterCount > 0; - /// - /// Gets whether JIT compilation is supported. - /// - public virtual bool SupportsJitCompilation => false; - /// /// Initializes a new instance of the CausalModelBase class. /// @@ -587,14 +582,6 @@ public virtual void LoadState(Stream stream) Deserialize(serializedData); } - /// - /// Exports the computation graph for JIT compilation. - /// - public virtual ComputationNode ExportComputationGraph(List> inputNodes) - { - throw new NotSupportedException("JIT compilation is not supported for this causal model."); - } - /// /// Gets additional model data to include in serialization. /// Override in derived classes to persist class-specific state. diff --git a/src/Classification/ClassifierBase.cs b/src/Classification/ClassifierBase.cs index ab7e3a82f8..072171082a 100644 --- a/src/Classification/ClassifierBase.cs +++ b/src/Classification/ClassifierBase.cs @@ -675,16 +675,4 @@ public virtual void LoadState(Stream stream) Deserialize(serializedData); } - #region IJitCompilable Implementation - - /// - public virtual bool SupportsJitCompilation => false; - - /// - public virtual ComputationNode ExportComputationGraph(List> inputNodes) - { - throw new NotSupportedException("JIT compilation is not supported for this classifier. Override this method in derived classes to enable JIT support."); - } - - #endregion } diff --git a/src/Classification/MultiLabel/MultiLabelClassifierBase.cs b/src/Classification/MultiLabel/MultiLabelClassifierBase.cs index d24a061679..303844d20a 100644 --- a/src/Classification/MultiLabel/MultiLabelClassifierBase.cs +++ b/src/Classification/MultiLabel/MultiLabelClassifierBase.cs @@ -445,19 +445,6 @@ public virtual Vector ComputeGradients(Matrix input, Matrix target, ILo return gradients; } - #region IJitCompilable Implementation - - /// - public virtual bool SupportsJitCompilation => false; - - /// - public virtual ComputationNode ExportComputationGraph(List> inputNodes) - { - throw new NotSupportedException("JIT compilation is not supported for this multi-label classifier."); - } - - #endregion - /// /// Binary cross-entropy loss for multi-label classification. /// diff --git a/src/Diagnostics/AccelerationDiagnostics.cs b/src/Diagnostics/AccelerationDiagnostics.cs new file mode 100644 index 0000000000..28ec19769d --- /dev/null +++ b/src/Diagnostics/AccelerationDiagnostics.cs @@ -0,0 +1,120 @@ +using System.Text; +using AiDotNet.Tensors.Engines; +using AiDotNet.Tensors.Helpers.Autotune; + +namespace AiDotNet.Diagnostics; + +/// +/// Snapshots the live acceleration environment so users can see which SIMD, GPU, and +/// native BLAS paths are actually engaged at runtime, instead of assuming from config. +/// +/// +/// +/// Wraps Tensors' and +/// into a single facade-friendly report that can be logged at builder time and surfaced +/// on PredictionModelResult for production observability. +/// +/// +public static class AccelerationDiagnostics +{ + public static string GetReport() + { + var caps = PlatformDetector.Capabilities; + var nativeStatus = NativeLibraryDetector.Status; + var sb = new StringBuilder(); + + sb.AppendLine("=== AiDotNet Acceleration Report ==="); + sb.AppendLine($" Engine: {AiDotNetEngine.Current.Name}"); + sb.AppendLine($" Deterministic: {AiDotNetEngine.DeterministicMode}"); + sb.AppendLine($" Framework: {caps.FrameworkDescription}"); + sb.AppendLine($" OS: {caps.OSDescription}"); + sb.AppendLine($" Arch: {caps.Architecture} ({(caps.Is64BitProcess ? "64-bit" : "32-bit")} process)"); + sb.AppendLine($" Processors: {caps.ProcessorCount}"); + sb.AppendLine($" L1/L2/L3 cache: {caps.L1CacheSize / 1024}KB / {caps.L2CacheSize / 1024}KB / {caps.L3CacheSize / 1024}KB"); + sb.AppendLine($" Best SIMD: {caps.GetBestSimdSet()}"); + + sb.AppendLine(" x86 SIMD: " + + $"SSE={caps.HasSSE} SSE2={caps.HasSSE2} SSE3={caps.HasSSE3} SSSE3={caps.HasSSSE3} " + + $"SSE4.1={caps.HasSSE41} SSE4.2={caps.HasSSE42} AVX={caps.HasAVX} AVX2={caps.HasAVX2} FMA={caps.HasFMA}"); + sb.AppendLine(" AVX-512: " + + $"F={caps.HasAVX512F} BW={caps.HasAVX512BW} DQ={caps.HasAVX512DQ} VL={caps.HasAVX512VL}"); + sb.AppendLine(" ARM SIMD: " + + $"NEON={caps.HasNeon} AES={caps.HasArmAes} CRC32={caps.HasArmCrc32} DP={caps.HasArmDp}"); + + sb.AppendLine(" GPU backends: " + + $"CUDA={caps.HasCudaSupport} OpenCL={caps.HasOpenCLSupport} HIP={caps.HasHipSupport}"); + sb.AppendLine(" Native BLAS: " + + $"OpenBLAS={nativeStatus.HasOpenBlas} CLBlast={nativeStatus.HasClBlast} " + + $"MKL={nativeStatus.HasMkl} CpuBLAS={nativeStatus.HasCpuBlas}"); + sb.AppendLine(" Native GPU libs: " + + $"CUDA={nativeStatus.HasCuda} HIP={nativeStatus.HasHip} OpenCL={nativeStatus.HasOpenCl}"); + + sb.AppendLine(" Autotune cache: " + + $"path={AutotuneCache.DefaultCachePath}"); + sb.AppendLine(" Autotune HW fp: " + + $"{AutotuneCache.CurrentHardwareFingerprint}"); + + sb.Append(NativeLibraryDetector.GetStatusSummary()); + return sb.ToString(); + } + + /// + /// Gets a structured snapshot of the current acceleration environment. + /// Intended for programmatic checks (assertions in tests, automated CI reports). + /// + public static AccelerationSnapshot GetSnapshot() + { + var caps = PlatformDetector.Capabilities; + var status = NativeLibraryDetector.Status; + return new AccelerationSnapshot + { + EngineName = AiDotNetEngine.Current.Name, + DeterministicMode = AiDotNetEngine.DeterministicMode, + BestSimdSet = caps.GetBestSimdSet(), + HasAvx2 = caps.HasAVX2, + HasAvx512F = caps.HasAVX512F, + HasFma = caps.HasFMA, + HasNeon = caps.HasNeon, + HasCuda = caps.HasCudaSupport, + HasOpenCl = caps.HasOpenCLSupport, + HasHip = caps.HasHipSupport, + HasOpenBlas = status.HasOpenBlas, + HasClBlast = status.HasClBlast, + HasMkl = status.HasMkl, + HasGpuAcceleration = status.HasGpuAcceleration, + ProcessorCount = caps.ProcessorCount, + L1CacheKB = caps.L1CacheSize / 1024, + L2CacheKB = caps.L2CacheSize / 1024, + L3CacheKB = caps.L3CacheSize / 1024, + AutotuneCachePath = AutotuneCache.DefaultCachePath, + AutotuneHardwareFingerprint = AutotuneCache.CurrentHardwareFingerprint, + }; + } +} + +/// +/// Immutable snapshot of acceleration state at a point in time. +/// +public sealed class AccelerationSnapshot +{ + public string EngineName { get; init; } = ""; + public bool DeterministicMode { get; init; } + public string BestSimdSet { get; init; } = ""; + public bool HasAvx2 { get; init; } + public bool HasAvx512F { get; init; } + public bool HasFma { get; init; } + public bool HasNeon { get; init; } + public bool HasCuda { get; init; } + public bool HasOpenCl { get; init; } + public bool HasHip { get; init; } + public bool HasOpenBlas { get; init; } + public bool HasClBlast { get; init; } + public bool HasMkl { get; init; } + public bool HasGpuAcceleration { get; init; } + public int ProcessorCount { get; init; } + public int L1CacheKB { get; init; } + public int L2CacheKB { get; init; } + public int L3CacheKB { get; init; } + public string AutotuneCachePath { get; init; } = ""; + public string AutotuneHardwareFingerprint { get; init; } = ""; +} diff --git a/src/Diagnostics/ProfilingReport.cs b/src/Diagnostics/ProfilingReport.cs new file mode 100644 index 0000000000..71c186755d --- /dev/null +++ b/src/Diagnostics/ProfilingReport.cs @@ -0,0 +1,86 @@ +using System.Text; +using AiDotNet.Tensors.Engines.Optimization; + +namespace AiDotNet.Diagnostics; + +/// +/// Structured performance report captured at build time when the builder opts in via +/// EnableProfiling(). Wraps Tensors' output +/// so callers don't have to dip into Tensors internals for a timing breakdown. +/// +public sealed class TensorsOperationProfile +{ + /// Per-operation timing statistics, sorted by total time descending. + public IReadOnlyList Operations { get; init; } = Array.Empty(); + + /// Total wall-clock time across every profiled operation (ms). + public double TotalMilliseconds { get; init; } + + /// The profiler's raw text report (via PerformanceProfiler.GenerateReport). + public string RawReport { get; init; } = ""; + + /// + /// Builds a ProfilingReport from the live Tensors + /// . + /// + public static TensorsOperationProfile Capture() + { + var profiler = PerformanceProfiler.Instance; + var stats = profiler.GetAllStats(); + var ops = new List(stats.Length); + double total = 0; + + foreach (var s in stats) + { + ops.Add(new OperationTiming + { + Name = s.OperationName, + CallCount = s.CallCount, + TotalMilliseconds = s.TotalMilliseconds, + AverageMilliseconds = s.AverageMilliseconds, + MinMilliseconds = s.MinMilliseconds, + MaxMilliseconds = s.MaxMilliseconds, + ThroughputOpsPerSecond = s.ThroughputOpsPerSecond, + TotalMemoryMB = s.TotalMemoryMB, + }); + total += s.TotalMilliseconds; + } + + ops.Sort((a, b) => b.TotalMilliseconds.CompareTo(a.TotalMilliseconds)); + + return new TensorsOperationProfile + { + Operations = ops, + TotalMilliseconds = total, + RawReport = profiler.GenerateReport(), + }; + } + + public string FormatSummary(int topN = 10) + { + var sb = new StringBuilder(); + sb.AppendLine($"=== AiDotNet Profiling Summary ({Operations.Count} ops, {TotalMilliseconds:F2} ms total) ==="); + foreach (var op in Operations.Take(topN)) + { + sb.AppendLine( + $" {op.Name,-40} " + + $"calls={op.CallCount,6} " + + $"total={op.TotalMilliseconds,9:F2}ms " + + $"avg={op.AverageMilliseconds,7:F3}ms " + + $"throughput={op.ThroughputOpsPerSecond,10:F0}/s"); + } + return sb.ToString(); + } +} + +public sealed class OperationTiming +{ + public string Name { get; init; } = ""; + public long CallCount { get; init; } + public double TotalMilliseconds { get; init; } + public double AverageMilliseconds { get; init; } + public double MinMilliseconds { get; init; } + public double MaxMilliseconds { get; init; } + public double ThroughputOpsPerSecond { get; init; } + public double TotalMemoryMB { get; init; } +} diff --git a/src/Diffusion/DiffusionModelBase.cs b/src/Diffusion/DiffusionModelBase.cs index ca398dd539..97c43359e4 100644 --- a/src/Diffusion/DiffusionModelBase.cs +++ b/src/Diffusion/DiffusionModelBase.cs @@ -238,9 +238,6 @@ private static IEnumerable ReflectInstanceDisposables(object root) /// public NeuralNetworkArchitecture? Architecture => _architecture; - /// - public virtual bool SupportsJitCompilation => false; - /// /// Initializes a new instance of the DiffusionModelBase class. /// @@ -937,15 +934,6 @@ public virtual void ApplyGradients(Vector gradients, T learningRate) #endregion - #region IJitCompilable Implementation - - /// - public virtual ComputationNode ExportComputationGraph(List> inputNodes) - { - throw new NotSupportedException("This diffusion model does not support JIT compilation. Override ExportComputationGraph in derived class if needed."); - } - - #endregion #region Helper Methods diff --git a/src/Diffusion/NoisePredictors/NoisePredictorBase.cs b/src/Diffusion/NoisePredictors/NoisePredictorBase.cs index 000543c417..809d53d5d0 100644 --- a/src/Diffusion/NoisePredictors/NoisePredictorBase.cs +++ b/src/Diffusion/NoisePredictors/NoisePredictorBase.cs @@ -42,7 +42,7 @@ public abstract class NoisePredictorBase : INoisePredictor, IModelShape, I /// loop. First call traces, subsequent calls replay. Falls back to eager when /// compilation is disabled or fails. /// - private readonly AiDotNet.NeuralNetworks.CompiledModelHost _compileHost = new(); + private readonly AiDotNet.NeuralNetworks.CompiledModelHost _compileHost; /// /// Monotonic layer-graph version. Concrete predictors bump this via @@ -233,9 +233,6 @@ protected void InvalidateCompiledPlans() /// public ILossFunction DefaultLossFunction => LossFunction; - /// - public virtual bool SupportsJitCompilation => false; - /// /// Initializes a new instance of the NoisePredictorBase class. /// @@ -247,6 +244,9 @@ protected NoisePredictorBase(ILossFunction? lossFunction = null, int? seed = RandomGenerator = seed.HasValue ? RandomHelper.CreateSeededRandom(seed.Value) : RandomHelper.CreateSecureRandom(); + _compileHost = new AiDotNet.NeuralNetworks.CompiledModelHost( + shapeMode: AiDotNet.NeuralNetworks.SymbolicShapeMode.BatchDynamic, + modelIdentity: GetType().FullName ?? GetType().Name); } #region Lazy Layer Factories @@ -724,16 +724,6 @@ public virtual void ApplyGradients(Vector gradients, T learningRate) #endregion - #region IJitCompilable Implementation - - /// - public virtual ComputationNode ExportComputationGraph(List> inputNodes) - { - throw new NotSupportedException("This noise predictor does not support JIT compilation. Override ExportComputationGraph in derived class if needed."); - } - - #endregion - #region Helper Methods /// diff --git a/src/Diffusion/VAE/VAEModelBase.cs b/src/Diffusion/VAE/VAEModelBase.cs index 563f9606b4..0d348c9dff 100644 --- a/src/Diffusion/VAE/VAEModelBase.cs +++ b/src/Diffusion/VAE/VAEModelBase.cs @@ -90,9 +90,6 @@ public abstract class VAEModelBase : IVAEModel, IModelShape /// public ILossFunction DefaultLossFunction => LossFunction; - /// - public virtual bool SupportsJitCompilation => false; - /// /// Initializes a new instance of the VAEModelBase class. /// @@ -593,16 +590,6 @@ public virtual void ApplyGradients(Vector gradients, T learningRate) #endregion - #region IJitCompilable Implementation - - /// - public virtual ComputationNode ExportComputationGraph(List> inputNodes) - { - throw new NotSupportedException("This VAE does not support JIT compilation. Override ExportComputationGraph in derived class if needed."); - } - - #endregion - #region Helper Methods /// diff --git a/src/DistributedTraining/ShardedModelBase.cs b/src/DistributedTraining/ShardedModelBase.cs index 0953563a50..c398f8d8e7 100644 --- a/src/DistributedTraining/ShardedModelBase.cs +++ b/src/DistributedTraining/ShardedModelBase.cs @@ -484,32 +484,6 @@ public virtual void ApplyGradients(Vector gradients, T learningRate) } - #region IJitCompilable Implementation - - /// - /// Gets whether this model currently supports JIT compilation. - /// - /// True if the wrapped model supports JIT compilation, false otherwise. - /// - /// - /// Sharded models delegate JIT compilation support to their wrapped model. - /// JIT compilation is performed on the full model representation, not on individual shards. - /// - /// For Beginners: Distributed models can be JIT compiled if the underlying model supports it. - /// - /// The sharding strategy (splitting parameters across processes) doesn't prevent JIT compilation. - /// The JIT compiler works with the full computation graph, which is the same across all processes. - /// Individual processes execute the same compiled code but operate on different parameter shards. - /// - /// - public virtual bool SupportsJitCompilation => false; - - public virtual ComputationNode ExportComputationGraph(List> inputNodes) - { - throw new NotSupportedException("JIT compilation has been removed."); - } - - #endregion /// /// Saves the model's current state to a stream. /// diff --git a/src/Enums/OperationType.cs b/src/Enums/OperationType.cs index 6e1c7058b2..acaa7d4c8c 100644 --- a/src/Enums/OperationType.cs +++ b/src/Enums/OperationType.cs @@ -722,7 +722,7 @@ public enum OperationType /// Attention, - // InferenceOptimization Operations + // Layer and graph node operations /// /// Output node in computation graph. @@ -944,7 +944,7 @@ public enum OperationType /// Scatter, - // Fused Operations for InferenceOptimization + // Fused operations /// /// Fused Conv + BatchNorm + ReLU. diff --git a/src/Helpers/InterfaceGuard.cs b/src/Helpers/InterfaceGuard.cs index 62df27634a..74e2189d3b 100644 --- a/src/Helpers/InterfaceGuard.cs +++ b/src/Helpers/InterfaceGuard.cs @@ -2,11 +2,19 @@ namespace AiDotNet.Helpers; /// /// Provides safe runtime capability checks for interface segregation. -/// After removing IParameterizable, IFeatureAware, IGradientComputable, and IJitCompilable -/// from IFullModel, callers must validate capabilities before use. These methods provide -/// clear error messages when a model doesn't support the requested capability. +/// After removing IParameterizable, IFeatureAware, and IGradientComputable from IFullModel, +/// callers must validate capabilities before use. These methods provide clear error +/// messages when a model doesn't support the requested capability. /// -public static class InterfaceGuard +/// +/// +/// Visibility: internal to match the facade pattern — users interact with +/// AiModelBuilder / AiModelResult, and the InternalsVisibleTo attribute on +/// AiDotNet.csproj exposes this helper to the test/console/serving assemblies that need +/// capability checks from outside the main assembly. +/// +/// +internal static class InterfaceGuard { /// /// Returns the model as IParameterizable or throws with a clear message. diff --git a/src/Helpers/UsingsHelper.cs b/src/Helpers/UsingsHelper.cs index e396297ace..0b1c6019e5 100644 --- a/src/Helpers/UsingsHelper.cs +++ b/src/Helpers/UsingsHelper.cs @@ -7,4 +7,3 @@ // Resolve type ambiguity between AiDotNet and AiDotNet.Tensors.Helpers (0.13.0+) global using QuantizationMode = AiDotNet.Enums.QuantizationMode; -global using MemoryLayout = AiDotNet.InferenceOptimization.IR.Common.MemoryLayout; diff --git a/src/InferenceOptimization/ARCHITECTURE.md b/src/InferenceOptimization/ARCHITECTURE.md deleted file mode 100644 index b17eaa3f62..0000000000 --- a/src/InferenceOptimization/ARCHITECTURE.md +++ /dev/null @@ -1,60 +0,0 @@ -# AiDotNet Inference Optimization Architecture - -This document describes the internal structure of the `AiDotNet.InferenceOptimization` module and its extension points. - -## Design Goals - -- Hardware-aware CPU acceleration (SIMD when available) -- Deterministic behavior with safe fallbacks -- Low overhead when optimizations are disabled -- Extensible operator/kernels surface for future backends -- Thread-safe initialization and registration -- Optional profiling hooks for diagnosis - -## Key Components - -### OptimizationInitializer - -Responsibilities: -- One entrypoint to initialize platform detection and (optionally) profiling -- Ensures module initialization is safe to call multiple times - -### PlatformDetector - -Responsibilities: -- Detects process architecture and SIMD availability (x86/x64 and ARM) -- Exposes `PlatformCapabilities` used for selecting implementations - -Notes: -- Capability checks are runtime-based; unsupported intrinsics must always fall back to scalar implementations. - -### CustomOperatorRegistry - -Responsibilities: -- Registers multiple implementations per operation name -- Chooses the best supported implementation at runtime -- Caches the selection to avoid repeated capability checks - -### Kernels (`Kernels/*`) - -Responsibilities: -- Optimized building blocks for critical inference workloads: - - GEMM / matmul - - attention - - convolution - -Notes: -- Kernels are implemented with safe, span-based loops and use platform intrinsics only behind runtime capability checks. - -### CPU Helpers (`AiDotNet.Tensors/Engines/Optimization/*`) - -Responsibilities: -- Cache-aware helpers (tiling/transposition heuristics) -- Loop tiling/unrolling utilities where beneficial -- Optional profiling (`PerformanceProfiler`) for hotspot tracking - -## Integration Points - -- `AiDotNet.Inference.InferenceOptimizer` selects and applies inference-time implementations (e.g., attention variants, paged KV-cache) based on `InferenceOptimizationConfig`. -- `AiDotNet.Models.Results.AiModelResult` exposes facade-friendly entrypoints (`Predict`, `BeginInferenceSession`) while keeping internal complexity non-user-facing by default. - diff --git a/src/InferenceOptimization/Core/GraphBuilder.cs b/src/InferenceOptimization/Core/GraphBuilder.cs deleted file mode 100644 index 04324e0673..0000000000 --- a/src/InferenceOptimization/Core/GraphBuilder.cs +++ /dev/null @@ -1,166 +0,0 @@ -using AiDotNet.Enums; -using AiDotNet.Interfaces; - -namespace AiDotNet.InferenceOptimization.Core; - -/// -/// Builds a computation graph from a neural network or sequence of layers. -/// -/// The numeric type (double, float, decimal) -public class GraphBuilder where T : struct -{ - private readonly OptimizationGraph _graph; - private readonly Dictionary> _layerToNode; - - public GraphBuilder() - { - _graph = new OptimizationGraph(); - _layerToNode = new Dictionary>(); - } - - /// - /// Creates a graph from a list of layers. - /// - public IOptimizationGraph BuildFromLayers(IEnumerable> layers) - { - OptimizationNode? previousNode = null; - - // Create input node - var inputNode = new OptimizationNode - { - OperationType = OperationType.Input, - Name = "input" - }; - _graph.AddNode(inputNode); - previousNode = inputNode; - - // Convert each layer to a node - foreach (var layer in layers) - { - var node = LayerToNode(layer); - - if (previousNode != null) - { - node.AddInput(previousNode); - } - - _graph.AddNode(node); - _layerToNode[layer] = node; - - previousNode = node; - } - - // Create output node - var outputNode = new OptimizationNode - { - OperationType = OperationType.Output, - Name = "output" - }; - - if (previousNode != null) - { - outputNode.AddInput(previousNode); - } - - _graph.AddNode(outputNode); - - return _graph; - } - - /// - /// Converts a layer to a computation node. - /// - private OptimizationNode LayerToNode(ILayer layer) - { - var layerType = layer.GetType(); - var operationType = InferOperationType(layerType.Name); - - var node = new OptimizationNode - { - OperationType = operationType, - Name = layerType.Name, - OriginalLayer = layer, - Parameters = ExtractParameters(layer) - }; - - return node; - } - - /// - /// Infers the operation type from the layer type name. - /// - private OperationType InferOperationType(string layerTypeName) - { - // Remove generic type suffixes and "Layer" suffix - var cleanName = layerTypeName - .Replace("`1", "") - .Replace("Layer", "") - .Replace("", ""); - - return cleanName switch - { - "Convolutional" => OperationType.Convolution, - "Convolution2D" => OperationType.Convolution2D, - "BatchNormalization" => OperationType.BatchNormalization, - "LayerNormalization" => OperationType.LayerNormalization, - "ReLU" => OperationType.ReLU, - "LeakyReLU" => OperationType.LeakyReLU, - "Sigmoid" => OperationType.Sigmoid, - "Tanh" => OperationType.Tanh, - "Softmax" => OperationType.Softmax, - "MaxPooling" => OperationType.MaxPooling, - "AveragePooling" => OperationType.AveragePooling, - "FullyConnected" => OperationType.FullyConnected, - "Dense" => OperationType.Dense, - "LSTM" => OperationType.LSTM, - "GRU" => OperationType.GRU, - "Attention" => OperationType.Attention, - "MultiHeadAttention" => OperationType.MultiHeadAttention, - "Dropout" => OperationType.Dropout, - "Flatten" => OperationType.Flatten, - "Embedding" => OperationType.Embedding, - _ => OperationType.Custom - }; - } - - /// - /// Extracts parameters from a layer. - /// - private Dictionary ExtractParameters(ILayer layer) - { - var parameters = new Dictionary(); - - // Use reflection to get layer properties - var props = layer.GetType().GetProperties(); - - foreach (var prop in props) - { - try - { - var value = prop.GetValue(layer); - if (value != null) - { - parameters[prop.Name] = value; - } - } - catch (System.Reflection.TargetInvocationException) - { - // Skip properties whose getter throws - } - catch (InvalidOperationException) - { - // Skip properties that can't be read - } - } - - return parameters; - } - - /// - /// Gets the computation graph. - /// - public IOptimizationGraph GetGraph() - { - return _graph; - } -} diff --git a/src/InferenceOptimization/Core/GraphOptimizer.cs b/src/InferenceOptimization/Core/GraphOptimizer.cs deleted file mode 100644 index 6f5574ffd7..0000000000 --- a/src/InferenceOptimization/Core/GraphOptimizer.cs +++ /dev/null @@ -1,187 +0,0 @@ -using System.Diagnostics; -using AiDotNet.InferenceOptimization.Passes; - -namespace AiDotNet.InferenceOptimization.Core; - -/// -/// Main optimizer engine that orchestrates all optimization passes. -/// Applies transformations to computation graphs to improve inference performance. -/// -/// The numeric type (double, float, decimal) -public class GraphOptimizer where T : struct -{ - private readonly OptimizationOptions _options; - private readonly List> _passes; - - public GraphOptimizer(OptimizationOptions? options = null) - { - _options = options ?? OptimizationOptions.FromLevel(OptimizationLevel.Standard); - _passes = new List>(); - - InitializePasses(); - } - - /// - /// Optimizes the given computation graph. - /// - public IOptimizationGraph Optimize(IOptimizationGraph graph) - { - if (graph == null) - throw new ArgumentNullException(nameof(graph)); - - if (_options.Level == OptimizationLevel.None) - { - return graph; - } - - var optimizedGraph = graph.Clone(); - var stopwatch = Stopwatch.StartNew(); - - if (_options.PrintStatistics) - { - Console.WriteLine("=== Graph Optimization Started ==="); - Console.WriteLine($"Optimization Level: {_options.Level}"); - Console.WriteLine($"Initial Graph: {optimizedGraph.GetStatistics()}"); - Console.WriteLine(); - } - - int iteration = 0; - bool graphChanged = true; - - while (graphChanged && iteration < _options.MaxIterations) - { - graphChanged = false; - - foreach (var pass in _passes) - { - if (!pass.CanApply(optimizedGraph)) - { - continue; - } - - var passStopwatch = Stopwatch.StartNew(); - bool passModified = pass.Apply(optimizedGraph); - passStopwatch.Stop(); - - if (passModified) - { - graphChanged = true; - - if (_options.PrintStatistics) - { - Console.WriteLine($"[Iteration {iteration}] {pass.Name}: Modified graph in {passStopwatch.ElapsedMilliseconds}ms"); - } - - if (_options.ValidateAfterEachPass && !optimizedGraph.Validate()) - { - throw new InvalidOperationException($"Graph validation failed after {pass.Name}"); - } - } - } - - iteration++; - } - - stopwatch.Stop(); - - if (_options.PrintStatistics) - { - Console.WriteLine(); - Console.WriteLine($"Final Graph: {optimizedGraph.GetStatistics()}"); - Console.WriteLine($"Total Iterations: {iteration}"); - Console.WriteLine($"Total Time: {stopwatch.ElapsedMilliseconds}ms"); - Console.WriteLine("=== Graph Optimization Completed ==="); - } - - return optimizedGraph; - } - - /// - /// Adds a custom optimization pass. - /// - public void AddPass(IOptimizationPass pass) - { - if (pass == null) - throw new ArgumentNullException(nameof(pass)); - - _passes.Add(pass); - } - - /// - /// Removes all passes of a specific type. - /// - public void RemovePass(Type passType) - { - _passes.RemoveAll(p => p.GetType() == passType); - } - - private void InitializePasses() - { - // Phase 1: Algebraic simplification (reduces graph complexity early) - if (_options.EnableAlgebraicSimplification) - { - _passes.Add(new AlgebraicSimplificationPass()); - } - - // Phase 2: Constant folding (evaluate constants early) - if (_options.EnableConstantFolding) - { - _passes.Add(new ConstantFoldingPass()); - } - - // Phase 3: Operator fusion (critical for performance) - if (_options.EnableOperatorFusion) - { - // Order matters: try more specific fusions first - _passes.Add(new ConvBatchNormReLUFusionPass()); - _passes.Add(new ConvBatchNormFusionPass()); - _passes.Add(new MatMulBiasActivationFusionPass()); - _passes.Add(new MatMulBiasFusionPass()); - _passes.Add(new MultiHeadAttentionFusionPass()); - _passes.Add(new ElementwiseFusionPass()); - } - - // Phase 4: Common subexpression elimination - if (_options.EnableCSE) - { - _passes.Add(new CommonSubexpressionEliminationPass()); - } - - // Phase 5: Strength reduction - if (_options.EnableStrengthReduction) - { - _passes.Add(new StrengthReductionPass()); - } - - // Phase 6: Layout optimization - if (_options.EnableLayoutOptimization) - { - _passes.Add(new LayoutOptimizationPass(_options.TargetLayout)); - } - - // Phase 7: Memory optimizations - if (_options.EnableInPlaceOptimization) - { - _passes.Add(new InPlaceOptimizationPass()); - } - - if (_options.EnableMemoryReuse) - { - _passes.Add(new MemoryReuseOptimizationPass()); - } - - // Phase 8: Dead code elimination (should be last to clean up) - if (_options.EnableDeadCodeElimination) - { - _passes.Add(new DeadCodeEliminationPass()); - } - } - - /// - /// Gets the list of active optimization passes. - /// - public IReadOnlyList> GetPasses() - { - return _passes.AsReadOnly(); - } -} diff --git a/src/InferenceOptimization/Core/IOptimizationGraph.cs b/src/InferenceOptimization/Core/IOptimizationGraph.cs deleted file mode 100644 index f6c0a3ed19..0000000000 --- a/src/InferenceOptimization/Core/IOptimizationGraph.cs +++ /dev/null @@ -1,70 +0,0 @@ -namespace AiDotNet.InferenceOptimization.Core; - -/// -/// Interface for an optimization graph that represents the structure of neural network operations. -/// The graph can be optimized through various passes for improved inference performance. -/// -/// -/// -/// IOptimizationGraph is the core interface for the middle-layer IR in our two-tier architecture. -/// It provides graph manipulation capabilities needed for optimization passes. -/// -/// -/// The numeric type (double, float, decimal) -public interface IOptimizationGraph where T : struct -{ - /// - /// All nodes in the optimization graph. - /// - List> Nodes { get; } - - /// - /// Input nodes of the graph. - /// - List> InputNodes { get; } - - /// - /// Output nodes of the graph. - /// - List> OutputNodes { get; } - - /// - /// Adds a new node to the graph. - /// - void AddNode(OptimizationNode node); - - /// - /// Removes a node from the graph. - /// - void RemoveNode(OptimizationNode node); - - /// - /// Finds a node by its ID. - /// - OptimizationNode? FindNodeById(string id); - - /// - /// Finds nodes by name. - /// - List> FindNodesByName(string name); - - /// - /// Gets nodes in topological order (inputs to outputs). - /// - List> GetTopologicalOrder(); - - /// - /// Validates the graph structure. - /// - bool Validate(); - - /// - /// Creates a deep copy of the graph. - /// - IOptimizationGraph Clone(); - - /// - /// Gets statistics about the graph. - /// - GraphStatistics GetStatistics(); -} diff --git a/src/InferenceOptimization/Core/OptimizationGraph.cs b/src/InferenceOptimization/Core/OptimizationGraph.cs deleted file mode 100644 index c691a89b42..0000000000 --- a/src/InferenceOptimization/Core/OptimizationGraph.cs +++ /dev/null @@ -1,263 +0,0 @@ -using AiDotNet.Enums; - -namespace AiDotNet.InferenceOptimization.Core; - -/// -/// Represents an optimization graph for neural network inference. -/// The graph consists of nodes (operations) and edges (data dependencies). -/// -/// -/// -/// OptimizationGraph is the concrete implementation of the middle-layer IR graph. -/// It provides efficient graph manipulation for optimization passes. -/// -/// -/// The numeric type (double, float, decimal) -public class OptimizationGraph : IOptimizationGraph where T : struct -{ - public List> Nodes { get; private set; } - public List> InputNodes { get; private set; } - public List> OutputNodes { get; private set; } - - private readonly Dictionary> _nodeIndex; - - public OptimizationGraph() - { - Nodes = new List>(); - InputNodes = new List>(); - OutputNodes = new List>(); - _nodeIndex = new Dictionary>(); - } - - public void AddNode(OptimizationNode node) - { - if (node == null) - { - throw new ArgumentNullException(nameof(node)); - } - - if (!_nodeIndex.ContainsKey(node.Id)) - { - Nodes.Add(node); - _nodeIndex[node.Id] = node; - - // Track input/output nodes - if (node.OperationType == OperationType.Input) - { - InputNodes.Add(node); - } - else if (node.OperationType == OperationType.Output) - { - OutputNodes.Add(node); - } - } - } - - public void RemoveNode(OptimizationNode node) - { - if (node == null) return; - - // Remove connections - foreach (var input in node.Inputs.ToList()) - { - input.Outputs.Remove(node); - } - - foreach (var output in node.Outputs.ToList()) - { - output.Inputs.Remove(node); - } - - // Remove from collections - Nodes.Remove(node); - InputNodes.Remove(node); - OutputNodes.Remove(node); - _nodeIndex.Remove(node.Id); - } - - public OptimizationNode? FindNodeById(string id) - { - if (id == null) - throw new ArgumentNullException(nameof(id)); - - return _nodeIndex.TryGetValue(id, out var node) ? node : null; - } - - public List> FindNodesByName(string name) - { - if (name == null) - throw new ArgumentNullException(nameof(name)); - - return Nodes.Where(n => n.Name == name).ToList(); - } - - public List> GetTopologicalOrder() - { - var visited = new HashSet>(); - var result = new List>(); - var inStack = new HashSet>(); - - foreach (var node in Nodes.Where(node => !visited.Contains(node))) - { - if (!TopologicalSortUtil(node, visited, inStack, result)) - { - throw new InvalidOperationException("Graph contains a cycle"); - } - } - - return result; - } - - private bool TopologicalSortUtil( - OptimizationNode node, - HashSet> visited, - HashSet> inStack, - List> result) - { - if (inStack.Contains(node)) - { - return false; // Cycle detected - } - - if (visited.Contains(node)) - { - return true; // Already processed - } - - visited.Add(node); - inStack.Add(node); - - foreach (var input in node.Inputs) - { - if (!TopologicalSortUtil(input, visited, inStack, result)) - { - return false; - } - } - - inStack.Remove(node); - result.Add(node); - - return true; - } - - public bool Validate() - { - try - { - // Check for cycles (GetTopologicalOrder throws if there's a cycle) - GetTopologicalOrder(); - - // Check that all nodes are reachable from inputs - var reachable = new HashSet>(); - var queue = new Queue>(InputNodes); - - while (queue.Count > 0) - { - var node = queue.Dequeue(); - if (reachable.Contains(node)) continue; - - reachable.Add(node); - - foreach (var output in node.Outputs) - { - queue.Enqueue(output); - } - } - - // All nodes should be reachable from inputs - var unreachable = Nodes.Where(n => !reachable.Contains(n) && n.OperationType != OperationType.Constant).ToList(); - if (unreachable.Any()) - { - return false; - } - - return true; - } - catch (InvalidOperationException) - { - // Graph contains a cycle or other structural issue - return false; - } - } - - public IOptimizationGraph Clone() - { - var clonedGraph = new OptimizationGraph(); - var nodeMapping = new Dictionary>(); - - // Clone all nodes - foreach (var node in Nodes) - { - var clonedNode = node.Clone(); - clonedGraph.AddNode(clonedNode); - nodeMapping[node.Id] = clonedNode; - } - - // Rebuild connections - foreach (var node in Nodes) - { - var clonedNode = nodeMapping[node.Id]; - - foreach (var input in node.Inputs) - { - var clonedInput = nodeMapping[input.Id]; - clonedNode.AddInput(clonedInput); - } - } - - return clonedGraph; - } - - /// - /// Gets statistics about the graph. - /// - public GraphStatistics GetStatistics() - { - var stats = new GraphStatistics - { - TotalNodes = Nodes.Count, - InputNodes = InputNodes.Count, - OutputNodes = OutputNodes.Count, - FusedNodes = Nodes.Count(n => n.IsFused), - OperationTypeCounts = new Dictionary() - }; - - foreach (var node in Nodes) - { - if (!stats.OperationTypeCounts.ContainsKey(node.OperationType)) - { - stats.OperationTypeCounts[node.OperationType] = 0; - } - stats.OperationTypeCounts[node.OperationType]++; - } - - return stats; - } - - public override string ToString() - { - return $"OptimizationGraph: {Nodes.Count} nodes, {InputNodes.Count} inputs, {OutputNodes.Count} outputs"; - } -} - -/// -/// Statistics about an optimization graph. -/// -public class GraphStatistics -{ - public int TotalNodes { get; set; } - public int InputNodes { get; set; } - public int OutputNodes { get; set; } - public int FusedNodes { get; set; } - public int TotalOperations { get; set; } - public long EstimatedFLOPs { get; set; } - public long EstimatedMemoryBytes { get; set; } - public Dictionary OperationTypeCounts { get; set; } = new(); - - public override string ToString() - { - var opsString = string.Join(", ", OperationTypeCounts.Select(kv => $"{kv.Key}: {kv.Value}")); - return $"Total: {TotalNodes}, Inputs: {InputNodes}, Outputs: {OutputNodes}, Fused: {FusedNodes}\nOperations: {opsString}"; - } -} diff --git a/src/InferenceOptimization/Core/OptimizationLevel.cs b/src/InferenceOptimization/Core/OptimizationLevel.cs deleted file mode 100644 index 27f77de547..0000000000 --- a/src/InferenceOptimization/Core/OptimizationLevel.cs +++ /dev/null @@ -1,37 +0,0 @@ -namespace AiDotNet.InferenceOptimization.Core; - -/// -/// Defines the level of optimization to apply to the computation graph. -/// Higher levels apply more aggressive optimizations but may take longer to compile. -/// -public enum OptimizationLevel -{ - /// - /// No optimization - use the graph as-is. - /// - None = 0, - - /// - /// Basic optimizations - dead code elimination, constant folding. - /// Fast to compile, minimal speedup. - /// - Basic = 1, - - /// - /// Standard optimizations - includes basic + operator fusion + algebraic simplification. - /// Balanced compile time and performance. Recommended for most use cases. - /// - Standard = 2, - - /// - /// Aggressive optimizations - includes standard + memory optimizations + CSE. - /// Longer compile time, significant speedup. Good for production deployments. - /// - Aggressive = 3, - - /// - /// Maximum optimizations - all available optimizations. - /// Longest compile time, maximum speedup. Use for critical inference paths. - /// - Maximum = 4 -} diff --git a/src/InferenceOptimization/Core/OptimizationNode.cs b/src/InferenceOptimization/Core/OptimizationNode.cs deleted file mode 100644 index a111c929f5..0000000000 --- a/src/InferenceOptimization/Core/OptimizationNode.cs +++ /dev/null @@ -1,211 +0,0 @@ -using AiDotNet.Enums; -using AiDotNet.LinearAlgebra; - -namespace AiDotNet.InferenceOptimization.Core; - -/// -/// Represents a single operation node in the optimization graph. -/// This is the middle-layer IR node used for graph-level optimizations -/// before lowering to the final execution representation. -/// -/// -/// Architecture Position: -/// -/// OptimizationNode sits in the middle layer of our two-tier IR architecture: -/// -/// -/// High-Level IR (HLIR): Model-level semantic operations -/// Middle Layer (OptimizationNode): Graph optimization target -/// Low-Level IR (LLIR): Hardware-optimized execution operations -/// -/// -/// OptimizationNode differs from Autodiff.ComputationNode in that it focuses on -/// static graph representation for inference optimization rather than gradient computation. -/// -/// -/// The numeric type (double, float, decimal) -public class OptimizationNode where T : struct -{ - /// - /// Unique identifier for this node in the graph. - /// - public string Id { get; set; } - - /// - /// The type of operation this node performs. - /// - public OperationType OperationType { get; set; } - - /// - /// Human-readable name for this node (e.g., "conv1", "bn1", "relu1"). - /// - public string Name { get; set; } - - /// - /// Input nodes that feed data into this node. - /// - public List> Inputs { get; set; } - - /// - /// Output nodes that consume data from this node. - /// - public List> Outputs { get; set; } - - /// - /// Shape of the output tensor produced by this node. - /// - public int[] OutputShape { get; set; } - - /// - /// Parameters associated with this operation (e.g., weights, biases). - /// - public Dictionary Parameters { get; set; } - - /// - /// Metadata for this node (e.g., stride, padding, kernel size). - /// - public Dictionary Metadata { get; set; } - - /// - /// Constant value if this is a constant node. - /// - public Tensor? ConstantValue { get; set; } - - /// - /// Indicates if this node can be eliminated (e.g., has no side effects). - /// - public bool CanEliminate { get; set; } - - /// - /// Indicates if this node can perform in-place operations. - /// - public bool CanOperateInPlace { get; set; } - - /// - /// Reference to the original layer (if applicable). - /// - public object? OriginalLayer { get; set; } - - /// - /// Indicates if this node has been marked for deletion. - /// - public bool IsMarkedForDeletion { get; set; } - - /// - /// Indicates if this node is a fused operation. - /// - public bool IsFused { get; set; } - - /// - /// If this is a fused node, contains the original nodes that were fused. - /// - public List>? FusedFrom { get; set; } - - public OptimizationNode() - { - Id = Guid.NewGuid().ToString(); - Name = string.Empty; - Inputs = new List>(); - Outputs = new List>(); - OutputShape = Array.Empty(); - Parameters = new Dictionary(); - Metadata = new Dictionary(); - CanEliminate = true; - CanOperateInPlace = false; - IsMarkedForDeletion = false; - IsFused = false; - } - - /// - /// Adds an input node to this node. - /// - public void AddInput(OptimizationNode inputNode) - { - if (inputNode == null) - throw new ArgumentNullException(nameof(inputNode)); - - if (!Inputs.Contains(inputNode)) - { - Inputs.Add(inputNode); - } - - if (!inputNode.Outputs.Contains(this)) - { - inputNode.Outputs.Add(this); - } - } - - /// - /// Removes an input node from this node. - /// - public void RemoveInput(OptimizationNode inputNode) - { - if (inputNode == null) - throw new ArgumentNullException(nameof(inputNode)); - - Inputs.Remove(inputNode); - inputNode.Outputs.Remove(this); - } - - /// - /// Replaces an input node with another node. - /// - public void ReplaceInput(OptimizationNode oldInput, OptimizationNode newInput) - { - if (oldInput == null) - throw new ArgumentNullException(nameof(oldInput)); - if (newInput == null) - throw new ArgumentNullException(nameof(newInput)); - - var index = Inputs.IndexOf(oldInput); - if (index >= 0) - { - Inputs[index] = newInput; - oldInput.Outputs.Remove(this); - - if (!newInput.Outputs.Contains(this)) - { - newInput.Outputs.Add(this); - } - } - } - - /// - /// Checks if this node has any consumers (output nodes). - /// - public bool HasConsumers() => Outputs.Count > 0; - - /// - /// Gets the number of consumers for this node. - /// - public int ConsumerCount() => Outputs.Count; - - /// - /// Creates a deep copy of this node (without connections). - /// - public OptimizationNode Clone() - { - return new OptimizationNode - { - Id = Guid.NewGuid().ToString(), // New ID for clone - OperationType = OperationType, - Name = Name + "_clone", - OutputShape = (int[])OutputShape.Clone(), - Parameters = new Dictionary(Parameters), - Metadata = new Dictionary(Metadata), - ConstantValue = ConstantValue, - CanEliminate = CanEliminate, - CanOperateInPlace = CanOperateInPlace, - OriginalLayer = OriginalLayer, - IsFused = IsFused - }; - } - - public override string ToString() - { - var inputCount = Inputs.Count; - var outputCount = Outputs.Count; - var shape = OutputShape.Length > 0 ? $"[{string.Join(", ", OutputShape)}]" : "[]"; - return $"{Name} ({OperationType}) - Inputs: {inputCount}, Outputs: {outputCount}, Shape: {shape}"; - } -} diff --git a/src/InferenceOptimization/Core/OptimizationOptions.cs b/src/InferenceOptimization/Core/OptimizationOptions.cs deleted file mode 100644 index c55a504d57..0000000000 --- a/src/InferenceOptimization/Core/OptimizationOptions.cs +++ /dev/null @@ -1,144 +0,0 @@ -namespace AiDotNet.InferenceOptimization.Core; - -/// -/// Configuration options for graph optimization. -/// -public class OptimizationOptions -{ - /// - /// The optimization level to apply. - /// - public OptimizationLevel Level { get; set; } = OptimizationLevel.Standard; - - /// - /// Target layout for tensor operations (NCHW or NHWC). - /// - public string TargetLayout { get; set; } = "NCHW"; - - /// - /// Maximum number of optimization iterations. - /// - public int MaxIterations { get; set; } = 10; - - /// - /// Enable operator fusion. - /// - public bool EnableOperatorFusion { get; set; } = true; - - /// - /// Enable constant folding. - /// - public bool EnableConstantFolding { get; set; } = true; - - /// - /// Enable dead code elimination. - /// - public bool EnableDeadCodeElimination { get; set; } = true; - - /// - /// Enable common subexpression elimination. - /// - public bool EnableCSE { get; set; } = true; - - /// - /// Enable layout optimization. - /// - public bool EnableLayoutOptimization { get; set; } = true; - - /// - /// Enable in-place operations. - /// - public bool EnableInPlaceOptimization { get; set; } = true; - - /// - /// Enable memory reuse optimization. - /// - public bool EnableMemoryReuse { get; set; } = true; - - /// - /// Enable algebraic simplification. - /// - public bool EnableAlgebraicSimplification { get; set; } = true; - - /// - /// Enable strength reduction. - /// - public bool EnableStrengthReduction { get; set; } = true; - - /// - /// Print optimization statistics. - /// - public bool PrintStatistics { get; set; } = false; - - /// - /// Validate graph after each pass. - /// - public bool ValidateAfterEachPass { get; set; } = false; - - /// - /// Creates options based on optimization level. - /// - public static OptimizationOptions FromLevel(OptimizationLevel level) - { - var options = new OptimizationOptions { Level = level }; - - switch (level) - { - case OptimizationLevel.None: - options.DisableAllOptimizations(); - break; - - case OptimizationLevel.Basic: - options.DisableAllOptimizations(); - options.EnableDeadCodeElimination = true; - options.EnableConstantFolding = true; - break; - - case OptimizationLevel.Standard: - options.EnableOperatorFusion = true; - options.EnableConstantFolding = true; - options.EnableDeadCodeElimination = true; - options.EnableAlgebraicSimplification = true; - break; - - case OptimizationLevel.Aggressive: - options.EnableOperatorFusion = true; - options.EnableConstantFolding = true; - options.EnableDeadCodeElimination = true; - options.EnableCSE = true; - options.EnableAlgebraicSimplification = true; - options.EnableStrengthReduction = true; - options.EnableInPlaceOptimization = true; - options.EnableMemoryReuse = true; - break; - - case OptimizationLevel.Maximum: - // Enable everything - options.EnableOperatorFusion = true; - options.EnableConstantFolding = true; - options.EnableDeadCodeElimination = true; - options.EnableCSE = true; - options.EnableLayoutOptimization = true; - options.EnableInPlaceOptimization = true; - options.EnableMemoryReuse = true; - options.EnableAlgebraicSimplification = true; - options.EnableStrengthReduction = true; - break; - } - - return options; - } - - private void DisableAllOptimizations() - { - EnableOperatorFusion = false; - EnableConstantFolding = false; - EnableDeadCodeElimination = false; - EnableCSE = false; - EnableLayoutOptimization = false; - EnableInPlaceOptimization = false; - EnableMemoryReuse = false; - EnableAlgebraicSimplification = false; - EnableStrengthReduction = false; - } -} diff --git a/src/InferenceOptimization/CustomOperatorRegistry.cs b/src/InferenceOptimization/CustomOperatorRegistry.cs deleted file mode 100644 index cf661f5446..0000000000 --- a/src/InferenceOptimization/CustomOperatorRegistry.cs +++ /dev/null @@ -1,205 +0,0 @@ -using System; -using System.Collections.Concurrent; -using System.Collections.Generic; -using System.Linq; - -namespace AiDotNet.InferenceOptimization -{ - /// - /// Thread-safe registry for managing custom operators with automatic fallback - /// - public sealed class CustomOperatorRegistry - { - private static readonly Lazy _instance = - new Lazy(() => new CustomOperatorRegistry()); - - private readonly ConcurrentDictionary> _operators; - private readonly ConcurrentDictionary _selectedOperators; - private readonly ConcurrentDictionary _operatorVersions; - - /// - /// Gets the singleton instance of the registry - /// - public static CustomOperatorRegistry Instance => _instance.Value; - - private CustomOperatorRegistry() - { - _operators = new ConcurrentDictionary>(); - _selectedOperators = new ConcurrentDictionary(); - _operatorVersions = new ConcurrentDictionary(); - } - - /// - /// Registers a custom operator - /// - public void Register(ICustomOperator op) - { - if (op == null) - throw new ArgumentNullException(nameof(op)); - - // Bump the version after the operator set is updated. - // This avoids stale cached selections without requiring coarse locking. - void BumpVersion() => _operatorVersions.AddOrUpdate(op.Name, 1, (_, v) => v + 1); - - // Use AddOrUpdate with factory that always creates a new sorted list - // This ensures thread-safety by never mutating existing lists - _operators.AddOrUpdate( - op.Name, - _ => new List { op }, - (_, existingList) => - { - // Create a new list with all existing operators plus the new one - // This avoids race conditions from modifying the existing list - List newList; - lock (existingList) - { - newList = new List(existingList) { op }; - } - newList.Sort((a, b) => b.Priority.CompareTo(a.Priority)); - return newList; - }); - - BumpVersion(); - } - - /// - /// Gets the best available operator for the given name - /// - public ICustomOperator? GetOperator(string name) - { - if (string.IsNullOrEmpty(name)) - throw new ArgumentException("Operator name cannot be null or empty", nameof(name)); - - while (true) - { - long version = _operatorVersions.GetOrAdd(name, 0); - - if (_selectedOperators.TryGetValue(name, out var existing) && existing.Version == version) - { - return existing.Operator is NullOperator ? null : existing.Operator; - } - - var selected = SelectOperatorOrNull(name); - - // Only publish the cached selection if the operator set version did not change while we were selecting. - if (_operatorVersions.TryGetValue(name, out var current) && current == version) - { - _selectedOperators[name] = new SelectedOperatorEntry(version, selected); - return selected is NullOperator ? null : selected; - } - - // Operator set changed while selecting; retry to avoid caching a stale choice. - } - } - - private ICustomOperator SelectOperatorOrNull(string name) - { - if (!_operators.TryGetValue(name, out var candidates)) - return new NullOperator(); - - lock (candidates) - { - // Find the highest priority supported operator - var result = candidates.FirstOrDefault(op => op.IsSupported()); - return result ?? new NullOperator(); - } - } - - /// - /// Gets a typed operator - /// - public ICustomOperator? GetOperator(string name) where T : struct - { - return GetOperator(name) as ICustomOperator; - } - - /// - /// Internal marker type for null operators - /// - private sealed class NullOperator : ICustomOperator - { - public string Name => string.Empty; - public string Version => string.Empty; - public int Priority => int.MinValue; - public bool IsSupported() => false; - public double EstimatedSpeedup() => 0; - } - - /// - /// Checks if an operator is available - /// - public bool HasOperator(string name) - { - return GetOperator(name) != null; - } - - /// - /// Unregisters all operators with the given name - /// - public void Unregister(string name) - { - _operators.TryRemove(name, out _); - _selectedOperators.TryRemove(name, out _); - _operatorVersions.TryRemove(name, out _); - } - - /// - /// Gets all registered operator names - /// - public IEnumerable GetRegisteredOperatorNames() - { - return _operators.Keys.ToArray(); - } - - /// - /// Gets detailed information about all registered operators - /// - public Dictionary> GetOperatorInfo() - { - var result = new Dictionary>(); - - foreach (var kvp in _operators) - { - lock (kvp.Value) - { - result[kvp.Key] = kvp.Value.Select(op => new OperatorInfo - { - Name = op.Name, - Version = op.Version, - Priority = op.Priority, - IsSupported = op.IsSupported(), - EstimatedSpeedup = op.EstimatedSpeedup(), - Type = op.GetType().FullName ?? op.GetType().Name - }).ToList(); - } - } - - return result; - } - - /// - /// Clears all registered operators - /// - public void Clear() - { - _operators.Clear(); - _selectedOperators.Clear(); - _operatorVersions.Clear(); - } - - private readonly record struct SelectedOperatorEntry(long Version, ICustomOperator Operator); - } - - /// - /// Information about a registered operator - /// - public class OperatorInfo - { - public string Name { get; set; } = string.Empty; - public string Version { get; set; } = string.Empty; - public int Priority { get; set; } - public bool IsSupported { get; set; } - public double EstimatedSpeedup { get; set; } - public string Type { get; set; } = string.Empty; - } -} diff --git a/src/InferenceOptimization/ICustomOperator.cs b/src/InferenceOptimization/ICustomOperator.cs deleted file mode 100644 index 7342855347..0000000000 --- a/src/InferenceOptimization/ICustomOperator.cs +++ /dev/null @@ -1,48 +0,0 @@ -using System; -using AiDotNet.LinearAlgebra; - -namespace AiDotNet.InferenceOptimization -{ - /// - /// Defines the contract for custom operators with hardware-specific optimizations - /// - public interface ICustomOperator - { - /// - /// Gets the unique name of the operator - /// - string Name { get; } - - /// - /// Gets the version of the operator implementation - /// - string Version { get; } - - /// - /// Gets the priority level (higher values are preferred) - /// - int Priority { get; } - - /// - /// Determines if the operator can run on the current platform - /// - bool IsSupported(); - - /// - /// Estimates the relative performance gain over reference implementation - /// - /// Expected speedup multiplier (e.g., 2.0 for 2x speedup) - double EstimatedSpeedup(); - } - - /// - /// Base interface for custom operators that work with tensors - /// - public interface ICustomOperator : ICustomOperator where T : struct - { - /// - /// Executes the operator on input tensors - /// - Tensor Execute(params Tensor[] inputs); - } -} diff --git a/src/InferenceOptimization/IR/Common/IRTypes.cs b/src/InferenceOptimization/IR/Common/IRTypes.cs deleted file mode 100644 index f0b43868e1..0000000000 --- a/src/InferenceOptimization/IR/Common/IRTypes.cs +++ /dev/null @@ -1,306 +0,0 @@ -using System.Numerics; - -namespace AiDotNet.InferenceOptimization.IR.Common; - -/// -/// Represents the data type of a tensor in the IR. -/// Exceeds industry standards by supporting all common ML types plus quantized types. -/// -/// -/// Industry Comparison: -/// -/// MLIR: Uses builtin types with explicit bit widths -/// XLA: PrimitiveType enum with similar coverage -/// TVM: DataType class with code/bits/lanes -/// Our approach: Comprehensive enum with quantization support and extension points -/// -/// -public enum IRDataType -{ - // Standard floating point - Float16, - Float32, - Float64, - BFloat16, - - // Standard integers - Int8, - Int16, - Int32, - Int64, - UInt8, - UInt16, - UInt32, - UInt64, - - // Specialized types - Bool, - Complex64, - Complex128, - Decimal, - - // Quantized types (exceeds most frameworks) - QInt8, // Quantized int8 with scale/zero-point - QUInt8, // Quantized uint8 with scale/zero-point - QInt4, // 4-bit quantized (for LLMs) - QInt2, // 2-bit quantized (extreme compression) - - // Dynamic/unknown - Unknown -} - -/// -/// Memory layout for tensors. Critical for hardware optimization. -/// -/// -/// Industry Comparison: -/// -/// TVM: Uses layout strings like "NCHW", "NHWC" -/// ONNX: Implicit layouts based on operator -/// Our approach: Explicit enum with all common layouts plus extensibility -/// -/// -public enum MemoryLayout -{ - // Standard layouts - RowMajor, // C-style, last dimension contiguous - ColumnMajor, // Fortran-style, first dimension contiguous - - // Image layouts - NCHW, // Batch, Channel, Height, Width (PyTorch default) - NHWC, // Batch, Height, Width, Channel (TensorFlow default) - CHWN, // For specific hardware optimizations - - // Tiled layouts (for GPU/TPU optimization) - Tiled4x4, - Tiled8x8, - Tiled16x16, - Tiled32x32, - - // Blocked layouts (for CPU SIMD) - Blocked, - - // Custom/unknown - Custom, - Unknown -} - -/// -/// Execution device target for operations. -/// -public enum DeviceType -{ - CPU, - GPU, - TPU, - NPU, - FPGA, - Auto, // Let the scheduler decide - Any // Can run on any device -} - -/// -/// Quantization parameters for quantized tensor types. -/// -public class QuantizationParams -{ - public double Scale { get; set; } = 1.0; - public int ZeroPoint { get; set; } = 0; - public double Min { get; set; } = double.MinValue; - public double Max { get; set; } = double.MaxValue; - public bool PerChannel { get; set; } = false; - public int QuantizationAxis { get; set; } = -1; - public double[]? PerChannelScales { get; set; } - public int[]? PerChannelZeroPoints { get; set; } -} - -/// -/// Comprehensive tensor type information. -/// Exceeds industry standards by combining type, shape, layout, and device info. -/// -public class TensorType -{ - /// - /// Element data type. - /// - public IRDataType DataType { get; set; } = IRDataType.Float32; - - /// - /// Tensor shape. Empty for scalars, -1 for dynamic dimensions. - /// - public int[] Shape { get; set; } = Array.Empty(); - - /// - /// Memory layout. - /// - public MemoryLayout Layout { get; set; } = MemoryLayout.RowMajor; - - /// - /// Target device. - /// - public DeviceType Device { get; set; } = DeviceType.Auto; - - /// - /// Quantization parameters (if quantized type). - /// - public QuantizationParams? Quantization { get; set; } - - /// - /// Strides for each dimension (computed from shape and layout if not specified). - /// - public long[]? Strides { get; set; } - - /// - /// Whether this tensor has dynamic (runtime-determined) shape. - /// - public bool HasDynamicShape => Shape.Any(d => d < 0); - - /// - /// Total number of elements (returns -1 if dynamic). - /// - public long NumElements - { - get - { - if (HasDynamicShape) return -1; - if (Shape.Length == 0) return 1; // scalar - return Shape.Aggregate(1L, (acc, dim) => acc * dim); - } - } - - /// - /// Size in bytes of each element. - /// - public int ElementSize => DataType switch - { - IRDataType.Bool => 1, - IRDataType.Int8 or IRDataType.UInt8 or IRDataType.QInt8 or IRDataType.QUInt8 => 1, - IRDataType.Int16 or IRDataType.UInt16 or IRDataType.Float16 or IRDataType.BFloat16 => 2, - IRDataType.Int32 or IRDataType.UInt32 or IRDataType.Float32 => 4, - IRDataType.Int64 or IRDataType.UInt64 or IRDataType.Float64 or IRDataType.Complex64 => 8, - IRDataType.Complex128 or IRDataType.Decimal => 16, - IRDataType.QInt4 or IRDataType.QInt2 => 1, // packed - _ => 8 - }; - - /// - /// Total memory size in bytes. - /// - public long TotalBytes => NumElements >= 0 ? NumElements * ElementSize : -1; - - /// - /// Check if this type is compatible with another for broadcasting. - /// - public bool IsBroadcastCompatible(TensorType other) - { - if (other == null) - throw new ArgumentNullException(nameof(other)); - - if (HasDynamicShape || other.HasDynamicShape) return true; - - int maxRank = Math.Max(Shape.Length, other.Shape.Length); - for (int i = 0; i < maxRank; i++) - { - int dim1 = i < Shape.Length ? Shape[Shape.Length - 1 - i] : 1; - int dim2 = i < other.Shape.Length ? other.Shape[other.Shape.Length - 1 - i] : 1; - - if (dim1 != dim2 && dim1 != 1 && dim2 != 1) return false; - } - return true; - } - - public TensorType Clone() => new() - { - DataType = DataType, - Shape = (int[])Shape.Clone(), - Layout = Layout, - Device = Device, - Quantization = Quantization, - Strides = Strides != null ? (long[])Strides.Clone() : null - }; - - public override string ToString() - { - var shape = Shape.Length == 0 ? "scalar" : $"[{string.Join(", ", Shape.Select(d => d < 0 ? "?" : d.ToString()))}]"; - return $"{DataType}{shape}@{Device}"; - } -} - -/// -/// Extension methods for IRDataType. -/// -public static class IRDataTypeExtensions -{ - public static IRDataType FromSystemType(Type type) - { - if (type == null) - throw new ArgumentNullException(nameof(type)); - - return type switch - { - Type t when t == typeof(float) => IRDataType.Float32, - Type t when t == typeof(double) => IRDataType.Float64, - Type t when t == typeof(Half) => IRDataType.Float16, - Type t when t == typeof(int) => IRDataType.Int32, - Type t when t == typeof(long) => IRDataType.Int64, - Type t when t == typeof(short) => IRDataType.Int16, - Type t when t == typeof(byte) => IRDataType.UInt8, - Type t when t == typeof(sbyte) => IRDataType.Int8, - Type t when t == typeof(ushort) => IRDataType.UInt16, - Type t when t == typeof(uint) => IRDataType.UInt32, - Type t when t == typeof(ulong) => IRDataType.UInt64, - Type t when t == typeof(bool) => IRDataType.Bool, - Type t when t == typeof(decimal) => IRDataType.Decimal, - Type t when t == typeof(Complex) => IRDataType.Complex128, - _ => IRDataType.Unknown - }; - } - - public static Type ToSystemType(this IRDataType type) - { - return type switch - { - IRDataType.Float16 => typeof(Half), - IRDataType.Float32 => typeof(float), - IRDataType.Float64 => typeof(double), - IRDataType.BFloat16 => typeof(Half), // Approximation - IRDataType.Int8 or IRDataType.QInt8 => typeof(sbyte), - IRDataType.Int16 => typeof(short), - IRDataType.Int32 => typeof(int), - IRDataType.Int64 => typeof(long), - IRDataType.UInt8 or IRDataType.QUInt8 or IRDataType.QInt4 or IRDataType.QInt2 => typeof(byte), // Packed quantized types stored as bytes - IRDataType.UInt16 => typeof(ushort), - IRDataType.UInt32 => typeof(uint), - IRDataType.UInt64 => typeof(ulong), - IRDataType.Bool => typeof(bool), - IRDataType.Decimal => typeof(decimal), - IRDataType.Complex64 or IRDataType.Complex128 => typeof(Complex), - _ => typeof(double) - }; - } - - public static bool IsFloatingPoint(this IRDataType type) => - type is IRDataType.Float16 or IRDataType.Float32 or IRDataType.Float64 or IRDataType.BFloat16; - - public static bool IsInteger(this IRDataType type) => - type is IRDataType.Int8 or IRDataType.Int16 or IRDataType.Int32 or IRDataType.Int64 or - IRDataType.UInt8 or IRDataType.UInt16 or IRDataType.UInt32 or IRDataType.UInt64; - - public static bool IsQuantized(this IRDataType type) => - type is IRDataType.QInt8 or IRDataType.QUInt8 or IRDataType.QInt4 or IRDataType.QInt2; - - /// - /// Gets the size in bytes of each element for the given data type. - /// - public static int ElementSizeInBytes(this IRDataType type) => type switch - { - IRDataType.Bool => 1, - IRDataType.Int8 or IRDataType.UInt8 or IRDataType.QInt8 or IRDataType.QUInt8 => 1, - IRDataType.Int16 or IRDataType.UInt16 or IRDataType.Float16 or IRDataType.BFloat16 => 2, - IRDataType.Int32 or IRDataType.UInt32 or IRDataType.Float32 => 4, - IRDataType.Int64 or IRDataType.UInt64 or IRDataType.Float64 or IRDataType.Complex64 => 8, - IRDataType.Complex128 or IRDataType.Decimal => 16, - IRDataType.QInt4 or IRDataType.QInt2 => 1, // packed representation - _ => 8 // default to 8 for unknown types - }; -} diff --git a/src/InferenceOptimization/IR/HighLevel/HLIRGraph.cs b/src/InferenceOptimization/IR/HighLevel/HLIRGraph.cs deleted file mode 100644 index a85b5a8960..0000000000 --- a/src/InferenceOptimization/IR/HighLevel/HLIRGraph.cs +++ /dev/null @@ -1,797 +0,0 @@ -using AiDotNet.Enums; -using AiDotNet.InferenceOptimization.IR.Common; - -namespace AiDotNet.InferenceOptimization.IR.HighLevel; - -/// -/// High-Level Intermediate Representation Graph. -/// Represents the complete computation graph at a semantic level. -/// -/// -/// Design Philosophy: -/// -/// HLIRGraph provides a container for HLIRNodes with efficient traversal, validation, -/// and transformation capabilities. It maintains both node-reference and ID-based -/// representations for flexibility. -/// -/// -/// Industry Comparison: -/// -/// TVM Relay: Module with global functions - we add richer graph operations -/// MLIR: ModuleOp with nested regions - we simplify for ML workloads -/// XLA HLO: HloModule with computations - we add bidirectional edges -/// ONNX: GraphProto - we add dynamic modification support -/// -/// -/// Exceeds Standards By: -/// -/// Incremental validation and repair -/// Built-in pattern matching for optimization passes -/// Automatic ID management with compaction -/// Comprehensive graph statistics and profiling -/// Subgraph extraction and splicing -/// -/// -public class HLIRGraph where T : struct -{ - #region Fields - - private readonly Dictionary> _nodeMap = new(); - private int _nextNodeId; - private bool _isDirty = true; - private List>? _cachedTopologicalOrder; - - #endregion - - #region Properties - - /// - /// All nodes in the graph. - /// - public IReadOnlyList> Nodes => _nodeMap.Values.ToList(); - - /// - /// Input nodes (nodes with no inputs). - /// - public List> InputNodes { get; } = new(); - - /// - /// Output nodes (nodes with no outputs, or explicitly marked). - /// - public List> OutputNodes { get; } = new(); - - /// - /// Graph name for debugging. - /// - public string Name { get; set; } = "HLIRGraph"; - - /// - /// Graph-level metadata. - /// - public Dictionary Metadata { get; } = new(); - - /// - /// Number of nodes in the graph. - /// - public int NodeCount => _nodeMap.Count; - - /// - /// Version number, incremented on each modification. - /// - public int Version { get; private set; } - - #endregion - - #region Node Management - - /// - /// Adds a new node to the graph, assigning it a unique ID. - /// - public HLIRNode AddNode(HLIRNode node) - { - if (node.Id < 0) - { - node.Id = _nextNodeId++; - } - else if (_nodeMap.ContainsKey(node.Id)) - { - throw new InvalidOperationException($"Node with ID {node.Id} already exists"); - } - else - { - _nextNodeId = Math.Max(_nextNodeId, node.Id + 1); - } - - _nodeMap[node.Id] = node; - MarkDirty(); - node.AddProvenance($"Added to graph '{Name}'"); - return node; - } - - /// - /// Creates and adds a new node with the specified operation. - /// - public HLIRNode CreateNode(OperationType operation, string name, params HLIRNode[] inputs) - { - var node = new HLIRNode - { - Id = _nextNodeId++, - Name = name, - Operation = operation - }; - - foreach (var input in inputs) - { - node.AddInput(input); - } - - _nodeMap[node.Id] = node; - MarkDirty(); - node.AddProvenance($"Created in graph '{Name}'"); - return node; - } - - /// - /// Removes a node from the graph. - /// - public bool RemoveNode(HLIRNode node) - { - if (!_nodeMap.ContainsKey(node.Id)) return false; - - // Disconnect from inputs - foreach (var input in node.Inputs.ToList()) - { - node.RemoveInput(input); - } - - // Disconnect from outputs - foreach (var output in node.Outputs.ToList()) - { - output.RemoveInput(node); - } - - // Remove from input/output lists - InputNodes.Remove(node); - OutputNodes.Remove(node); - - _nodeMap.Remove(node.Id); - MarkDirty(); - return true; - } - - /// - /// Finds a node by ID. - /// - public HLIRNode? FindNode(int id) => - _nodeMap.TryGetValue(id, out var node) ? node : null; - - /// - /// Finds nodes by name (partial match). - /// - public IEnumerable> FindNodesByName(string name) => - _nodeMap.Values.Where(n => n.Name.Contains(name, StringComparison.OrdinalIgnoreCase)); - - /// - /// Finds nodes by operation type. - /// - public IEnumerable> FindNodesByOperation(OperationType operation) => - _nodeMap.Values.Where(n => n.Operation == operation); - - /// - /// Replaces a node with another, updating all connections. - /// - /// The existing node to replace. - /// The new node to insert in its place. - /// Thrown when oldNode is not in the graph. - /// - /// - /// This method performs a complete replacement by: - /// - /// Copying all input connections from oldNode to newNode - /// Redirecting all output connections from oldNode to newNode - /// Adding newNode to the graph - /// Removing oldNode from the graph - /// - /// - /// - /// If newNode has the same ID as oldNode, a new ID will be automatically assigned - /// to avoid conflicts during the replacement process. - /// - /// - public void ReplaceNode(HLIRNode oldNode, HLIRNode newNode) - { - if (!_nodeMap.ContainsKey(oldNode.Id)) - { - throw new InvalidOperationException($"Node {oldNode.Id} not in graph"); - } - - // Handle case where new node has same ID as old node - // This prevents AddNode from failing when oldNode hasn't been removed yet - if (newNode.Id == oldNode.Id) - { - newNode.Id = -1; // Force new ID assignment during AddNode - } - - // Copy connections - foreach (var input in oldNode.Inputs.ToList()) - { - newNode.AddInput(input); - } - - foreach (var output in oldNode.Outputs.ToList()) - { - output.ReplaceInput(oldNode, newNode); - } - - // Update graph - order matters: add new node first, then remove old - AddNode(newNode); - RemoveNode(oldNode); - - newNode.AddProvenance($"Replaced node n{oldNode.Id}"); - } - - #endregion - - #region Traversal - - /// - /// Gets nodes in topological order (inputs before outputs). - /// - public List> GetTopologicalOrder() - { - if (!_isDirty && _cachedTopologicalOrder != null) - { - return _cachedTopologicalOrder; - } - - var result = new List>(); - var visited = new HashSet(); - var visiting = new HashSet(); - - void Visit(HLIRNode node) - { - if (visited.Contains(node.Id)) return; - if (visiting.Contains(node.Id)) - { - throw new InvalidOperationException($"Cycle detected at node {node.Id}"); - } - - visiting.Add(node.Id); - - foreach (var input in node.Inputs) - { - Visit(input); - } - - visiting.Remove(node.Id); - visited.Add(node.Id); - result.Add(node); - } - - // Start from output nodes or all nodes - var startNodes = OutputNodes.Count > 0 - ? OutputNodes - : _nodeMap.Values.Where(n => n.Outputs.Count == 0).ToList(); - - foreach (var node in startNodes) - { - Visit(node); - } - - // Add any remaining nodes (disconnected components) - foreach (var node in _nodeMap.Values.Where(n => !visited.Contains(n.Id))) - { - Visit(node); - } - - _cachedTopologicalOrder = result; - _isDirty = false; - return result; - } - - /// - /// Gets nodes in reverse topological order (outputs before inputs). - /// - /// A new list containing nodes in reverse topological order (does not modify cached order). - public List> GetReverseTopologicalOrder() - { - var order = GetTopologicalOrder(); - var reversed = new List>(order); - reversed.Reverse(); - return reversed; - } - - /// - /// Iterates over all nodes in BFS order from inputs. - /// - public IEnumerable> BreadthFirstFromInputs() - { - var visited = new HashSet(); - var queue = new Queue>(InputNodes); - - foreach (var node in InputNodes) - { - visited.Add(node.Id); - } - - while (queue.Count > 0) - { - var node = queue.Dequeue(); - yield return node; - - foreach (var output in node.Outputs) - { - if (!visited.Contains(output.Id)) - { - visited.Add(output.Id); - queue.Enqueue(output); - } - } - } - } - - #endregion - - #region Pattern Matching - - /// - /// Finds sequences of nodes matching an operation pattern. - /// - public List>> FindPatterns(params OperationType[] pattern) - { - if (pattern.Length == 0) return new List>>(); - - var results = new List>>(); - - foreach (var startNode in FindNodesByOperation(pattern[0])) - { - if (startNode.IsFused || startNode.IsMarkedForDeletion) continue; - - var sequence = TryMatchPattern(startNode, pattern); - if (sequence != null) - { - results.Add(sequence); - } - } - - return results; - } - - private List>? TryMatchPattern(HLIRNode startNode, OperationType[] pattern) - { - var sequence = new List> { startNode }; - var currentNode = startNode; - - for (int i = 1; i < pattern.Length; i++) - { - // Must have exactly one output - if (currentNode.Outputs.Count != 1) return null; - - var nextNode = currentNode.Outputs[0]; - - // Next node must match pattern and have single input - if (nextNode.Operation != pattern[i] || - nextNode.Inputs.Count != 1 || - nextNode.IsFused || - nextNode.IsMarkedForDeletion) - { - return null; - } - - sequence.Add(nextNode); - currentNode = nextNode; - } - - return sequence; - } - - /// - /// Finds diamond patterns (fork-join). - /// - public List<(HLIRNode fork, List> branches, HLIRNode join)> FindDiamondPatterns() - { - var results = new List<(HLIRNode, List>, HLIRNode)>(); - - foreach (var forkNode in _nodeMap.Values.Where(n => n.Outputs.Count > 1)) - { - // Check if all outputs eventually merge to the same node - var outputPaths = new Dictionary, HashSet>>(); - - foreach (var output in forkNode.Outputs) - { - outputPaths[output] = GetReachableNodes(output); - } - - // Find common descendants - if (outputPaths.Count < 2) continue; - - var common = outputPaths.Values.First().ToHashSet(); - foreach (var paths in outputPaths.Values.Skip(1)) - { - common.IntersectWith(paths); - } - - // Find the first common node (join point) - var joinNode = common - .OrderBy(n => GetTopologicalOrder().IndexOf(n)) - .FirstOrDefault(); - - if (joinNode != null && forkNode.Outputs.All(o => joinNode.Inputs.Contains(o) || - GetReachableNodes(o).Contains(joinNode))) - { - results.Add((forkNode, forkNode.Outputs.ToList(), joinNode)); - } - } - - return results; - } - - private HashSet> GetReachableNodes(HLIRNode start) - { - var reachable = new HashSet>(); - var stack = new Stack>(); - stack.Push(start); - - while (stack.Count > 0) - { - var node = stack.Pop(); - if (reachable.Add(node)) - { - foreach (var output in node.Outputs) - { - stack.Push(output); - } - } - } - - return reachable; - } - - #endregion - - #region Validation - - /// - /// Validates the graph structure. - /// - public ValidationResult Validate() - { - var errors = new List(); - var warnings = new List(); - - // Check for duplicate IDs - var ids = new HashSet(); - foreach (var node in _nodeMap.Values) - { - if (!ids.Add(node.Id)) - { - errors.Add($"Duplicate node ID: {node.Id}"); - } - } - - // Check node validity - foreach (var node in _nodeMap.Values) - { - if (!node.Validate()) - { - errors.Add($"Invalid node: n{node.Id} ({node.Name})"); - } - - // Check that inputs are in graph - foreach (var input in node.Inputs) - { - if (!_nodeMap.ContainsKey(input.Id)) - { - errors.Add($"Node n{node.Id} references missing input n{input.Id}"); - } - } - } - - // Check for cycles - try - { - GetTopologicalOrder(); - } - catch (InvalidOperationException ex) - { - errors.Add($"Graph contains cycle: {ex.Message}"); - } - - // Check input/output node consistency - foreach (var input in InputNodes) - { - if (!_nodeMap.ContainsKey(input.Id)) - { - errors.Add($"Input node n{input.Id} not in graph"); - } - } - - foreach (var output in OutputNodes) - { - if (!_nodeMap.ContainsKey(output.Id)) - { - errors.Add($"Output node n{output.Id} not in graph"); - } - } - - // Warnings - var deadNodes = _nodeMap.Values.Where(n => - !n.HasConsumers && - !OutputNodes.Contains(n) && - n.CanEliminate).ToList(); - - if (deadNodes.Count > 0) - { - warnings.Add($"Graph has {deadNodes.Count} dead nodes that could be eliminated"); - } - - return new ValidationResult - { - IsValid = errors.Count == 0, - Errors = errors, - Warnings = warnings - }; - } - - #endregion - - #region Statistics - - /// - /// Gets comprehensive graph statistics. - /// - public HLIRGraphStatistics GetStatistics() - { - var stats = new HLIRGraphStatistics - { - TotalNodes = _nodeMap.Count, - InputNodes = InputNodes.Count, - OutputNodes = OutputNodes.Count - }; - - // Count by operation - foreach (var node in _nodeMap.Values) - { - if (!stats.NodesByOperation.ContainsKey(node.Operation)) - { - stats.NodesByOperation[node.Operation] = 0; - } - stats.NodesByOperation[node.Operation]++; - - if (node.IsFused) stats.FusedNodes++; - if (node.Cost != null) - { - stats.TotalFLOPs += node.Cost.FLOPs; - stats.TotalMemoryRead += node.Cost.MemoryRead; - stats.TotalMemoryWrite += node.Cost.MemoryWrite; - } - } - - // Graph depth (longest path) - stats.GraphDepth = ComputeGraphDepth(); - - // Critical path - stats.CriticalPathLength = ComputeCriticalPathLength(); - - return stats; - } - - private int ComputeGraphDepth() - { - var depth = new Dictionary(); - foreach (var node in GetTopologicalOrder()) - { - var inputDepth = node.Inputs.Count > 0 - ? node.Inputs.Max(i => depth.GetValueOrDefault(i.Id, 0)) - : 0; - depth[node.Id] = inputDepth + 1; - } - return depth.Count > 0 ? depth.Values.Max() : 0; - } - - private long ComputeCriticalPathLength() - { - var pathCost = new Dictionary(); - foreach (var node in GetTopologicalOrder()) - { - var inputCost = node.Inputs.Count > 0 - ? node.Inputs.Max(i => pathCost.GetValueOrDefault(i.Id, 0)) - : 0; - var nodeCost = node.Cost?.EstimatedLatencyNs ?? 1; - pathCost[node.Id] = inputCost + nodeCost; - } - return pathCost.Count > 0 ? pathCost.Values.Max() : 0; - } - - #endregion - - #region Utilities - - /// - /// Creates a deep copy of the graph. - /// - public HLIRGraph Clone() - { - var clone = new HLIRGraph { Name = Name + "_clone" }; - - // Clone all nodes - var nodeClones = new Dictionary>(); - foreach (var node in _nodeMap.Values) - { - var nodeClone = node.Clone(); - nodeClone.Id = node.Id; - nodeClones[node.Id] = nodeClone; - clone._nodeMap[nodeClone.Id] = nodeClone; - } - - // Reconnect edges - foreach (var node in _nodeMap.Values) - { - var nodeClone = nodeClones[node.Id]; - foreach (var input in node.Inputs) - { - nodeClone.AddInput(nodeClones[input.Id]); - } - } - - // Copy input/output lists - foreach (var input in InputNodes) - { - clone.InputNodes.Add(nodeClones[input.Id]); - } - foreach (var output in OutputNodes) - { - clone.OutputNodes.Add(nodeClones[output.Id]); - } - - // Copy metadata - foreach (var kvp in Metadata) - { - clone.Metadata[kvp.Key] = kvp.Value; - } - - clone._nextNodeId = _nextNodeId; - return clone; - } - - /// - /// Extracts a subgraph containing specified nodes and their dependencies. - /// - public HLIRGraph ExtractSubgraph(IEnumerable> nodes) - { - var subgraph = new HLIRGraph { Name = Name + "_subgraph" }; - var nodeSet = new HashSet(nodes.Select(n => n.Id)); - - // Add all dependencies - var toProcess = new Queue>(nodes); - while (toProcess.Count > 0) - { - var node = toProcess.Dequeue(); - foreach (var input in node.Inputs) - { - if (nodeSet.Add(input.Id)) - { - toProcess.Enqueue(input); - } - } - } - - // Clone nodes into subgraph - var clones = new Dictionary>(); - foreach (var id in nodeSet) - { - var node = _nodeMap[id]; - var clone = node.Clone(); - clone.Id = node.Id; - clones[id] = clone; - subgraph.AddNode(clone); - } - - // Reconnect edges within subgraph - foreach (var id in nodeSet) - { - var original = _nodeMap[id]; - var clone = clones[id]; - foreach (var input in original.Inputs) - { - if (clones.TryGetValue(input.Id, out var inputClone)) - { - clone.AddInput(inputClone); - } - } - } - - return subgraph; - } - - /// - /// Compacts node IDs to be sequential starting from 0. - /// - public void CompactNodeIds() - { - var mapping = new Dictionary(); - var orderedNodes = GetTopologicalOrder(); - - for (int i = 0; i < orderedNodes.Count; i++) - { - mapping[orderedNodes[i].Id] = i; - } - - // First pass: Update all InputIds arrays using old IDs (before node IDs change) - foreach (var node in orderedNodes) - { - node.InputIds = node.Inputs.Select(inp => mapping[inp.Id]).ToArray(); - } - - // Second pass: Update all node IDs and rebuild node map - _nodeMap.Clear(); - foreach (var node in orderedNodes) - { - node.Id = mapping[node.Id]; - _nodeMap[node.Id] = node; - } - - _nextNodeId = orderedNodes.Count; - MarkDirty(); - } - - private void MarkDirty() - { - _isDirty = true; - _cachedTopologicalOrder = null; - Version++; - } - - public override string ToString() - { - return $"HLIRGraph '{Name}': {_nodeMap.Count} nodes, {InputNodes.Count} inputs, {OutputNodes.Count} outputs"; - } - - #endregion -} - -/// -/// Result of graph validation. -/// -public class ValidationResult -{ - public bool IsValid { get; init; } - public List Errors { get; init; } = new(); - public List Warnings { get; init; } = new(); - - public override string ToString() - { - if (IsValid && Warnings.Count == 0) return "Valid"; - var parts = new List(); - if (!IsValid) parts.Add($"{Errors.Count} errors"); - if (Warnings.Count > 0) parts.Add($"{Warnings.Count} warnings"); - return string.Join(", ", parts); - } -} - -/// -/// Comprehensive graph statistics. -/// -public class HLIRGraphStatistics -{ - public int TotalNodes { get; set; } - public int InputNodes { get; set; } - public int OutputNodes { get; set; } - public int FusedNodes { get; set; } - public int GraphDepth { get; set; } - public long CriticalPathLength { get; set; } - public long TotalFLOPs { get; set; } - public long TotalMemoryRead { get; set; } - public long TotalMemoryWrite { get; set; } - public Dictionary NodesByOperation { get; } = new(); - - public double ArithmeticIntensity => - (TotalMemoryRead + TotalMemoryWrite) > 0 - ? (double)TotalFLOPs / (TotalMemoryRead + TotalMemoryWrite) - : 0; - - public override string ToString() - { - return $"Nodes: {TotalNodes}, Depth: {GraphDepth}, FLOPs: {TotalFLOPs:N0}, Memory: {(TotalMemoryRead + TotalMemoryWrite) / 1024.0 / 1024.0:F2} MB"; - } -} diff --git a/src/InferenceOptimization/IR/HighLevel/HLIRNode.cs b/src/InferenceOptimization/IR/HighLevel/HLIRNode.cs deleted file mode 100644 index 6c9bdef778..0000000000 --- a/src/InferenceOptimization/IR/HighLevel/HLIRNode.cs +++ /dev/null @@ -1,425 +0,0 @@ -using AiDotNet.Enums; -using AiDotNet.InferenceOptimization.IR.Common; -using AiDotNet.LinearAlgebra; - -namespace AiDotNet.InferenceOptimization.IR.HighLevel; - -/// -/// High-Level Intermediate Representation Node. -/// Represents semantic operations at the model level, similar to TVM's Relay or MLIR's high-level dialects. -/// -/// -/// Design Philosophy: -/// -/// HLIR represents operations at a semantic level, preserving model structure and enabling -/// high-level optimizations like: -/// - Operator fusion (Conv+BN+ReLU) -/// - Algebraic simplification -/// - Common subexpression elimination -/// - Dead code elimination -/// - Constant folding -/// -/// -/// Industry Comparison: -/// -/// TVM Relay: Functional IR with let bindings and closures - we add richer metadata -/// MLIR High-Level: Dialect-based with regions - we add optimization hints -/// XLA HLO: Flat operation list - we add graph structure -/// ONNX: Static graph - we add dynamic shape support and fusion tracking -/// -/// -/// Exceeds Standards By: -/// -/// Combining graph-based and SSA-style representations -/// Rich provenance tracking for debugging -/// Built-in fusion pattern matching -/// Hardware-aware cost model hints -/// Full quantization metadata support -/// -/// -/// The numeric type for constant values -public class HLIRNode where T : struct -{ - #region Identity - - /// - /// Unique identifier for this node. Uses integer for fast lookup (like JitCompiler). - /// - public int Id { get; set; } - - /// - /// Human-readable name for debugging and visualization. - /// - public string Name { get; set; } = string.Empty; - - /// - /// The semantic operation type. - /// - public OperationType Operation { get; set; } - - #endregion - - #region Graph Structure - - /// - /// Input nodes (predecessor edges). Using node references for graph traversal. - /// - public List> Inputs { get; set; } = new(); - - /// - /// Output nodes (successor edges). Maintained bidirectionally for efficient traversal. - /// - public List> Outputs { get; set; } = new(); - - /// - /// Input tensor IDs for SSA-style representation (compatible with JitCompiler). - /// - public int[] InputIds { get; set; } = Array.Empty(); - - #endregion - - #region Type Information - - /// - /// Comprehensive type information for the output tensor. - /// - public TensorType OutputType { get; set; } = new(); - - /// - /// Types of each input (for validation and type inference). - /// - public List InputTypes { get; set; } = new(); - - #endregion - - #region Operation Parameters - - /// - /// Operation-specific parameters (weights, biases, etc.). - /// Key: parameter name, Value: parameter value or tensor. - /// - public Dictionary Parameters { get; set; } = new(); - - /// - /// Operation attributes (stride, padding, kernel size, etc.). - /// - public Dictionary Attributes { get; set; } = new(); - - /// - /// Constant tensor value (if this is a constant node). - /// - public Tensor? ConstantValue { get; set; } - - #endregion - - #region Optimization Metadata - - /// - /// Whether this node can be eliminated (no side effects). - /// - public bool CanEliminate { get; set; } = true; - - /// - /// Whether this node can perform in-place operations. - /// - public bool CanOperateInPlace { get; set; } - - /// - /// Whether this node has been fused from multiple operations. - /// - public bool IsFused { get; set; } - - /// - /// Original nodes that were fused into this one. - /// - public List>? FusedFrom { get; set; } - - /// - /// Marked for deletion during optimization passes. - /// - public bool IsMarkedForDeletion { get; set; } - - /// - /// Cost estimate for scheduling (FLOPs, memory access, etc.). - /// - public OperationCost? Cost { get; set; } - - /// - /// Optimization hints for passes. - /// - public OptimizationHints Hints { get; set; } = new(); - - #endregion - - #region Provenance (Debugging) - - /// - /// Reference to the original layer (for debugging). - /// - public object? OriginalLayer { get; set; } - - /// - /// Source location information for debugging. - /// - public SourceLocation? SourceLocation { get; set; } - - /// - /// Provenance chain showing how this node was derived. - /// - public List Provenance { get; set; } = new(); - - #endregion - - #region Methods - - /// - /// Adds an input node with bidirectional linking. - /// - public void AddInput(HLIRNode input) - { - if (!Inputs.Contains(input)) - { - Inputs.Add(input); - InputTypes.Add(input.OutputType.Clone()); - } - if (!input.Outputs.Contains(this)) - { - input.Outputs.Add(this); - } - } - - /// - /// Removes an input node with bidirectional unlinking. - /// - public void RemoveInput(HLIRNode input) - { - var index = Inputs.IndexOf(input); - if (index >= 0) - { - Inputs.RemoveAt(index); - if (index < InputTypes.Count) - { - InputTypes.RemoveAt(index); - } - } - input.Outputs.Remove(this); - } - - /// - /// Replaces an input node with another. - /// - public void ReplaceInput(HLIRNode oldInput, HLIRNode newInput) - { - var index = Inputs.IndexOf(oldInput); - if (index >= 0) - { - Inputs[index] = newInput; - if (index < InputTypes.Count) - { - InputTypes[index] = newInput.OutputType.Clone(); - } - oldInput.Outputs.Remove(this); - if (!newInput.Outputs.Contains(this)) - { - newInput.Outputs.Add(this); - } - } - } - - /// - /// Checks if this node has any consumers. - /// - public bool HasConsumers => Outputs.Count > 0; - - /// - /// Gets the number of consumers. - /// - public int ConsumerCount => Outputs.Count; - - /// - /// Performs type inference based on operation and inputs. - /// - public void InferOutputType() - { - // Default implementation - specific operations override - if (InputTypes.Count > 0) - { - OutputType = InputTypes[0].Clone(); - } - } - - /// - /// Validates this node's structure. - /// - public bool Validate() - { - // Basic validation - if (Id < 0) return false; - if (OutputType == null) return false; - - // Input/output consistency - foreach (var input in Inputs) - { - if (!input.Outputs.Contains(this)) return false; - } - - foreach (var output in Outputs) - { - if (!output.Inputs.Contains(this)) return false; - } - - return true; - } - - /// - /// Creates a deep copy of this node (without connections). - /// - public HLIRNode Clone() - { - return new HLIRNode - { - Id = -1, // New ID should be assigned by graph - Name = Name + "_clone", - Operation = Operation, - OutputType = OutputType.Clone(), - Parameters = new Dictionary(Parameters), - Attributes = new Dictionary(Attributes), - ConstantValue = ConstantValue, - CanEliminate = CanEliminate, - CanOperateInPlace = CanOperateInPlace, - IsFused = IsFused, - OriginalLayer = OriginalLayer, - Hints = Hints.Clone() - }; - } - - /// - /// Adds provenance information. - /// - public void AddProvenance(string info) - { - Provenance.Add($"[{DateTime.UtcNow:HH:mm:ss}] {info}"); - } - - public override string ToString() - { - var inputStr = Inputs.Count > 0 - ? string.Join(", ", Inputs.Select(i => $"n{i.Id}")) - : "none"; - return $"n{Id}: {Name} ({Operation}) [{OutputType}] <- ({inputStr})"; - } - - #endregion -} - -/// -/// Cost estimate for an operation (for scheduling and optimization). -/// -public class OperationCost -{ - /// - /// Estimated FLOPs (floating-point operations). - /// - public long FLOPs { get; set; } - - /// - /// Estimated memory read in bytes. - /// - public long MemoryRead { get; set; } - - /// - /// Estimated memory write in bytes. - /// - public long MemoryWrite { get; set; } - - /// - /// Arithmetic intensity (FLOPs / memory bytes). - /// Higher means more compute-bound, lower means more memory-bound. - /// - public double ArithmeticIntensity => - (MemoryRead + MemoryWrite) > 0 - ? (double)FLOPs / (MemoryRead + MemoryWrite) - : 0; - - /// - /// Whether this operation is likely memory-bound. - /// - public bool IsMemoryBound => ArithmeticIntensity < 10; - - /// - /// Estimated latency in nanoseconds (device-specific). - /// - public long EstimatedLatencyNs { get; set; } -} - -/// -/// Optimization hints for passes. -/// -public class OptimizationHints -{ - /// - /// Preferred device for execution. - /// - public DeviceType PreferredDevice { get; set; } = DeviceType.Auto; - - /// - /// Whether to prioritize memory efficiency. - /// - public bool PrioritizeMemory { get; set; } - - /// - /// Whether to prioritize latency. - /// - public bool PrioritizeLatency { get; set; } - - /// - /// Whether this node is a good fusion candidate. - /// - public bool IsFusionCandidate { get; set; } = true; - - /// - /// Custom scheduling priority (higher = earlier). - /// - public int SchedulingPriority { get; set; } - - /// - /// Tile sizes for tiled execution. - /// - public int[]? TileSizes { get; set; } - - /// - /// Whether to use vectorization. - /// - public bool EnableVectorization { get; set; } = true; - - /// - /// Whether to use parallelization. - /// - public bool EnableParallelization { get; set; } = true; - - public OptimizationHints Clone() => new() - { - PreferredDevice = PreferredDevice, - PrioritizeMemory = PrioritizeMemory, - PrioritizeLatency = PrioritizeLatency, - IsFusionCandidate = IsFusionCandidate, - SchedulingPriority = SchedulingPriority, - TileSizes = TileSizes != null ? (int[])TileSizes.Clone() : null, - EnableVectorization = EnableVectorization, - EnableParallelization = EnableParallelization - }; -} - -/// -/// Source location for debugging. -/// -public class SourceLocation -{ - public string? FileName { get; set; } - public int Line { get; set; } - public int Column { get; set; } - public string? FunctionName { get; set; } - - public override string ToString() => - $"{FileName ?? "unknown"}:{Line}:{Column} in {FunctionName ?? "unknown"}"; -} diff --git a/src/InferenceOptimization/IR/IIRCompiler.cs b/src/InferenceOptimization/IR/IIRCompiler.cs deleted file mode 100644 index b03b0998cd..0000000000 --- a/src/InferenceOptimization/IR/IIRCompiler.cs +++ /dev/null @@ -1,171 +0,0 @@ -using AiDotNet.InferenceOptimization.IR.Common; -using AiDotNet.InferenceOptimization.IR.HighLevel; -using AiDotNet.InferenceOptimization.IR.LowLevel; - -namespace AiDotNet.InferenceOptimization.IR; - -/// -/// Interface for the two-tier IR compiler pipeline. -/// -/// -/// -/// The IR compiler transforms models through multiple stages: -/// 1. Model → HLIR: Convert model to high-level IR -/// 2. HLIR Optimization: Apply high-level optimizations -/// 3. HLIR → LLIR: Lower to low-level IR -/// 4. LLIR Optimization: Apply low-level optimizations -/// 5. Code Generation: Generate executable code -/// -/// -public interface IIRCompiler where T : struct -{ - /// - /// Compiles a model to optimized LLIR. - /// - LLIRGraph Compile(HLIRGraph hlir); - - /// - /// Gets the compilation options. - /// - IRCompilerOptions Options { get; } - - /// - /// Gets compilation statistics from the last compilation. - /// - IRCompilationStats? LastCompilationStats { get; } -} - -/// -/// Options for IR compilation. -/// -public class IRCompilerOptions -{ - /// - /// Target device for execution. - /// - public DeviceType TargetDevice { get; set; } = DeviceType.CPU; - - /// - /// Device configuration. - /// - public DeviceConfiguration DeviceConfig { get; set; } = new(); - - /// - /// Optimization level. - /// - public IROptimizationLevel OptimizationLevel { get; set; } = IROptimizationLevel.O2; - - /// - /// Target data type. - /// - public IRDataType TargetDataType { get; set; } = IRDataType.Float32; - - /// - /// Whether to enable fusion optimizations. - /// - public bool EnableFusion { get; set; } = true; - - /// - /// Whether to enable constant folding. - /// - public bool EnableConstantFolding { get; set; } = true; - - /// - /// Whether to enable dead code elimination. - /// - public bool EnableDeadCodeElimination { get; set; } = true; - - /// - /// Whether to enable memory optimization. - /// - public bool EnableMemoryOptimization { get; set; } = true; - - /// - /// Whether to enable auto-scheduling. - /// - public bool EnableAutoScheduling { get; set; } = true; - - /// - /// Whether to enable auto-tuning. - /// - public bool EnableAutoTuning { get; set; } = false; - - /// - /// Whether to preserve debug information. - /// - public bool PreserveDebugInfo { get; set; } = false; - - /// - /// Maximum number of fusion candidates to explore. - /// - public int MaxFusionCandidates { get; set; } = 100; - - /// - /// Memory limit for optimization (in bytes). - /// - public long MemoryLimitBytes { get; set; } = long.MaxValue; -} - -/// -/// Optimization level for IR compilation. -/// -public enum IROptimizationLevel -{ - /// - /// No optimization, fastest compilation. - /// - O0, - - /// - /// Basic optimizations (constant folding, DCE). - /// - O1, - - /// - /// Standard optimizations (+ fusion, algebraic simplification). - /// - O2, - - /// - /// Aggressive optimizations (+ memory optimization, auto-scheduling). - /// - O3, - - /// - /// Size optimization. - /// - Os -} - -/// -/// Statistics from IR compilation. -/// -public class IRCompilationStats -{ - public TimeSpan TotalTime { get; set; } - public TimeSpan HLIROptimizationTime { get; set; } - public TimeSpan LoweringTime { get; set; } - public TimeSpan LLIROptimizationTime { get; set; } - - public int OriginalNodeCount { get; set; } - public int OptimizedHLIRNodeCount { get; set; } - public int FinalLLIROpCount { get; set; } - - public int FusionsApplied { get; set; } - public int ConstantsFolded { get; set; } - public int DeadNodesEliminated { get; set; } - - public long OriginalMemoryEstimate { get; set; } - public long OptimizedMemoryEstimate { get; set; } - public double MemoryReductionPercent => - OriginalMemoryEstimate > 0 - ? (1 - (double)OptimizedMemoryEstimate / OriginalMemoryEstimate) * 100 - : 0; - - public override string ToString() - { - return $"Compiled in {TotalTime.TotalMilliseconds:F1}ms: " + - $"{OriginalNodeCount} → {FinalLLIROpCount} ops, " + - $"{FusionsApplied} fusions, {MemoryReductionPercent:F1}% memory reduction"; - } -} diff --git a/src/InferenceOptimization/IR/LowLevel/LLIRGraph.cs b/src/InferenceOptimization/IR/LowLevel/LLIRGraph.cs deleted file mode 100644 index d812521b79..0000000000 --- a/src/InferenceOptimization/IR/LowLevel/LLIRGraph.cs +++ /dev/null @@ -1,655 +0,0 @@ -using AiDotNet.InferenceOptimization.IR.Common; - -namespace AiDotNet.InferenceOptimization.IR.LowLevel; - -/// -/// Low-Level Intermediate Representation Graph. -/// Represents the computation graph optimized for hardware execution. -/// -/// -/// Design Philosophy: -/// -/// LLIRGraph is the final representation before code generation. It contains: -/// - Operations with scheduling information -/// - Memory allocation plan -/// - Device placement decisions -/// - Execution order -/// -/// -/// Exceeds Standards By: -/// -/// Integrated memory planning with buffer reuse -/// Multi-device execution support -/// Streaming execution capability -/// Auto-tuning integration points -/// -/// -public class LLIRGraph -{ - #region Properties - - /// - /// Operations in execution order. - /// - public List Operations { get; } = new(); - - /// - /// Mapping from buffer ID to shape. - /// - public Dictionary BufferShapes { get; } = new(); - - /// - /// Mapping from buffer ID to data type. - /// - public Dictionary BufferTypes { get; } = new(); - - /// - /// Input buffer IDs. - /// - public List InputIds { get; } = new(); - - /// - /// Output buffer IDs. - /// - public List OutputIds { get; } = new(); - - /// - /// Memory allocation plan. - /// - public MemoryPlan? MemoryPlan { get; set; } - - /// - /// Graph name. - /// - public string Name { get; set; } = "LLIRGraph"; - - /// - /// Metadata. - /// - public Dictionary Metadata { get; } = new(); - - /// - /// Target device configuration. - /// - public DeviceConfiguration DeviceConfig { get; set; } = new(); - - /// - /// Next buffer ID for allocation. - /// - private int _nextBufferId; - - #endregion - - #region Operations - - /// - /// Adds an operation to the graph. - /// - public void AddOperation(LLIROp op) - { - if (op.OutputId < 0) - { - op.OutputId = _nextBufferId++; - } - else - { - _nextBufferId = Math.Max(_nextBufferId, op.OutputId + 1); - } - - Operations.Add(op); - BufferShapes[op.OutputId] = op.OutputShape; - BufferTypes[op.OutputId] = op.OutputDataType; - } - - /// - /// Creates a new buffer ID. - /// - public int AllocateBufferId() => _nextBufferId++; - - /// - /// Gets an operation by output buffer ID. - /// - public LLIROp? GetOperationByOutputId(int bufferId) => - Operations.FirstOrDefault(op => op.OutputId == bufferId); - - /// - /// Gets all operations that use a buffer as input. - /// - public IEnumerable GetConsumers(int bufferId) => - Operations.Where(op => op.InputIds.Contains(bufferId)); - - #endregion - - #region Validation - - /// - /// Validates the graph structure and scheduling. - /// - public LLIRValidationResult Validate() - { - var errors = new List(); - var warnings = new List(); - - // Check buffer definitions - var definedBuffers = new HashSet(InputIds); - foreach (var op in Operations) - { - // Check inputs are defined - foreach (var inputId in op.InputIds) - { - if (!definedBuffers.Contains(inputId)) - { - errors.Add($"Operation {op.Name} uses undefined buffer b{inputId}"); - } - } - - // Check operation validity - if (!op.Validate()) - { - errors.Add($"Invalid operation: {op.Name}"); - } - - definedBuffers.Add(op.OutputId); - } - - // Check outputs are defined - foreach (var outputId in OutputIds) - { - if (!definedBuffers.Contains(outputId)) - { - errors.Add($"Output buffer b{outputId} not defined"); - } - } - - // Check scheduling - foreach (var op in Operations) - { - if (op.Schedule.VectorWidth > 1 && op.Schedule.VectorAxis < 0) - { - warnings.Add($"Operation {op.Name} has vector width but no vector axis"); - } - - if (op.Device == DeviceType.GPU && - (op.Schedule.ThreadBlockDims.Length == 0 || op.Schedule.GridDims.Length == 0)) - { - warnings.Add($"GPU operation {op.Name} missing thread/grid dimensions"); - } - } - - // Check memory plan - if (MemoryPlan != null) - { - var planResult = MemoryPlan.Validate(); - errors.AddRange(planResult.Errors); - warnings.AddRange(planResult.Warnings); - } - - return new LLIRValidationResult - { - IsValid = errors.Count == 0, - Errors = errors, - Warnings = warnings - }; - } - - #endregion - - #region Analysis - - /// - /// Computes total metrics for the graph. - /// - public LLIRGraphMetrics ComputeMetrics() - { - var metrics = new LLIRGraphMetrics(); - - foreach (var op in Operations) - { - var opMetrics = op.EstimateCost(); - metrics.TotalFLOPs += opMetrics.FLOPs; - metrics.TotalIntOps += opMetrics.IntOps; - metrics.TotalMemoryRead += opMetrics.MemoryRead; - metrics.TotalMemoryWrite += opMetrics.MemoryWrite; - metrics.TotalLatencyNs += opMetrics.LatencyNs; - - if (!metrics.OpCountByType.ContainsKey(op.OpType)) - { - metrics.OpCountByType[op.OpType] = 0; - } - metrics.OpCountByType[op.OpType]++; - - if (!metrics.FLOPsByDevice.ContainsKey(op.Device)) - { - metrics.FLOPsByDevice[op.Device] = 0; - } - metrics.FLOPsByDevice[op.Device] += opMetrics.FLOPs; - } - - metrics.OperationCount = Operations.Count; - metrics.BufferCount = BufferShapes.Count; - metrics.PeakMemoryBytes = MemoryPlan?.PeakMemoryBytes ?? EstimatePeakMemory(); - - return metrics; - } - - private long EstimatePeakMemory() - { - long peak = 0; - long current = 0; - var liveBuffers = new Dictionary(); - - // Add input buffers - foreach (var inputId in InputIds) - { - if (BufferShapes.TryGetValue(inputId, out var shape)) - { - var size = shape.Aggregate(1L, (a, b) => a * b) * GetElementSize(BufferTypes.GetValueOrDefault(inputId, IRDataType.Float32)); - liveBuffers[inputId] = size; - current += size; - } - } - peak = Math.Max(peak, current); - - foreach (var op in Operations) - { - // Add output buffer - var outputSize = op.OutputShape.Aggregate(1L, (a, b) => a * b) * - GetElementSize(op.OutputDataType); - liveBuffers[op.OutputId] = outputSize; - current += outputSize; - peak = Math.Max(peak, current); - - // Remove dead buffers (simplified - actual implementation would use liveness) - // For now, just track peak - } - - return peak; - } - - private static int GetElementSize(IRDataType type) => type switch - { - IRDataType.Float16 or IRDataType.BFloat16 => 2, - IRDataType.Float32 or IRDataType.Int32 or IRDataType.UInt32 => 4, - IRDataType.Float64 or IRDataType.Int64 or IRDataType.UInt64 => 8, - IRDataType.Int8 or IRDataType.UInt8 or IRDataType.QInt8 or IRDataType.QUInt8 => 1, - IRDataType.Int16 or IRDataType.UInt16 => 2, - _ => 4 - }; - - /// - /// Computes the critical path length. - /// - public long ComputeCriticalPath() - { - var latency = new Dictionary(); - - // Initialize inputs with 0 latency - foreach (var inputId in InputIds) - { - latency[inputId] = 0; - } - - foreach (var op in Operations) - { - var inputLatency = op.InputIds.Length > 0 - ? op.InputIds.Max(id => latency.GetValueOrDefault(id, 0)) - : 0; - - var opLatency = op.EstimateCost().LatencyNs; - latency[op.OutputId] = inputLatency + opLatency; - } - - return OutputIds.Count > 0 - ? OutputIds.Max(id => latency.GetValueOrDefault(id, 0)) - : latency.Values.DefaultIfEmpty(0).Max(); - } - - #endregion - - #region Optimization - - /// - /// Applies memory optimization to reuse buffers. - /// - public void OptimizeMemory() - { - // Compute buffer liveness - var liveness = ComputeLiveness(); - - // Assign memory pools - var pools = new List<(int lastUse, long size, int poolId)>(); - var poolAssignment = new Dictionary(); - - foreach (var op in Operations) - { - var bufferId = op.OutputId; - var (firstUse, lastUse) = liveness.GetValueOrDefault(bufferId, (0, Operations.Count)); - var bufferSize = op.OutputShape.Aggregate(1L, (a, b) => a * b) * - GetElementSize(op.OutputDataType); - - // Find reusable pool - int assignedPool = -1; - long offset = 0; - - for (int i = 0; i < pools.Count; i++) - { - var (poolLastUse, poolSize, poolId) = pools[i]; - if (poolLastUse < firstUse && poolSize >= bufferSize) - { - assignedPool = poolId; - pools[i] = (lastUse, poolSize, poolId); - break; - } - } - - if (assignedPool < 0) - { - assignedPool = pools.Count; - pools.Add((lastUse, bufferSize, assignedPool)); - } - - poolAssignment[bufferId] = (assignedPool, offset); - - // Update operation's buffer info - op.BufferAllocation = new BufferInfo - { - SizeBytes = bufferSize, - MemoryPoolId = assignedPool, - PoolOffset = offset, - FirstUseIndex = firstUse, - LastUseIndex = lastUse - }; - } - - // Create memory plan - MemoryPlan = new MemoryPlan - { - PoolCount = pools.Count, - PoolSizes = pools.Select(p => p.size).ToArray(), - BufferAssignments = poolAssignment, - PeakMemoryBytes = pools.Sum(p => p.size) - }; - } - - private Dictionary ComputeLiveness() - { - var liveness = new Dictionary(); - - for (int i = 0; i < Operations.Count; i++) - { - var op = Operations[i]; - - // Output is first used here - liveness[op.OutputId] = (i, i); - - // Update last use of inputs - foreach (var inputId in op.InputIds) - { - if (liveness.TryGetValue(inputId, out var existing)) - { - liveness[inputId] = (existing.Item1, i); - } - else - { - liveness[inputId] = (0, i); - } - } - } - - return liveness; - } - - /// - /// Selects optimal schedules for operations based on device capabilities. - /// - public void AutoSchedule() - { - foreach (var op in Operations) - { - AutoScheduleOp(op); - } - } - - private void AutoScheduleOp(LLIROp op) - { - var schedule = op.Schedule; - - if (op.Device == DeviceType.CPU) - { - // CPU scheduling - var vectorWidth = DeviceConfig.CPUVectorWidth; - - // Find best vectorization axis (innermost with size divisible by vector width) - for (int i = op.OutputShape.Length - 1; i >= 0; i--) - { - if (op.OutputShape[i] >= vectorWidth && op.OutputShape[i] % vectorWidth == 0) - { - schedule.VectorAxis = i; - schedule.VectorWidth = vectorWidth; - break; - } - } - - // Parallelization on outermost axis - if (op.OutputShape.Length > 0 && op.OutputShape[0] >= DeviceConfig.CPUCores) - { - schedule.ParallelAxes = new[] { 0 }; - } - - // Tiling for cache - if (op is MatMulOp matmul) - { - var tileSize = (int)Math.Sqrt(DeviceConfig.L2CacheBytes / 3 / 4); // 3 matrices, 4 bytes each - tileSize = Math.Min(tileSize, 64); - schedule.TileSizes = new[] { tileSize, tileSize, tileSize }; - } - } - else if (op.Device == DeviceType.GPU) - { - // GPU scheduling - var totalElements = op.OutputShape.Aggregate(1L, (a, b) => a * b); - - // Thread block size - var blockSize = Math.Min(256, (int)totalElements); - var numBlocks = (int)Math.Ceiling((double)totalElements / blockSize); - - schedule.ThreadBlockDims = new[] { blockSize, 1, 1 }; - schedule.GridDims = new[] { numBlocks, 1, 1 }; - - // Shared memory for reductions - if (op is ReduceOp) - { - schedule.SharedMemoryBytes = blockSize * 4; - } - } - } - - #endregion - - #region Utilities - - /// - /// Creates a copy of the graph. - /// - public LLIRGraph Clone() - { - var clone = new LLIRGraph { Name = Name + "_clone" }; - - foreach (var op in Operations) - { - // Note: In production, implement proper deep clone for each op type - clone.Operations.Add(op); - } - - foreach (var kvp in BufferShapes) - { - clone.BufferShapes[kvp.Key] = (int[])kvp.Value.Clone(); - } - - foreach (var kvp in BufferTypes) - { - clone.BufferTypes[kvp.Key] = kvp.Value; - } - - clone.InputIds.AddRange(InputIds); - clone.OutputIds.AddRange(OutputIds); - clone.MemoryPlan = MemoryPlan; - clone.DeviceConfig = DeviceConfig; - - foreach (var kvp in Metadata) - { - clone.Metadata[kvp.Key] = kvp.Value; - } - - return clone; - } - - /// - /// Computes structure hash for caching. - /// - public int ComputeStructureHash() - { - int hash = 17; - - foreach (var inputId in InputIds.OrderBy(id => id)) - { - hash = hash * 31 + inputId; - if (BufferShapes.TryGetValue(inputId, out var shape)) - { - foreach (var dim in shape) - { - hash = hash * 31 + dim; - } - } - } - - foreach (var op in Operations) - { - hash = hash * 31 + op.OpType.GetHashCode(); - hash = hash * 31 + op.OutputId; - hash = hash * 31 + op.OutputDataType.GetHashCode(); - foreach (var dim in op.OutputShape) - { - hash = hash * 31 + dim; - } - foreach (var inputId in op.InputIds) - { - hash = hash * 31 + inputId; - } - } - - foreach (var outputId in OutputIds.OrderBy(id => id)) - { - hash = hash * 31 + outputId; - } - - return hash; - } - - public override string ToString() - { - var metrics = ComputeMetrics(); - return $"LLIRGraph '{Name}': {Operations.Count} ops, {BufferShapes.Count} buffers, " + - $"{metrics.TotalFLOPs:N0} FLOPs, {metrics.PeakMemoryBytes / 1024.0 / 1024.0:F2} MB peak"; - } - - #endregion -} - -/// -/// Memory allocation plan. -/// -public class MemoryPlan -{ - public int PoolCount { get; set; } - public long[] PoolSizes { get; set; } = Array.Empty(); - public Dictionary BufferAssignments { get; set; } = new(); - public long PeakMemoryBytes { get; set; } - - public LLIRValidationResult Validate() - { - var errors = new List(); - var warnings = new List(); - - // Check pool assignments - foreach (var (bufferId, (poolId, offset)) in BufferAssignments) - { - if (poolId < 0 || poolId >= PoolCount) - { - errors.Add($"Buffer {bufferId} assigned to invalid pool {poolId}"); - } - - if (offset < 0) - { - errors.Add($"Buffer {bufferId} has negative offset {offset}"); - } - } - - return new LLIRValidationResult - { - IsValid = errors.Count == 0, - Errors = errors, - Warnings = warnings - }; - } -} - -/// -/// Device configuration for scheduling. -/// -public class DeviceConfiguration -{ - // CPU - public int CPUCores { get; set; } = Environment.ProcessorCount; - public int CPUVectorWidth { get; set; } = 8; // AVX-256 default - public long L1CacheBytes { get; set; } = 32 * 1024; - public long L2CacheBytes { get; set; } = 256 * 1024; - public long L3CacheBytes { get; set; } = 8 * 1024 * 1024; - - // GPU - public int GPUSMCount { get; set; } = 0; - public int GPUMaxThreadsPerBlock { get; set; } = 1024; - public long GPUSharedMemoryPerBlock { get; set; } = 48 * 1024; - public long GPUGlobalMemory { get; set; } = 0; - public bool GPUHasTensorCores { get; set; } = false; - - // Memory bandwidth (GB/s) - public double CPUMemoryBandwidth { get; set; } = 50; - public double GPUMemoryBandwidth { get; set; } = 500; - - // Peak compute (GFLOPS) - public double CPUPeakGFLOPS { get; set; } = 500; - public double GPUPeakGFLOPS { get; set; } = 10000; -} - -/// -/// LLIR validation result. -/// -public class LLIRValidationResult -{ - public bool IsValid { get; init; } - public List Errors { get; init; } = new(); - public List Warnings { get; init; } = new(); -} - -/// -/// LLIR graph metrics. -/// -public class LLIRGraphMetrics -{ - public int OperationCount { get; set; } - public int BufferCount { get; set; } - public long TotalFLOPs { get; set; } - public long TotalIntOps { get; set; } - public long TotalMemoryRead { get; set; } - public long TotalMemoryWrite { get; set; } - public long TotalLatencyNs { get; set; } - public long PeakMemoryBytes { get; set; } - public Dictionary OpCountByType { get; } = new(); - public Dictionary FLOPsByDevice { get; } = new(); - - public double ArithmeticIntensity => - (TotalMemoryRead + TotalMemoryWrite) > 0 - ? (double)(TotalFLOPs + TotalIntOps) / (TotalMemoryRead + TotalMemoryWrite) - : 0; -} diff --git a/src/InferenceOptimization/IR/LowLevel/LLIROp.cs b/src/InferenceOptimization/IR/LowLevel/LLIROp.cs deleted file mode 100644 index 14f750a110..0000000000 --- a/src/InferenceOptimization/IR/LowLevel/LLIROp.cs +++ /dev/null @@ -1,644 +0,0 @@ -using AiDotNet.InferenceOptimization.IR.Common; - -namespace AiDotNet.InferenceOptimization.IR.LowLevel; - -/// -/// Low-Level Intermediate Representation Operation. -/// Hardware-oriented operations for efficient execution, similar to TVM's TIR or MLIR's low-level dialects. -/// -/// -/// Design Philosophy: -/// -/// LLIR represents operations at a hardware level, enabling fine-grained optimizations: -/// - Loop nest transformations (tiling, unrolling, fusion) -/// - Memory layout optimization -/// - Vectorization and parallelization -/// - Device-specific code generation -/// -/// -/// Industry Comparison: -/// -/// TVM TIR: Imperative with loop primitives - we add richer scheduling -/// MLIR Linalg: Generic named operations - we add explicit buffer management -/// Halide: Scheduling separated from algorithm - we integrate both -/// Triton: Block-level programming - we support multiple granularities -/// -/// -/// Exceeds Standards By: -/// -/// Unified representation for CPU/GPU/TPU -/// Automatic vectorization width selection -/// Memory hierarchy awareness (L1/L2/L3/DRAM) -/// Built-in profiling and auto-tuning support -/// Cross-platform buffer management -/// -/// -public abstract class LLIROp -{ - #region Identity - - /// - /// Unique identifier for the output buffer. - /// Defaults to -1 (invalid) to prevent silent buffer collisions from missed assignments. - /// - public int OutputId { get; set; } = -1; - - /// - /// Input buffer identifiers. - /// - public int[] InputIds { get; set; } = Array.Empty(); - - /// - /// Operation type name. - /// - public virtual string OpType => GetType().Name.Replace("Op", ""); - - /// - /// Debug name. - /// - public string Name { get; set; } = string.Empty; - - #endregion - - #region Type Information - - /// - /// Output data type. - /// - public IRDataType OutputDataType { get; set; } = IRDataType.Float32; - - /// - /// Output shape. - /// - public int[] OutputShape { get; set; } = Array.Empty(); - - /// - /// Output strides for memory access. - /// - public long[]? OutputStrides { get; set; } - - /// - /// Memory layout for output. - /// - public MemoryLayout OutputLayout { get; set; } = MemoryLayout.RowMajor; - - #endregion - - #region Execution - - /// - /// Target device for execution. - /// - public DeviceType Device { get; set; } = DeviceType.CPU; - - /// - /// Scheduling information for this operation. - /// - public ScheduleInfo Schedule { get; set; } = new(); - - /// - /// Buffer allocation information. - /// - public BufferInfo? BufferAllocation { get; set; } - - #endregion - - #region Provenance - - /// - /// ID of the HLIR node this was lowered from. - /// - public int SourceHLIRNodeId { get; set; } = -1; - - #endregion - - #region Methods - - /// - /// Validates the operation. - /// - public virtual bool Validate() - { - if (OutputId < 0) return false; - if (OutputShape == null || OutputShape.Length == 0) return false; - return true; - } - - /// - /// Estimates the cost of this operation. - /// - public abstract OperationMetrics EstimateCost(); - - public override string ToString() - { - var inputs = string.Join(", ", InputIds.Select(id => $"b{id}")); - var shape = $"[{string.Join(", ", OutputShape)}]"; - return $"b{OutputId} = {OpType}({inputs}) : {OutputDataType}{shape}@{Device}"; - } - - #endregion -} - -/// -/// Scheduling information for loop nest optimization. -/// -public class ScheduleInfo -{ - /// - /// Tile sizes for each loop dimension. - /// - public int[] TileSizes { get; set; } = Array.Empty(); - - /// - /// Loop order (dimension indices). - /// - public int[] LoopOrder { get; set; } = Array.Empty(); - - /// - /// Parallelization axes. - /// - public int[] ParallelAxes { get; set; } = Array.Empty(); - - /// - /// Vectorization axis (-1 for none). - /// - public int VectorAxis { get; set; } = -1; - - /// - /// Vector width (e.g., 4 for SSE, 8 for AVX, 16 for AVX-512). - /// - public int VectorWidth { get; set; } = 1; - - /// - /// Unroll factor for innermost loop. - /// - public int UnrollFactor { get; set; } = 1; - - /// - /// Whether to use software pipelining. - /// - public bool UseSoftwarePipelining { get; set; } - - /// - /// Prefetch distance (in iterations). - /// - public int PrefetchDistance { get; set; } - - /// - /// Thread block dimensions (for GPU). - /// - public int[] ThreadBlockDims { get; set; } = Array.Empty(); - - /// - /// Grid dimensions (for GPU). - /// - public int[] GridDims { get; set; } = Array.Empty(); - - /// - /// Shared memory usage in bytes (for GPU). - /// - public int SharedMemoryBytes { get; set; } - - /// - /// Register pressure estimate. - /// - public int RegisterPressure { get; set; } - - public ScheduleInfo Clone() => new() - { - TileSizes = (int[])TileSizes.Clone(), - LoopOrder = (int[])LoopOrder.Clone(), - ParallelAxes = (int[])ParallelAxes.Clone(), - VectorAxis = VectorAxis, - VectorWidth = VectorWidth, - UnrollFactor = UnrollFactor, - UseSoftwarePipelining = UseSoftwarePipelining, - PrefetchDistance = PrefetchDistance, - ThreadBlockDims = (int[])ThreadBlockDims.Clone(), - GridDims = (int[])GridDims.Clone(), - SharedMemoryBytes = SharedMemoryBytes, - RegisterPressure = RegisterPressure - }; -} - -/// -/// Buffer allocation and memory management information. -/// -public class BufferInfo -{ - /// - /// Buffer size in bytes. - /// - public long SizeBytes { get; set; } - - /// - /// Memory alignment requirement. - /// - public int Alignment { get; set; } = 64; // Cache line aligned by default - - /// - /// Memory pool ID for buffer reuse. - /// - public int MemoryPoolId { get; set; } = -1; - - /// - /// Offset within memory pool. - /// - public long PoolOffset { get; set; } - - /// - /// Memory hierarchy level (L1, L2, L3, DRAM, HBM). - /// - public MemoryLevel MemoryLevel { get; set; } = MemoryLevel.DRAM; - - /// - /// First use index in topological order. - /// - public int FirstUseIndex { get; set; } - - /// - /// Last use index in topological order. - /// - public int LastUseIndex { get; set; } - - /// - /// Whether this buffer can be allocated in-place with an input. - /// - public bool CanInPlace { get; set; } - - /// - /// Input buffer ID for in-place operation. - /// - public int InPlaceInputId { get; set; } = -1; - - /// - /// Whether this buffer is persistent (survives across invocations). - /// - public bool IsPersistent { get; set; } -} - -/// -/// Memory hierarchy level. -/// -public enum MemoryLevel -{ - Register, - L1Cache, - L2Cache, - L3Cache, - DRAM, - HBM, // High Bandwidth Memory (GPU) - SharedMemory, // GPU shared memory - GlobalMemory, // GPU global memory - ConstantMemory, // GPU constant memory - TextureMemory // GPU texture memory -} - -/// -/// Operation performance metrics. -/// -public class OperationMetrics -{ - /// - /// Floating-point operations. - /// - public long FLOPs { get; set; } - - /// - /// Integer operations. - /// - public long IntOps { get; set; } - - /// - /// Memory read in bytes. - /// - public long MemoryRead { get; set; } - - /// - /// Memory write in bytes. - /// - public long MemoryWrite { get; set; } - - /// - /// Estimated cycles on target device. - /// - public long EstimatedCycles { get; set; } - - /// - /// Estimated latency in nanoseconds. - /// - public long LatencyNs { get; set; } - - /// - /// Arithmetic intensity (ops per byte). - /// - public double ArithmeticIntensity => - (MemoryRead + MemoryWrite) > 0 - ? (double)(FLOPs + IntOps) / (MemoryRead + MemoryWrite) - : 0; - - /// - /// Whether operation is memory-bound. - /// - public bool IsMemoryBound => ArithmeticIntensity < 10; - - /// - /// Roofline model bound (theoretical max GFLOPS). - /// - public double RooflineGFLOPS(double peakGFLOPS, double memBandwidthGBps) - { - return Math.Min(peakGFLOPS, ArithmeticIntensity * memBandwidthGBps); - } -} - -#region Concrete Operations - -/// -/// Matrix multiplication operation. -/// -public class MatMulOp : LLIROp -{ - public int M { get; set; } - public int N { get; set; } - public int K { get; set; } - public bool TransposeA { get; set; } - public bool TransposeB { get; set; } - public double Alpha { get; set; } = 1.0; - public double Beta { get; set; } = 0.0; - - public override OperationMetrics EstimateCost() - { - var flops = 2L * M * N * K; - var memRead = (long)(M * K + K * N) * GetElementSize(); - var memWrite = (long)(M * N) * GetElementSize(); - - return new OperationMetrics - { - FLOPs = flops, - MemoryRead = memRead, - MemoryWrite = memWrite, - LatencyNs = flops / 100 // Rough estimate - }; - } - - private int GetElementSize() => OutputDataType switch - { - IRDataType.Float32 => 4, - IRDataType.Float64 => 8, - IRDataType.Float16 or IRDataType.BFloat16 => 2, - _ => 4 - }; -} - -/// -/// Elementwise operation. -/// -public class ElementwiseOp : LLIROp -{ - public ElementwiseOpType ElementwiseType { get; set; } - - public override OperationMetrics EstimateCost() - { - var elements = OutputShape.Aggregate(1L, (a, b) => a * b); - // Use proper element size based on data type (Float16=2, Float32=4, Float64=8, etc.) - var elemSize = OutputDataType.ElementSizeInBytes(); - - return new OperationMetrics - { - FLOPs = elements * (ElementwiseType == ElementwiseOpType.FusedMultiplyAdd ? 2 : 1), - MemoryRead = elements * elemSize * InputIds.Length, - MemoryWrite = elements * elemSize, - // Use ceiling division to ensure non-zero latency for small arrays - LatencyNs = Math.Max(1, (elements + 999) / 1000) - }; - } -} - -public enum ElementwiseOpType -{ - Add, Subtract, Multiply, Divide, - Exp, Log, Sqrt, Rsqrt, - ReLU, Sigmoid, Tanh, GELU, SiLU, Swish, - Max, Min, Abs, Neg, - FusedMultiplyAdd, - Compare, Select, - Softmax, LogSoftmax, - Identity -} - -/// -/// Convolution operation. -/// -public class Conv2DOp : LLIROp -{ - public int BatchSize { get; set; } - public int InputChannels { get; set; } - public int OutputChannels { get; set; } - public int InputHeight { get; set; } - public int InputWidth { get; set; } - public int KernelHeight { get; set; } - public int KernelWidth { get; set; } - public int StrideH { get; set; } = 1; - public int StrideW { get; set; } = 1; - public int PadH { get; set; } - public int PadW { get; set; } - public int DilationH { get; set; } = 1; - public int DilationW { get; set; } = 1; - public int Groups { get; set; } = 1; - public ConvAlgorithm Algorithm { get; set; } = ConvAlgorithm.Auto; - - public override OperationMetrics EstimateCost() - { - var outH = (InputHeight + 2 * PadH - DilationH * (KernelHeight - 1) - 1) / StrideH + 1; - var outW = (InputWidth + 2 * PadW - DilationW * (KernelWidth - 1) - 1) / StrideW + 1; - - // 2 * (multiply + add) per output element per kernel element - var flops = 2L * BatchSize * OutputChannels * outH * outW * - (InputChannels / Groups) * KernelHeight * KernelWidth; - - // Use proper element size based on data type (Float16=2, Float32=4, Float64=8, etc.) - var elemSize = OutputDataType.ElementSizeInBytes(); - var inputSize = (long)BatchSize * InputChannels * InputHeight * InputWidth * elemSize; - var kernelSize = (long)OutputChannels * (InputChannels / Groups) * KernelHeight * KernelWidth * elemSize; - var outputSize = (long)BatchSize * OutputChannels * outH * outW * elemSize; - - return new OperationMetrics - { - FLOPs = flops, - MemoryRead = inputSize + kernelSize, - MemoryWrite = outputSize, - LatencyNs = flops / 100 - }; - } -} - -public enum ConvAlgorithm -{ - Auto, - Direct, // Direct convolution - Im2Col, // Image to column + GEMM - Winograd, // Winograd transform (for small kernels) - FFT, // FFT-based (for large kernels) - Implicit, // Implicit GEMM (for GPU) - TensorCore // Using tensor cores (for GPU) -} - -/// -/// Reduction operation. -/// -public class ReduceOp : LLIROp -{ - public ReduceType ReduceType { get; set; } - public int[] Axes { get; set; } = Array.Empty(); - public bool KeepDims { get; set; } - - /// - /// Shape of the input tensor before reduction. Required for accurate cost estimation. - /// - public int[] InputShape { get; set; } = Array.Empty(); - - public override OperationMetrics EstimateCost() - { - // Calculate input and output element counts - var inputElements = InputShape.Length > 0 - ? InputShape.Aggregate(1L, (a, b) => a * b) - : OutputShape.Aggregate(1L, (a, b) => a * b); // Fallback if InputShape not set - var outputElements = OutputShape.Aggregate(1L, (a, b) => a * b); - - // Use proper element size based on data type - var elemSize = OutputDataType.ElementSizeInBytes(); - - return new OperationMetrics - { - // Each output element requires processing all elements along reduced axes - FLOPs = inputElements, - MemoryRead = inputElements * elemSize, - MemoryWrite = outputElements * elemSize, - LatencyNs = Math.Max(1, inputElements / 100) - }; - } -} - -public enum ReduceType -{ - Sum, Mean, Max, Min, Prod, - L1Norm, L2Norm, LogSumExp, - All, Any -} - -/// -/// Memory operation (copy, reshape, transpose). -/// -public class MemoryOp : LLIROp -{ - public MemoryOpType MemoryOpType { get; set; } - public int[] Permutation { get; set; } = Array.Empty(); // For transpose - public int[] NewShape { get; set; } = Array.Empty(); // For reshape - - public override OperationMetrics EstimateCost() - { - var elements = OutputShape.Aggregate(1L, (a, b) => a * b); - // Use proper element size based on data type (Float16=2, Float32=4, Float64=8, etc.) - var elemSize = OutputDataType.ElementSizeInBytes(); - - return new OperationMetrics - { - FLOPs = 0, - MemoryRead = elements * elemSize, - MemoryWrite = elements * elemSize, - LatencyNs = Math.Max(1, elements / 1000) - }; - } -} - -public enum MemoryOpType -{ - Copy, - Reshape, - Transpose, - Slice, - Concat, - Broadcast, - Pad, - Gather, - Scatter -} - -/// -/// Fused operation combining multiple operations. -/// -public class FusedOp : LLIROp -{ - /// - /// Sequence of fused operations. - /// - public List FusedOps { get; set; } = new(); - - /// - /// Pattern name (e.g., "ConvBNReLU", "MatMulBiasGELU"). - /// - public string FusionPattern { get; set; } = string.Empty; - - /// - /// Additional attributes for the fused operation (e.g., pooling parameters). - /// - /// - /// Used to store operation-specific parameters that don't fit in the standard properties, - /// such as kernel size, stride, and padding for pooling operations. - /// - public Dictionary Attributes { get; set; } = new(); - - public override OperationMetrics EstimateCost() - { - var combined = new OperationMetrics(); - foreach (var op in FusedOps) - { - var opCost = op.EstimateCost(); - combined.FLOPs += opCost.FLOPs; - combined.IntOps += opCost.IntOps; - combined.LatencyNs += opCost.LatencyNs; - } - - // Fusion reduces memory traffic - only read inputs once and write final output - if (FusedOps.Count > 0) - { - var firstOp = FusedOps[0]; - var lastOp = FusedOps[^1]; - combined.MemoryRead = firstOp.EstimateCost().MemoryRead; - combined.MemoryWrite = lastOp.EstimateCost().MemoryWrite; - } - - return combined; - } - - public override string ToString() - { - return $"b{OutputId} = Fused[{FusionPattern}]({string.Join(", ", InputIds.Select(id => $"b{id}"))}) : {OutputDataType}"; - } -} - -/// -/// Constant/parameter loading operation. -/// -public class ConstantOp : LLIROp -{ - public byte[]? Data { get; set; } - public bool IsParameter { get; set; } - public string ParameterName { get; set; } = string.Empty; - - public override OperationMetrics EstimateCost() - { - // Use actual data length if available, otherwise calculate from shape and proper element size - var bytes = Data?.Length ?? OutputShape.Aggregate(1L, (a, b) => a * b) * OutputDataType.ElementSizeInBytes(); - - return new OperationMetrics - { - FLOPs = 0, - MemoryRead = bytes, - MemoryWrite = bytes, - LatencyNs = Math.Max(1, bytes / 1000) - }; - } -} - -#endregion diff --git a/src/InferenceOptimization/IR/Lowering/HLIRToLLIRLowering.cs b/src/InferenceOptimization/IR/Lowering/HLIRToLLIRLowering.cs deleted file mode 100644 index 1321c57187..0000000000 --- a/src/InferenceOptimization/IR/Lowering/HLIRToLLIRLowering.cs +++ /dev/null @@ -1,1754 +0,0 @@ -using AiDotNet.Enums; -using AiDotNet.InferenceOptimization.IR.Common; -using AiDotNet.InferenceOptimization.IR.HighLevel; -using AiDotNet.InferenceOptimization.IR.LowLevel; - -namespace AiDotNet.InferenceOptimization.IR.Lowering; - -/// -/// Lowers High-Level IR to Low-Level IR. -/// Transforms semantic operations into hardware-optimized operations. -/// -/// -/// Design Philosophy: -/// -/// The lowering process transforms high-level semantic operations into low-level -/// hardware-optimized operations. This is similar to MLIR's progressive lowering -/// or TVM's Relay to TIR conversion. -/// -/// -/// Lowering Stages: -/// -/// Type conversion: Generic T -> concrete types (float32, etc.) -/// Operation mapping: Semantic ops -> hardware ops -/// Layout selection: Choose optimal memory layouts -/// Scheduling: Generate schedules for loops -/// Memory planning: Allocate and reuse buffers -/// -/// -/// Exceeds Standards By: -/// -/// Preserves fusion information from HLIR -/// Automatic algorithm selection (Winograd, FFT, etc.) -/// Multi-device lowering in single pass -/// Quantization-aware lowering -/// -/// -public class HLIRToLLIRLowering where T : struct -{ - #region Configuration - - /// - /// Target device for lowering. - /// - public DeviceType TargetDevice { get; set; } = DeviceType.CPU; - - /// - /// Device configuration. - /// - public DeviceConfiguration DeviceConfig { get; set; } = new(); - - /// - /// Whether to optimize memory usage. - /// - public bool OptimizeMemory { get; set; } = true; - - /// - /// Whether to auto-schedule operations. - /// - public bool AutoSchedule { get; set; } = true; - - /// - /// Target data type (for type conversion). - /// - public IRDataType TargetDataType { get; set; } = IRDataType.Float32; - - /// - /// Preferred memory layout. - /// - public MemoryLayout PreferredLayout { get; set; } = MemoryLayout.RowMajor; - - #endregion - - #region State - - private readonly Dictionary _hlirToLlirBufferMap = new(); - private LLIRGraph _llirGraph = new(); - - #endregion - - #region Main Entry Point - - /// - /// Lowers an HLIR graph to an LLIR graph. - /// - public LLIRGraph Lower(HLIRGraph hlirGraph) - { - _llirGraph = new LLIRGraph - { - Name = hlirGraph.Name + "_lowered", - DeviceConfig = DeviceConfig - }; - _hlirToLlirBufferMap.Clear(); - - // Process nodes in topological order - var orderedNodes = hlirGraph.GetTopologicalOrder(); - - // First pass: map input nodes - foreach (var inputNode in hlirGraph.InputNodes) - { - MapInputNode(inputNode); - } - - // Second pass: lower each node - foreach (var node in orderedNodes) - { - if (!hlirGraph.InputNodes.Contains(node)) - { - LowerNode(node); - } - } - - // Map output nodes - fail-fast if any output is missing - foreach (var outputNode in hlirGraph.OutputNodes) - { - if (!_hlirToLlirBufferMap.TryGetValue(outputNode.Id, out var llirId)) - { - throw new InvalidOperationException( - $"Output node '{outputNode.Name}' (ID: {outputNode.Id}) was not lowered. " + - $"This indicates a missing lowering implementation for operation type '{outputNode.Operation}'."); - } - _llirGraph.OutputIds.Add(llirId); - } - - // Post-processing - if (OptimizeMemory) - { - _llirGraph.OptimizeMemory(); - } - - if (AutoSchedule) - { - _llirGraph.AutoSchedule(); - } - - return _llirGraph; - } - - #endregion - - #region Node Lowering - - private void MapInputNode(HLIRNode node) - { - var bufferId = _llirGraph.AllocateBufferId(); - _hlirToLlirBufferMap[node.Id] = bufferId; - _llirGraph.InputIds.Add(bufferId); - _llirGraph.BufferShapes[bufferId] = node.OutputType.Shape; - _llirGraph.BufferTypes[bufferId] = ConvertDataType(node.OutputType.DataType); - } - - private void LowerNode(HLIRNode node) - { - // Handle fused nodes - if (node.IsFused && node.FusedFrom != null && node.FusedFrom.Count > 0) - { - LowerFusedNode(node); - return; - } - - // Lower based on operation type - var llirOp = node.Operation switch - { - OperationType.MatMul or OperationType.Gemm => LowerMatMul(node), - OperationType.Add or OperationType.Subtract or OperationType.Multiply or OperationType.Divide => - LowerElementwise(node), - OperationType.Conv2D or OperationType.Convolution2D or OperationType.Convolution => - LowerConv2D(node), - OperationType.ReLU or OperationType.Sigmoid or OperationType.Tanh or OperationType.GELU or - OperationType.Softmax or OperationType.LogSoftmax => - LowerActivation(node), - OperationType.BatchNormalization or OperationType.LayerNormalization => - LowerNormalization(node), - OperationType.MaxPool2D or OperationType.AvgPool2D or OperationType.GlobalAveragePooling => - LowerPooling(node), - OperationType.Reshape or OperationType.Transpose or OperationType.Flatten or - OperationType.Concat or OperationType.Split or OperationType.Slice => - LowerMemoryOp(node), - OperationType.Constant => - LowerConstant(node), - OperationType.Input or OperationType.Output => - LowerInputOutput(node), - OperationType.ReduceSum or OperationType.Mean or OperationType.ReduceMax or OperationType.ReduceMin => - LowerReduction(node), - OperationType.Dense or OperationType.FullyConnected => - LowerDense(node), - OperationType.Embedding => - LowerEmbedding(node), - OperationType.Attention or OperationType.MultiHeadAttention => - LowerAttention(node), - OperationType.Dropout => - LowerDropout(node), - OperationType.FusedConvBatchNormReLU or OperationType.FusedMatMulBias or - OperationType.FusedMatMulBiasReLU or OperationType.FusedMatMulBiasGELU or - OperationType.FusedMultiHeadAttention or OperationType.FusedLayerNormAttention => - LowerFusedOperation(node), - _ => LowerGeneric(node) - }; - - if (llirOp != null) - { - _llirGraph.AddOperation(llirOp); - _hlirToLlirBufferMap[node.Id] = llirOp.OutputId; - } - } - - #endregion - - #region Operation-Specific Lowering - - private LLIROp LowerMatMul(HLIRNode node) - { - var inputIds = GetLLIRInputIds(node); - var (m, n, k) = InferMatMulDims(node); - var bufferId = _llirGraph.AllocateBufferId(); - - var transposeA = GetAttributeBool(node, "transposeA", false); - var transposeB = GetAttributeBool(node, "transposeB", false); - - return new MatMulOp - { - OutputId = bufferId, - Name = node.Name, - InputIds = inputIds, - OutputShape = node.OutputType.Shape, - OutputDataType = ConvertDataType(node.OutputType.DataType), - Device = GetDeviceForNode(node), - M = m, - N = n, - K = k, - TransposeA = transposeA, - TransposeB = transposeB, - SourceHLIRNodeId = node.Id - }; - } - - private LLIROp LowerElementwise(HLIRNode node) - { - var bufferId = _llirGraph.AllocateBufferId(); - var opType = node.Operation switch - { - OperationType.Add => ElementwiseOpType.Add, - OperationType.Subtract => ElementwiseOpType.Subtract, - OperationType.Multiply => ElementwiseOpType.Multiply, - OperationType.Divide => ElementwiseOpType.Divide, - _ => ElementwiseOpType.Add - }; - - return new ElementwiseOp - { - OutputId = bufferId, - Name = node.Name, - InputIds = GetLLIRInputIds(node), - OutputShape = node.OutputType.Shape, - OutputDataType = ConvertDataType(node.OutputType.DataType), - Device = GetDeviceForNode(node), - ElementwiseType = opType, - SourceHLIRNodeId = node.Id - }; - } - - /// - /// Lowers a Conv2D operation from HLIR to LLIR representation. - /// - /// The HLIR node representing the Conv2D operation. - /// The lowered Conv2DOp for LLIR execution. - /// - /// Thrown when the node has insufficient input type information for proper lowering. - /// - /// - /// - /// Conv2D lowering requires complete shape information for both input tensor and kernel. - /// The input tensor shape must be in NCHW format: [batch, channels, height, width]. - /// The kernel shape must be in OIHW format: [out_channels, in_channels, kernel_h, kernel_w]. - /// - /// - private LLIROp LowerConv2D(HLIRNode node) - { - var inputIds = GetLLIRInputIds(node); - var bufferId = _llirGraph.AllocateBufferId(); - - // Validate input types - Conv2D requires proper shape information - if (node.InputTypes.Count < 2) - { - throw new InvalidOperationException( - $"Conv2D node '{node.Name}' (id={node.Id}) requires at least 2 input types " + - $"(input tensor and kernel), but only has {node.InputTypes.Count}."); - } - - var inputShape = node.InputTypes[0].Shape; - var kernelShape = node.InputTypes[1].Shape; - - if (inputShape == null || inputShape.Length < 4) - { - throw new InvalidOperationException( - $"Conv2D node '{node.Name}' (id={node.Id}) has invalid input tensor shape. " + - $"Expected 4D NCHW tensor, got {(inputShape == null ? "null" : $"{inputShape.Length}D")}."); - } - - if (kernelShape == null || kernelShape.Length < 4) - { - throw new InvalidOperationException( - $"Conv2D node '{node.Name}' (id={node.Id}) has invalid kernel shape. " + - $"Expected 4D OIHW tensor, got {(kernelShape == null ? "null" : $"{kernelShape.Length}D")}."); - } - - // Extract convolution parameters - var strideH = GetAttributeInt(node, "strideH", 1); - var strideW = GetAttributeInt(node, "strideW", 1); - var padH = GetAttributeInt(node, "padH", 0); - var padW = GetAttributeInt(node, "padW", 0); - var groups = GetAttributeInt(node, "groups", 1); - - // Shapes have been validated - safe to access directly - // Input shape: NCHW [batch, channels, height, width] - // Kernel shape: OIHW [out_channels, in_channels, kernel_h, kernel_w] - var op = new Conv2DOp - { - OutputId = bufferId, - Name = node.Name, - InputIds = inputIds, - OutputShape = node.OutputType.Shape, - OutputDataType = ConvertDataType(node.OutputType.DataType), - Device = GetDeviceForNode(node), - BatchSize = inputShape[0], - InputChannels = inputShape[1], - InputHeight = inputShape[2], - InputWidth = inputShape[3], - OutputChannels = kernelShape[0], - KernelHeight = kernelShape[2], - KernelWidth = kernelShape[3], - StrideH = strideH, - StrideW = strideW, - PadH = padH, - PadW = padW, - Groups = groups, - SourceHLIRNodeId = node.Id - }; - - // Select algorithm - op.Algorithm = SelectConvAlgorithm(op); - - return op; - } - - private LLIROp LowerActivation(HLIRNode node) - { - var bufferId = _llirGraph.AllocateBufferId(); - var opType = node.Operation switch - { - OperationType.ReLU => ElementwiseOpType.ReLU, - OperationType.Sigmoid => ElementwiseOpType.Sigmoid, - OperationType.Tanh => ElementwiseOpType.Tanh, - OperationType.GELU => ElementwiseOpType.GELU, - OperationType.Softmax => ElementwiseOpType.Softmax, - OperationType.LogSoftmax => ElementwiseOpType.LogSoftmax, - _ => ElementwiseOpType.ReLU - }; - - return new ElementwiseOp - { - OutputId = bufferId, - Name = node.Name, - InputIds = GetLLIRInputIds(node), - OutputShape = node.OutputType.Shape, - OutputDataType = ConvertDataType(node.OutputType.DataType), - Device = GetDeviceForNode(node), - ElementwiseType = opType, - SourceHLIRNodeId = node.Id - }; - } - - private LLIROp LowerNormalization(HLIRNode node) - { - var bufferId = _llirGraph.AllocateBufferId(); - - // Extract normalization-specific parameters - var epsilon = GetAttributeDouble(node, "epsilon", 1e-5); - var momentum = GetAttributeDouble(node, "momentum", 0.1); - var axis = GetAttributeInt(node, "axis", -1); - - var fusedOp = new FusedOp - { - OutputId = bufferId, - Name = node.Name, - InputIds = GetLLIRInputIds(node), - OutputShape = node.OutputType.Shape, - OutputDataType = ConvertDataType(node.OutputType.DataType), - Device = GetDeviceForNode(node), - FusionPattern = node.Operation == OperationType.BatchNormalization - ? "BatchNorm" - : "LayerNorm", - SourceHLIRNodeId = node.Id - }; - - // Store normalization parameters in attributes for runtime execution - fusedOp.Attributes["epsilon"] = epsilon; - fusedOp.Attributes["momentum"] = momentum; - fusedOp.Attributes["axis"] = axis; - - return fusedOp; - } - - /// - /// Lowers a pooling operation from HLIR to LLIR representation. - /// - /// The HLIR node representing the pooling operation. - /// The lowered FusedOp containing pooling parameters for LLIR execution. - /// - /// - /// Pooling operations (MaxPool2D, AvgPool2D) have spatial window, stride, and padding - /// parameters that distinguish them from simple reductions. This method uses FusedOp - /// to preserve these windowed operation semantics rather than ReduceOp which would - /// lose the spatial parameters. - /// - /// - /// The pooling parameters (kernel size, stride, padding) are extracted from the node's - /// attributes and stored in the FusedOp's Attributes dictionary for runtime execution. - /// - /// - private LLIROp LowerPooling(HLIRNode node) - { - var bufferId = _llirGraph.AllocateBufferId(); - var pattern = node.Operation switch - { - OperationType.MaxPool2D => "MaxPool2D", - OperationType.AvgPool2D => "AvgPool2D", - OperationType.GlobalAveragePooling => "GlobalAvgPool", - _ => "MaxPool2D" - }; - - // Extract pooling parameters from node attributes - var kernelH = GetAttributeInt(node, "kernelH", 2); - var kernelW = GetAttributeInt(node, "kernelW", 2); - var strideH = GetAttributeInt(node, "strideH", 2); - var strideW = GetAttributeInt(node, "strideW", 2); - var padH = GetAttributeInt(node, "padH", 0); - var padW = GetAttributeInt(node, "padW", 0); - - var fusedOp = new FusedOp - { - OutputId = bufferId, - Name = node.Name, - InputIds = GetLLIRInputIds(node), - OutputShape = node.OutputType.Shape, - OutputDataType = ConvertDataType(node.OutputType.DataType), - Device = GetDeviceForNode(node), - FusionPattern = pattern, - SourceHLIRNodeId = node.Id - }; - - // Store pooling parameters in attributes for runtime execution - fusedOp.Attributes["kernelH"] = kernelH; - fusedOp.Attributes["kernelW"] = kernelW; - fusedOp.Attributes["strideH"] = strideH; - fusedOp.Attributes["strideW"] = strideW; - fusedOp.Attributes["padH"] = padH; - fusedOp.Attributes["padW"] = padW; - - return fusedOp; - } - - private LLIROp LowerMemoryOp(HLIRNode node) - { - var bufferId = _llirGraph.AllocateBufferId(); - var memOpType = node.Operation switch - { - OperationType.Reshape => MemoryOpType.Reshape, - OperationType.Transpose => MemoryOpType.Transpose, - OperationType.Flatten => MemoryOpType.Reshape, - OperationType.Concat => MemoryOpType.Concat, - OperationType.Split => MemoryOpType.Slice, - OperationType.Slice => MemoryOpType.Slice, - _ => MemoryOpType.Copy - }; - - var op = new MemoryOp - { - OutputId = bufferId, - Name = node.Name, - InputIds = GetLLIRInputIds(node), - OutputShape = node.OutputType.Shape, - OutputDataType = ConvertDataType(node.OutputType.DataType), - Device = GetDeviceForNode(node), - MemoryOpType = memOpType, - SourceHLIRNodeId = node.Id - }; - - if (node.Operation == OperationType.Transpose && - node.Attributes.TryGetValue("perm", out var perm)) - { - op.Permutation = GetAttributeIntArray(perm, "perm"); - } - - if (node.Operation == OperationType.Reshape) - { - op.NewShape = node.OutputType.Shape; - } - - return op; - } - - private LLIROp? LowerConstant(HLIRNode node) - { - var bufferId = _llirGraph.AllocateBufferId(); - _hlirToLlirBufferMap[node.Id] = bufferId; - - var op = new ConstantOp - { - OutputId = bufferId, - Name = node.Name, - OutputShape = node.OutputType.Shape, - OutputDataType = ConvertDataType(node.OutputType.DataType), - Device = GetDeviceForNode(node), - IsParameter = node.Parameters.Count > 0, - ParameterName = node.Name, - SourceHLIRNodeId = node.Id - }; - - _llirGraph.AddOperation(op); - return null; // Already added - } - - private LLIROp? LowerInputOutput(HLIRNode node) - { - // Input/output nodes are handled separately - return null; - } - - private LLIROp LowerReduction(HLIRNode node) - { - var bufferId = _llirGraph.AllocateBufferId(); - var reduceType = node.Operation switch - { - OperationType.ReduceSum => ReduceType.Sum, - OperationType.Mean or OperationType.ReduceMean => ReduceType.Mean, - OperationType.ReduceMax => ReduceType.Max, - OperationType.ReduceMin => ReduceType.Min, - _ => ReduceType.Sum - }; - - var axes = node.Attributes.TryGetValue("axes", out var ax) ? GetAttributeIntArray(ax, "axes") : Array.Empty(); - var keepDims = node.Attributes.TryGetValue("keepDims", out var kd) && (bool)kd; - - // Get input shape for accurate cost estimation - var inputShape = node.InputTypes.Count > 0 && node.InputTypes[0].Shape != null - ? node.InputTypes[0].Shape - : Array.Empty(); - - return new ReduceOp - { - OutputId = bufferId, - Name = node.Name, - InputIds = GetLLIRInputIds(node), - OutputShape = node.OutputType.Shape, - OutputDataType = ConvertDataType(node.OutputType.DataType), - Device = GetDeviceForNode(node), - ReduceType = reduceType, - Axes = axes, - KeepDims = keepDims, - InputShape = inputShape, - SourceHLIRNodeId = node.Id - }; - } - - private LLIROp LowerDense(HLIRNode node) - { - // Dense is lowered to MatMul + optional bias add - return LowerMatMul(node); - } - - private LLIROp LowerEmbedding(HLIRNode node) - { - var bufferId = _llirGraph.AllocateBufferId(); - return new MemoryOp - { - OutputId = bufferId, - Name = node.Name, - InputIds = GetLLIRInputIds(node), - OutputShape = node.OutputType.Shape, - OutputDataType = ConvertDataType(node.OutputType.DataType), - Device = GetDeviceForNode(node), - MemoryOpType = MemoryOpType.Gather, - SourceHLIRNodeId = node.Id - }; - } - - private LLIROp LowerAttention(HLIRNode node) - { - var bufferId = _llirGraph.AllocateBufferId(); - return new FusedOp - { - OutputId = bufferId, - Name = node.Name, - InputIds = GetLLIRInputIds(node), - OutputShape = node.OutputType.Shape, - OutputDataType = ConvertDataType(node.OutputType.DataType), - Device = GetDeviceForNode(node), - FusionPattern = "Attention", - SourceHLIRNodeId = node.Id - }; - } - - /// - /// Lowers a dropout operation from HLIR to LLIR representation. - /// - /// The HLIR node representing the dropout operation. - /// Always returns null since dropout is a no-op in inference mode. - /// - /// - /// During inference, dropout is a no-op (identity operation) - inputs pass through - /// unchanged. This method simply maps the node's output to its input buffer, - /// avoiding unnecessary memory allocation or computation. - /// - /// - private LLIROp? LowerDropout(HLIRNode node) - { - // Dropout in inference mode is a no-op (identity) - // Validate that the node has at least one input before accessing - if (node.Inputs.Count > 0 && _hlirToLlirBufferMap.TryGetValue(node.Inputs[0].Id, out var inputId)) - { - _hlirToLlirBufferMap[node.Id] = inputId; - } - return null; - } - - private LLIROp LowerFusedOperation(HLIRNode node) - { - var bufferId = _llirGraph.AllocateBufferId(); - var pattern = node.Operation switch - { - OperationType.FusedConvBatchNormReLU => "ConvBNReLU", - OperationType.FusedMatMulBias => "MatMulBias", - OperationType.FusedMatMulBiasReLU => "MatMulBiasReLU", - OperationType.FusedMatMulBiasGELU => "MatMulBiasGELU", - OperationType.FusedMultiHeadAttention => "FusedMHA", - OperationType.FusedLayerNormAttention => "LNAttention", - _ => "Fused" - }; - - return new FusedOp - { - OutputId = bufferId, - Name = node.Name, - InputIds = GetLLIRInputIds(node), - OutputShape = node.OutputType.Shape, - OutputDataType = ConvertDataType(node.OutputType.DataType), - Device = GetDeviceForNode(node), - FusionPattern = pattern, - SourceHLIRNodeId = node.Id - }; - } - - private LLIROp LowerGeneric(HLIRNode node) - { - var bufferId = _llirGraph.AllocateBufferId(); - return new ElementwiseOp - { - OutputId = bufferId, - Name = node.Name, - InputIds = GetLLIRInputIds(node), - OutputShape = node.OutputType.Shape, - OutputDataType = ConvertDataType(node.OutputType.DataType), - Device = GetDeviceForNode(node), - ElementwiseType = ElementwiseOpType.Identity, - SourceHLIRNodeId = node.Id - }; - } - - private void LowerFusedNode(HLIRNode node) - { - var fusedFrom = node.FusedFrom; - if (fusedFrom == null || fusedFrom.Count == 0) - { - return; - } - - var bufferId = _llirGraph.AllocateBufferId(); - - // Create a FusedOp that captures the fusion pattern - var pattern = string.Join("_", fusedFrom.Select(n => n.Operation.ToString())); - - var fusedOp = new FusedOp - { - OutputId = bufferId, - Name = node.Name, - InputIds = GetLLIRInputIds(node), - OutputShape = node.OutputType.Shape, - OutputDataType = ConvertDataType(node.OutputType.DataType), - Device = GetDeviceForNode(node), - FusionPattern = pattern, - SourceHLIRNodeId = node.Id - }; - - // Local map for internal fusion buffer wiring (e.g., Conv→BN→ReLU chains) - // This allows sub-ops within the fusion to reference each other's outputs - var fusionBufferMap = new Dictionary(); - - // Lower each original node as part of the fused op with fusion-aware input resolution - foreach (var originalNode in fusedFrom) - { - var llirOp = LowerNodeToOpWithFusionContext(originalNode, fusionBufferMap); - if (llirOp != null) - { - fusedOp.FusedOps.Add(llirOp); - // Store this sub-op's output in the fusion buffer map for subsequent sub-ops - fusionBufferMap[originalNode.Id] = llirOp.OutputId; - } - } - - _llirGraph.AddOperation(fusedOp); - _hlirToLlirBufferMap[node.Id] = bufferId; - } - - /// - /// Lowers a single HLIR node to an LLIR operation with fusion-aware input resolution. - /// - /// The HLIR node to lower. - /// Local buffer map for internal fusion dependencies (e.g., Conv→BN→ReLU chains). - /// The lowered LLIR operation with complete metadata, or null if the operation type is not supported. - /// - /// - /// This method creates LLIR operations with full metadata transfer from the source HLIR node, - /// including shape information, data types, and operation-specific dimensions. This is critical - /// for accurate cost estimation in fused operations. - /// - /// - /// For chained fusion patterns (e.g., Conv→BN→ReLU), sub-ops may depend on each other's outputs. - /// The fusionBufferMap is checked first for internal fusion dependencies, then falls back to - /// the global _hlirToLlirBufferMap for external dependencies. - /// - /// - /// For MatMul operations, the M, N, K dimensions are extracted from input/output shapes. - /// For ElementwiseOp operations, OutputShape and OutputDataType are transferred. - /// - /// - private LLIROp? LowerNodeToOpWithFusionContext(HLIRNode node, Dictionary fusionBufferMap) - { - // Get common properties from the node - var outputShape = node.OutputType?.Shape ?? Array.Empty(); - var outputDataType = node.OutputType != null ? ConvertDataType(node.OutputType.DataType) : IRDataType.Float32; - // Use fusion-aware input resolution for chained fusion patterns - var inputIds = GetLLIRInputIdsWithFusionContext(node, fusionBufferMap); - - return node.Operation switch - { - // Convolution operations - OperationType.Conv2D => CreateConv2DOpForFusion(node, outputShape, outputDataType, inputIds), - - // Matrix operations - OperationType.MatMul or OperationType.Gemm => CreateMatMulOp(node, outputShape, outputDataType, inputIds), - - // Elementwise arithmetic - OperationType.Add => CreateElementwiseOp(node, ElementwiseOpType.Add, outputShape, outputDataType, inputIds), - OperationType.Subtract => CreateElementwiseOp(node, ElementwiseOpType.Subtract, outputShape, outputDataType, inputIds), - OperationType.Multiply => CreateElementwiseOp(node, ElementwiseOpType.Multiply, outputShape, outputDataType, inputIds), - OperationType.Divide => CreateElementwiseOp(node, ElementwiseOpType.Divide, outputShape, outputDataType, inputIds), - - // Activation functions - OperationType.ReLU => CreateElementwiseOp(node, ElementwiseOpType.ReLU, outputShape, outputDataType, inputIds), - OperationType.Sigmoid => CreateElementwiseOp(node, ElementwiseOpType.Sigmoid, outputShape, outputDataType, inputIds), - OperationType.Tanh => CreateElementwiseOp(node, ElementwiseOpType.Tanh, outputShape, outputDataType, inputIds), - OperationType.GELU => CreateElementwiseOp(node, ElementwiseOpType.GELU, outputShape, outputDataType, inputIds), - OperationType.Softmax => CreateElementwiseOp(node, ElementwiseOpType.Softmax, outputShape, outputDataType, inputIds), - OperationType.LogSoftmax => CreateElementwiseOp(node, ElementwiseOpType.LogSoftmax, outputShape, outputDataType, inputIds), - - // Normalization operations (for fused Conv+BN+ReLU patterns) - OperationType.BatchNormalization or OperationType.LayerNormalization => - CreateNormalizationOpForFusion(node, outputShape, outputDataType, inputIds), - - // Pooling operations (for fused Conv+Pool patterns) - OperationType.MaxPool2D or OperationType.AvgPool2D or OperationType.GlobalAveragePooling => - CreatePoolingOpForFusion(node, outputShape, outputDataType, inputIds), - - // Memory/reshape operations (for fused attention + reshape patterns) - OperationType.Reshape or OperationType.Transpose or OperationType.Flatten => - CreateMemoryOpForFusion(node, outputShape, outputDataType, inputIds), - - // Reduction operations (for fused attention + reduction patterns) - OperationType.ReduceSum or OperationType.Mean or OperationType.ReduceMean or - OperationType.ReduceMax or OperationType.ReduceMin => - CreateReductionOpForFusion(node, outputShape, outputDataType, inputIds), - - // Dense/fully-connected operations (commonly fused with activations) - OperationType.Dense or OperationType.FullyConnected => - CreateMatMulOp(node, outputShape, outputDataType, inputIds), - - // Attention operations (for transformer fusions) - OperationType.Attention or OperationType.MultiHeadAttention => - CreateAttentionOpForFusion(node, outputShape, outputDataType, inputIds), - - // Dropout is identity during inference (no-op in fusion context) - OperationType.Dropout => - CreateIdentityOpForFusion(node, outputShape, outputDataType, inputIds), - - // Unsupported operation in fused context - _ => throw new InvalidOperationException( - $"Operation '{node.Operation}' is not supported within fused operations. " + - $"Node: '{node.Name}' (ID: {node.Id})") - }; - } - - /// - /// Creates a MatMulOp with proper dimension information from the HLIR node. - /// - private MatMulOp CreateMatMulOp(HLIRNode node, int[] outputShape, IRDataType outputDataType, int[] inputIds) - { - // Infer M, N, K dimensions from input/output shapes - // MatMul: [M, K] × [K, N] = [M, N] - int m = 1, n = 1, k = 1; - - if (node.InputTypes.Count >= 2) - { - var leftShape = node.InputTypes[0].Shape; - var rightShape = node.InputTypes[1].Shape; - - if (leftShape != null && leftShape.Length >= 2) - { - m = leftShape[^2]; // Second-to-last dimension - k = leftShape[^1]; // Last dimension - } - - if (rightShape != null && rightShape.Length >= 2) - { - n = rightShape[^1]; // Last dimension - } - } - else if (outputShape.Length >= 2) - { - // Fallback: infer from output shape - m = outputShape[^2]; - n = outputShape[^1]; - } - - // Allocate output buffer for this sub-op within the fused operation - var outputId = _llirGraph.AllocateBufferId(); - - // Extract transpose flags using the same safe helper as LowerMatMul - var transposeA = GetAttributeBool(node, "transposeA", false); - var transposeB = GetAttributeBool(node, "transposeB", false); - - return new MatMulOp - { - OutputId = outputId, - Name = node.Name, - InputIds = inputIds, - OutputShape = outputShape, - OutputDataType = outputDataType, - Device = GetDeviceForNode(node), - M = m, - N = n, - K = k, - TransposeA = transposeA, - TransposeB = transposeB, - SourceHLIRNodeId = node.Id - }; - } - - /// - /// Creates an ElementwiseOp with proper metadata from the HLIR node. - /// - private ElementwiseOp CreateElementwiseOp( - HLIRNode node, - ElementwiseOpType opType, - int[] outputShape, - IRDataType outputDataType, - int[] inputIds) - { - // Allocate output buffer for this sub-op within the fused operation - var outputId = _llirGraph.AllocateBufferId(); - - return new ElementwiseOp - { - OutputId = outputId, - Name = node.Name, - ElementwiseType = opType, - InputIds = inputIds, - OutputShape = outputShape, - OutputDataType = outputDataType, - Device = GetDeviceForNode(node), - SourceHLIRNodeId = node.Id - }; - } - - /// - /// Creates a Conv2DOp for use within fused operations. - /// - private Conv2DOp CreateConv2DOpForFusion( - HLIRNode node, - int[] outputShape, - IRDataType outputDataType, - int[] inputIds) - { - // Allocate output buffer for this sub-op within the fused operation - var outputId = _llirGraph.AllocateBufferId(); - - // Extract shape information from node's input types if available - int batchSize = 1, inputChannels = 1, inputHeight = 1, inputWidth = 1; - int outputChannels = 1, kernelHeight = 3, kernelWidth = 3; - - if (node.InputTypes.Count >= 1 && node.InputTypes[0].Shape?.Length >= 4) - { - var inputShape = node.InputTypes[0].Shape; - batchSize = inputShape[0]; - inputChannels = inputShape[1]; - inputHeight = inputShape[2]; - inputWidth = inputShape[3]; - } - - if (node.InputTypes.Count >= 2 && node.InputTypes[1].Shape?.Length >= 4) - { - var kernelShape = node.InputTypes[1].Shape; - outputChannels = kernelShape[0]; - kernelHeight = kernelShape[2]; - kernelWidth = kernelShape[3]; - } - else if (outputShape.Length >= 4) - { - // Fallback: infer output channels from output shape - outputChannels = outputShape[1]; - } - - // Extract convolution parameters from attributes - var strideH = GetAttributeInt(node, "strideH", 1); - var strideW = GetAttributeInt(node, "strideW", 1); - var padH = GetAttributeInt(node, "padH", 0); - var padW = GetAttributeInt(node, "padW", 0); - var groups = GetAttributeInt(node, "groups", 1); - - var conv2DOp = new Conv2DOp - { - OutputId = outputId, - Name = node.Name, - InputIds = inputIds, - OutputShape = outputShape, - OutputDataType = outputDataType, - BatchSize = batchSize, - InputChannels = inputChannels, - InputHeight = inputHeight, - InputWidth = inputWidth, - OutputChannels = outputChannels, - KernelHeight = kernelHeight, - KernelWidth = kernelWidth, - StrideH = strideH, - StrideW = strideW, - PadH = padH, - PadW = padW, - Groups = groups, - Device = GetDeviceForNode(node), - SourceHLIRNodeId = node.Id - }; - - conv2DOp.Algorithm = SelectConvAlgorithm(conv2DOp); - return conv2DOp; - } - - /// - /// Creates a FusedOp representing a normalization operation for use within fused operation contexts. - /// - /// The HLIR node representing the normalization operation. - /// The output shape for this sub-operation. - /// The output data type for this sub-operation. - /// The resolved input buffer IDs within the fusion context. - /// A FusedOp containing normalization-specific parameters for LLIR execution. - /// - /// - /// This method enables normalization operations (BatchNormalization, LayerNormalization) to be - /// included in fused operation patterns such as Conv+BN+ReLU. The normalization parameters - /// (epsilon, momentum, axis) are extracted from the HLIR node attributes and stored in the - /// FusedOp's Attributes dictionary. - /// - /// - /// For BatchNormalization, the typical parameters include: - /// - epsilon: Small constant for numerical stability (default 1e-5) - /// - momentum: Running mean/variance momentum (default 0.1) - /// - axis: Channel axis for normalization (default -1, typically channel dimension) - /// - /// - /// For LayerNormalization, similar parameters are used but normalization occurs - /// across different dimensions (typically the feature dimension). - /// - /// - private FusedOp CreateNormalizationOpForFusion( - HLIRNode node, - int[] outputShape, - IRDataType outputDataType, - int[] inputIds) - { - // Allocate output buffer for this sub-op within the fused operation - var outputId = _llirGraph.AllocateBufferId(); - - // Extract normalization-specific parameters from node attributes - var epsilon = GetAttributeDouble(node, "epsilon", 1e-5); - var momentum = GetAttributeDouble(node, "momentum", 0.1); - var axis = GetAttributeInt(node, "axis", -1); - - // Determine fusion pattern based on operation type - var fusionPattern = node.Operation == OperationType.BatchNormalization - ? "BatchNorm" - : "LayerNorm"; - - var fusedOp = new FusedOp - { - OutputId = outputId, - Name = node.Name, - InputIds = inputIds, - OutputShape = outputShape, - OutputDataType = outputDataType, - Device = GetDeviceForNode(node), - FusionPattern = fusionPattern, - SourceHLIRNodeId = node.Id - }; - - // Store normalization parameters in attributes for runtime execution - fusedOp.Attributes["epsilon"] = epsilon; - fusedOp.Attributes["momentum"] = momentum; - fusedOp.Attributes["axis"] = axis; - - return fusedOp; - } - - /// - /// Creates a FusedOp representing a pooling operation for use within fused operation contexts. - /// - /// The HLIR node representing the pooling operation. - /// The output shape for this sub-operation. - /// The output data type for this sub-operation. - /// The resolved input buffer IDs within the fusion context. - /// A FusedOp containing pooling-specific parameters for LLIR execution. - /// - /// - /// This method enables pooling operations (MaxPool2D, AvgPool2D, GlobalAveragePooling) to be - /// included in fused operation patterns such as Conv+Pool or Conv+BN+ReLU+Pool. The pooling - /// parameters (kernel size, stride, padding) are extracted from the HLIR node attributes. - /// - /// - /// Pooling operations have spatial window, stride, and padding parameters that distinguish - /// them from simple reductions. This method uses FusedOp to preserve these windowed operation - /// semantics for efficient runtime execution. - /// - /// - private FusedOp CreatePoolingOpForFusion( - HLIRNode node, - int[] outputShape, - IRDataType outputDataType, - int[] inputIds) - { - var outputId = _llirGraph.AllocateBufferId(); - - var pattern = node.Operation switch - { - OperationType.MaxPool2D => "MaxPool2D", - OperationType.AvgPool2D => "AvgPool2D", - OperationType.GlobalAveragePooling => "GlobalAvgPool", - _ => "MaxPool2D" - }; - - // Extract pooling parameters from node attributes - var kernelH = GetAttributeInt(node, "kernelH", 2); - var kernelW = GetAttributeInt(node, "kernelW", 2); - var strideH = GetAttributeInt(node, "strideH", 2); - var strideW = GetAttributeInt(node, "strideW", 2); - var padH = GetAttributeInt(node, "padH", 0); - var padW = GetAttributeInt(node, "padW", 0); - - var fusedOp = new FusedOp - { - OutputId = outputId, - Name = node.Name, - InputIds = inputIds, - OutputShape = outputShape, - OutputDataType = outputDataType, - Device = GetDeviceForNode(node), - FusionPattern = pattern, - SourceHLIRNodeId = node.Id - }; - - // Store pooling parameters in attributes for runtime execution - fusedOp.Attributes["kernelH"] = kernelH; - fusedOp.Attributes["kernelW"] = kernelW; - fusedOp.Attributes["strideH"] = strideH; - fusedOp.Attributes["strideW"] = strideW; - fusedOp.Attributes["padH"] = padH; - fusedOp.Attributes["padW"] = padW; - - return fusedOp; - } - - /// - /// Creates a MemoryOp representing a memory/reshape operation for use within fused operation contexts. - /// - /// The HLIR node representing the memory operation. - /// The output shape for this sub-operation. - /// The output data type for this sub-operation. - /// The resolved input buffer IDs within the fusion context. - /// A MemoryOp containing reshape/transpose parameters for LLIR execution. - /// - /// - /// This method enables memory operations (Reshape, Transpose, Flatten) to be included in - /// fused operation patterns such as attention mechanisms that require reshaping between - /// matrix multiplications. These operations are typically zero-copy or view operations. - /// - /// - private MemoryOp CreateMemoryOpForFusion( - HLIRNode node, - int[] outputShape, - IRDataType outputDataType, - int[] inputIds) - { - var outputId = _llirGraph.AllocateBufferId(); - - var memOpType = node.Operation switch - { - OperationType.Reshape => MemoryOpType.Reshape, - OperationType.Transpose => MemoryOpType.Transpose, - OperationType.Flatten => MemoryOpType.Reshape, - _ => MemoryOpType.Copy - }; - - var op = new MemoryOp - { - OutputId = outputId, - Name = node.Name, - InputIds = inputIds, - OutputShape = outputShape, - OutputDataType = outputDataType, - Device = GetDeviceForNode(node), - MemoryOpType = memOpType, - SourceHLIRNodeId = node.Id - }; - - // Handle transpose permutation - if (node.Operation == OperationType.Transpose && - node.Attributes.TryGetValue("perm", out var perm)) - { - op.Permutation = GetAttributeIntArray(perm, "perm"); - } - - // Handle reshape new shape - if (node.Operation == OperationType.Reshape || node.Operation == OperationType.Flatten) - { - op.NewShape = outputShape; - } - - return op; - } - - /// - /// Creates a ReduceOp representing a reduction operation for use within fused operation contexts. - /// - /// The HLIR node representing the reduction operation. - /// The output shape for this sub-operation. - /// The output data type for this sub-operation. - /// The resolved input buffer IDs within the fusion context. - /// A ReduceOp containing reduction parameters for LLIR execution. - /// - /// - /// This method enables reduction operations (ReduceSum, Mean, ReduceMax, ReduceMin) to be - /// included in fused operation patterns such as attention softmax normalization or - /// mean pooling operations. The reduction axes and keepDims parameters are preserved. - /// - /// - private ReduceOp CreateReductionOpForFusion( - HLIRNode node, - int[] outputShape, - IRDataType outputDataType, - int[] inputIds) - { - var outputId = _llirGraph.AllocateBufferId(); - - var reduceType = node.Operation switch - { - OperationType.ReduceSum => ReduceType.Sum, - OperationType.Mean or OperationType.ReduceMean => ReduceType.Mean, - OperationType.ReduceMax => ReduceType.Max, - OperationType.ReduceMin => ReduceType.Min, - _ => ReduceType.Sum - }; - - var axes = node.Attributes.TryGetValue("axes", out var ax) ? GetAttributeIntArray(ax, "axes") : Array.Empty(); - var keepDims = node.Attributes.TryGetValue("keepDims", out var kd) && (bool)kd; - - // Get input shape for accurate cost estimation - var inputShape = node.InputTypes.Count > 0 && node.InputTypes[0].Shape != null - ? node.InputTypes[0].Shape - : Array.Empty(); - - return new ReduceOp - { - OutputId = outputId, - Name = node.Name, - InputIds = inputIds, - OutputShape = outputShape, - OutputDataType = outputDataType, - Device = GetDeviceForNode(node), - ReduceType = reduceType, - Axes = axes, - KeepDims = keepDims, - InputShape = inputShape, - SourceHLIRNodeId = node.Id - }; - } - - /// - /// Creates a FusedOp representing an attention operation for use within fused operation contexts. - /// - /// The HLIR node representing the attention operation. - /// The output shape for this sub-operation. - /// The output data type for this sub-operation. - /// The resolved input buffer IDs within the fusion context. - /// A FusedOp containing attention parameters for LLIR execution. - /// - /// - /// This method enables attention operations (Attention, MultiHeadAttention) to be included - /// in fused transformer patterns such as LayerNorm+Attention or Attention+FFN fusions. - /// These operations benefit significantly from fusion to reduce memory bandwidth. - /// - /// - private FusedOp CreateAttentionOpForFusion( - HLIRNode node, - int[] outputShape, - IRDataType outputDataType, - int[] inputIds) - { - var outputId = _llirGraph.AllocateBufferId(); - - var pattern = node.Operation == OperationType.MultiHeadAttention - ? "MultiHeadAttention" - : "Attention"; - - // Extract attention-specific parameters - var numHeads = GetAttributeInt(node, "numHeads", 8); - var headDim = GetAttributeInt(node, "headDim", 64); - var scale = GetAttributeDouble(node, "scale", 1.0 / Math.Sqrt(headDim)); - var causal = GetAttributeBool(node, "causal", false); - - var fusedOp = new FusedOp - { - OutputId = outputId, - Name = node.Name, - InputIds = inputIds, - OutputShape = outputShape, - OutputDataType = outputDataType, - Device = GetDeviceForNode(node), - FusionPattern = pattern, - SourceHLIRNodeId = node.Id - }; - - // Store attention parameters for runtime execution - fusedOp.Attributes["numHeads"] = numHeads; - fusedOp.Attributes["headDim"] = headDim; - fusedOp.Attributes["scale"] = scale; - fusedOp.Attributes["causal"] = causal; - - return fusedOp; - } - - /// - /// Creates an identity ElementwiseOp for operations that are no-ops during inference. - /// - /// The HLIR node representing the no-op operation. - /// The output shape for this sub-operation. - /// The output data type for this sub-operation. - /// The resolved input buffer IDs within the fusion context. - /// An ElementwiseOp configured as an identity operation. - /// - /// - /// This method handles operations like Dropout that become identity operations during - /// inference. Rather than special-casing these in the fusion executor, we emit an - /// explicit identity op that can be optimized away during execution planning. - /// - /// - private ElementwiseOp CreateIdentityOpForFusion( - HLIRNode node, - int[] outputShape, - IRDataType outputDataType, - int[] inputIds) - { - var outputId = _llirGraph.AllocateBufferId(); - - return new ElementwiseOp - { - OutputId = outputId, - Name = node.Name, - InputIds = inputIds, - OutputShape = outputShape, - OutputDataType = outputDataType, - Device = GetDeviceForNode(node), - ElementwiseType = ElementwiseOpType.Identity, - SourceHLIRNodeId = node.Id - }; - } - - #endregion - - #region Helpers - - private int[] GetLLIRInputIds(HLIRNode node) - { - var ids = new List(); - foreach (var input in node.Inputs) - { - if (!_hlirToLlirBufferMap.TryGetValue(input.Id, out var llirId)) - { - throw new InvalidOperationException( - $"Input node '{input.Name}' (ID: {input.Id}) was not lowered before being used by " + - $"node '{node.Name}' (ID: {node.Id}). This indicates a topological ordering issue or " + - $"missing lowering implementation for operation type '{input.Operation}'."); - } - ids.Add(llirId); - } - return ids.ToArray(); - } - - /// - /// Resolves LLIR input IDs with fusion-aware context for chained fusion patterns. - /// - /// The HLIR node whose inputs need resolution. - /// Local buffer map for internal fusion dependencies. - /// Array of resolved LLIR buffer IDs. - /// - /// For chained fusions (e.g., Conv→BN→ReLU), sub-ops depend on each other's outputs. - /// This method checks the fusionBufferMap first for internal fusion dependencies, - /// then falls back to the global _hlirToLlirBufferMap for external dependencies. - /// - private int[] GetLLIRInputIdsWithFusionContext(HLIRNode node, Dictionary fusionBufferMap) - { - var ids = new List(); - foreach (var input in node.Inputs) - { - // First check fusion buffer map for internal fusion dependencies - if (fusionBufferMap.TryGetValue(input.Id, out var fusionId)) - { - ids.Add(fusionId); - } - // Fall back to global buffer map for external dependencies - else if (_hlirToLlirBufferMap.TryGetValue(input.Id, out var globalId)) - { - ids.Add(globalId); - } - else - { - throw new InvalidOperationException( - $"Cannot resolve input '{input.Name}' (ID: {input.Id}) for node '{node.Name}' (ID: {node.Id}). " + - $"Input not found in fusion buffer map or global buffer map. " + - $"This indicates a topological ordering issue within the fusion or missing lowering implementation."); - } - } - return ids.ToArray(); - } - - private IRDataType ConvertDataType(IRDataType hlirType) - { - // If HLIR has unknown type, use target type - if (hlirType == IRDataType.Unknown) - { - return TargetDataType; - } - return hlirType; - } - - private DeviceType GetDeviceForNode(HLIRNode node) - { - // Use node's preferred device or fall back to target - if (node.Hints.PreferredDevice != DeviceType.Auto) - { - return node.Hints.PreferredDevice; - } - return TargetDevice; - } - - private (int m, int n, int k) InferMatMulDims(HLIRNode node) - { - // Infer from input shapes with null safety - if (node.InputTypes.Count >= 2) - { - var shapeA = node.InputTypes[0].Shape; - var shapeB = node.InputTypes[1].Shape; - - // Guard against null shapes - shapes may not be known at lowering time - if (shapeA != null && shapeB != null && shapeA.Length >= 2 && shapeB.Length >= 2) - { - return (shapeA[^2], shapeB[^1], shapeA[^1]); - } - } - - // Default when shapes are unknown or incompatible - return (1, 1, 1); - } - - private ConvAlgorithm SelectConvAlgorithm(Conv2DOp op) - { - // Select based on kernel size and device - if (op.Device == DeviceType.GPU && DeviceConfig.GPUHasTensorCores) - { - return ConvAlgorithm.TensorCore; - } - - if (op.KernelHeight <= 3 && op.KernelWidth <= 3) - { - return ConvAlgorithm.Winograd; - } - - if (op.KernelHeight >= 7 && op.KernelWidth >= 7) - { - return ConvAlgorithm.FFT; - } - - return op.Device == DeviceType.GPU ? ConvAlgorithm.Implicit : ConvAlgorithm.Im2Col; - } - - /// - /// Safely retrieves an integer attribute from a node's attribute dictionary. - /// - /// The HLIR node containing the attributes. - /// The attribute key to look up. - /// The default value to return if the attribute is missing or invalid. - /// The attribute value as an integer, or the default value if not found or conversion fails. - /// - /// - /// This method handles various attribute types safely: - /// - /// Direct integer types (int, long, short, byte) are converted directly - /// String values are parsed using int.TryParse - /// Other IConvertible types use Convert.ToInt32 with exception handling - /// Any conversion failure returns the default value - /// - /// - /// - private int GetAttributeInt(HLIRNode node, string key, int defaultValue) - { - if (!node.Attributes.TryGetValue(key, out var value) || value == null) - { - return defaultValue; - } - - // Handle common integer types directly - if (value is int intValue) - { - return intValue; - } - if (value is long longValue) - { - return longValue is >= int.MinValue and <= int.MaxValue ? (int)longValue : defaultValue; - } - if (value is short shortValue) - { - return shortValue; - } - if (value is byte byteValue) - { - return byteValue; - } - - // Handle string values - if (value is string strValue) - { - return int.TryParse(strValue, out var parsed) ? parsed : defaultValue; - } - - // Handle other IConvertible types with exception handling - try - { - return Convert.ToInt32(value); - } - catch (FormatException) - { - return defaultValue; - } - catch (OverflowException) - { - return defaultValue; - } - catch (InvalidCastException) - { - return defaultValue; - } - } - - /// - /// Safely converts an attribute value to a boolean with defensive type handling. - /// - /// The HLIR node containing attributes. - /// The attribute key to look up. - /// The value to return if key doesn't exist or conversion fails. - /// The boolean value, or defaultValue if not found or conversion fails. - /// - /// Design Philosophy: - /// - /// Handles various runtime types that may be stored in node attributes: - /// - /// - /// bool - returned directly - /// string "true"/"false" (case-insensitive) - parsed to bool - /// int/long 0/1 - converted to false/true - /// Any conversion failure returns the default value - /// - /// - private bool GetAttributeBool(HLIRNode node, string key, bool defaultValue) - { - if (!node.Attributes.TryGetValue(key, out var value) || value == null) - { - return defaultValue; - } - - // Handle bool directly - if (value is bool boolValue) - { - return boolValue; - } - - // Handle string values (case-insensitive) - if (value is string strValue) - { - if (bool.TryParse(strValue, out var parsed)) - { - return parsed; - } - // Handle "1" and "0" strings - if (strValue == "1") return true; - if (strValue == "0") return false; - return defaultValue; - } - - // Handle integer types (0 = false, non-zero = true) - if (value is int intValue) - { - return intValue != 0; - } - if (value is long longValue) - { - return longValue != 0; - } - - // Handle other IConvertible types with exception handling - try - { - return Convert.ToBoolean(value); - } - catch (FormatException) - { - return defaultValue; - } - catch (InvalidCastException) - { - return defaultValue; - } - } - - /// - /// Safely converts an attribute value to a double with defensive type handling. - /// - /// The HLIR node containing attributes. - /// The attribute key to look up. - /// The value to return if key doesn't exist or conversion fails. - /// The double value, or defaultValue if not found or conversion fails. - private double GetAttributeDouble(HLIRNode node, string key, double defaultValue) - { - if (!node.Attributes.TryGetValue(key, out var value) || value == null) - { - return defaultValue; - } - - // Handle numeric types directly - if (value is double doubleValue) - { - return doubleValue; - } - if (value is float floatValue) - { - return floatValue; - } - if (value is int intValue) - { - return intValue; - } - if (value is long longValue) - { - return longValue; - } - if (value is decimal decimalValue) - { - return (double)decimalValue; - } - - // Handle string values - if (value is string strValue) - { - return double.TryParse(strValue, out var parsed) ? parsed : defaultValue; - } - - // Handle other IConvertible types with exception handling - try - { - return Convert.ToDouble(value); - } - catch (FormatException) - { - return defaultValue; - } - catch (OverflowException) - { - return defaultValue; - } - catch (InvalidCastException) - { - return defaultValue; - } - } - - /// - /// Safely converts an attribute value to an int array with defensive type handling. - /// - /// The attribute value to convert. - /// The name of the attribute (for error messages). - /// An int array, or empty array if conversion is not possible. - /// - /// Design Philosophy: - /// - /// Handles various runtime types that may be stored in node attributes: - /// - /// - /// int[] - returned directly - /// long[] - converted with overflow checking - /// IList<int> - materialized to int[] - /// IEnumerable<long> - converted element-wise - /// object[] - each element converted via Convert.ToInt32 - /// IEnumerable - each element converted via Convert.ToInt32 - /// - /// - private static int[] GetAttributeIntArray(object? value, string attributeName) - { - if (value == null) - { - return Array.Empty(); - } - - // Handle int[] directly - if (value is int[] intArray) - { - return intArray; - } - - // Handle long[] with overflow checking - if (value is long[] longArray) - { - var result = new int[longArray.Length]; - for (int i = 0; i < longArray.Length; i++) - { - long val = longArray[i]; - if (val is < int.MinValue or > int.MaxValue) - { - throw new OverflowException( - $"Value {val} at index {i} in attribute '{attributeName}' exceeds int range."); - } - result[i] = (int)val; - } - return result; - } - - // Handle IList (e.g., List) - if (value is IList intList) - { - var result = new int[intList.Count]; - intList.CopyTo(result, 0); - return result; - } - - // Handle IEnumerable (but not IList, already handled) - if (value is IEnumerable intEnumerable) - { - return intEnumerable.ToArray(); - } - - // Handle IEnumerable - if (value is IEnumerable longEnumerable) - { - var list = new List(); - int index = 0; - foreach (long val in longEnumerable) - { - if (val is < int.MinValue or > int.MaxValue) - { - throw new OverflowException( - $"Value {val} at index {index} in attribute '{attributeName}' exceeds int range."); - } - list.Add((int)val); - index++; - } - return list.ToArray(); - } - - // Handle object[] - convert each element individually - if (value is object[] objArray) - { - var result = new int[objArray.Length]; - for (int i = 0; i < objArray.Length; i++) - { - object? element = objArray[i]; - if (element == null) - { - throw new InvalidCastException( - $"Null element at index {i} in attribute '{attributeName}' cannot be converted to int."); - } - - try - { - result[i] = Convert.ToInt32(element); - } - catch (Exception ex) when (ex is FormatException or OverflowException or InvalidCastException) - { - throw new InvalidCastException( - $"Element at index {i} (type: {element.GetType().Name}, value: {element}) " + - $"in attribute '{attributeName}' cannot be converted to int.", ex); - } - } - return result; - } - - // Handle any other IEnumerable - try to convert each element - if (value is System.Collections.IEnumerable enumerable) - { - var list = new List(); - int index = 0; - foreach (object? element in enumerable) - { - if (element == null) - { - throw new InvalidCastException( - $"Null element at index {index} in attribute '{attributeName}' cannot be converted to int."); - } - - try - { - list.Add(Convert.ToInt32(element)); - } - catch (Exception ex) when (ex is FormatException or OverflowException or InvalidCastException) - { - throw new InvalidCastException( - $"Element at index {index} (type: {element.GetType().Name}, value: {element}) " + - $"in attribute '{attributeName}' cannot be converted to int.", ex); - } - index++; - } - return list.ToArray(); - } - - // If none of the above, throw a descriptive exception - throw new InvalidCastException( - $"Attribute '{attributeName}' has type '{value.GetType().Name}' which cannot be converted to int[]. " + - "Expected int[], long[], List, or another enumerable of numeric values."); - } - - #endregion -} diff --git a/src/InferenceOptimization/Kernels/AttentionKernel.cs b/src/InferenceOptimization/Kernels/AttentionKernel.cs deleted file mode 100644 index 4227e327ea..0000000000 --- a/src/InferenceOptimization/Kernels/AttentionKernel.cs +++ /dev/null @@ -1,328 +0,0 @@ -using System; -using System.Threading.Tasks; -using AiDotNet.LinearAlgebra; -using AiDotNet.Tensors.Engines.Simd; - -namespace AiDotNet.InferenceOptimization.Kernels -{ - /// - /// Fused attention kernel for transformer models - /// Implements optimized scaled dot-product attention: softmax(QK^T/sqrt(d_k))V - /// - public class AttentionKernel : ICustomOperator - { - public string Name => "FusedAttention"; - public string Version => "1.0.0"; - public int Priority => 100; - - public AttentionKernel() { } - - public bool IsSupported() - { - return true; - } - - public double EstimatedSpeedup() - { - // Fused attention reduces memory traffic significantly - return 2.5; - } - - public Tensor Execute(params Tensor[] inputs) - { - if (inputs == null || inputs.Length < 3) - throw new ArgumentException("Attention requires Q, K, V tensors"); - - var q = inputs[0]; // [batch_size, seq_len_q, d_k] - var k = inputs[1]; // [batch_size, seq_len_k, d_k] - var v = inputs[2]; // [batch_size, seq_len_v, d_v] - - bool useMask = inputs.Length > 3; - Tensor? mask = useMask ? inputs[3] : null; - - return ExecuteInternal(q, k, v, mask, maskBatchModulo: q.Shape.Length == 3 ? q.Shape[0] : 0); - } - - private Tensor ExecuteInternal( - Tensor q, - Tensor k, - Tensor v, - Tensor? mask, - int maskBatchModulo) - { - if (q.Shape.Length != 3 || k.Shape.Length != 3 || v.Shape.Length != 3) - throw new ArgumentException("Attention requires 3D tensors [batch, seq_len, features]"); - - int batchSize = q.Shape[0]; - int seqLenQ = q.Shape[1]; - int seqLenK = k.Shape[1]; - int dK = q.Shape[2]; - int dV = v.Shape[2]; - - if (k.Shape[0] != batchSize || v.Shape[0] != batchSize) - throw new ArgumentException("Q, K, and V must have the same batch size"); - - if (k.Shape[2] != dK) - throw new ArgumentException("Q and K must have same feature dimension"); - - if (v.Shape[1] != seqLenK) - throw new ArgumentException("K and V must have same sequence length"); - - if (mask != null) - { - if (mask.Shape.Length != 3) - throw new ArgumentException("Attention mask must be a 3D tensor [batch, seq_len_q, seq_len_k]"); - - if (mask.Shape[1] != seqLenQ || mask.Shape[2] != seqLenK) - throw new ArgumentException("Attention mask must match [batch, seq_len_q, seq_len_k]"); - - if (maskBatchModulo <= 0) - { - if (mask.Shape[0] != batchSize) - throw new ArgumentException("Attention mask must have the same batch size as Q when used in Execute()"); - } - else - { - if (mask.Shape[0] != maskBatchModulo) - throw new ArgumentException("Attention mask batch dimension must match the provided maskBatchModulo"); - } - } - - var result = new Tensor(new[] { batchSize, seqLenQ, dV }); - - // Process each batch in parallel - Parallel.For(0, batchSize, b => - { - ProcessBatch(q, k, v, mask, result, b, seqLenQ, seqLenK, dK, dV, maskBatchModulo); - }); - - return result; - } - - private void ProcessBatch( - Tensor q, Tensor k, Tensor v, - Tensor? mask, Tensor result, - int batchIdx, int seqLenQ, int seqLenK, int dK, int dV, - int maskBatchModulo) - { - float scale = 1.0f / MathF.Sqrt(dK); - - // Extract batch slices - int qOffset = batchIdx * seqLenQ * dK; - int kOffset = batchIdx * seqLenK * dK; - int vOffset = batchIdx * seqLenK * dV; - int outOffset = batchIdx * seqLenQ * dV; - - // Compute attention scores: QK^T - var scores = new float[seqLenQ * seqLenK]; - - for (int i = 0; i < seqLenQ; i++) - { - int qRowOffset = qOffset + i * dK; - var qRow = q.Data.Span.Slice(qRowOffset, dK); - - for (int j = 0; j < seqLenK; j++) - { - int kRowOffset = kOffset + j * dK; - var kRow = k.Data.Span.Slice(kRowOffset, dK); - float score = SimdKernels.DotProduct(qRow, kRow) * scale; - - // Apply mask if provided - if (mask != null) - { - int effectiveMaskBatch = maskBatchModulo > 0 ? (batchIdx % maskBatchModulo) : batchIdx; - int maskIdx = effectiveMaskBatch * seqLenQ * seqLenK + i * seqLenK + j; - // Use epsilon-based comparison for floating point equality - if (MathF.Abs(mask.Data.Span[maskIdx]) < 1e-6f) - { - score = float.NegativeInfinity; - } - } - - scores[i * seqLenK + j] = score; - } - } - - // Apply softmax over each row - ApplySoftmax(scores, seqLenQ, seqLenK); - - // Compute weighted sum: attention_weights * V - for (int i = 0; i < seqLenQ; i++) - { - var outRow = result.Data.Span.Slice(outOffset + i * dV, dV); - outRow.Clear(); - - // Accumulate weighted values - for (int j = 0; j < seqLenK; j++) - { - float weight = scores[i * seqLenK + j]; - if (weight <= 0f) - { - continue; - } - - var vRow = v.Data.Span.Slice(vOffset + j * dV, dV); - SimdKernels.ScalarMultiplyAdd(outRow, vRow, weight, outRow); - } - } - } - - private void ApplySoftmax(float[] data, int rows, int cols) - { - for (int i = 0; i < rows; i++) - { - int rowOffset = i * cols; - - // Find max for numerical stability - float maxVal = float.NegativeInfinity; - for (int j = 0; j < cols; j++) - { - float v = data[rowOffset + j]; - if (v > maxVal) - { - maxVal = v; - } - } - - // Compute exp and sum - float sum = 0.0f; - for (int j = 0; j < cols; j++) - { - int idx = rowOffset + j; - float v = data[idx]; - if (float.IsNegativeInfinity(v)) - { - data[idx] = 0.0f; - continue; - } - - float ev = MathF.Exp(v - maxVal); - data[idx] = ev; - sum += ev; - } - - // Normalize - if (sum > 0.0f) - { - float invSum = 1.0f / sum; - for (int j = 0; j < cols; j++) - { - data[rowOffset + j] *= invSum; - } - } - } - } - - /// - /// Multi-head attention variant - /// - public Tensor MultiHeadAttention( - Tensor q, Tensor k, Tensor v, - int numHeads, Tensor? mask = null) - { - if (q.Shape.Length != 3 || k.Shape.Length != 3 || v.Shape.Length != 3) - throw new ArgumentException("Multi-head attention requires 3D tensors"); - - int batchSize = q.Shape[0]; - int dModel = q.Shape[2]; - - if (k.Shape[0] != batchSize || v.Shape[0] != batchSize) - throw new ArgumentException("Q, K, and V must have the same batch size"); - - if (dModel % numHeads != 0) - throw new ArgumentException("d_model must be divisible by num_heads"); - - int dK = dModel / numHeads; - - if (k.Shape[2] != dModel || v.Shape[2] != dModel) - throw new ArgumentException("Q, K, and V must have the same feature dimension (d_model)"); - - if (v.Shape[1] != k.Shape[1]) - throw new ArgumentException("K and V must have the same sequence length"); - - // Reshape to [batch * num_heads, seq_len, d_k] - var qReshaped = ReshapeForMultiHead(q, numHeads, dK); - var kReshaped = ReshapeForMultiHead(k, numHeads, dK); - var vReshaped = ReshapeForMultiHead(v, numHeads, dK); - - // Apply attention - Tensor attended; - if (mask is null) - { - attended = ExecuteInternal(qReshaped, kReshaped, vReshaped, mask: null, maskBatchModulo: 0); - } - else - { - int expectedPerHeadBatch = batchSize * numHeads; - if (mask.Shape.Length != 3) - throw new ArgumentException("Multi-head attention mask must be a 3D tensor"); - - if (mask.Shape[1] != q.Shape[1] || mask.Shape[2] != k.Shape[1]) - throw new ArgumentException("Multi-head attention mask must match [batch, seq_len_q, seq_len_k]"); - - // Accept either per-batch mask [B, SQ, SK] (broadcast across heads) or per-head mask [B*H, SQ, SK]. - int maskBatchModulo = mask.Shape[0] switch - { - int b when b == expectedPerHeadBatch => 0, - int b when b == batchSize => batchSize, - _ => throw new ArgumentException("Multi-head attention mask must have batch dimension B or B*numHeads"), - }; - - attended = ExecuteInternal(qReshaped, kReshaped, vReshaped, mask, maskBatchModulo); - } - - // Reshape back to [batch, seq_len, d_model] - return ReshapeFromMultiHead(attended, batchSize, q.Shape[1], dModel); - } - - private Tensor ReshapeForMultiHead(Tensor input, int numHeads, int dK) - { - int batchSize = input.Shape[0]; - int seqLen = input.Shape[1]; - var reshaped = new Tensor(new[] { batchSize * numHeads, seqLen, dK }); - - for (int b = 0; b < batchSize; b++) - { - for (int h = 0; h < numHeads; h++) - { - for (int s = 0; s < seqLen; s++) - { - for (int d = 0; d < dK; d++) - { - int srcIdx = b * seqLen * numHeads * dK + s * numHeads * dK + h * dK + d; - int dstIdx = (b * numHeads + h) * seqLen * dK + s * dK + d; - reshaped.Data.Span[dstIdx] = input.Data.Span[srcIdx]; - } - } - } - } - - return reshaped; - } - - private Tensor ReshapeFromMultiHead(Tensor input, int batchSize, int seqLen, int dModel) - { - var reshaped = new Tensor(new[] { batchSize, seqLen, dModel }); - int numHeads = input.Shape[0] / batchSize; - int dK = input.Shape[2]; - - for (int b = 0; b < batchSize; b++) - { - for (int h = 0; h < numHeads; h++) - { - for (int s = 0; s < seqLen; s++) - { - for (int d = 0; d < dK; d++) - { - int srcIdx = (b * numHeads + h) * seqLen * dK + s * dK + d; - int dstIdx = b * seqLen * dModel + s * dModel + h * dK + d; - reshaped.Data.Span[dstIdx] = input.Data.Span[srcIdx]; - } - } - } - } - - return reshaped; - } - } -} diff --git a/src/InferenceOptimization/Kernels/ConvolutionKernel.cs b/src/InferenceOptimization/Kernels/ConvolutionKernel.cs deleted file mode 100644 index b094c52e34..0000000000 --- a/src/InferenceOptimization/Kernels/ConvolutionKernel.cs +++ /dev/null @@ -1,384 +0,0 @@ -using System; -using System.Threading.Tasks; -using AiDotNet.LinearAlgebra; -using AiDotNet.Tensors.Engines; - -namespace AiDotNet.InferenceOptimization.Kernels -{ - /// - /// Optimized convolution kernels including depthwise and group convolutions - /// - public class ConvolutionKernel : ICustomOperator - { - public string Name => "Convolution"; - public string Version => "1.0.0"; - public int Priority => 100; - - public bool IsSupported() - { - return true; - } - - public double EstimatedSpeedup() - { - var caps = PlatformDetector.Capabilities; - if (caps.HasAVX2) return 2.5; - if (caps.HasNeon) return 2.0; - return 1.5; - } - - /// - /// Executes convolution on the provided inputs. - /// Expects 2-3 inputs: input tensor, kernel tensor, and optional config tensor. - /// Config tensor format: [stride, padding] (defaults to stride=1, padding=0) - /// - public Tensor Execute(params Tensor[] inputs) - { - if (inputs == null || inputs.Length < 2) - { - throw new ArgumentException( - "ConvolutionKernel requires at least 2 inputs: input tensor and kernel tensor. " + - "Optional 3rd input for config [stride, padding]."); - } - - var input = inputs[0]; - var kernel = inputs[1]; - - // Extract stride and padding from optional config tensor or use defaults - int stride = 1; - int padding = 0; - - if (inputs.Length >= 3 && inputs[2] != null && inputs[2].Data.Length >= 2) - { - stride = Math.Max(1, (int)inputs[2].Data.Span[0]); - padding = Math.Max(0, (int)inputs[2].Data.Span[1]); - } - - // Determine convolution type based on kernel shape - // Standard: kernel[out_channels, in_channels, kH, kW] - // Depthwise: kernel[channels, 1, kH, kW] - if (kernel.Shape.Length == 4 && kernel.Shape[1] == 1) - { - // Depthwise convolution (kernel has 1 in_channel dimension) - return DepthwiseConv2D(input, kernel, stride, padding); - } - - // Default to standard 2D convolution - return Conv2D(input, kernel, stride, padding); - } - - /// - /// Standard 2D convolution - /// - public Tensor Conv2D( - Tensor input, - Tensor kernel, - int stride = 1, - int padding = 0) - { - // Input: [batch, in_channels, height, width] - // Kernel: [out_channels, in_channels, kernel_h, kernel_w] - - if (input.Shape.Length != 4 || kernel.Shape.Length != 4) - throw new ArgumentException("Conv2D requires 4D tensors"); - - if (stride <= 0) - throw new ArgumentOutOfRangeException(nameof(stride), $"stride must be positive, but got {stride}"); - - if (padding < 0) - throw new ArgumentOutOfRangeException(nameof(padding), $"padding must be non-negative, but got {padding}"); - - int batchSize = input.Shape[0]; - int inChannels = input.Shape[1]; - int inHeight = input.Shape[2]; - int inWidth = input.Shape[3]; - - int outChannels = kernel.Shape[0]; - int kernelH = kernel.Shape[2]; - int kernelW = kernel.Shape[3]; - - if (kernelH <= 0 || kernelW <= 0) - throw new ArgumentException($"Kernel dimensions must be positive, but got {kernelH}x{kernelW}"); - - if (kernel.Shape[1] != inChannels) - throw new ArgumentException($"Conv2D requires kernel.Shape[1] == inChannels ({inChannels}), but got {kernel.Shape[1]}"); - - int outHeight = (inHeight + 2 * padding - kernelH) / stride + 1; - int outWidth = (inWidth + 2 * padding - kernelW) / stride + 1; - - if (outHeight <= 0 || outWidth <= 0) - throw new ArgumentException( - $"Invalid output dimensions ({outHeight}x{outWidth}). " + - $"Check stride ({stride}), padding ({padding}), and kernel size ({kernelH}x{kernelW})."); - var output = new Tensor(new[] { batchSize, outChannels, outHeight, outWidth }); - - // Parallelize over batch and output channels - Parallel.For(0, batchSize * outChannels, idx => - { - int b = idx / outChannels; - int oc = idx % outChannels; - - Conv2DSingleOutput(input, kernel, output, b, oc, - inChannels, inHeight, inWidth, - kernelH, kernelW, stride, padding, - outHeight, outWidth); - }); - - return output; - } - - private void Conv2DSingleOutput( - Tensor input, Tensor kernel, Tensor output, - int batch, int outChannel, - int inChannels, int inHeight, int inWidth, - int kernelH, int kernelW, int stride, int padding, - int outHeight, int outWidth) - { - var inputData = input.Data.Span; - var kernelData = kernel.Data.Span; - var outputData = output.Data.Span; - - for (int oh = 0; oh < outHeight; oh++) - { - for (int ow = 0; ow < outWidth; ow++) - { - float sum = 0.0f; - - for (int ic = 0; ic < inChannels; ic++) - { - for (int kh = 0; kh < kernelH; kh++) - { - for (int kw = 0; kw < kernelW; kw++) - { - int ih = oh * stride - padding + kh; - int iw = ow * stride - padding + kw; - - if (ih >= 0 && ih < inHeight && iw >= 0 && iw < inWidth) - { - int inputIdx = ((batch * inChannels + ic) * inHeight + ih) * inWidth + iw; - int kernelIdx = ((outChannel * inChannels + ic) * kernelH + kh) * kernelW + kw; - sum += inputData[inputIdx] * kernelData[kernelIdx]; - } - } - } - } - - int outputIdx = ((batch * output.Shape[1] + outChannel) * outHeight + oh) * outWidth + ow; - outputData[outputIdx] = sum; - } - } - } - - /// - /// Depthwise separable convolution (more efficient for mobile architectures) - /// - public Tensor DepthwiseConv2D( - Tensor input, - Tensor kernel, - int stride = 1, - int padding = 0) - { - // Input: [batch, channels, height, width] - // Kernel: [channels, 1, kernel_h, kernel_w] - - if (input.Shape.Length != 4 || kernel.Shape.Length != 4) - throw new ArgumentException("DepthwiseConv2D requires 4D tensors"); - - if (stride <= 0) - throw new ArgumentOutOfRangeException(nameof(stride), $"stride must be positive, but got {stride}"); - - if (padding < 0) - throw new ArgumentOutOfRangeException(nameof(padding), $"padding must be non-negative, but got {padding}"); - - int batchSize = input.Shape[0]; - int channels = input.Shape[1]; - int inHeight = input.Shape[2]; - int inWidth = input.Shape[3]; - - int kernelH = kernel.Shape[2]; - int kernelW = kernel.Shape[3]; - - if (kernelH <= 0 || kernelW <= 0) - throw new ArgumentException($"Kernel dimensions must be positive, but got {kernelH}x{kernelW}"); - - int outHeight = (inHeight + 2 * padding - kernelH) / stride + 1; - int outWidth = (inWidth + 2 * padding - kernelW) / stride + 1; - - if (outHeight <= 0 || outWidth <= 0) - throw new ArgumentException( - $"Invalid output dimensions ({outHeight}x{outWidth}). " + - $"Check stride ({stride}), padding ({padding}), and kernel size ({kernelH}x{kernelW})."); - - if (kernel.Shape[1] != 1) - throw new ArgumentException( - $"Depthwise convolution requires kernel.Shape[1] == 1, but got {kernel.Shape[1]}"); - - if (kernel.Shape[0] != channels) - throw new ArgumentException( - $"Depthwise convolution requires kernel.Shape[0] == channels ({channels}), but got {kernel.Shape[0]}"); - - var output = new Tensor(new[] { batchSize, channels, outHeight, outWidth }); - - Parallel.For(0, batchSize * channels, idx => - { - int b = idx / channels; - int c = idx % channels; - - DepthwiseConv2DSingleChannel(input, kernel, output, b, c, - inHeight, inWidth, kernelH, kernelW, - stride, padding, outHeight, outWidth); - }); - - return output; - } - - private void DepthwiseConv2DSingleChannel( - Tensor input, Tensor kernel, Tensor output, - int batch, int channel, - int inHeight, int inWidth, int kernelH, int kernelW, - int stride, int padding, int outHeight, int outWidth) - { - var inputData = input.Data.Span; - var kernelData = kernel.Data.Span; - var outputData = output.Data.Span; - - for (int oh = 0; oh < outHeight; oh++) - { - for (int ow = 0; ow < outWidth; ow++) - { - float sum = 0.0f; - - for (int kh = 0; kh < kernelH; kh++) - { - for (int kw = 0; kw < kernelW; kw++) - { - int ih = oh * stride - padding + kh; - int iw = ow * stride - padding + kw; - - if (ih >= 0 && ih < inHeight && iw >= 0 && iw < inWidth) - { - int inputIdx = ((batch * input.Shape[1] + channel) * inHeight + ih) * inWidth + iw; - int kernelIdx = (channel * kernelH + kh) * kernelW + kw; - sum += inputData[inputIdx] * kernelData[kernelIdx]; - } - } - } - - int outputIdx = ((batch * output.Shape[1] + channel) * outHeight + oh) * outWidth + ow; - outputData[outputIdx] = sum; - } - } - } - - /// - /// Group convolution (reduces parameters and computation) - /// - public Tensor GroupConv2D( - Tensor input, - Tensor kernel, - int groups, - int stride = 1, - int padding = 0) - { - if (input.Shape.Length != 4 || kernel.Shape.Length != 4) - throw new ArgumentException("GroupConv2D requires 4D tensors"); - - int batchSize = input.Shape[0]; - int inChannels = input.Shape[1]; - int inHeight = input.Shape[2]; - int inWidth = input.Shape[3]; - - int outChannels = kernel.Shape[0]; - int kernelH = kernel.Shape[2]; - int kernelW = kernel.Shape[3]; - - if (groups <= 0) - throw new ArgumentOutOfRangeException(nameof(groups), "groups must be positive."); - - if (inChannels % groups != 0 || outChannels % groups != 0) - throw new ArgumentException("Channels must be divisible by groups"); - - int inChannelsPerGroup = inChannels / groups; - int outChannelsPerGroup = outChannels / groups; - - if (kernel.Shape[1] != inChannelsPerGroup) - throw new ArgumentException( - $"Group convolution requires kernel.Shape[1] == inChannelsPerGroup ({inChannelsPerGroup}), " + - $"but got {kernel.Shape[1]}"); - - int outHeight = (inHeight + 2 * padding - kernelH) / stride + 1; - int outWidth = (inWidth + 2 * padding - kernelW) / stride + 1; - - if (outHeight <= 0 || outWidth <= 0) - throw new ArgumentException( - $"Invalid output dimensions ({outHeight}x{outWidth}). " + - $"Check stride ({stride}), padding ({padding}), and kernel size ({kernelH}x{kernelW})."); - - var output = new Tensor(new[] { batchSize, outChannels, outHeight, outWidth }); - - // Process each group independently - Parallel.For(0, groups, g => - { - for (int b = 0; b < batchSize; b++) - { - for (int oc = 0; oc < outChannelsPerGroup; oc++) - { - int globalOutChannel = g * outChannelsPerGroup + oc; - - GroupConv2DSingleOutput(input, kernel, output, b, globalOutChannel, g, - inChannelsPerGroup, inHeight, inWidth, - kernelH, kernelW, stride, padding, - outHeight, outWidth); - } - } - }); - - return output; - } - - private void GroupConv2DSingleOutput( - Tensor input, Tensor kernel, Tensor output, - int batch, int outChannel, int group, - int inChannelsPerGroup, int inHeight, int inWidth, - int kernelH, int kernelW, int stride, int padding, - int outHeight, int outWidth) - { - int inChannelStart = group * inChannelsPerGroup; - var inputData = input.Data.Span; - var kernelData = kernel.Data.Span; - var outputData = output.Data.Span; - - for (int oh = 0; oh < outHeight; oh++) - { - for (int ow = 0; ow < outWidth; ow++) - { - float sum = 0.0f; - - for (int ic = 0; ic < inChannelsPerGroup; ic++) - { - int globalInChannel = inChannelStart + ic; - - for (int kh = 0; kh < kernelH; kh++) - { - for (int kw = 0; kw < kernelW; kw++) - { - int ih = oh * stride - padding + kh; - int iw = ow * stride - padding + kw; - - if (ih >= 0 && ih < inHeight && iw >= 0 && iw < inWidth) - { - int inputIdx = ((batch * input.Shape[1] + globalInChannel) * inHeight + ih) * inWidth + iw; - int kernelIdx = ((outChannel * inChannelsPerGroup + ic) * kernelH + kh) * kernelW + kw; - sum += inputData[inputIdx] * kernelData[kernelIdx]; - } - } - } - } - - int outputIdx = ((batch * output.Shape[1] + outChannel) * outHeight + oh) * outWidth + ow; - outputData[outputIdx] = sum; - } - } - } - } -} diff --git a/src/InferenceOptimization/Kernels/GemmKernel.cs b/src/InferenceOptimization/Kernels/GemmKernel.cs deleted file mode 100644 index 760daa11ac..0000000000 --- a/src/InferenceOptimization/Kernels/GemmKernel.cs +++ /dev/null @@ -1,181 +0,0 @@ -using System; -using System.Runtime.CompilerServices; -using System.Threading.Tasks; -using AiDotNet.LinearAlgebra; -using AiDotNet.Tensors.Engines; -using AiDotNet.Tensors.Engines.Simd; - -namespace AiDotNet.InferenceOptimization.Kernels -{ - /// - /// Optimized General Matrix Multiplication (GEMM) kernel - /// Implements cache-aware blocked matrix multiplication with SIMD - /// - public class GemmKernel : ICustomOperator - { - private const int BlockSize = 64; // Tuned for typical L1 cache - private const int MinParallelSize = 256; // Minimum size for parallel execution - - public string Name => "GEMM"; - public string Version => "1.0.0"; - public int Priority => 100; - - public bool IsSupported() - { - // GEMM is always supported, but performance varies by platform - return true; - } - - public double EstimatedSpeedup() - { - var caps = PlatformDetector.Capabilities; - if (caps.HasAVX2) return 3.0; - if (caps.HasSSE42) return 2.0; - if (caps.HasNeon) return 2.5; - return 1.5; - } - - public Tensor Execute(params Tensor[] inputs) - { - if (inputs == null || inputs.Length < 2) - throw new ArgumentException("GEMM requires at least 2 input tensors"); - - var a = inputs[0]; - var b = inputs[1]; - - if (a.Shape.Length != 2 || b.Shape.Length != 2) - throw new ArgumentException("GEMM requires 2D tensors (matrices)"); - - int m = a.Shape[0]; - int k = a.Shape[1]; - int n = b.Shape[1]; - - if (k != b.Shape[0]) - throw new ArgumentException($"Matrix dimensions incompatible: ({m}x{k}) * ({b.Shape[0]}x{n})"); - - var result = new Tensor(new[] { m, n }); - - // Choose strategy based on matrix size - if (m * n * k < MinParallelSize * MinParallelSize) - { - GemmBlocked(a.Data.Span, b.Data.Span, result.Data.Span, m, n, k); - } - else - { - GemmParallel(a.Data, b.Data, result.Data, m, n, k); - } - - return result; - } - - /// - /// Cache-blocked GEMM implementation - /// - [MethodImpl(MethodImplOptions.AggressiveInlining)] - private void GemmBlocked(Span A, Span B, Span C, int M, int N, int K) - { - // Blocked algorithm for cache efficiency - for (int i = 0; i < M; i += BlockSize) - { - int iMax = Math.Min(i + BlockSize, M); - - for (int j = 0; j < N; j += BlockSize) - { - int jMax = Math.Min(j + BlockSize, N); - int spanLen = jMax - j; - - for (int k = 0; k < K; k += BlockSize) - { - int kMax = Math.Min(k + BlockSize, K); - - // Process block - for (int ii = i; ii < iMax; ii++) - { - for (int kk = k; kk < kMax; kk++) - { - float aVal = A[ii * K + kk]; - var bRow = B.Slice(kk * N + j, spanLen); - var cRow = C.Slice(ii * N + j, spanLen); - - // SIMD-optimized inner loop: cRow = cRow + aVal * bRow - SimdKernels.ScalarMultiplyAdd(cRow, bRow, aVal, cRow); - } - } - } - } - } - } - - /// - /// Parallel GEMM implementation for large matrices - /// - [MethodImpl(MethodImplOptions.AggressiveInlining)] - private void GemmParallel(Memory A, Memory B, Memory C, int M, int N, int K) - { - // Parallelize over rows of A - Parallel.For(0, (M + BlockSize - 1) / BlockSize, iBlock => - { - int i = iBlock * BlockSize; - int iMax = Math.Min(i + BlockSize, M); - - // Convert Memory to Span inside lambda (ref structs can't be captured) - var aSpan = A.Span; - var bSpan = B.Span; - var cSpan = C.Span; - - for (int j = 0; j < N; j += BlockSize) - { - int jMax = Math.Min(j + BlockSize, N); - int spanLen = jMax - j; - - for (int k = 0; k < K; k += BlockSize) - { - int kMax = Math.Min(k + BlockSize, K); - - for (int ii = i; ii < iMax; ii++) - { - for (int kk = k; kk < kMax; kk++) - { - float aVal = aSpan[ii * K + kk]; - var bRow = bSpan.Slice(kk * N + j, spanLen); - var cRow = cSpan.Slice(ii * N + j, spanLen); - - SimdKernels.ScalarMultiplyAdd(cRow, bRow, aVal, cRow); - } - } - } - } - }); - } - - /// - /// Matrix multiplication with transpose B optimization (C = A * B^T) - /// - public Tensor GemmTransposeB(Tensor a, Tensor b) - { - if (a.Shape.Length != 2 || b.Shape.Length != 2) - throw new ArgumentException("GemmTransposeB requires 2D tensors"); - - int m = a.Shape[0]; - int k = a.Shape[1]; - int n = b.Shape[0]; // Note: B is transposed - - if (k != b.Shape[1]) - throw new ArgumentException("Matrix dimensions incompatible for transpose"); - - var result = new Tensor(new[] { m, n }); - - Parallel.For(0, m, i => - { - var rowA = a.Data.Span.Slice(i * k, k); - for (int j = 0; j < n; j++) - { - var rowB = b.Data.Span.Slice(j * k, k); - result.Data.Span[i * n + j] = SimdKernels.DotProduct(rowA, rowB); - } - }); - - return result; - } - } -} diff --git a/src/InferenceOptimization/OptimizationInitializer.cs b/src/InferenceOptimization/OptimizationInitializer.cs deleted file mode 100644 index 2ae7068aa5..0000000000 --- a/src/InferenceOptimization/OptimizationInitializer.cs +++ /dev/null @@ -1,108 +0,0 @@ -using System; -using AiDotNet.InferenceOptimization.Kernels; -using AiDotNet.Tensors.Engines; -using AiDotNet.Tensors.Engines.Optimization; - -namespace AiDotNet.InferenceOptimization -{ - /// - /// Initializes and registers all optimized kernels and operators - /// - public static class OptimizationInitializer - { - private static bool _initialized = false; - private static readonly object _lock = new object(); - - /// - /// Initializes the inference optimization system - /// - public static void Initialize(bool enableProfiling = false) - { - lock (_lock) - { - if (_initialized) - return; - - // Enable profiling if requested - PerformanceProfiler.Instance.Enabled = enableProfiling; - - // Register optimized kernels - RegisterKernels(); - - // Print platform capabilities - LogPlatformInfo(); - - _initialized = true; - } - } - - private static void RegisterKernels() - { - var registry = CustomOperatorRegistry.Instance; - - // Register GEMM kernel - registry.Register(new GemmKernel()); - - // Register Attention kernel - registry.Register(new AttentionKernel()); - - // Register Convolution kernel - registry.Register(new ConvolutionKernel()); - - // Future: Register GPU kernels when available - // if (PlatformDetector.Capabilities.HasCudaSupport) - // { - // registry.Register(new CudaGemmKernel()); - // registry.Register(new CudaConvolutionKernel()); - // } - } - - private static void LogPlatformInfo() - { - Console.WriteLine("=== AiDotNet Inference Optimization ==="); - Console.WriteLine(PlatformDetector.GetCapabilitiesDescription()); - Console.WriteLine(); - Console.WriteLine("Registered Operators:"); - - var operatorInfo = CustomOperatorRegistry.Instance.GetOperatorInfo(); - foreach (var kvp in operatorInfo) - { - Console.WriteLine($" {kvp.Key}:"); - foreach (var info in kvp.Value) - { - var status = info.IsSupported ? "✓" : "✗"; - Console.WriteLine($" {status} {info.Version} - Priority: {info.Priority}, Speedup: {info.EstimatedSpeedup:F1}x"); - } - } - Console.WriteLine(); - } - - /// - /// Gets a performance summary - /// - public static string GetPerformanceSummary() - { - if (!_initialized) - return "Optimization system not initialized."; - - var report = PerformanceProfiler.Instance.GenerateReport(); - return report; - } - - /// - /// Resets all profiling statistics - /// - public static void ResetStatistics() - { - PerformanceProfiler.Instance.Clear(); - } - - /// - /// Enables or disables profiling at runtime - /// - public static void SetProfilingEnabled(bool enabled) - { - PerformanceProfiler.Instance.Enabled = enabled; - } - } -} diff --git a/src/InferenceOptimization/Passes/AlgebraicSimplificationPass.cs b/src/InferenceOptimization/Passes/AlgebraicSimplificationPass.cs deleted file mode 100644 index 45d5cb263a..0000000000 --- a/src/InferenceOptimization/Passes/AlgebraicSimplificationPass.cs +++ /dev/null @@ -1,242 +0,0 @@ -using AiDotNet.Enums; -using AiDotNet.InferenceOptimization.Core; - -namespace AiDotNet.InferenceOptimization.Passes; - -/// -/// Applies algebraic simplification rules to reduce computational complexity. -/// Examples: -/// - x * 1 = x -/// - x + 0 = x -/// - x * 0 = 0 -/// - x / 1 = x -/// - x^1 = x -/// - x^0 = 1 -/// -/// The numeric type (double, float, decimal) -public class AlgebraicSimplificationPass : OptimizationPassBase where T : struct -{ - public override OptimizationPassType PassType => OptimizationPassType.AlgebraicSimplification; - public override string Name => "Algebraic Simplification"; - - public override bool Apply(IOptimizationGraph graph) - { - bool modified = false; - bool changed; - - do - { - changed = false; - - var simplifiableOps = new HashSet - { - OperationType.Multiply, OperationType.Add, OperationType.Subtract, - OperationType.Divide, OperationType.Power - }; - foreach (var node in graph.Nodes.Where(n => simplifiableOps.Contains(n.OperationType)).ToList()) - { - if (TrySimplifyNode(graph, node)) - { - changed = true; - modified = true; - } - } - } while (changed); - - return modified; - } - - private bool TrySimplifyNode(IOptimizationGraph graph, OptimizationNode node) - { - return node.OperationType switch - { - OperationType.Multiply => SimplifyMultiply(graph, node), - OperationType.Add => SimplifyAdd(graph, node), - OperationType.Subtract => SimplifySubtract(graph, node), - OperationType.Divide => SimplifyDivide(graph, node), - OperationType.Power => SimplifyPower(graph, node), - _ => false - }; - } - - private bool SimplifyMultiply(IOptimizationGraph graph, OptimizationNode node) - { - if (node.Inputs.Count != 2) return false; - - var left = node.Inputs[0]; - var right = node.Inputs[1]; - - // x * 0 = 0 - if (IsZeroConstant(left) || IsZeroConstant(right)) - { - ReplaceWithConstant(graph, node, CreateZeroConstant(node)); - return true; - } - - // x * 1 = x - if (IsOneConstant(left)) - { - ReplaceWithNode(graph, node, right); - return true; - } - - if (IsOneConstant(right)) - { - ReplaceWithNode(graph, node, left); - return true; - } - - return false; - } - - private bool SimplifyAdd(IOptimizationGraph graph, OptimizationNode node) - { - if (node.Inputs.Count != 2) return false; - - var left = node.Inputs[0]; - var right = node.Inputs[1]; - - // x + 0 = x - if (IsZeroConstant(left)) - { - ReplaceWithNode(graph, node, right); - return true; - } - - if (IsZeroConstant(right)) - { - ReplaceWithNode(graph, node, left); - return true; - } - - return false; - } - - private bool SimplifySubtract(IOptimizationGraph graph, OptimizationNode node) - { - if (node.Inputs.Count != 2) return false; - - var right = node.Inputs[1]; - - // x - 0 = x - if (IsZeroConstant(right)) - { - ReplaceWithNode(graph, node, node.Inputs[0]); - return true; - } - - // x - x = 0 (if inputs are the same) - if (node.Inputs[0] == node.Inputs[1]) - { - ReplaceWithConstant(graph, node, CreateZeroConstant(node)); - return true; - } - - return false; - } - - private bool SimplifyDivide(IOptimizationGraph graph, OptimizationNode node) - { - if (node.Inputs.Count != 2) return false; - - var right = node.Inputs[1]; - - // x / 1 = x - if (IsOneConstant(right)) - { - ReplaceWithNode(graph, node, node.Inputs[0]); - return true; - } - - return false; - } - - private bool SimplifyPower(IOptimizationGraph graph, OptimizationNode node) - { - if (node.Inputs.Count != 2) return false; - - var right = node.Inputs[1]; - - // x^0 = 1 - if (IsZeroConstant(right)) - { - ReplaceWithConstant(graph, node, CreateOneConstant(node)); - return true; - } - - // x^1 = x - if (IsOneConstant(right)) - { - ReplaceWithNode(graph, node, node.Inputs[0]); - return true; - } - - return false; - } - - private bool IsZeroConstant(OptimizationNode node) - { - return node.OperationType == OperationType.Constant && - node.Metadata.TryGetValue("IsZero", out var isZero) && - (bool)isZero; - } - - private bool IsOneConstant(OptimizationNode node) - { - return node.OperationType == OperationType.Constant && - node.Metadata.TryGetValue("IsOne", out var isOne) && - (bool)isOne; - } - - private OptimizationNode CreateZeroConstant(OptimizationNode templateNode) - { - return new OptimizationNode - { - OperationType = OperationType.Constant, - Name = "const_zero", - OutputShape = templateNode.OutputShape, - Metadata = new Dictionary { ["IsZero"] = true } - }; - } - - private OptimizationNode CreateOneConstant(OptimizationNode templateNode) - { - return new OptimizationNode - { - OperationType = OperationType.Constant, - Name = "const_one", - OutputShape = templateNode.OutputShape, - Metadata = new Dictionary { ["IsOne"] = true } - }; - } - - private void ReplaceWithNode(IOptimizationGraph graph, OptimizationNode oldNode, OptimizationNode newNode) - { - // Replace all uses of oldNode with newNode - foreach (var output in oldNode.Outputs.ToList()) - { - output.ReplaceInput(oldNode, newNode); - } - - graph.RemoveNode(oldNode); - } - - private void ReplaceWithConstant(IOptimizationGraph graph, OptimizationNode oldNode, OptimizationNode constantNode) - { - // Add constant to graph - graph.AddNode(constantNode); - - // Replace uses - foreach (var output in oldNode.Outputs.ToList()) - { - output.ReplaceInput(oldNode, constantNode); - } - - graph.RemoveNode(oldNode); - } - - public override bool CanApply(IOptimizationGraph graph) - { - return base.CanApply(graph); - } -} diff --git a/src/InferenceOptimization/Passes/CommonSubexpressionEliminationPass.cs b/src/InferenceOptimization/Passes/CommonSubexpressionEliminationPass.cs deleted file mode 100644 index 73779d593e..0000000000 --- a/src/InferenceOptimization/Passes/CommonSubexpressionEliminationPass.cs +++ /dev/null @@ -1,139 +0,0 @@ -using AiDotNet.Enums; -using AiDotNet.InferenceOptimization.Core; - -namespace AiDotNet.InferenceOptimization.Passes; - -/// -/// Eliminates common subexpressions by sharing computation. -/// If the same operation with the same inputs appears multiple times, -/// compute it once and reuse the result. -/// -/// The numeric type (double, float, decimal) -public class CommonSubexpressionEliminationPass : OptimizationPassBase where T : struct -{ - public override OptimizationPassType PassType => OptimizationPassType.CommonSubexpressionElimination; - public override string Name => "Common Subexpression Elimination"; - - public override bool Apply(IOptimizationGraph graph) - { - bool modified = false; - - // Build a signature for each node based on its operation and inputs - var signatureToNode = new Dictionary>(); - - foreach (var node in graph.GetTopologicalOrder()) - { - // Skip certain types that shouldn't be CSE'd - if (node.OperationType == OperationType.Input || - node.OperationType == OperationType.Output || - node.OperationType == OperationType.Dropout || // Non-deterministic - !node.CanEliminate) - { - continue; - } - - var signature = ComputeSignature(node); - - if (signatureToNode.TryGetValue(signature, out var existingNode)) - { - // Found a common subexpression! - // Replace all uses of this node with the existing one - foreach (var output in node.Outputs.ToList()) - { - output.ReplaceInput(node, existingNode); - } - - graph.RemoveNode(node); - modified = true; - } - else - { - signatureToNode[signature] = node; - } - } - - return modified; - } - - private string ComputeSignature(OptimizationNode node) - { - // Create a signature based on: - // 1. Operation type - // 2. Input node IDs (sorted only for commutative operations) - // 3. Key parameters - - var parts = new List - { - node.OperationType.ToString() - }; - - // For commutative operations (Add, Multiply), sort input IDs so a+b == b+a - // For non-commutative operations (Subtract, Divide, MatMul, Power), preserve order - var inputIds = IsCommutativeOperation(node.OperationType) - ? node.Inputs.Select(n => n.Id).OrderBy(id => id) - : node.Inputs.Select(n => n.Id); - parts.AddRange(inputIds); - - // Add key parameters (if any) - foreach (var param in node.Parameters.OrderBy(kv => kv.Key)) - { - parts.Add($"{param.Key}={param.Value}"); - } - - // Add key metadata (if any) - // Only include metadata that affects computation - foreach (var meta in node.Metadata.Where(kv => IsComputationalMetadata(kv.Key)).OrderBy(kv => kv.Key)) - { - parts.Add($"{meta.Key}={meta.Value}"); - } - - return string.Join("|", parts); - } - - private bool IsComputationalMetadata(string key) - { - // These metadata keys affect the computation and should be part of the signature - var computationalKeys = new HashSet - { - "stride", - "padding", - "kernel_size", - "dilation", - "groups", - "transpose", - "alpha", // For LeakyReLU, etc. - "beta" - }; - - return computationalKeys.Contains(key.ToLower()); - } - - /// - /// Determines if an operation is commutative (operand order doesn't matter). - /// For commutative operations like Add and Multiply, a+b == b+a, so input IDs can be sorted. - /// For non-commutative operations like Subtract and Divide, a-b != b-a, so order must be preserved. - /// - private static bool IsCommutativeOperation(OperationType opType) - { - return opType switch - { - // Commutative operations (order doesn't matter) - OperationType.Add => true, - OperationType.Multiply => true, - // Non-commutative operations (order matters) - OperationType.Subtract => false, - OperationType.Divide => false, - OperationType.Power => false, - OperationType.MatMul => false, - OperationType.Convolution2D => false, - OperationType.BatchNorm => false, - // Default to non-commutative (safer - preserves operand order) - _ => false - }; - } - - public override bool CanApply(IOptimizationGraph graph) - { - return base.CanApply(graph) && graph.Nodes.Count > 1; - } -} diff --git a/src/InferenceOptimization/Passes/ConstantFoldingPass.cs b/src/InferenceOptimization/Passes/ConstantFoldingPass.cs deleted file mode 100644 index 0e2d130422..0000000000 --- a/src/InferenceOptimization/Passes/ConstantFoldingPass.cs +++ /dev/null @@ -1,555 +0,0 @@ -using AiDotNet.Engines; -using AiDotNet.Enums; -using AiDotNet.InferenceOptimization.Core; -using AiDotNet.LinearAlgebra; - -namespace AiDotNet.InferenceOptimization.Passes; - -/// -/// Folds constant expressions at compile time to reduce runtime computation. -/// For example: If two constants are multiplied, compute the result once during optimization -/// rather than every inference call. -/// -/// The numeric type (double, float, decimal) -public class ConstantFoldingPass : OptimizationPassBase where T : struct -{ - public override OptimizationPassType PassType => OptimizationPassType.ConstantFolding; - public override string Name => "Constant Folding"; - - private static readonly HashSet FoldableOps = new() - { - OperationType.Add, - OperationType.Subtract, - OperationType.Multiply, - OperationType.Divide, - OperationType.Power, - OperationType.Sqrt, - OperationType.Exp, - OperationType.Log, - OperationType.MatMul - }; - - public override bool Apply(IOptimizationGraph graph) - { - bool modified = false; - bool changed; - - // Keep folding until no more changes (iterative constant propagation) - do - { - changed = false; - - foreach (var node in graph.Nodes - .Where(n => FoldableOps.Contains(n.OperationType) && !n.IsFused) - .Where(n => n.Inputs.All(input => input.OperationType == OperationType.Constant)) - .ToList()) - { - if (TryFoldConstant(graph, node)) - { - changed = true; - modified = true; - } - } - } while (changed); - - return modified; - } - - /// - /// Attempts to fold a constant expression node into a single constant value. - /// - /// The optimization graph containing the node. - /// The operation node whose inputs are all constants. - /// - /// True if the operation was successfully folded into a constant node; - /// false if folding could not be performed. - /// - /// - /// - /// This method computes the result of the operation at optimization time and replaces - /// the operation node with a constant node containing the precomputed result. This - /// eliminates runtime computation for operations involving only constants. - /// - /// - /// The folding process: - /// - /// Compute the operation result using vectorized Engine operations - /// Create a new constant node with the computed result - /// Update all output connections to reference the new constant - /// Remove the original operation node from the graph - /// - /// - /// Thread Safety: This method modifies the graph structure and is not thread-safe. - /// - private bool TryFoldConstant(IOptimizationGraph graph, OptimizationNode node) - { - try - { - // Compute the constant result using vectorized Engine operations - Tensor? result = node.OperationType switch - { - OperationType.Add => FoldAdd(node), - OperationType.Subtract => FoldSubtract(node), - OperationType.Multiply => FoldMultiply(node), - OperationType.Divide => FoldDivide(node), - OperationType.Power => FoldPower(node), - OperationType.Sqrt => FoldSqrt(node), - OperationType.Exp => FoldExp(node), - OperationType.Log => FoldLog(node), - OperationType.MatMul => FoldMatMul(node), - _ => null - }; - - if (result == null) - { - return false; - } - - // Create a new constant node with the result - var constantNode = new OptimizationNode - { - OperationType = OperationType.Constant, - Name = $"{node.Name}_folded", - OutputShape = node.OutputShape, - ConstantValue = result, - CanEliminate = false // Constants should not be eliminated - }; - - // Replace the operation node with the constant node - foreach (var output in node.Outputs.ToList()) - { - output.ReplaceInput(node, constantNode); - } - - // Add constant node and remove operation node - graph.AddNode(constantNode); - graph.RemoveNode(node); - - return true; - } - catch (InvalidOperationException) - { - // If folding fails due to invalid graph state, leave the node as is - return false; - } - catch (ArgumentException) - { - // If folding fails due to invalid arguments, leave the node as is - return false; - } - } - - /// - /// Folds an addition operation by computing the elementwise sum of two constant tensors. - /// - /// The optimization node representing the addition operation. - /// - /// A tensor containing the elementwise sum of the two input tensors, - /// or null if the operation cannot be folded. - /// - /// - /// - /// This method uses the Engine's vectorized TensorAdd operation for optimal performance. - /// The operation requires both input tensors to have identical shapes since broadcasting - /// is not supported during constant folding. - /// - /// Performance: Uses hardware-accelerated SIMD operations when available. - /// - private Tensor? FoldAdd(OptimizationNode node) - { - if (node.Inputs.Count != 2) return null; - - var left = node.Inputs[0].ConstantValue; - var right = node.Inputs[1].ConstantValue; - - if (left == null || right == null) return null; - - // Perform vectorized tensor addition using Engine operations - try - { - // Use elementwise addition - tensors must have compatible shapes - if (!left._shape.SequenceEqual(right._shape)) - { - // Shape mismatch - cannot fold without broadcasting support - return null; - } - - // Use Engine's vectorized addition for optimal performance - var engine = AiDotNetEngine.Current; - return engine.TensorAdd(left, right); - } - catch - { - // If tensor arithmetic fails, don't fold - return null; - } - } - - /// - /// Folds a subtraction operation by computing the elementwise difference of two constant tensors. - /// - /// The optimization node representing the subtraction operation. - /// - /// A tensor containing the elementwise difference of the two input tensors (left - right), - /// or null if the operation cannot be folded. - /// - /// - /// - /// This method uses the Engine's vectorized TensorSubtract operation for optimal performance. - /// The operation requires both input tensors to have identical shapes since broadcasting - /// is not supported during constant folding. - /// - /// Performance: Uses hardware-accelerated SIMD operations when available. - /// - private Tensor? FoldSubtract(OptimizationNode node) - { - if (node.Inputs.Count != 2) return null; - - var left = node.Inputs[0].ConstantValue; - var right = node.Inputs[1].ConstantValue; - - if (left == null || right == null) return null; - - // Perform vectorized tensor subtraction using Engine operations - try - { - if (!left._shape.SequenceEqual(right._shape)) - { - // Shape mismatch - cannot fold without broadcasting support - return null; - } - - // Use Engine's vectorized subtraction for optimal performance - var engine = AiDotNetEngine.Current; - return engine.TensorSubtract(left, right); - } - catch - { - // If tensor arithmetic fails, don't fold - return null; - } - } - - /// - /// Folds a multiplication operation by computing the elementwise product of two constant tensors. - /// - /// The optimization node representing the multiplication operation. - /// - /// A tensor containing the elementwise product (Hadamard product) of the two input tensors, - /// or null if the operation cannot be folded. - /// - /// - /// - /// This method uses the Engine's vectorized TensorMultiply operation for optimal performance. - /// The operation requires both input tensors to have identical shapes since broadcasting - /// is not supported during constant folding. - /// - /// Note: This performs elementwise multiplication (Hadamard product), not - /// matrix multiplication. For matrix multiplication, use . - /// Performance: Uses hardware-accelerated SIMD operations when available. - /// - private Tensor? FoldMultiply(OptimizationNode node) - { - if (node.Inputs.Count != 2) return null; - - var left = node.Inputs[0].ConstantValue; - var right = node.Inputs[1].ConstantValue; - - if (left == null || right == null) return null; - - // Perform vectorized tensor multiplication (elementwise) using Engine operations - try - { - if (!left._shape.SequenceEqual(right._shape)) - { - // Shape mismatch - cannot fold without broadcasting support - return null; - } - - // Use Engine's vectorized multiplication for optimal performance - var engine = AiDotNetEngine.Current; - return engine.TensorMultiply(left, right); - } - catch - { - // If tensor arithmetic fails, don't fold - return null; - } - } - - /// - /// Folds a division operation by computing the elementwise quotient of two constant tensors. - /// - /// The optimization node representing the division operation. - /// - /// A tensor containing the elementwise quotient of the two input tensors, - /// or null if the operation cannot be folded. - /// - /// - /// - /// This method uses the Engine's vectorized TensorDivide operation for optimal performance. - /// The operation requires both input tensors to have identical shapes since broadcasting - /// is not supported during constant folding. - /// - /// Performance: Uses hardware-accelerated SIMD operations when available. - /// - private Tensor? FoldDivide(OptimizationNode node) - { - if (node.Inputs.Count != 2) return null; - - var left = node.Inputs[0].ConstantValue; - var right = node.Inputs[1].ConstantValue; - - if (left == null || right == null) return null; - - // Perform vectorized tensor division using Engine operations - try - { - if (!left._shape.SequenceEqual(right._shape)) - { - // Shape mismatch - cannot fold without broadcasting support - return null; - } - - // Use Engine's vectorized division for optimal performance - var engine = AiDotNetEngine.Current; - return engine.TensorDivide(left, right); - } - catch - { - // If tensor arithmetic fails (e.g., division by zero), don't fold - return null; - } - } - - /// - /// Folds a power operation by computing the elementwise power of a constant tensor. - /// - /// The optimization node representing the power operation. - /// - /// A tensor containing the elementwise power result, or null if the operation cannot be folded. - /// - /// - /// - /// This method supports two modes: - /// - /// Single input with scalar exponent from node attributes - /// Two inputs: base tensor and exponent tensor (elementwise power) - /// - /// - /// Performance: Uses vectorized Engine operations for optimal performance. - /// - private Tensor? FoldPower(OptimizationNode node) - { - var engine = AiDotNetEngine.Current; - var numOps = MathHelper.GetNumericOperations(); - - try - { - if (node.Inputs.Count == 1) - { - // Unary power with scalar exponent from attributes - var baseValue = node.Inputs[0].ConstantValue; - if (baseValue == null) return null; - - // Get exponent from node metadata/parameters - if (node.Metadata.TryGetValue("exponent", out var expObj) && expObj is double expDouble) - { - var exponent = numOps.FromDouble(expDouble); - return engine.TensorPow(baseValue, exponent); - } - - // Default to square if no exponent specified - var two = numOps.FromDouble(2.0); - return engine.TensorPow(baseValue, two); - } - else if (node.Inputs.Count == 2) - { - // Binary power: base ^ exponent (elementwise) - var baseValue = node.Inputs[0].ConstantValue; - var exponentValue = node.Inputs[1].ConstantValue; - - if (baseValue == null || exponentValue == null) return null; - - if (!baseValue._shape.SequenceEqual(exponentValue._shape)) - { - return null; - } - - // Elementwise power using scalar operations - var result = new Tensor(baseValue._shape); - for (int i = 0; i < baseValue.Length; i++) - { - var b = numOps.ToDouble(baseValue[i]); - var e = numOps.ToDouble(exponentValue[i]); - result[i] = numOps.FromDouble(Math.Pow(b, e)); - } - return result; - } - - return null; - } - catch - { - return null; - } - } - - /// - /// Folds a square root operation by computing the elementwise square root of a constant tensor. - /// - /// The optimization node representing the square root operation. - /// - /// A tensor containing the elementwise square root, or null if the operation cannot be folded. - /// - /// - /// - /// This method uses the Engine's vectorized TensorSqrt operation for optimal performance. - /// Square root is a unary operation requiring exactly one input tensor. - /// - /// Performance: Uses hardware-accelerated SIMD operations when available. - /// - private Tensor? FoldSqrt(OptimizationNode node) - { - if (node.Inputs.Count != 1) return null; - - var input = node.Inputs[0].ConstantValue; - if (input == null) return null; - - try - { - var engine = AiDotNetEngine.Current; - return engine.TensorSqrt(input); - } - catch - { - // If sqrt fails (e.g., negative values), don't fold - return null; - } - } - - /// - /// Folds an exponential operation by computing the elementwise e^x of a constant tensor. - /// - /// The optimization node representing the exponential operation. - /// - /// A tensor containing the elementwise exponential (e^x), or null if the operation cannot be folded. - /// - /// - /// - /// This method uses the Engine's vectorized TensorExp operation for optimal performance. - /// Exponential is a unary operation requiring exactly one input tensor. - /// - /// Performance: Uses hardware-accelerated SIMD operations when available. - /// - private Tensor? FoldExp(OptimizationNode node) - { - if (node.Inputs.Count != 1) return null; - - var input = node.Inputs[0].ConstantValue; - if (input == null) return null; - - try - { - var engine = AiDotNetEngine.Current; - return engine.TensorExp(input); - } - catch - { - // If exp fails (e.g., overflow), don't fold - return null; - } - } - - /// - /// Folds a natural logarithm operation by computing the elementwise ln(x) of a constant tensor. - /// - /// The optimization node representing the logarithm operation. - /// - /// A tensor containing the elementwise natural logarithm, or null if the operation cannot be folded. - /// - /// - /// - /// This method uses the Engine's vectorized TensorLog operation for optimal performance. - /// Logarithm is a unary operation requiring exactly one input tensor. - /// - /// Performance: Uses hardware-accelerated SIMD operations when available. - /// - private Tensor? FoldLog(OptimizationNode node) - { - if (node.Inputs.Count != 1) return null; - - var input = node.Inputs[0].ConstantValue; - if (input == null) return null; - - try - { - var engine = AiDotNetEngine.Current; - return engine.TensorLog(input); - } - catch - { - // If log fails (e.g., non-positive values), don't fold - return null; - } - } - - /// - /// Folds a matrix multiplication operation by computing the product of two constant tensors. - /// - /// The optimization node representing the matrix multiplication operation. - /// - /// A tensor containing the matrix product of the two input tensors, - /// or null if the operation cannot be folded. - /// - /// - /// - /// This method uses the Engine's vectorized TensorMatMul operation for optimal performance. - /// Matrix multiplication requires 2D tensors with compatible dimensions: - /// left[M,K] × right[K,N] = result[M,N]. - /// - /// Performance: Uses optimized BLAS-like operations when available, - /// with cache-friendly memory access patterns and potential GPU acceleration. - /// - private Tensor? FoldMatMul(OptimizationNode node) - { - if (node.Inputs.Count != 2) return null; - - var left = node.Inputs[0].ConstantValue; - var right = node.Inputs[1].ConstantValue; - - if (left == null || right == null) return null; - - // Matrix multiplication requires 2D tensors with compatible dimensions - try - { - if (left.Shape.Length != 2 || right.Shape.Length != 2) - { - return null; - } - - int k = left.Shape[1]; - - // Check dimension compatibility: left[m,k] @ right[k,n] = result[m,n] - if (k != right.Shape[0]) - { - return null; - } - - // Use Engine's vectorized matrix multiplication for optimal performance - var engine = AiDotNetEngine.Current; - return engine.TensorMatMul(left, right); - } - catch - { - // If matrix multiplication fails, don't fold - return null; - } - } - - public override bool CanApply(IOptimizationGraph graph) - { - return base.CanApply(graph) && - graph.Nodes.Any(n => n.OperationType == OperationType.Constant); - } -} diff --git a/src/InferenceOptimization/Passes/ConvBatchNormFusionPass.cs b/src/InferenceOptimization/Passes/ConvBatchNormFusionPass.cs deleted file mode 100644 index 414dfc592d..0000000000 --- a/src/InferenceOptimization/Passes/ConvBatchNormFusionPass.cs +++ /dev/null @@ -1,62 +0,0 @@ -using AiDotNet.Enums; -using AiDotNet.InferenceOptimization.Core; - -namespace AiDotNet.InferenceOptimization.Passes; - -/// -/// Fuses Convolution + BatchNormalization into a single operation. -/// This is a critical optimization that eliminates the normalization overhead during inference. -/// -/// The numeric type (double, float, decimal) -public class ConvBatchNormFusionPass : OptimizationPassBase where T : struct -{ - public override OptimizationPassType PassType => OptimizationPassType.ConvBatchNormFusion; - public override string Name => "Conv + BatchNorm Fusion"; - - public override bool Apply(IOptimizationGraph graph) - { - bool modified = false; - - // Find Conv -> BatchNorm patterns - var candidates = FindFusionCandidates( - graph, - OperationType.Convolution, - OperationType.BatchNormalization - ); - - // Also check for Convolution2D variant - candidates.AddRange(FindFusionCandidates( - graph, - OperationType.Convolution2D, - OperationType.BatchNormalization - )); - - foreach (var sequence in candidates) - { - FuseConvBatchNorm(graph, sequence); - modified = true; - } - - return modified; - } - - private void FuseConvBatchNorm(IOptimizationGraph graph, List> nodes) - { - // Create fused node - var fusedNode = FuseNodes(graph, nodes, OperationType.FusedConvBatchNorm); - - // Note: In a real implementation, you would fold the BatchNorm parameters - // (mean, variance, gamma, beta) into the convolution weights and bias. - // This requires numerical computation which would be done at optimization time. - - // Mark for special handling during execution - fusedNode.Metadata["RequiresWeightFolding"] = true; - } - - public override bool CanApply(IOptimizationGraph graph) - { - return base.CanApply(graph) && - graph.Nodes.Any(n => n.OperationType == OperationType.Convolution || - n.OperationType == OperationType.Convolution2D); - } -} diff --git a/src/InferenceOptimization/Passes/ConvBatchNormReLUFusionPass.cs b/src/InferenceOptimization/Passes/ConvBatchNormReLUFusionPass.cs deleted file mode 100644 index ef663cd05c..0000000000 --- a/src/InferenceOptimization/Passes/ConvBatchNormReLUFusionPass.cs +++ /dev/null @@ -1,76 +0,0 @@ -using AiDotNet.Enums; -using AiDotNet.InferenceOptimization.Core; - -namespace AiDotNet.InferenceOptimization.Passes; - -/// -/// Fuses Convolution + BatchNormalization + ReLU into a single operation. -/// This is one of the most common patterns in CNNs (ResNet, VGG, etc.) and provides -/// significant speedup by reducing memory traffic and kernel launches. -/// -/// The numeric type (double, float, decimal) -public class ConvBatchNormReLUFusionPass : OptimizationPassBase where T : struct -{ - public override OptimizationPassType PassType => OptimizationPassType.ConvBatchNormReLUFusion; - public override string Name => "Conv + BatchNorm + ReLU Fusion"; - - public override bool Apply(IOptimizationGraph graph) - { - bool modified = false; - - // Find Conv -> BatchNorm -> ReLU patterns - var candidates = FindFusionCandidates( - graph, - OperationType.Convolution, - OperationType.BatchNormalization, - OperationType.ReLU - ); - - // Also check for Convolution2D variant - candidates.AddRange(FindFusionCandidates( - graph, - OperationType.Convolution2D, - OperationType.BatchNormalization, - OperationType.ReLU - )); - - // Also check for LeakyReLU and other ReLU variants - candidates.AddRange(FindFusionCandidates( - graph, - OperationType.Convolution, - OperationType.BatchNormalization, - OperationType.LeakyReLU - )); - - foreach (var sequence in candidates) - { - FuseConvBatchNormReLU(graph, sequence); - modified = true; - } - - return modified; - } - - private void FuseConvBatchNormReLU(IOptimizationGraph graph, List> nodes) - { - // Create fused node - var fusedNode = FuseNodes(graph, nodes, OperationType.FusedConvBatchNormReLU); - - // This fusion provides maximum benefit: - // 1. Fold BatchNorm into Conv weights - // 2. Apply ReLU in-place - // 3. Single kernel launch instead of 3 - fusedNode.Metadata["RequiresWeightFolding"] = true; - fusedNode.CanOperateInPlace = true; // ReLU can be in-place - } - - public override bool CanApply(IOptimizationGraph graph) - { - return base.CanApply(graph) && - graph.Nodes.Any(n => n.OperationType == OperationType.Convolution || - n.OperationType == OperationType.Convolution2D) && - graph.Nodes.Any(n => n.OperationType == OperationType.BatchNormalization) && - graph.Nodes.Any(n => n.OperationType == OperationType.ReLU || - n.OperationType == OperationType.LeakyReLU); - } -} diff --git a/src/InferenceOptimization/Passes/DeadCodeEliminationPass.cs b/src/InferenceOptimization/Passes/DeadCodeEliminationPass.cs deleted file mode 100644 index 5cc901d4eb..0000000000 --- a/src/InferenceOptimization/Passes/DeadCodeEliminationPass.cs +++ /dev/null @@ -1,89 +0,0 @@ -using AiDotNet.Enums; -using AiDotNet.InferenceOptimization.Core; - -namespace AiDotNet.InferenceOptimization.Passes; - -/// -/// Eliminates nodes that don't contribute to the output (dead code). -/// This includes nodes with no consumers and nodes that are not reachable from outputs. -/// -/// The numeric type (double, float, decimal) -public class DeadCodeEliminationPass : OptimizationPassBase where T : struct -{ - public override OptimizationPassType PassType => OptimizationPassType.DeadCodeElimination; - public override string Name => "Dead Code Elimination"; - - public override bool Apply(IOptimizationGraph graph) - { - bool modified = false; - - // Mark all nodes reachable from outputs - var reachable = MarkReachableNodes(graph); - - // Remove unreachable nodes - var nodesToRemove = graph.Nodes - .Where(n => !reachable.Contains(n) && - n.OperationType != OperationType.Input && - n.OperationType != OperationType.Output && - n.CanEliminate) - .ToList(); - - foreach (var node in nodesToRemove) - { - graph.RemoveNode(node); - modified = true; - } - - // Also remove nodes with no consumers (unless they're outputs) - var noConsumerNodes = graph.Nodes - .Where(n => n.Outputs.Count == 0 && - n.OperationType != OperationType.Output && - n.CanEliminate) - .ToList(); - - foreach (var node in noConsumerNodes) - { - graph.RemoveNode(node); - modified = true; - } - - return modified; - } - - private HashSet> MarkReachableNodes(IOptimizationGraph graph) - { - var reachable = new HashSet>(); - var queue = new Queue>(); - - // Start from output nodes and work backwards - foreach (var output in graph.OutputNodes) - { - queue.Enqueue(output); - } - - while (queue.Count > 0) - { - var node = queue.Dequeue(); - - if (reachable.Contains(node)) - { - continue; - } - - reachable.Add(node); - - // Add all inputs to the queue - foreach (var input in node.Inputs.Where(i => !reachable.Contains(i))) - { - queue.Enqueue(input); - } - } - - return reachable; - } - - public override bool CanApply(IOptimizationGraph graph) - { - return base.CanApply(graph) && graph.OutputNodes.Count > 0; - } -} diff --git a/src/InferenceOptimization/Passes/ElementwiseFusionPass.cs b/src/InferenceOptimization/Passes/ElementwiseFusionPass.cs deleted file mode 100644 index 451922d56d..0000000000 --- a/src/InferenceOptimization/Passes/ElementwiseFusionPass.cs +++ /dev/null @@ -1,248 +0,0 @@ -using AiDotNet.Enums; -using AiDotNet.InferenceOptimization.Core; - -namespace AiDotNet.InferenceOptimization.Passes; - -/// -/// Fuses consecutive elementwise operations into a single operation. -/// For example: (x + y) * z can be computed in a single fused kernel. -/// This reduces memory bandwidth by avoiding intermediate results. -/// -/// The numeric type (double, float, decimal) -/// -/// -/// Elementwise fusion is a key optimization for neural network inference that reduces memory -/// bandwidth requirements. Instead of computing each elementwise operation separately and -/// writing intermediate results to memory, fused operations process data in a single pass. -/// -/// Fusion Strategy: -/// -/// Only linear chains are fused (single consumer at each step) -/// Chains must start from "head" nodes (not outputs of other fuseable ops) -/// Disjoint chains are processed to avoid overlapping fusion attempts -/// -/// Supported Operations: Add, Subtract, Multiply, Divide, Power, Sqrt, Exp, Log, ReLU, Sigmoid, Tanh -/// Performance Impact: Typically 2-3x speedup for memory-bound elementwise sequences. -/// -public class ElementwiseFusionPass : OptimizationPassBase where T : struct -{ - /// - public override OptimizationPassType PassType => OptimizationPassType.ElementwiseFusion; - - /// - public override string Name => "Elementwise Operation Fusion"; - - /// - /// The set of operation types that can be fused in an elementwise chain. - /// - private static readonly HashSet ElementwiseOps = new() - { - OperationType.Add, - OperationType.Subtract, - OperationType.Multiply, - OperationType.Divide, - OperationType.Power, - OperationType.Sqrt, - OperationType.Exp, - OperationType.Log, - OperationType.ReLU, - OperationType.Sigmoid, - OperationType.Tanh - }; - - /// - /// Applies elementwise fusion optimization to the graph. - /// - /// The optimization graph to transform. - /// True if any operations were fused; false otherwise. - /// - /// - /// The algorithm identifies disjoint chains of elementwise operations by: - /// - /// Finding chain "heads" - nodes that are not single-consumer outputs of other elementwise ops - /// Building chains forward from each head, tracking visited nodes to ensure disjointness - /// Fusing chains with 2+ operations into single fused nodes - /// - /// - /// Thread Safety: This method modifies the graph structure and is not thread-safe. - /// - public override bool Apply(IOptimizationGraph graph) - { - bool modified = false; - - // Track visited nodes to ensure disjoint chains - var visited = new HashSet>(); - - // Get all unfused elementwise candidates - var candidates = graph.Nodes - .Where(n => ElementwiseOps.Contains(n.OperationType) && !n.IsFused) - .ToList(); - - foreach (var node in candidates) - { - // Skip nodes already processed in another chain - if (visited.Contains(node)) continue; - - // Only start chains from "head" nodes to avoid overlapping chains - if (!IsChainHead(node)) continue; - - var chain = FindElementwiseChain(node, visited); - if (chain.Count >= 2) - { - FuseElementwiseChain(graph, chain); - modified = true; - } - } - - return modified; - } - - /// - /// Determines whether a node is a valid chain head for elementwise fusion. - /// - /// The node to check. - /// True if the node can be the start of a fusion chain; false otherwise. - /// - /// - /// A node is a chain head if it is NOT the single-consumer output of another unfused - /// elementwise operation. This ensures we start chains at their natural beginning - /// rather than from the middle, preventing overlapping fusion attempts. - /// - /// - private bool IsChainHead(OptimizationNode node) - { - // Head if it's not the single-consumer output of another unfused elementwise op - return !node.Inputs.Any(input => - ElementwiseOps.Contains(input.OperationType) && - !input.IsFused && - input.Outputs.Count == 1 && - input.Outputs[0] == node); - } - - /// - /// Discovers a chain of consecutive elementwise operations starting from the given node. - /// - /// The starting node for chain discovery. - /// Set of already-visited nodes to ensure disjoint chains. - /// A list of nodes forming the elementwise chain, starting with the head. - /// - /// - /// Chain discovery follows these rules: - /// - /// Only single-output nodes can extend the chain - /// The next node must be an unfused elementwise operation - /// The next node must have the current node in its inputs - /// Already-visited nodes cannot be added to the chain - /// - /// - /// - private List> FindElementwiseChain( - OptimizationNode startNode, - HashSet> visited) - { - var chain = new List> { startNode }; - var current = startNode; - visited.Add(current); - - // Follow the chain forward through single-consumer elementwise ops - while (current.Outputs.Count == 1) - { - var next = current.Outputs[0]; - - // Stop if not a simple single-consumer elementwise progression - if (!ElementwiseOps.Contains(next.OperationType) || - next.IsFused || - visited.Contains(next) || - !next.Inputs.Contains(current)) - { - break; - } - - chain.Add(next); - current = next; - visited.Add(current); - } - - return chain; - } - - /// - /// Fuses a chain of elementwise operations into a single fused node. - /// - /// The optimization graph being transformed. - /// The chain of nodes to fuse, ordered from head to tail. - /// - /// - /// The fusion process: - /// - /// Create a new fused node with combined operation sequence - /// Connect all external inputs to the fused node - /// Redirect all outputs of the chain tail to the fused node - /// Remove all original chain nodes from the graph - /// - /// - /// Metadata: The fused node stores the operation sequence in its metadata - /// under the "OperationSequence" key for code generation. - /// - private void FuseElementwiseChain(IOptimizationGraph graph, List> chain) - { - var firstNode = chain[0]; - var lastNode = chain[^1]; - - // Create fused elementwise node - var fusedNode = new OptimizationNode - { - OperationType = OperationType.Custom, - Name = $"{firstNode.Name}_elementwise_fused", - OutputShape = lastNode.OutputShape, - IsFused = true, - CanOperateInPlace = chain.All(n => n.CanOperateInPlace), - FusedFrom = new List>(chain) - }; - - // Store the operation sequence for code generation - fusedNode.Metadata["OperationSequence"] = chain.Select(n => n.OperationType).ToList(); - - // Collect all unique inputs from the chain that are not part of the chain itself - var externalInputs = new HashSet>(chain - .SelectMany(node => node.Inputs) - .Where(input => !chain.Contains(input))); - - // Connect external inputs to the fused node - foreach (var input in externalInputs) - { - fusedNode.AddInput(input); - - // Remove connections from external inputs to chain nodes - foreach (var chainNode in chain) - { - input.Outputs.Remove(chainNode); - } - } - - // Connect outputs - use ToList() to avoid collection modification during iteration - foreach (var output in lastNode.Outputs.ToList()) - { - output.ReplaceInput(lastNode, fusedNode); - } - - // Add fused node to graph - graph.AddNode(fusedNode); - - // Remove original chain nodes from graph - foreach (var node in chain) - { - // Clear internal chain connections before removal to avoid dangling references - node.Inputs.Clear(); - node.Outputs.Clear(); - graph.RemoveNode(node); - } - } - - /// - public override bool CanApply(IOptimizationGraph graph) - { - return base.CanApply(graph) && - graph.Nodes.Any(n => ElementwiseOps.Contains(n.OperationType)); - } -} diff --git a/src/InferenceOptimization/Passes/IOptimizationPass.cs b/src/InferenceOptimization/Passes/IOptimizationPass.cs deleted file mode 100644 index fe43ccaf70..0000000000 --- a/src/InferenceOptimization/Passes/IOptimizationPass.cs +++ /dev/null @@ -1,33 +0,0 @@ -using AiDotNet.Enums; -using AiDotNet.InferenceOptimization.Core; - -namespace AiDotNet.InferenceOptimization.Passes; - -/// -/// Interface for optimization passes that transform computation graphs to improve inference performance. -/// -/// The numeric type (double, float, decimal) -public interface IOptimizationPass where T : struct -{ - /// - /// The type of optimization pass. - /// - OptimizationPassType PassType { get; } - - /// - /// The name of this optimization pass. - /// - string Name { get; } - - /// - /// Applies the optimization pass to the computation graph. - /// - /// The computation graph to optimize - /// True if the graph was modified, false otherwise - bool Apply(IOptimizationGraph graph); - - /// - /// Checks if this pass can be applied to the graph. - /// - bool CanApply(IOptimizationGraph graph); -} diff --git a/src/InferenceOptimization/Passes/InPlaceOptimizationPass.cs b/src/InferenceOptimization/Passes/InPlaceOptimizationPass.cs deleted file mode 100644 index f601c777c4..0000000000 --- a/src/InferenceOptimization/Passes/InPlaceOptimizationPass.cs +++ /dev/null @@ -1,95 +0,0 @@ -using AiDotNet.Enums; -using AiDotNet.InferenceOptimization.Core; - -namespace AiDotNet.InferenceOptimization.Passes; - -/// -/// Marks operations that can be performed in-place to reduce memory allocation. -/// Operations like ReLU, Dropout, and some elementwise operations can modify -/// their input tensors directly instead of allocating new memory. -/// -/// The numeric type (double, float, decimal) -public class InPlaceOptimizationPass : OptimizationPassBase where T : struct -{ - public override OptimizationPassType PassType => OptimizationPassType.InPlaceOptimization; - public override string Name => "In-Place Operation Optimization"; - - private static readonly HashSet InPlaceCandidates = new() - { - OperationType.ReLU, - OperationType.LeakyReLU, - OperationType.ELU, - OperationType.SELU, - OperationType.Sigmoid, - OperationType.Tanh, - OperationType.Dropout, - OperationType.Add, // Can be in-place if one input won't be used again - OperationType.Multiply, // Same as Add - OperationType.Clip - }; - - public override bool Apply(IOptimizationGraph graph) - { - bool modified = false; - - foreach (var node in graph.Nodes.Where(n => InPlaceCandidates.Contains(n.OperationType) && CanBeInPlace(n))) - { - node.CanOperateInPlace = true; - node.Metadata["InPlaceOptimized"] = true; - modified = true; - } - - return modified; - } - - private bool CanBeInPlace(OptimizationNode node) - { - // Check if this operation can safely be performed in-place - // An operation can be in-place if: - // 1. It has exactly one input (for unary operations) - // 2. OR for binary operations, one of the inputs is not used elsewhere - - if (IsUnaryOperation(node.OperationType)) - { - // For unary operations, check if input has only one consumer (this node) - if (node.Inputs.Count == 1) - { - var input = node.Inputs[0]; - return input.Outputs.Count == 1; // Only this node uses the input - } - } - else if (IsBinaryOperation(node.OperationType)) - { - // For binary operations, at least one input should have only this node as consumer - return node.Inputs.Any(input => input.Outputs.Count == 1); - } - - return false; - } - - private bool IsUnaryOperation(OperationType opType) - { - return opType == OperationType.ReLU || - opType == OperationType.LeakyReLU || - opType == OperationType.ELU || - opType == OperationType.SELU || - opType == OperationType.Sigmoid || - opType == OperationType.Tanh || - opType == OperationType.Dropout || - opType == OperationType.Clip; - } - - private bool IsBinaryOperation(OperationType opType) - { - return opType == OperationType.Add || - opType == OperationType.Multiply || - opType == OperationType.Subtract || - opType == OperationType.Divide; - } - - public override bool CanApply(IOptimizationGraph graph) - { - return base.CanApply(graph) && - graph.Nodes.Any(n => InPlaceCandidates.Contains(n.OperationType)); - } -} diff --git a/src/InferenceOptimization/Passes/LayoutOptimizationPass.cs b/src/InferenceOptimization/Passes/LayoutOptimizationPass.cs deleted file mode 100644 index bf01629f49..0000000000 --- a/src/InferenceOptimization/Passes/LayoutOptimizationPass.cs +++ /dev/null @@ -1,355 +0,0 @@ -using AiDotNet.Enums; -using AiDotNet.InferenceOptimization.Core; - -namespace AiDotNet.InferenceOptimization.Passes; - -/// -/// Optimizes tensor layout (NCHW vs NHWC) for better hardware utilization. -/// Different hardware architectures prefer different layouts: -/// - NCHW: Better for GPUs (NVIDIA) -/// - NHWC: Better for some CPUs and TPUs -/// -/// The numeric type (double, float, decimal) -/// -/// -/// Tensor layout optimization is crucial for achieving peak performance on different hardware. -/// The layout determines how tensor dimensions are stored in memory: -/// -/// NCHW: Batch-Channel-Height-Width (preferred by NVIDIA GPUs, cuDNN) -/// NHWC: Batch-Height-Width-Channel (preferred by CPUs, TPUs, some mobile GPUs) -/// -/// -/// How It Works: This pass analyzes the graph to identify nodes with different -/// layout preferences, then inserts transpose operations at layout boundaries to ensure -/// optimal memory access patterns for each operation. -/// Performance Impact: Proper layout selection can yield 20-50% speedup -/// for memory-bound convolution operations. -/// -public class LayoutOptimizationPass : OptimizationPassBase where T : struct -{ - /// - public override OptimizationPassType PassType => OptimizationPassType.LayoutOptimization; - - /// - public override string Name => "Layout Optimization"; - - /// - /// The target memory layout to optimize for. - /// - private readonly string _targetLayout; - - /// - /// Operations that prefer channel-first layout (NCHW). - /// - private static readonly HashSet ChannelFirstOps = new() - { - OperationType.Convolution, - OperationType.Convolution2D, - OperationType.Conv2D, - OperationType.BatchNormalization, - OperationType.MaxPooling, - OperationType.AveragePooling - }; - - /// - /// The set of supported memory layouts for tensor operations. - /// - private static readonly HashSet SupportedLayouts = new() { "NCHW", "NHWC" }; - - /// - /// Initializes a new instance of the class. - /// - /// The target layout to optimize for. Default is "NCHW". - /// - /// Thrown when the specified targetLayout is not supported. - /// - /// - /// Supported layouts: - /// - /// "NCHW" - Batch-Channel-Height-Width (preferred by NVIDIA GPUs) - /// "NHWC" - Batch-Height-Width-Channel (preferred by CPUs/TPUs) - /// - /// - public LayoutOptimizationPass(string targetLayout = "NCHW") - { - if (!SupportedLayouts.Contains(targetLayout)) - { - throw new ArgumentException( - $"Unsupported layout '{targetLayout}'. Supported layouts: {string.Join(", ", SupportedLayouts)}.", - nameof(targetLayout)); - } - _targetLayout = targetLayout; - } - - /// - /// Applies layout optimization to the graph by inserting transpose operations at layout boundaries. - /// - /// The optimization graph to transform. - /// True if any transpose operations were inserted; false otherwise. - /// - /// - /// The optimization process: - /// - /// Analyze each node's preferred layout based on operation type - /// Identify layout mismatches between connected nodes - /// Insert transpose operations to convert between layouts - /// - /// - /// Thread Safety: This method modifies the graph structure and is not thread-safe. - /// - public override bool Apply(IOptimizationGraph graph) - { - bool modified = false; - - // Analyze the graph to determine optimal layout for each node - var layoutInfo = AnalyzeLayouts(graph); - - // Find nodes that require layout conversion due to mismatched inputs - var nodesToConvert = graph.Nodes - .Where(n => RequiresLayoutConversion(n, layoutInfo)) - .ToList(); - - // Insert transpose operations where needed - foreach (var node in nodesToConvert) - { - if (InsertLayoutConversion(graph, node, layoutInfo)) - { - modified = true; - } - } - - return modified; - } - - /// - /// Analyzes the graph to determine the preferred layout for each node. - /// - /// The optimization graph to analyze. - /// A dictionary mapping each node to its preferred layout string. - private Dictionary, string> AnalyzeLayouts(IOptimizationGraph graph) - { - var layouts = new Dictionary, string>(); - - foreach (var node in graph.Nodes) - { - // Determine preferred layout for this operation - var preferredLayout = GetPreferredLayout(node.OperationType); - layouts[node] = preferredLayout; - } - - return layouts; - } - - /// - /// Gets the intrinsic preferred memory layout for a given operation type. - /// - /// The operation type to check. - /// The intrinsic preferred layout string ("NCHW", "NHWC", or "AGNOSTIC"). - /// - /// - /// This method returns the operation's intrinsic layout preference, which is independent - /// of the target hardware layout. Operations like convolution, batch normalization, and - /// pooling inherently prefer NCHW layout due to their memory access patterns. - /// - /// - /// The pass then inserts layout conversions at boundaries between operations with - /// different intrinsic preferences and the target layout. - /// - /// - private string GetPreferredLayout(OperationType opType) - { - // Channel-first operations (Conv, BatchNorm, Pooling) intrinsically prefer NCHW - // regardless of the target hardware layout - if (ChannelFirstOps.Contains(opType)) - { - return "NCHW"; - } - - // Most other operations are layout-agnostic - return "AGNOSTIC"; - } - - /// - /// Determines whether a node requires layout conversion based on mismatched input layouts. - /// - /// The node to check. - /// Dictionary of node layout preferences. - /// True if the node requires layout conversion for any of its inputs. - private bool RequiresLayoutConversion(OptimizationNode node, Dictionary, string> layoutInfo) - { - // Skip if this node has no layout preference - if (!layoutInfo.TryGetValue(node, out var nodeLayout) || nodeLayout == "AGNOSTIC") - { - return false; - } - - // Check if any input has a different non-agnostic layout - return node.Inputs.Any(input => - layoutInfo.TryGetValue(input, out var inputLayout) && - inputLayout != "AGNOSTIC" && - inputLayout != nodeLayout); - } - - /// - /// Inserts transpose operations between mismatched layout boundaries. - /// - /// The optimization graph to modify. - /// The node requiring layout conversion. - /// Dictionary of node layout preferences. - /// True if any transpose operations were inserted; false otherwise. - /// - /// - /// This method inserts a transpose node between each input with a mismatched layout - /// and the target node. The transpose node converts the tensor from the input's - /// layout to the target node's preferred layout. - /// - /// - /// For NCHW to NHWC conversion, the permutation is [0, 2, 3, 1]. - /// For NHWC to NCHW conversion, the permutation is [0, 3, 1, 2]. - /// - /// - private bool InsertLayoutConversion( - IOptimizationGraph graph, - OptimizationNode node, - Dictionary, string> layoutInfo) - { - bool inserted = false; - - if (!layoutInfo.TryGetValue(node, out var nodeLayout)) - { - return false; - } - - // Find all inputs with mismatched layouts - var inputsToConvert = node.Inputs - .Where(input => - layoutInfo.TryGetValue(input, out var inputLayout) && - inputLayout != "AGNOSTIC" && - inputLayout != nodeLayout) - .ToList(); - - foreach (var input in inputsToConvert) - { - var inputLayout = layoutInfo[input]; - - // Create transpose node for layout conversion - var transposeNode = new OptimizationNode - { - OperationType = OperationType.Transpose, - Name = $"{input.Name}_to_{nodeLayout}", - OutputShape = ComputeTransposedShape(input.OutputShape, inputLayout, nodeLayout), - Metadata = new Dictionary - { - ["LayoutConversion"] = true, - ["SourceLayout"] = inputLayout, - ["TargetLayout"] = nodeLayout, - ["Permutation"] = GetLayoutPermutation(inputLayout, nodeLayout) - } - }; - - // Wire up the transpose node: - // 1. Transpose receives input from the original source - transposeNode.AddInput(input); - - // 2. Replace the input on the target node with the transpose output - // Note: ReplaceInput handles the edge removal internally via oldInput.Outputs.Remove(this) - node.ReplaceInput(input, transposeNode); - - // 4. Add the transpose node to the graph - graph.AddNode(transposeNode); - - inserted = true; - } - - return inserted; - } - - /// - /// Computes the output shape after layout transposition. - /// - /// The input tensor shape. - /// The source layout. - /// The target layout. - /// The transposed output shape, or the original shape if not exactly 4D. - /// - /// - /// Layout conversion (NCHW ↔ NHWC) only applies to exactly 4D tensors. - /// Non-4D tensors are returned unchanged since they don't follow the - /// NCHW/NHWC layout conventions. - /// - /// - private int[] ComputeTransposedShape(int[] inputShape, string sourceLayout, string targetLayout) - { - // Layout conversion only applies to exactly 4D tensors (NCHW/NHWC format) - // Non-4D tensors are returned unchanged - if (inputShape == null || inputShape.Length != 4) - { - return inputShape ?? Array.Empty(); - } - - var permutation = GetLayoutPermutation(sourceLayout, targetLayout); - var outputShape = new int[4]; - - for (int i = 0; i < 4; i++) - { - outputShape[i] = inputShape[permutation[i]]; - } - - return outputShape; - } - - /// - /// Gets the permutation array for converting between layouts. - /// - /// The source layout. - /// The target layout. - /// The permutation array for the transpose operation. - /// - /// Thrown when the layout conversion is not supported (different layouts that aren't NCHW/NHWC). - /// - /// - /// Supported conversions: - /// - /// NCHW → NHWC: permutation [0, 2, 3, 1] - /// NHWC → NCHW: permutation [0, 3, 1, 2] - /// Same layout: identity permutation [0, 1, 2, 3] - /// - /// - private int[] GetLayoutPermutation(string sourceLayout, string targetLayout) - { - // NCHW indices: N=0, C=1, H=2, W=3 - // NHWC indices: N=0, H=1, W=2, C=3 - - if (sourceLayout == targetLayout) - { - // Same layout - identity permutation (no-op) - return new[] { 0, 1, 2, 3 }; - } - - if (sourceLayout == "NCHW" && targetLayout == "NHWC") - { - // NCHW -> NHWC: Move C from position 1 to position 3 - return new[] { 0, 2, 3, 1 }; - } - - if (sourceLayout == "NHWC" && targetLayout == "NCHW") - { - // NHWC -> NCHW: Move C from position 3 to position 1 - return new[] { 0, 3, 1, 2 }; - } - - // Unsupported conversion - fail fast to catch configuration errors - throw new NotSupportedException( - $"Layout conversion from '{sourceLayout}' to '{targetLayout}' is not supported. " + - "Supported conversions: NCHW <-> NHWC."); - } - - /// - public override bool CanApply(IOptimizationGraph graph) - { - return base.CanApply(graph) && - graph.Nodes.Any(n => n.OperationType == OperationType.Convolution || - n.OperationType == OperationType.Convolution2D || - n.OperationType == OperationType.Conv2D); - } -} diff --git a/src/InferenceOptimization/Passes/MatMulBiasActivationFusionPass.cs b/src/InferenceOptimization/Passes/MatMulBiasActivationFusionPass.cs deleted file mode 100644 index 8ae7f0270e..0000000000 --- a/src/InferenceOptimization/Passes/MatMulBiasActivationFusionPass.cs +++ /dev/null @@ -1,240 +0,0 @@ -using AiDotNet.Enums; -using AiDotNet.InferenceOptimization.Core; - -namespace AiDotNet.InferenceOptimization.Passes; - -/// -/// Fuses MatMul + Bias + Activation (ReLU, GELU, etc.) into a single operation. -/// This is the most common pattern in transformer feed-forward networks and MLPs. -/// -/// The numeric type (double, float, decimal) -/// -/// -/// Matrix multiplication followed by bias addition and activation is the fundamental -/// building block of neural networks. Fusing these operations provides significant -/// performance benefits by: -/// -/// Reducing memory bandwidth (no intermediate tensor writes) -/// Enabling hardware-optimized fused kernels (cuBLAS, oneDNN) -/// Reducing kernel launch overhead -/// -/// -/// Fusion Patterns: -/// -/// MatMul + Add (bias) + ReLU → FusedMatMulBiasReLU -/// MatMul + Add (bias) + GELU → FusedMatMulBiasGELU -/// FusedMatMulBias + Activation → FusedMatMulBias{Activation} -/// -/// Performance Impact: Typically 30-50% speedup for transformer feed-forward layers. -/// -public class MatMulBiasActivationFusionPass : OptimizationPassBase where T : struct -{ - /// - public override OptimizationPassType PassType => OptimizationPassType.MatMulBiasActivationFusion; - - /// - public override string Name => "MatMul + Bias + Activation Fusion"; - - private static readonly HashSet SupportedActivations = new() - { - OperationType.ReLU, - OperationType.GELU, - OperationType.LeakyReLU, - OperationType.Tanh, - OperationType.Sigmoid, - OperationType.Swish, - OperationType.Mish - }; - - public override bool Apply(IOptimizationGraph graph) - { - bool modified = false; - - // Look for already fused MatMulBias nodes followed by activation - foreach (var fusedNode in graph.Nodes.Where(n => - n.OperationType == OperationType.FusedMatMulBias).ToList()) - { - if (fusedNode.Outputs.Count == 1) - { - var activationNode = fusedNode.Outputs[0]; - - if (SupportedActivations.Contains(activationNode.OperationType) && - activationNode.Inputs.Count == 1) - { - FuseMatMulBiasActivation(graph, fusedNode, activationNode); - modified = true; - } - } - } - - // Also look for unfused MatMul -> Add -> Activation - foreach (var matmulNode in graph.Nodes.Where(n => - (n.OperationType == OperationType.MatMul || - n.OperationType == OperationType.Dense || - n.OperationType == OperationType.FullyConnected) && !n.IsFused).ToList()) - { - if (matmulNode.Outputs.Count == 1) - { - var addNode = matmulNode.Outputs[0]; - - if (addNode.OperationType == OperationType.Add && - addNode.Inputs.Count == 2 && - addNode.Outputs.Count == 1) - { - var activationNode = addNode.Outputs[0]; - - if (SupportedActivations.Contains(activationNode.OperationType) && - activationNode.Inputs.Count == 1) - { - // Check if Add has a constant bias - var biasNode = addNode.Inputs.FirstOrDefault(n => n != matmulNode); - - if (biasNode != null && biasNode.OperationType == OperationType.Constant) - { - FuseMatMulBiasActivationFromScratch( - graph, - matmulNode, - addNode, - biasNode, - activationNode); - modified = true; - } - } - } - } - } - - return modified; - } - - private void FuseMatMulBiasActivation( - IOptimizationGraph graph, - OptimizationNode fusedMatMulBias, - OptimizationNode activation) - { - // Determine the fused operation type based on activation - var fusedType = activation.OperationType switch - { - OperationType.ReLU => OperationType.FusedMatMulBiasReLU, - OperationType.GELU => OperationType.FusedMatMulBiasGELU, - _ => OperationType.FusedMatMulBias - }; - - // Create new fused node - var newFusedNode = new OptimizationNode - { - OperationType = fusedType, - Name = $"{fusedMatMulBias.Name}_{activation.OperationType.ToString().ToLower()}", - OutputShape = activation.OutputShape, - IsFused = true, - CanOperateInPlace = true, - FusedFrom = new List> { fusedMatMulBias, activation } - }; - - // Copy all parameters from fused MatMulBias - foreach (var param in fusedMatMulBias.Parameters) - { - newFusedNode.Parameters[param.Key] = param.Value; - } - - // Copy activation parameters if any - foreach (var param in activation.Parameters) - { - newFusedNode.Parameters[$"activation_{param.Key}"] = param.Value; - } - - // Connect inputs - foreach (var input in fusedMatMulBias.Inputs) - { - newFusedNode.AddInput(input); - input.Outputs.Remove(fusedMatMulBias); - } - - // Connect outputs - foreach (var output in activation.Outputs) - { - output.ReplaceInput(activation, newFusedNode); - } - - // Add new fused node and remove old ones - graph.AddNode(newFusedNode); - graph.RemoveNode(fusedMatMulBias); - graph.RemoveNode(activation); - } - - private void FuseMatMulBiasActivationFromScratch( - IOptimizationGraph graph, - OptimizationNode matmul, - OptimizationNode add, - OptimizationNode bias, - OptimizationNode activation) - { - var fusedType = activation.OperationType switch - { - OperationType.ReLU => OperationType.FusedMatMulBiasReLU, - OperationType.GELU => OperationType.FusedMatMulBiasGELU, - _ => OperationType.FusedMatMulBias - }; - - var fusedNode = new OptimizationNode - { - OperationType = fusedType, - Name = $"{matmul.Name}_fused", - OutputShape = activation.OutputShape, - IsFused = true, - CanOperateInPlace = true, - FusedFrom = new List> { matmul, add, activation } - }; - - // Copy matmul parameters - foreach (var param in matmul.Parameters) - { - fusedNode.Parameters[param.Key] = param.Value; - } - - // Add bias - fusedNode.Parameters["bias"] = bias.ConstantValue!; - - // Copy activation parameters - foreach (var param in activation.Parameters) - { - fusedNode.Parameters[$"activation_{param.Key}"] = param.Value; - } - - // Connect inputs - foreach (var input in matmul.Inputs) - { - fusedNode.AddInput(input); - input.Outputs.Remove(matmul); - } - - // Connect outputs - foreach (var output in activation.Outputs) - { - output.ReplaceInput(activation, fusedNode); - } - - // Add fused node and remove originals - graph.AddNode(fusedNode); - graph.RemoveNode(matmul); - graph.RemoveNode(add); - - // Only remove bias if it's not used elsewhere (shared biases should be kept) - // The bias was consumed by the Add node, but if it has other consumers, keep it - if (bias.Outputs.Count <= 1) - { - graph.RemoveNode(bias); - } - - graph.RemoveNode(activation); - } - - public override bool CanApply(IOptimizationGraph graph) - { - return base.CanApply(graph) && - graph.Nodes.Any(n => n.OperationType == OperationType.MatMul || - n.OperationType == OperationType.Dense || - n.OperationType == OperationType.FullyConnected || - n.OperationType == OperationType.FusedMatMulBias); - } -} diff --git a/src/InferenceOptimization/Passes/MatMulBiasFusionPass.cs b/src/InferenceOptimization/Passes/MatMulBiasFusionPass.cs deleted file mode 100644 index 3d2d0f9cf1..0000000000 --- a/src/InferenceOptimization/Passes/MatMulBiasFusionPass.cs +++ /dev/null @@ -1,101 +0,0 @@ -using AiDotNet.Enums; -using AiDotNet.InferenceOptimization.Core; - -namespace AiDotNet.InferenceOptimization.Passes; - -/// -/// Fuses MatMul + Bias (Add) operations into a single Gemm operation. -/// This is extremely common in fully connected layers and transformers. -/// -/// The numeric type (double, float, decimal) -public class MatMulBiasFusionPass : OptimizationPassBase where T : struct -{ - public override OptimizationPassType PassType => OptimizationPassType.MatMulBiasFusion; - public override string Name => "MatMul + Bias Fusion"; - - public override bool Apply(IOptimizationGraph graph) - { - bool modified = false; - - // Find MatMul -> Add patterns where Add is adding a bias - foreach (var matmulNode in graph.Nodes.Where(n => - (n.OperationType == OperationType.MatMul || - n.OperationType == OperationType.Dense || - n.OperationType == OperationType.FullyConnected) && !n.IsFused && n.Outputs.Count == 1).ToList()) - { - var addNode = matmulNode.Outputs[0]; - - // Check if it's an Add with a constant bias - if (addNode.OperationType == OperationType.Add && - addNode.Inputs.Count == 2 && - !addNode.IsFused) - { - // One input should be matmul, other should be constant - var otherInput = addNode.Inputs.FirstOrDefault(n => n != matmulNode); - - if (otherInput != null && otherInput.OperationType == OperationType.Constant) - { - FuseMatMulBias(graph, new List> { matmulNode, addNode, otherInput }); - modified = true; - } - } - } - - return modified; - } - - private void FuseMatMulBias(IOptimizationGraph graph, List> nodes) - { - var matmulNode = nodes[0]; - var addNode = nodes[1]; - var biasNode = nodes[2]; - - // Create fused Gemm node (General Matrix Multiplication with bias) - var fusedNode = new OptimizationNode - { - OperationType = OperationType.FusedMatMulBias, - Name = $"{matmulNode.Name}_gemm", - OutputShape = addNode.OutputShape, - IsFused = true, - FusedFrom = new List> { matmulNode, addNode } - }; - - // Copy parameters - foreach (var param in matmulNode.Parameters) - { - fusedNode.Parameters[param.Key] = param.Value; - } - - // Add bias as a parameter - fusedNode.Parameters["bias"] = biasNode.ConstantValue!; - - // Connect inputs (from matmul, excluding the bias constant) - foreach (var input in matmulNode.Inputs) - { - fusedNode.AddInput(input); - input.Outputs.Remove(matmulNode); - } - - // Connect outputs - foreach (var output in addNode.Outputs) - { - output.ReplaceInput(addNode, fusedNode); - } - - // Add fused node - graph.AddNode(fusedNode); - - // Remove original nodes - graph.RemoveNode(matmulNode); - graph.RemoveNode(addNode); - graph.RemoveNode(biasNode); - } - - public override bool CanApply(IOptimizationGraph graph) - { - return base.CanApply(graph) && - graph.Nodes.Any(n => n.OperationType == OperationType.MatMul || - n.OperationType == OperationType.Dense || - n.OperationType == OperationType.FullyConnected); - } -} diff --git a/src/InferenceOptimization/Passes/MemoryReuseOptimizationPass.cs b/src/InferenceOptimization/Passes/MemoryReuseOptimizationPass.cs deleted file mode 100644 index 1f65d5f78f..0000000000 --- a/src/InferenceOptimization/Passes/MemoryReuseOptimizationPass.cs +++ /dev/null @@ -1,141 +0,0 @@ -using AiDotNet.Enums; -using AiDotNet.InferenceOptimization.Core; - -namespace AiDotNet.InferenceOptimization.Passes; - -/// -/// Optimizes memory usage by reusing buffers for different operations. -/// Analyzes the lifetime of tensors and assigns memory pools to reduce -/// overall memory footprint during inference. -/// -/// The numeric type (double, float, decimal) -public class MemoryReuseOptimizationPass : OptimizationPassBase where T : struct -{ - public override OptimizationPassType PassType => OptimizationPassType.MemoryReuseOptimization; - public override string Name => "Memory Reuse Optimization"; - - public override bool Apply(IOptimizationGraph graph) - { - bool modified = false; - - // Perform liveness analysis - var liveness = PerformLivenessAnalysis(graph); - - // Assign memory pools based on liveness - var memoryPools = AssignMemoryPools(graph.Nodes, liveness); - - // Mark nodes with their memory pool assignments - foreach (var kvp in memoryPools) - { - kvp.Key.Metadata["MemoryPoolId"] = kvp.Value; - modified = true; - } - - return modified; - } - - private Dictionary, (int firstUse, int lastUse)> PerformLivenessAnalysis( - IOptimizationGraph graph) - { - var liveness = new Dictionary, (int, int)>(); - var topologicalOrder = graph.GetTopologicalOrder(); - - for (int i = 0; i < topologicalOrder.Count; i++) - { - var node = topologicalOrder[i]; - - // First use is when the node is computed - var firstUse = i; - - // Last use is when the last consumer reads it - var lastUse = i; - - foreach (var output in node.Outputs) - { - var outputIndex = topologicalOrder.IndexOf(output); - if (outputIndex > lastUse) - { - lastUse = outputIndex; - } - } - - liveness[node] = (firstUse, lastUse); - } - - return liveness; - } - - private Dictionary, int> AssignMemoryPools( - List> nodes, - Dictionary, (int firstUse, int lastUse)> liveness) - { - var poolAssignments = new Dictionary, int>(); - var pools = new List<(int lastUse, long size)>(); - - // Sort nodes by first use - var sortedNodes = nodes - .Where(n => liveness.ContainsKey(n)) - .OrderBy(n => liveness[n].firstUse) - .ToList(); - - foreach (var node in sortedNodes) - { - var (firstUse, lastUse) = liveness[node]; - var tensorSize = EstimateTensorSize(node); - - // Find a pool that's no longer in use - int assignedPool = -1; - - for (int i = 0; i < pools.Count; i++) - { - var (poolLastUse, poolSize) = pools[i]; - - // Can reuse this pool if it's no longer active and size matches - if (poolLastUse < firstUse && poolSize >= tensorSize) - { - assignedPool = i; - pools[i] = (lastUse, poolSize); - break; - } - } - - // If no pool found, create a new one - if (assignedPool == -1) - { - assignedPool = pools.Count; - pools.Add((lastUse, tensorSize)); - } - - poolAssignments[node] = assignedPool; - } - - return poolAssignments; - } - - private long EstimateTensorSize(OptimizationNode node) - { - // Estimate the memory size of the output tensor - if (node.OutputShape.Length == 0) - { - return 0; - } - - long size = 1; - foreach (var dim in node.OutputShape) - { - size *= dim; - } - - // Multiply by size of T (rough estimate) - var typeSize = typeof(T) == typeof(double) ? 8 : - typeof(T) == typeof(float) ? 4 : - typeof(T) == typeof(decimal) ? 16 : 8; - - return size * typeSize; - } - - public override bool CanApply(IOptimizationGraph graph) - { - return base.CanApply(graph) && graph.Nodes.Count > 2; - } -} diff --git a/src/InferenceOptimization/Passes/MultiHeadAttentionFusionPass.cs b/src/InferenceOptimization/Passes/MultiHeadAttentionFusionPass.cs deleted file mode 100644 index d0812e4f14..0000000000 --- a/src/InferenceOptimization/Passes/MultiHeadAttentionFusionPass.cs +++ /dev/null @@ -1,90 +0,0 @@ -using AiDotNet.Enums; -using AiDotNet.InferenceOptimization.Core; - -namespace AiDotNet.InferenceOptimization.Passes; - -/// -/// Fuses multi-head attention components into a single optimized operation. -/// Multi-head attention consists of multiple Q, K, V projections, attention computation, -/// and output projection. Fusing these provides significant speedup in transformers. -/// -/// The numeric type (double, float, decimal) -public class MultiHeadAttentionFusionPass : OptimizationPassBase where T : struct -{ - public override OptimizationPassType PassType => OptimizationPassType.AttentionFusion; - public override string Name => "Multi-Head Attention Fusion"; - - public override bool Apply(IOptimizationGraph graph) - { - bool modified = false; - - // Find multi-head attention layers that can be fused - var attentionNodes = graph.Nodes.Where(n => - n.OperationType == OperationType.MultiHeadAttention && !n.IsFused && CanFuseAttention(n)).ToList(); - - foreach (var attentionNode in attentionNodes) - { - FuseAttention(graph, attentionNode); - modified = true; - } - - return modified; - } - - private bool CanFuseAttention(OptimizationNode attentionNode) - { - // Check if the attention pattern is suitable for fusion - // In a real implementation, we'd check for: - // 1. Q, K, V projection matrices - // 2. Attention computation (softmax, dropout) - // 3. Output projection - return attentionNode.Outputs.Count > 0; - } - - private void FuseAttention(IOptimizationGraph graph, OptimizationNode attentionNode) - { - // Create a fused multi-head attention node - var fusedNode = new OptimizationNode - { - OperationType = OperationType.FusedMultiHeadAttention, - Name = $"{attentionNode.Name}_fused", - OutputShape = attentionNode.OutputShape, - IsFused = true, - FusedFrom = new List> { attentionNode } - }; - - // Copy all parameters - foreach (var param in attentionNode.Parameters) - { - fusedNode.Parameters[param.Key] = param.Value; - } - - // Set metadata for optimized attention computation - fusedNode.Metadata["UseFlashAttention"] = true; // Enable flash attention if available - fusedNode.Metadata["ScaledDotProduct"] = true; - - // Connect inputs - foreach (var input in attentionNode.Inputs) - { - fusedNode.AddInput(input); - input.Outputs.Remove(attentionNode); - } - - // Connect outputs - foreach (var output in attentionNode.Outputs) - { - output.ReplaceInput(attentionNode, fusedNode); - } - - // Replace in graph - graph.AddNode(fusedNode); - graph.RemoveNode(attentionNode); - } - - public override bool CanApply(IOptimizationGraph graph) - { - return base.CanApply(graph) && - graph.Nodes.Any(n => n.OperationType == OperationType.MultiHeadAttention || - n.OperationType == OperationType.Attention); - } -} diff --git a/src/InferenceOptimization/Passes/OptimizationPassBase.cs b/src/InferenceOptimization/Passes/OptimizationPassBase.cs deleted file mode 100644 index 2af827437b..0000000000 --- a/src/InferenceOptimization/Passes/OptimizationPassBase.cs +++ /dev/null @@ -1,146 +0,0 @@ -using AiDotNet.Enums; -using AiDotNet.InferenceOptimization.Core; - -namespace AiDotNet.InferenceOptimization.Passes; - -/// -/// Base class for optimization passes. -/// -/// The numeric type (double, float, decimal) -public abstract class OptimizationPassBase : IOptimizationPass where T : struct -{ - public abstract OptimizationPassType PassType { get; } - public abstract string Name { get; } - - public abstract bool Apply(IOptimizationGraph graph); - - public virtual bool CanApply(IOptimizationGraph graph) - { - return graph != null && graph.Nodes.Count > 0; - } - - /// - /// Helper method to find fusion candidates in the graph. - /// - protected List>> FindFusionCandidates( - IOptimizationGraph graph, - params OperationType[] pattern) - { - var candidates = new List>>(); - - foreach (var node in graph.Nodes.Where(n => n.OperationType == pattern[0] && !n.IsFused)) - { - var sequence = TryMatchPattern(node, pattern); - if (sequence != null) - { - candidates.Add(sequence); - } - } - - return candidates; - } - - /// - /// Tries to match a pattern starting from a given node. - /// - protected List>? TryMatchPattern( - OptimizationNode startNode, - OperationType[] pattern) - { - var sequence = new List> { startNode }; - var currentNode = startNode; - - for (int i = 1; i < pattern.Length; i++) - { - // Check if current node has exactly one output - if (currentNode.Outputs.Count != 1) - { - return null; - } - - var nextNode = currentNode.Outputs[0]; - - // Check if next node matches the pattern and has exactly one input - if (nextNode.OperationType != pattern[i] || nextNode.Inputs.Count != 1) - { - return null; - } - - // Check if next node is not already fused - if (nextNode.IsFused) - { - return null; - } - - sequence.Add(nextNode); - currentNode = nextNode; - } - - return sequence; - } - - /// - /// Replaces a sequence of nodes with a fused node. - /// - protected OptimizationNode FuseNodes( - IOptimizationGraph graph, - List> nodesToFuse, - OperationType fusedOperationType) - { - if (nodesToFuse.Count == 0) - { - throw new ArgumentException("No nodes to fuse"); - } - - var firstNode = nodesToFuse[0]; - var lastNode = nodesToFuse[nodesToFuse.Count - 1]; - - // Create fused node - var fusedNode = new OptimizationNode - { - OperationType = fusedOperationType, - Name = $"{firstNode.Name}_fused", - OutputShape = lastNode.OutputShape, - IsFused = true, - FusedFrom = new List>(nodesToFuse) - }; - - // Copy parameters from all nodes - foreach (var node in nodesToFuse) - { - foreach (var param in node.Parameters) - { - fusedNode.Parameters[$"{node.Name}_{param.Key}"] = param.Value; - } - - foreach (var meta in node.Metadata) - { - fusedNode.Metadata[$"{node.Name}_{meta.Key}"] = meta.Value; - } - } - - // Connect inputs from first node - foreach (var input in firstNode.Inputs) - { - fusedNode.AddInput(input); - input.Outputs.Remove(firstNode); - } - - // Connect outputs from last node (iterate over copy since ReplaceInput modifies the collection) - foreach (var output in lastNode.Outputs.ToList()) - { - output.ReplaceInput(lastNode, fusedNode); - } - - // Add fused node to graph - graph.AddNode(fusedNode); - - // Remove original nodes - foreach (var node in nodesToFuse) - { - graph.RemoveNode(node); - } - - return fusedNode; - } -} diff --git a/src/InferenceOptimization/Passes/StrengthReductionPass.cs b/src/InferenceOptimization/Passes/StrengthReductionPass.cs deleted file mode 100644 index 1da6a79c52..0000000000 --- a/src/InferenceOptimization/Passes/StrengthReductionPass.cs +++ /dev/null @@ -1,188 +0,0 @@ -using AiDotNet.Enums; -using AiDotNet.InferenceOptimization.Core; - -namespace AiDotNet.InferenceOptimization.Passes; - -/// -/// Replaces expensive operations with cheaper equivalent operations. -/// Examples: -/// - x^2 -> x * x -/// - x * 2 -> x + x -/// - x / 2 -> x * 0.5 -/// - sqrt(x^2) -> abs(x) -/// -/// The numeric type (double, float, decimal) -public class StrengthReductionPass : OptimizationPassBase where T : struct -{ - public override OptimizationPassType PassType => OptimizationPassType.StrengthReduction; - public override string Name => "Strength Reduction"; - - public override bool Apply(IOptimizationGraph graph) - { - bool modified = false; - - var reducibleOps = new HashSet - { - OperationType.Power, OperationType.Divide, OperationType.Multiply - }; - foreach (var node in graph.Nodes.Where(n => reducibleOps.Contains(n.OperationType)).ToList()) - { - if (TryReduceStrength(graph, node)) - { - modified = true; - } - } - - return modified; - } - - private bool TryReduceStrength(IOptimizationGraph graph, OptimizationNode node) - { - return node.OperationType switch - { - OperationType.Power => ReducePower(graph, node), - OperationType.Divide => ReduceDivide(graph, node), - OperationType.Multiply => ReduceMultiply(graph, node), - _ => false - }; - } - - private bool ReducePower(IOptimizationGraph graph, OptimizationNode node) - { - if (node.Inputs.Count != 2) return false; - - var exponent = node.Inputs[1]; - - // x^2 -> x * x (multiplication is faster than power) - if (IsConstantWithValue(exponent, 2)) - { - var multiplyNode = new OptimizationNode - { - OperationType = OperationType.Multiply, - Name = $"{node.Name}_strength_reduced", - OutputShape = node.OutputShape - }; - - var input = node.Inputs[0]; - multiplyNode.AddInput(input); - multiplyNode.AddInput(input); - - ReplaceNode(graph, node, multiplyNode); - return true; - } - - return false; - } - - private bool ReduceDivide(IOptimizationGraph graph, OptimizationNode node) - { - if (node.Inputs.Count != 2) return false; - - var divisor = node.Inputs[1]; - - // x / constant -> x * (1/constant) (multiplication is faster than division) - if (divisor.OperationType == OperationType.Constant) - { - var multiplyNode = new OptimizationNode - { - OperationType = OperationType.Multiply, - Name = $"{node.Name}_strength_reduced", - OutputShape = node.OutputShape - }; - - // Create reciprocal constant - var reciprocalNode = new OptimizationNode - { - OperationType = OperationType.Constant, - Name = $"{divisor.Name}_reciprocal", - OutputShape = divisor.OutputShape, - Metadata = new Dictionary - { - ["IsReciprocal"] = true, - ["OriginalConstant"] = divisor - } - }; - - graph.AddNode(reciprocalNode); - - multiplyNode.AddInput(node.Inputs[0]); - multiplyNode.AddInput(reciprocalNode); - - ReplaceNode(graph, node, multiplyNode); - return true; - } - - return false; - } - - private bool ReduceMultiply(IOptimizationGraph graph, OptimizationNode node) - { - if (node.Inputs.Count != 2) return false; - - // x * 2 -> x + x (addition might be faster on some hardware) - var left = node.Inputs[0]; - var right = node.Inputs[1]; - - if (IsConstantWithValue(right, 2)) - { - var addNode = new OptimizationNode - { - OperationType = OperationType.Add, - Name = $"{node.Name}_strength_reduced", - OutputShape = node.OutputShape - }; - - addNode.AddInput(left); - addNode.AddInput(left); - - // Note: Only apply this if beneficial on target hardware - if (IsAdditionFasterThanMultiplication()) - { - ReplaceNode(graph, node, addNode); - return true; - } - } - - return false; - } - - private bool IsConstantWithValue(OptimizationNode node, double value) - { - if (node.OperationType != OperationType.Constant) - { - return false; - } - - // In a real implementation, check the actual constant value - return node.Metadata.TryGetValue("Value", out var val) && - Math.Abs((double)val - value) < 1e-6; - } - - private bool IsAdditionFasterThanMultiplication() - { - // This would depend on the target hardware - // For most modern CPUs, multiplication and addition have similar latency - // So we return false by default - return false; - } - - private void ReplaceNode(IOptimizationGraph graph, OptimizationNode oldNode, OptimizationNode newNode) - { - // Replace all outputs - foreach (var output in oldNode.Outputs.ToList()) - { - output.ReplaceInput(oldNode, newNode); - } - - // Add new node and remove old one - graph.AddNode(newNode); - graph.RemoveNode(oldNode); - } - - public override bool CanApply(IOptimizationGraph graph) - { - return base.CanApply(graph) && - (graph.Nodes.Any(n => n.OperationType == OperationType.Power) || - graph.Nodes.Any(n => n.OperationType == OperationType.Divide)); - } -} diff --git a/src/InferenceOptimization/README.md b/src/InferenceOptimization/README.md deleted file mode 100644 index 2eadda18b0..0000000000 --- a/src/InferenceOptimization/README.md +++ /dev/null @@ -1,253 +0,0 @@ -# AiDotNet Inference Optimization - -This module provides low-level kernel optimization for critical operations, enabling hardware-specific acceleration for efficient AI model inference. - -## Features - -### 1. Custom Operator Registration System -- Thread-safe operator registry with automatic fallback -- Priority-based operator selection -- Support for multiple implementations per operation -- Runtime operator switching based on platform capabilities - -### 2. Platform Detection -- Automatic detection of CPU architecture (x86/x64, ARM) -- SIMD instruction set detection (SSE, AVX, AVX2, AVX-512, NEON) -- Cache size estimation -- GPU capability detection (CUDA, OpenCL) - -### 3. SIMD Vectorization -- AVX2/AVX-512 optimized kernels for x86/x64 -- ARM NEON optimized kernels -- Automatic fallback to scalar implementations -- Optimized operations: - - Vector addition/multiplication - - Dot product with FMA support - - ReLU activation - - Sum reduction - - Scalar multiply-add - -### 4. Optimized Kernels - -#### GEMM (General Matrix Multiplication) -- Cache-blocked algorithm for L1 cache efficiency -- Parallel execution for large matrices -- SIMD-optimized inner loops -- Transpose optimization for better memory access patterns -- Expected speedup: 2-3x on AVX2, 2.5x on NEON - -#### Fused Attention Kernel -- Scaled dot-product attention: `softmax(QK^T/sqrt(d_k))V` -- Multi-head attention support -- Memory-efficient implementation -- Mask support for causal attention -- Expected speedup: 2.5x - -#### Convolution Kernels -- Standard 2D convolution -- Depthwise separable convolution -- Group convolution -- Parallel batch processing -- Expected speedup: 2-2.5x - -### 5. CPU Optimizations - -#### Cache Optimizer -- L1/L2/L3 cache-aware algorithms -- Automatic tiling parameter computation -- Prefetching for reduced latency -- Cache-aware transpose -- Z-order (Morton) indexing for 2D access patterns -- Cache miss estimation - -#### Loop Optimizer -- 2D and 3D loop tiling -- Loop unrolling (4x, 8x) -- Strip mining for cache utilization -- Loop fusion -- Loop interchange optimization -- Parallel tiling with work stealing - -### 6. Performance Profiling -- Thread-safe operation tracking -- Timing and memory usage statistics -- Per-operation metrics (min/avg/max/total) -- Performance report generation -- Runtime enable/disable capability - -### 7. GPU Optimization Infrastructure -- Base classes for GPU kernel implementations -- Memory management abstractions -- CUDA kernel base (ready for DirectCuda/ManagedCuda integration) -- Device capability querying - -## Quick Start - -```csharp -using AiDotNet.InferenceOptimization; -using AiDotNet.InferenceOptimization.Kernels; -using AiDotNet.Tensors.Engines.Simd; // SimdKernels location -using AiDotNet.Tensors.LinearAlgebra; - -// Initialize the optimization system -OptimizationInitializer.Initialize(enableProfiling: true); - -// Use optimized GEMM -var gemmKernel = new GemmKernel(); -var a = new Tensor(new[] { 1000, 500 }); -var b = new Tensor(new[] { 500, 1000 }); -var result = gemmKernel.Execute(a, b); - -// Use fused attention -var attentionKernel = new AttentionKernel(); -var q = new Tensor(new[] { 1, 128, 64 }); // [batch, seq_len, d_k] -var k = new Tensor(new[] { 1, 128, 64 }); -var v = new Tensor(new[] { 1, 128, 64 }); -var attended = attentionKernel.Execute(q, k, v); - -// Get performance report -var report = OptimizationInitializer.GetPerformanceSummary(); -Console.WriteLine(report); -``` - -## Platform Capabilities - -Check what optimizations are available on your platform: - -```csharp -var caps = PlatformDetector.Capabilities; -Console.WriteLine($"Best SIMD: {caps.GetBestSimdSet()}"); -Console.WriteLine($"Has AVX2: {caps.HasAVX2}"); -Console.WriteLine($"Has NEON: {caps.HasNeon}"); -Console.WriteLine($"Processor Count: {caps.ProcessorCount}"); -``` - -## Custom Operators - -Register your own optimized operators: - -```csharp -public class MyCustomKernel : ICustomOperator -{ - public string Name => "MyOperation"; - public string Version => "1.0.0"; - public int Priority => 100; - - public bool IsSupported() - { - return PlatformDetector.Capabilities.HasAVX2; - } - - public double EstimatedSpeedup() - { - return 3.0; // Expected 3x speedup - } - - public Tensor Execute(params Tensor[] inputs) - { - // Your optimized implementation - // ... - } -} - -// Register the operator -CustomOperatorRegistry.Instance.Register(new MyCustomKernel()); - -// Use the operator -var kernel = CustomOperatorRegistry.Instance.GetOperator("MyOperation"); -var result = kernel.Execute(input1, input2); -``` - -## Performance Profiling - -Enable profiling to track performance: - -```csharp -// Enable profiling -OptimizationInitializer.Initialize(enableProfiling: true); - -// Operations are automatically profiled -// ... - -// Get report -var report = OptimizationInitializer.GetPerformanceSummary(); -Console.WriteLine(report); - -// Reset statistics -OptimizationInitializer.ResetStatistics(); -``` - -## CPU Optimization Utilities - -Use cache-aware and loop optimization utilities: - -```csharp -using AiDotNet.Tensors.Engines.Optimization; - -// Determine optimal tile size -int tileSize = LoopOptimizer.DetermineOptimalTileSize(matrixSize); - -// Use tiled loops -LoopOptimizer.Tile2D(rows, cols, tileSize, (iStart, iEnd, jStart, jEnd) => -{ - // Process tile -}); - -// Use parallel tiling -LoopOptimizer.ParallelTile2D(rows, cols, tileSize, (iStart, iEnd, jStart, jEnd) => -{ - // Process tile in parallel -}); - -// Cache-aware transpose -CacheOptimizer.TransposeBlocked(sourceArray, destArray, rows, cols); -``` - -## Benchmarking - -See `AiDotNetBenchmarkTests/InferenceOptimization/` for benchmark examples. - -## Future Enhancements - -- GPU kernel implementations using DirectGpu backends or ManagedCuda -- Quantization support (INT8, FP16) -- Model graph optimization -- Operator fusion -- Dynamic batching optimization -- Memory pooling - -## Integration with Existing Codebase - -The optimization module integrates with existing AiDotNet components: - -- **Tensor Operations**: Optimized kernels work with `AiDotNet.LinearAlgebra.Tensor` -- **Neural Networks**: Can be used to accelerate layer operations in `NeuralNetworkBase` -- **Serving**: Integrates with `RequestBatcher` for optimized inference - -## Requirements - -- .NET 8.0 or later -- x86/x64 or ARM64 processor -- For GPU support: CUDA-capable GPU (future implementation) - -## Performance Targets - -- 2-5x speedup on critical operations (achieved through SIMD and cache optimization) -- Hardware-specific optimizations (AVX2, AVX-512, NEON) -- Graceful fallback behavior (automatic platform detection) -- Benchmarking against MKL and cuBLAS (future work) - -## Contributing - -To add new optimizations: - -1. Implement `ICustomOperator` interface -2. Override `IsSupported()` to check platform compatibility -3. Implement optimized `Execute()` method -4. Register operator with `CustomOperatorRegistry` -5. Add benchmarks in `AiDotNetBenchmarkTests/` - -## License - -Same as parent AiDotNet project. - diff --git a/src/Interfaces/IAiModelBuilder.cs b/src/Interfaces/IAiModelBuilder.cs index 8e49eb032c..358b844587 100644 --- a/src/Interfaces/IAiModelBuilder.cs +++ b/src/Interfaces/IAiModelBuilder.cs @@ -1403,6 +1403,33 @@ IAiModelBuilder ConfigureJitCompilation( /// This builder for fluent chaining. IAiModelBuilder AllowNondeterminism(); + /// + /// Captures SIMD/GPU/native-BLAS acceleration status at build time, logs it, and + /// surfaces a structured snapshot on PredictionModelResult.AccelerationSnapshot. + /// + /// + /// Optional callback receiving the formatted report. Defaults to . + /// + /// This builder for fluent chaining. + IAiModelBuilder ReportAccelerationStatus(Action? logger = null); + + /// + /// Enables disk-backed caching of compiled inference plans in the supplied directory. + /// Plans are saved after first compile and loaded transparently on next process + /// start, skipping cold-start compile cost. + /// + /// Filesystem directory to store plan files. Created if missing. + /// This builder for fluent chaining. + IAiModelBuilder ConfigurePlanCaching(string directory); + + /// + /// Enables low-level per-tensor-op profiling via Tensors' + /// PerformanceProfiler.Instance. After BuildAsync, timings are captured + /// on AiModelResult.TensorsOperationProfile. + /// + /// This builder for fluent chaining. + IAiModelBuilder EnableTensorsOpProfiling(); + /// /// Configures mixed-precision training for faster neural network training with reduced memory usage. /// diff --git a/src/Interfaces/IFullModel.cs b/src/Interfaces/IFullModel.cs index d4c6df67fd..2e8921238a 100644 --- a/src/Interfaces/IFullModel.cs +++ b/src/Interfaces/IFullModel.cs @@ -25,7 +25,6 @@ namespace AiDotNet.Interfaces; /// - IParameterizable: Get/set model parameters (linear models, neural networks) /// - IGradientComputable: Compute and apply gradients (gradient-based optimization) /// - IFeatureAware: Feature selection and tracking -/// - IJitCompilable: Export computation graph for JIT compilation /// /// Not all models support all capabilities. Tree-based and ensemble models /// may not implement IParameterizable or IGradientComputable. diff --git a/src/KnowledgeDistillation/Teachers/OnlineTeacherModel.cs b/src/KnowledgeDistillation/Teachers/OnlineTeacherModel.cs index 848a3af814..835a8b153d 100644 --- a/src/KnowledgeDistillation/Teachers/OnlineTeacherModel.cs +++ b/src/KnowledgeDistillation/Teachers/OnlineTeacherModel.cs @@ -101,10 +101,6 @@ public class OnlineTeacherModel : TeacherModelBase, Vector, T> /// How to update the teacher (default: EMA). /// Update rate for EMA or learning rate (default: 0.999 for EMA). /// How often to update (default: every step). - /// - /// Note: This constructor creates a non-JIT-compilable teacher. - /// For JIT support, use the constructor that accepts an IJitCompilable model. - /// public OnlineTeacherModel( Func, Vector> teacherForward, int inputDimension, diff --git a/src/KnowledgeDistillation/Teachers/PretrainedTeacherModel.cs b/src/KnowledgeDistillation/Teachers/PretrainedTeacherModel.cs index dc9d00f030..68f528d41b 100644 --- a/src/KnowledgeDistillation/Teachers/PretrainedTeacherModel.cs +++ b/src/KnowledgeDistillation/Teachers/PretrainedTeacherModel.cs @@ -12,11 +12,16 @@ namespace AiDotNet.KnowledgeDistillation.Teachers; /// Pretrained teacher model from external source (e.g., ImageNet, BERT). /// /// -/// Architecture Note: This class supports two construction modes: -/// -/// Function delegate mode: Uses a Func<> for forward pass (not JIT-compilable) -/// IJitCompilable mode: Uses a JIT-compilable model for forward pass (JIT-compilable) -/// +/// +/// This wrapper takes a Func<Vector<T>, Vector<T>> forward-pass +/// delegate and invokes it directly on every call. +/// The wrapper itself performs no caching or graph compilation — any +/// optimizations (including Tensors' AutoTracer auto-compile) depend entirely +/// on what happens inside the supplied delegate. A delegate that wraps a +/// standard neural-network model's Predict path will pick up those +/// engine-level optimizations; a delegate that invokes external code +/// (pre-converted ONNX, a REST call, etc.) will not. +/// /// [ModelDomain(ModelDomain.MachineLearning)] [ModelCategory(ModelCategory.NeuralNetwork)] diff --git a/src/KnowledgeDistillation/Teachers/QuantizedTeacherModel.cs b/src/KnowledgeDistillation/Teachers/QuantizedTeacherModel.cs index afdb3c298b..70631ccd67 100644 --- a/src/KnowledgeDistillation/Teachers/QuantizedTeacherModel.cs +++ b/src/KnowledgeDistillation/Teachers/QuantizedTeacherModel.cs @@ -20,10 +20,6 @@ namespace AiDotNet.KnowledgeDistillation.Teachers; /// Faster inference on hardware with integer support /// Reduced memory bandwidth requirements /// -/// JIT Support: When constructed with an IJitCompilable base model, this teacher -/// supports JIT compilation using FakeQuantization with Straight-Through Estimator (STE). -/// This allows the quantized model to be differentiated during training while simulating -/// quantization effects. /// [ModelDomain(ModelDomain.MachineLearning)] [ModelCategory(ModelCategory.NeuralNetwork)] @@ -57,8 +53,9 @@ public class QuantizedTeacherModel : TeacherModelBase, Vector, T /// The base teacher model to quantize. /// Number of bits for quantization (1-32). /// - /// This constructor uses dynamic quantization (per-batch min/max finding) which - /// does not support JIT compilation. Use the constructor with IJitCompilable for JIT support. + /// Uses dynamic quantization (per-batch min/max). The underlying teacher's forward + /// pass goes through Tensors' AutoTracer and is auto-compiled after the input-shape + /// pattern repeats. /// public QuantizedTeacherModel( ITeacherModel, Vector> baseTeacher, diff --git a/src/KnowledgeDistillation/Teachers/SelfTeacherModel.cs b/src/KnowledgeDistillation/Teachers/SelfTeacherModel.cs index d42753f7be..88b576395c 100644 --- a/src/KnowledgeDistillation/Teachers/SelfTeacherModel.cs +++ b/src/KnowledgeDistillation/Teachers/SelfTeacherModel.cs @@ -13,11 +13,8 @@ namespace AiDotNet.KnowledgeDistillation.Teachers; /// /// /// For Beginners: Self-distillation is a technique where a model learns from its own -/// earlier predictions. This teacher can operate in two modes: -/// -/// Cached Mode: Uses pre-computed predictions from earlier epochs (no JIT support) -/// Model Mode: Wraps an IJitCompilable model for dynamic predictions (JIT support available) -/// +/// earlier predictions. This teacher stores pre-computed predictions from earlier epochs and +/// returns them by index via . /// [ModelDomain(ModelDomain.MachineLearning)] [ModelCategory(ModelCategory.NeuralNetwork)] @@ -47,8 +44,7 @@ public class SelfTeacherModel : TeacherModelBase, Vector, T> /// The output dimension of predictions. /// /// Use this constructor when you want to manually cache predictions via - /// and retrieve them via . - /// JIT compilation is not supported in this mode. + /// and retrieve them via . /// public SelfTeacherModel(int outputDimension) { @@ -79,14 +75,19 @@ public void CachePredictions(Vector[] predictions) } /// - /// Gets logits from the underlying model. + /// Not supported for — always throws. /// - /// Input to the model. - /// The logits from the underlying model. - /// Thrown when no underlying model is configured. + /// Ignored. + /// This method does not return; it always throws. + /// + /// Always thrown. serves pre-computed + /// predictions by index via and cannot + /// evaluate a fresh input vector — it has no underlying model to run. + /// /// - /// This method is only available when the SelfTeacherModel was constructed with an - /// IJitCompilable model. For cached prediction mode, use . + /// Callers must use instead, which + /// returns a prediction from the cache populated via + /// . /// public override Vector GetLogits(Vector input) { diff --git a/src/KnowledgeDistillation/Teachers/TransformerTeacherModel.cs b/src/KnowledgeDistillation/Teachers/TransformerTeacherModel.cs index 22850db576..a60a36eaef 100644 --- a/src/KnowledgeDistillation/Teachers/TransformerTeacherModel.cs +++ b/src/KnowledgeDistillation/Teachers/TransformerTeacherModel.cs @@ -13,12 +13,13 @@ namespace AiDotNet.KnowledgeDistillation.Teachers; /// /// The numeric type for calculations (e.g., double, float). /// -/// Architecture Note: This class supports two construction modes: -/// -/// Function delegate mode: Uses a Func<> for forward pass (not JIT-compilable) -/// IJitCompilable mode: Uses a JIT-compilable model for forward pass (JIT-compilable) -/// -/// +/// +/// This wrapper takes a Func<Vector<T>, Vector<T>> forward-pass +/// delegate and invokes it directly on every call. +/// The wrapper performs no caching or graph compilation itself — any +/// optimizations (including Tensors' AutoTracer auto-compile) depend on what +/// the supplied delegate does internally. +/// /// For attention-based distillation strategies that need attention weights, implement /// a custom IDistillationStrategy that can extract attention from the underlying model. /// @@ -53,10 +54,6 @@ public class TransformerTeacherModel : TeacherModelBase, Vector, /// The number of output dimensions. /// Thrown when forwardFunc is null. /// Thrown when dimensions are not positive. - /// - /// Note: This constructor creates a non-JIT-compilable teacher. - /// For JIT support, use the constructor that accepts an IJitCompilable model. - /// public TransformerTeacherModel( Func, Vector> forwardFunc, int inputDimension, diff --git a/src/LinearAlgebra/ExpressionTree.cs b/src/LinearAlgebra/ExpressionTree.cs index b39e1934e8..1746dd4c91 100644 --- a/src/LinearAlgebra/ExpressionTree.cs +++ b/src/LinearAlgebra/ExpressionTree.cs @@ -1631,70 +1631,4 @@ public override void LoadState(Stream stream) } } - #region IJitCompilable Implementation - - /// - /// Recursively builds a computation graph from an expression tree node. - /// - /// The expression tree node to convert. - /// Dictionary mapping variable indices to their computation nodes. - /// The computation node representing this expression tree node. - private ComputationNode BuildComputationGraph( - ExpressionTree node, - Dictionary> variableNodes) - { - switch (node.Type) - { - case ExpressionNodeType.Constant: - // Create a constant tensor (scalar) - var constantTensor = new Tensor(new[] { 1 }); - constantTensor[0] = node.Value; - return new ComputationNode(constantTensor); - - case ExpressionNodeType.Variable: - // Get or create variable node - int varIndex = NumOps.ToInt32(node.Value); - if (!variableNodes.ContainsKey(varIndex)) - { - // Create placeholder for this variable - var varTensor = new Tensor(new[] { 1 }); - varTensor[0] = NumOps.Zero; // Placeholder value - variableNodes[varIndex] = new ComputationNode(varTensor); - } - return variableNodes[varIndex]; - - case ExpressionNodeType.Add: - if (node.Left == null || node.Right == null) - throw new InvalidOperationException("Add operation requires both left and right operands."); - return TensorOperations.Add( - BuildComputationGraph(node.Left, variableNodes), - BuildComputationGraph(node.Right, variableNodes)); - - case ExpressionNodeType.Subtract: - if (node.Left == null || node.Right == null) - throw new InvalidOperationException("Subtract operation requires both left and right operands."); - return TensorOperations.Subtract( - BuildComputationGraph(node.Left, variableNodes), - BuildComputationGraph(node.Right, variableNodes)); - - case ExpressionNodeType.Multiply: - if (node.Left == null || node.Right == null) - throw new InvalidOperationException("Multiply operation requires both left and right operands."); - return TensorOperations.ElementwiseMultiply( - BuildComputationGraph(node.Left, variableNodes), - BuildComputationGraph(node.Right, variableNodes)); - - case ExpressionNodeType.Divide: - if (node.Left == null || node.Right == null) - throw new InvalidOperationException("Divide operation requires both left and right operands."); - return TensorOperations.Divide( - BuildComputationGraph(node.Left, variableNodes), - BuildComputationGraph(node.Right, variableNodes)); - - default: - throw new InvalidOperationException($"Unknown expression node type: {node.Type}"); - } - } - - #endregion } diff --git a/src/Models/Results/AiModelResult.Diagnostics.cs b/src/Models/Results/AiModelResult.Diagnostics.cs new file mode 100644 index 0000000000..4844163dd0 --- /dev/null +++ b/src/Models/Results/AiModelResult.Diagnostics.cs @@ -0,0 +1,41 @@ +using AiDotNet.Diagnostics; + +namespace AiDotNet.Models.Results; + +/// +/// Diagnostics surface on — exposes the +/// acceleration environment snapshot captured at build time when the builder opts in via +/// ReportAccelerationStatus(). +/// +public partial class AiModelResult +{ + /// + /// Snapshot of the SIMD, GPU, and native-BLAS acceleration state captured when this + /// model was built. null if the builder did not call ReportAccelerationStatus. + /// + /// + /// + /// Useful for production observability, CI assertions, and diagnosing why a model is + /// slower than expected on a given host. Fields include CPU SIMD level + /// (), GPU backends detected + /// (, etc.), and native BLAS availability + /// (, etc.). + /// + /// + public AccelerationSnapshot? AccelerationSnapshot { get; internal set; } + + /// + /// Per-tensor-op performance profile captured when the builder opted in via + /// EnableTensorsOpProfiling(). Null otherwise. Complements the + /// higher-level ProfilingReport (from ConfigureProfiling) by + /// surfacing Tensors-package kernel timings, not just AiDotNet workflow timings. + /// + /// + /// + /// PyTorch-parity equivalent: low-level torch.profiler.profile CUDA/CPU op + /// breakdown. Operations are sorted by total time descending; use + /// for a one-line-per-op table. + /// + /// + public TensorsOperationProfile? TensorsOperationProfile { get; internal set; } +} diff --git a/src/Models/Results/AiModelResult.cs b/src/Models/Results/AiModelResult.cs index ee7ec25af9..132837e3fb 100644 --- a/src/Models/Results/AiModelResult.cs +++ b/src/Models/Results/AiModelResult.cs @@ -6657,95 +6657,4 @@ private IChatModel CreateChatModelFromAgentConfig() #endregion - #region IJitCompilable Implementation - - /// - /// Gets whether the underlying model currently supports JIT compilation. - /// - /// Returns true if the wrapped model implements IJitCompilable and supports JIT, false otherwise. - /// - /// - /// This property delegates to the wrapped model's SupportsJitCompilation property if the model - /// implements IJitCompilable. If the model does not implement this interface or does not support - /// JIT compilation, this returns false. - /// - /// For Beginners: Whether you can use JIT compilation depends on the type of model you trained. - /// - /// Models that support JIT compilation (SupportsJitCompilation = true): - /// - Linear regression models - /// - Polynomial regression models - /// - Ridge/Lasso regression models - /// - Models using differentiable operations - /// - /// Models that do NOT support JIT (SupportsJitCompilation = false): - /// - Decision trees - /// - Random forests - /// - Gradient boosted trees - /// - Models using discrete logic - /// - /// If your model supports JIT: - /// - Predictions will be 5-10x faster - /// - The computation graph is compiled to optimized native code - /// - You get this speedup automatically when calling Predict() - /// - /// If your model doesn't support JIT: - /// - Predictions still work normally - /// - No JIT acceleration, but still optimized for the model type - /// - /// - public bool SupportsJitCompilation - { - get - { - if (Model == null) - { - return false; - } - - // JIT compilation has been removed - return false; - } - } - - /// - /// Exports the underlying model's computation graph for JIT compilation. - /// - /// List to populate with input computation nodes. - /// The output computation node representing the model's prediction. - /// Thrown when Model is null. - /// Thrown when the underlying model does not support JIT compilation. - /// - /// - /// This method delegates to the wrapped model's ExportComputationGraph method if the model - /// implements IJitCompilable and supports JIT compilation. If the model does not implement - /// this interface or does not support JIT, this throws NotSupportedException. - /// - /// For Beginners: This method creates a "recipe" of your model's calculations for JIT compilation. - /// - /// If your model supports JIT (SupportsJitCompilation = true): - /// - This method creates a computation graph from your model - /// - The graph represents all the mathematical operations your model performs - /// - The JIT compiler uses this to create fast optimized code - /// - /// If your model doesn't support JIT (SupportsJitCompilation = false): - /// - This method will throw an exception - /// - Check SupportsJitCompilation before calling this - /// - Decision trees, random forests, etc. cannot export computation graphs - /// - /// You typically don't call this method directly. It's used internally by: - /// - AiModelBuilder when building models with JIT enabled - /// - The prediction pipeline to compile models for faster inference - /// - /// Example of what happens inside: - /// - Linear model: Creates graph with MatMul(X, Coefficients) + Intercept - /// - Neural network: Creates graph with all layers and activations - /// - Decision tree: Throws exception - cannot create computation graph - /// - /// - public AiDotNet.Autodiff.ComputationNode ExportComputationGraph(List> inputNodes) - { - throw new NotSupportedException("JIT compilation has been removed."); - } - - #endregion } diff --git a/src/Models/VectorModel.cs b/src/Models/VectorModel.cs index 09ae37d2e3..d033470708 100644 --- a/src/Models/VectorModel.cs +++ b/src/Models/VectorModel.cs @@ -1731,22 +1731,4 @@ public override void LoadState(Stream stream) } } - #region IJitCompilable Implementation - - /// - /// Converts a Vector to a Tensor for use in computation graphs. - /// - private Tensor VectorToTensor(Vector vector) - { - // Convert Vector to 2D Tensor: (length,) -> (length, 1) - var shape = new int[] { vector.Length, 1 }; - var data = new T[vector.Length]; - for (int i = 0; i < vector.Length; i++) - { - data[i] = vector[i]; - } - return new Tensor(shape, new Vector(data)); - } - - #endregion } diff --git a/src/NeuralNetworks/CompiledModelHost.cs b/src/NeuralNetworks/CompiledModelHost.cs index 87b1908056..d46852c8dd 100644 --- a/src/NeuralNetworks/CompiledModelHost.cs +++ b/src/NeuralNetworks/CompiledModelHost.cs @@ -59,6 +59,43 @@ internal sealed class CompiledModelHost : IDisposable private bool _disposed; + /// + /// Symbolic-shape strategy that governs which input dims the compile cache + /// treats as variable. BatchDynamic (default) lets a single compiled plan + /// serve batch-size 1, 4, 32, 128 etc. without recompiling each time — + /// matching PyTorch's default torch.compile(dynamic=True) posture. + /// + private readonly SymbolicShapeMode _shapeMode; + + /// + /// Stable identity string for this model (typically the full type name plus any + /// architecture hash). Used as the filename prefix for disk-cached compiled plans + /// so multiple model types can share the same directory + /// without collision. Null = disk caching disabled for this host. + /// + private readonly string? _modelIdentity; + + /// + /// Set of shape keys we've already attempted to load from disk for a given + /// structure version. Prevents repeated failed-load IO on every Predict call. + /// Keyed as "v{structureVersion}_s{shapeHash}". + /// + private HashSet? _diskCheckedShapes; + + /// + /// Plans loaded from disk for this host. Preempt so we skip + /// the compile pass entirely on cold-start replay. + /// + private Dictionary>? _preloadedPlans; + + public CompiledModelHost( + SymbolicShapeMode shapeMode = SymbolicShapeMode.BatchDynamic, + string? modelIdentity = null) + { + _shapeMode = shapeMode; + _modelIdentity = modelIdentity; + } + /// /// Synchronizes lifecycle mutations on _cache, _lastCompiledVersion, /// and _disposed. Without it a concurrent Dispose/Invalidate @@ -178,6 +215,17 @@ public Tensor Predict( { try { + var concreteShape = (int[])input._shape.Clone(); + + // Disk-cache preload: first time we see a shape at this structure + // version, try to load a previously-compiled plan from PlanCache. + // Skips the compile cost entirely on cold-start replay. + if (TryUseDiskCachedPlan(concreteShape, structureVersion, out var preloadedResult)) + { + _lastCompiledVersion = structureVersion; + return preloadedResult; + } + // Preserve the Tensor return from eagerForward so the compile // pass can explicitly identify the output tensor. Discarding the // return (previously `() => { eagerForward(); }`) can select a @@ -185,9 +233,13 @@ public Tensor Predict( // tensor ambiguous to the tracer, producing wrong outputs on // replay. Keep the expression-bodied lambda so the value threads // through unchanged. - var plan = cache.GetOrCompileInference( - (int[])input._shape.Clone(), - () => eagerForward()); + var symbolicShape = BuildSymbolicShape(concreteShape, _shapeMode); + var plan = symbolicShape is null + ? cache.GetOrCompileInference(concreteShape, () => eagerForward()) + : cache.GetOrCompileInference(concreteShape, () => eagerForward(), symbolicShape); + + // Fire-and-forget disk save so subsequent cold starts skip the compile. + MaybeSavePlanToDisk(plan, concreteShape, structureVersion); // Safe to write _lastCompiledVersion outside the lock — int writes // are atomic in .NET, and the value is only used as a hint by the @@ -386,4 +438,164 @@ public void Dispose() } } } + + /// + /// Checks for a previously-saved plan matching + /// the current shape + structure version, and executes it if found. Returns + /// true when the disk-cached plan was used (caller should skip compilation). + /// + private bool TryUseDiskCachedPlan( + int[] concreteShape, + int structureVersion, + out Tensor result) + { + result = null!; + var planCache = PlanCache.Current; + if (planCache is null || _modelIdentity is null) + { + return false; + } + + var shapeKey = ComputeShapeKey(concreteShape, structureVersion); + + // Preloaded-in-memory hit: skip disk entirely. + Dictionary>? preloaded; + lock (_sync) + { + preloaded = _preloadedPlans; + } + if (preloaded is not null && preloaded.TryGetValue(shapeKey, out var cachedPlan)) + { + if (cachedPlan.IsValid(concreteShape)) + { + result = cachedPlan.Execute(); + return true; + } + } + + // First-time-miss: check disk once per (shape, version). + bool shouldCheckDisk; + lock (_sync) + { + _diskCheckedShapes ??= new HashSet(); + shouldCheckDisk = _diskCheckedShapes.Add(shapeKey); + } + if (!shouldCheckDisk) + { + return false; + } + + try + { + var loaded = planCache.TryLoadInferenceAsync( + _modelIdentity, structureVersion, concreteShape, AiDotNetEngine.Current) + .GetAwaiter().GetResult(); + + if (loaded is null || !loaded.IsValid(concreteShape)) + { + return false; + } + + lock (_sync) + { + (_preloadedPlans ??= new Dictionary>())[shapeKey] = loaded; + } + + result = loaded.Execute(); + return true; + } + catch (Exception ex) + { + System.Diagnostics.Trace.TraceWarning( + $"CompiledModelHost: disk-cached plan load failed: {ex.GetType().Name}: {ex.Message}"); + return false; + } + } + + /// + /// Fires a fire-and-forget disk save so subsequent process starts skip compile. + /// Errors are swallowed — disk caching is an optimization, not a correctness + /// dependency. + /// + private void MaybeSavePlanToDisk(ICompiledPlan plan, int[] concreteShape, int structureVersion) + { + var planCache = PlanCache.Current; + if (planCache is null || _modelIdentity is null) + { + return; + } + + var shapeKey = ComputeShapeKey(concreteShape, structureVersion); + lock (_sync) + { + // Remember the plan in-memory too; avoids re-saving on next call. + (_preloadedPlans ??= new Dictionary>())[shapeKey] = plan; + } + + var identity = _modelIdentity; + _ = Task.Run(async () => + { + try + { + await planCache.SaveInferenceAsync(plan, identity, structureVersion, concreteShape) + .ConfigureAwait(false); + } + catch (Exception ex) + { + System.Diagnostics.Trace.TraceWarning( + $"CompiledModelHost: background plan save failed: {ex.GetType().Name}: {ex.Message}"); + } + }); + } + + private static string ComputeShapeKey(int[] shape, int structureVersion) + { + var sb = new System.Text.StringBuilder(16 + shape.Length * 4); + sb.Append('v').Append(structureVersion).Append('_'); + for (int i = 0; i < shape.Length; i++) + { + if (i > 0) sb.Append('x'); + sb.Append(shape[i]); + } + return sb.ToString(); + } + + /// + /// Translates + concrete shape into a Tensors + /// (or null to fall back to the purely concrete + /// overload when the rank is too small for the requested mode). + /// + private static SymbolicShape? BuildSymbolicShape(int[] shape, SymbolicShapeMode mode) + { + switch (mode) + { + case SymbolicShapeMode.Static: + return null; + case SymbolicShapeMode.BatchDynamic: + return shape.Length >= 2 ? SymbolicShape.BatchDynamic(shape) : null; + case SymbolicShapeMode.BatchAndSeqDynamic: + return shape.Length >= 3 ? SymbolicShape.BatchAndSeqDynamic(shape) : null; + case SymbolicShapeMode.AllDynamic: + return SymbolicShape.AllDynamic(shape); + default: + return null; + } + } +} + +/// +/// Strategy for how keys the compile cache. Dynamic +/// dims let one compiled plan serve multiple concrete shapes — essential for bursty +/// inference traffic where every request has a different batch size. +/// +public enum SymbolicShapeMode +{ + /// Every dim treated as static. Recompile on any shape change. + Static, + /// Dim 0 (batch) dynamic; all others static. PyTorch-default behavior. + BatchDynamic, + /// Dims 0 (batch) and 1 (seq-len) dynamic. For transformer-style inputs. + BatchAndSeqDynamic, + /// Every dim dynamic. Maximum reuse; slight dispatch overhead on replay. + AllDynamic, } diff --git a/src/NeuralNetworks/ConvolutionalNeuralNetwork.cs b/src/NeuralNetworks/ConvolutionalNeuralNetwork.cs index 7f51850525..357a4c2368 100644 --- a/src/NeuralNetworks/ConvolutionalNeuralNetwork.cs +++ b/src/NeuralNetworks/ConvolutionalNeuralNetwork.cs @@ -234,10 +234,11 @@ public override void UpdateParameters(Vector parameters) /// about what's in the image. /// /// - public override Tensor Predict(Tensor input) - { - return Forward(input); - } + /// + /// Routes inference through for + /// compiled-plan replay; remains the eager fallback. + /// + protected override Tensor PredictEager(Tensor input) => Forward(input); /// /// Trains the convolutional neural network using the provided input and expected output. diff --git a/src/NeuralNetworks/EfficientNetNetwork.cs b/src/NeuralNetworks/EfficientNetNetwork.cs index 134265eb1c..167014587b 100644 --- a/src/NeuralNetworks/EfficientNetNetwork.cs +++ b/src/NeuralNetworks/EfficientNetNetwork.cs @@ -293,10 +293,11 @@ public Tensor Forward(Tensor input) } /// - public override Tensor Predict(Tensor input) - { - return Forward(input); - } + /// + /// Routes inference through for + /// compiled-plan replay; remains the eager fallback. + /// + protected override Tensor PredictEager(Tensor input) => Forward(input); /// public override void Train(Tensor input, Tensor expectedOutput) diff --git a/src/NeuralNetworks/FastText.cs b/src/NeuralNetworks/FastText.cs index e8e1b1fd9f..5a3135a35e 100644 --- a/src/NeuralNetworks/FastText.cs +++ b/src/NeuralNetworks/FastText.cs @@ -244,11 +244,11 @@ public override void UpdateParameters(Vector parameters) } } - /// - public override Tensor Predict(Tensor input) - { - return Forward(input); - } + /// + /// Routes inference through for + /// compiled-plan replay; remains the eager fallback. + /// + protected override Tensor PredictEager(Tensor input) => Forward(input); /// /// Trains the model on a single step of data using standard backpropagation. diff --git a/src/NeuralNetworks/GloVe.cs b/src/NeuralNetworks/GloVe.cs index a39342e9f2..c0cd5cd68d 100644 --- a/src/NeuralNetworks/GloVe.cs +++ b/src/NeuralNetworks/GloVe.cs @@ -293,10 +293,11 @@ public override void UpdateParameters(Vector parameters) /// For Beginners: This is identical to the Forward pass—it takes word IDs and /// returns their addresses (embeddings). /// - public override Tensor Predict(Tensor input) - { - return Forward(input); - } + /// + /// Routes inference through for + /// compiled-plan replay; remains the eager fallback. + /// + protected override Tensor PredictEager(Tensor input) => Forward(input); /// /// Trains the model on a batch of word pairs and their co-occurrence counts. diff --git a/src/NeuralNetworks/Layers/DeformableConvolutionalLayer.cs b/src/NeuralNetworks/Layers/DeformableConvolutionalLayer.cs index be85179aa1..8bfb218a77 100644 --- a/src/NeuralNetworks/Layers/DeformableConvolutionalLayer.cs +++ b/src/NeuralNetworks/Layers/DeformableConvolutionalLayer.cs @@ -636,58 +636,6 @@ private Tensor InitializeWeights(int outC, int inC, int kH, int kW) return [_outputChannels, outH, outW]; } - /// - public ComputationNode BuildComputationGraph(ComputationNode inputNode, string namePrefix) - { - if (!SupportsJitCompilation) - throw new InvalidOperationException("Layer weights not initialized. Cannot build computation graph."); - - // Create constant nodes for weights - var kernelNode = TensorOperations.Constant(_weights, $"{namePrefix}kernel"); - var biasNode = TensorOperations.Constant(_bias, $"{namePrefix}bias"); - var offsetWeightsNode = TensorOperations.Constant(_offsetWeights, $"{namePrefix}offset_weights"); - var offsetBiasNode = TensorOperations.Constant(_offsetBias, $"{namePrefix}offset_bias"); - - // First compute offsets using standard convolution - var offsetsNode = TensorOperations.Conv2D( - inputNode, - offsetWeightsNode, - offsetBiasNode, - stride: new int[] { _stride, _stride }, - padding: new int[] { _padding, _padding }); - - // Optionally compute modulation mask - ComputationNode? maskNode = null; - if (_useModulation && _maskWeights != null && _maskBias != null) - { - var maskWeightsNode = TensorOperations.Constant(_maskWeights, $"{namePrefix}mask_weights"); - var maskBiasNode = TensorOperations.Constant(_maskBias, $"{namePrefix}mask_bias"); - - var rawMaskNode = TensorOperations.Conv2D( - inputNode, - maskWeightsNode, - maskBiasNode, - stride: new int[] { _stride, _stride }, - padding: new int[] { _padding, _padding }); - - // Apply sigmoid activation to mask - maskNode = TensorOperations.Sigmoid(rawMaskNode); - } - - // Apply deformable convolution with computed offsets and mask - var deformConvNode = TensorOperations.DeformableConv2D( - inputNode, - kernelNode, - offsetsNode, - maskNode, - biasNode, - stride: new int[] { _stride, _stride }, - padding: new int[] { _padding, _padding }, - dilation: new int[] { 1, 1 }); - - return deformConvNode; - } - #endregion #region Parameter Management diff --git a/src/NeuralNetworks/Layers/DenseBlockLayer.cs b/src/NeuralNetworks/Layers/DenseBlockLayer.cs index 115bd12324..8497160733 100644 --- a/src/NeuralNetworks/Layers/DenseBlockLayer.cs +++ b/src/NeuralNetworks/Layers/DenseBlockLayer.cs @@ -227,56 +227,5 @@ public override void ResetState() _conv3x3.ResetState(); } - /// - public ComputationNode BuildComputationGraph( - ComputationNode inputNode, - string namePrefix) - { - // BN1 - var bn1Node = TensorOperations.BatchNorm( - inputNode, - gamma: TensorOperations.Constant(_bn1.GetGamma(), $"{namePrefix}bn1_gamma"), - beta: TensorOperations.Constant(_bn1.GetBeta(), $"{namePrefix}bn1_beta"), - runningMean: _bn1.GetRunningMean(), - runningVar: _bn1.GetRunningVariance(), - training: false, - epsilon: NumOps.ToDouble(_bn1.GetEpsilon())); - - // ReLU1 - var relu1Node = TensorOperations.ReLU(bn1Node); - - // Conv1x1 - var conv1x1Biases = _conv1x1.GetBiases(); - var conv1x1Node = TensorOperations.Conv2D( - relu1Node, - TensorOperations.Constant(_conv1x1.GetFilters(), $"{namePrefix}conv1x1_kernel"), - conv1x1Biases is not null ? TensorOperations.Constant(conv1x1Biases, $"{namePrefix}conv1x1_bias") : null, - stride: new int[] { _conv1x1.Stride, _conv1x1.Stride }, - padding: new int[] { _conv1x1.Padding, _conv1x1.Padding }); - - // BN2 - var bn2Node = TensorOperations.BatchNorm( - conv1x1Node, - gamma: TensorOperations.Constant(_bn2.GetGamma(), $"{namePrefix}bn2_gamma"), - beta: TensorOperations.Constant(_bn2.GetBeta(), $"{namePrefix}bn2_beta"), - runningMean: _bn2.GetRunningMean(), - runningVar: _bn2.GetRunningVariance(), - training: false, - epsilon: NumOps.ToDouble(_bn2.GetEpsilon())); - - // ReLU2 - var relu2Node = TensorOperations.ReLU(bn2Node); - - // Conv3x3 - var conv3x3Biases = _conv3x3.GetBiases(); - var outputNode = TensorOperations.Conv2D( - relu2Node, - TensorOperations.Constant(_conv3x3.GetFilters(), $"{namePrefix}conv3x3_kernel"), - conv3x3Biases is not null ? TensorOperations.Constant(conv3x3Biases, $"{namePrefix}conv3x3_bias") : null, - stride: new int[] { _conv3x3.Stride, _conv3x3.Stride }, - padding: new int[] { _conv3x3.Padding, _conv3x3.Padding }); - - return outputNode; - } } diff --git a/src/NeuralNetworks/Layers/FeedForwardLayer.cs b/src/NeuralNetworks/Layers/FeedForwardLayer.cs index 51e6ee5cf1..6c59a66273 100644 --- a/src/NeuralNetworks/Layers/FeedForwardLayer.cs +++ b/src/NeuralNetworks/Layers/FeedForwardLayer.cs @@ -519,13 +519,23 @@ public override Tensor Forward(Tensor input) EnsureInitialized(); Input = input; - // Use Engine.TensorMatMul for GPU acceleration - var matmul = Engine.TensorMatMul(Input, _weights); + // Fuse MatMul + BiasAdd + Activation into a single FusedLinear call when the + // activation is one the engine's fused kernel supports (ReLU, GELU, Tanh, + // Sigmoid, Swish/SiLU, LeakyReLU). Saves 3 kernel launches per layer. + var fusedActivation = GetFusedActivationType(); - // Add biases (broadcast [1, outputSize] to [batchSize, outputSize]) using engine op - var biasBroadcast = Engine.Reshape(_biases, [1, _weights.Shape[1]]); - var linearOutput = Engine.TensorBroadcastAdd(matmul, biasBroadcast); + if (fusedActivation != FusedActivationType.None && !IsTrainingMode) + { + // Pure-inference fast path — no activation-gradient cache needed. + Output = Engine.FusedLinear(input, _weights, _biases, fusedActivation); + return Output; + } + // Training or non-fusable activation: emit the linear pre-activation in one + // fused call, then apply activation separately so the tape records a single + // FusedLinear entry (calling FusedLinear twice per forward corrupts tape + // accounting via RemoveLastNTapeEntries). + var linearOutput = Engine.FusedLinear(input, _weights, _biases, FusedActivationType.None); PreActivationOutput = linearOutput; Output = ApplyActivation(linearOutput); diff --git a/src/NeuralNetworks/Layers/GraphAttentionLayer.cs b/src/NeuralNetworks/Layers/GraphAttentionLayer.cs index d14b71e983..3e12711637 100644 --- a/src/NeuralNetworks/Layers/GraphAttentionLayer.cs +++ b/src/NeuralNetworks/Layers/GraphAttentionLayer.cs @@ -29,7 +29,6 @@ namespace AiDotNet.NeuralNetworks.Layers; /// Tensor-based weights for all parameters /// Dual backward pass: BackwardManual() for efficiency, BackwardViaAutodiff() for accuracy /// Full gradient computation through attention mechanism -/// JIT compilation support via ExportComputationGraph() /// Complete GetParameters()/SetParameters() for model persistence /// /// diff --git a/src/NeuralNetworks/Layers/GraphIsomorphismLayer.cs b/src/NeuralNetworks/Layers/GraphIsomorphismLayer.cs index 13b72255bf..7559d886af 100644 --- a/src/NeuralNetworks/Layers/GraphIsomorphismLayer.cs +++ b/src/NeuralNetworks/Layers/GraphIsomorphismLayer.cs @@ -29,7 +29,6 @@ namespace AiDotNet.NeuralNetworks.Layers; /// Tensor-based weights for all parameters /// Dual backward pass: BackwardManual() for efficiency, BackwardViaAutodiff() for accuracy /// Full gradient computation through MLP and aggregation paths -/// JIT compilation support via ExportComputationGraph() /// Complete GetParameters()/SetParameters() for model persistence /// /// diff --git a/src/NeuralNetworks/Layers/GraphSAGELayer.cs b/src/NeuralNetworks/Layers/GraphSAGELayer.cs index a545027846..8778427171 100644 --- a/src/NeuralNetworks/Layers/GraphSAGELayer.cs +++ b/src/NeuralNetworks/Layers/GraphSAGELayer.cs @@ -29,7 +29,6 @@ namespace AiDotNet.NeuralNetworks.Layers; /// Tensor-based weights for all parameters /// Dual backward pass: BackwardManual() for efficiency, BackwardViaAutodiff() for accuracy /// Full gradient computation through aggregation paths -/// JIT compilation support via ExportComputationGraph() /// Complete GetParameters()/SetParameters() for model persistence /// /// diff --git a/src/NeuralNetworks/Layers/InvertedResidualBlock.cs b/src/NeuralNetworks/Layers/InvertedResidualBlock.cs index 2764951765..44afd0ec1a 100644 --- a/src/NeuralNetworks/Layers/InvertedResidualBlock.cs +++ b/src/NeuralNetworks/Layers/InvertedResidualBlock.cs @@ -552,106 +552,6 @@ public override void ResetState() } - /// - public ComputationNode BuildComputationGraph(ComputationNode inputNode, string namePrefix) - { - var current = inputNode; - - // Expansion phase (if expansion > 1) - if (_hasExpansion && _expandConv is not null && _expandBn is not null) - { - // Conv1x1 expansion - var expandBiases = _expandConv.GetBiases(); - current = TensorOperations.Conv2D( - current, - TensorOperations.Constant(_expandConv.GetFilters(), $"{namePrefix}expand_kernel"), - expandBiases is not null ? TensorOperations.Constant(expandBiases, $"{namePrefix}expand_bias") : null, - stride: new int[] { _expandConv.Stride, _expandConv.Stride }, - padding: new int[] { _expandConv.Padding, _expandConv.Padding }); - - // BN - current = TensorOperations.BatchNorm( - current, - gamma: TensorOperations.Constant(_expandBn.GetGamma(), $"{namePrefix}expand_bn_gamma"), - beta: TensorOperations.Constant(_expandBn.GetBeta(), $"{namePrefix}expand_bn_beta"), - runningMean: _expandBn.GetRunningMean(), - runningVar: _expandBn.GetRunningVariance(), - training: false, - epsilon: NumOps.ToDouble(_expandBn.GetEpsilon())); - - // Apply activation using proper JIT graph integration - if (ScalarActivation is not null) - current = ScalarActivation.ApplyToGraph(current); - } - - // Depthwise convolution phase - { - var dwBiases = _dwConv.GetBiases(); - current = TensorOperations.Conv2D( - current, - TensorOperations.Constant(_dwConv.GetFilters(), $"{namePrefix}dw_kernel"), - dwBiases is not null ? TensorOperations.Constant(dwBiases, $"{namePrefix}dw_bias") : null, - stride: new int[] { _dwConv.Stride, _dwConv.Stride }, - padding: new int[] { _dwConv.Padding, _dwConv.Padding }); - - // BN - current = TensorOperations.BatchNorm( - current, - gamma: TensorOperations.Constant(_dwBn.GetGamma(), $"{namePrefix}dw_bn_gamma"), - beta: TensorOperations.Constant(_dwBn.GetBeta(), $"{namePrefix}dw_bn_beta"), - runningMean: _dwBn.GetRunningMean(), - runningVar: _dwBn.GetRunningVariance(), - training: false, - epsilon: NumOps.ToDouble(_dwBn.GetEpsilon())); - - // Apply activation using proper JIT graph integration - if (ScalarActivation is not null) - current = ScalarActivation.ApplyToGraph(current); - } - - // Squeeze-and-Excitation phase (if used) - // Note: SE layer expects NHWC format, so we need to transpose - if (_useSE && _se is not null) - { - // Transpose NCHW [B, C, H, W] -> NHWC [B, H, W, C] - current = TensorOperations.Permute(current, new int[] { 0, 2, 3, 1 }); - - // Apply SE layer - current = _se.BuildComputationGraph(current, $"{namePrefix}se_"); - - // Transpose NHWC [B, H, W, C] -> NCHW [B, C, H, W] - current = TensorOperations.Permute(current, new int[] { 0, 3, 1, 2 }); - } - - // Projection phase (LINEAR - no activation) - { - var projectBiases = _projectConv.GetBiases(); - current = TensorOperations.Conv2D( - current, - TensorOperations.Constant(_projectConv.GetFilters(), $"{namePrefix}project_kernel"), - projectBiases is not null ? TensorOperations.Constant(projectBiases, $"{namePrefix}project_bias") : null, - stride: new int[] { _projectConv.Stride, _projectConv.Stride }, - padding: new int[] { _projectConv.Padding, _projectConv.Padding }); - - // BN only, no activation (linear bottleneck) - current = TensorOperations.BatchNorm( - current, - gamma: TensorOperations.Constant(_projectBn.GetGamma(), $"{namePrefix}project_bn_gamma"), - beta: TensorOperations.Constant(_projectBn.GetBeta(), $"{namePrefix}project_bn_beta"), - runningMean: _projectBn.GetRunningMean(), - runningVar: _projectBn.GetRunningVariance(), - training: false, - epsilon: NumOps.ToDouble(_projectBn.GetEpsilon())); - } - - // Residual connection (only if dimensions match) - if (_useResidual) - { - current = TensorOperations.Add(current, inputNode); - } - - return current; - } #region Helper Methods diff --git a/src/NeuralNetworks/Layers/LayerBase.cs b/src/NeuralNetworks/Layers/LayerBase.cs index df3ff6e6a7..efe4d3c8f3 100644 --- a/src/NeuralNetworks/Layers/LayerBase.cs +++ b/src/NeuralNetworks/Layers/LayerBase.cs @@ -854,57 +854,6 @@ public virtual int[] GetInputShape() => /// public virtual Tensor? GetBiases() => null; - /// - /// Exports the layer's computation graph for JIT compilation. - /// - /// List to populate with input computation nodes. - /// The output computation node representing the layer's operation. - /// - /// - /// This method constructs a computation graph representation of the layer's forward pass - /// that can be JIT compiled for faster inference. All layers MUST implement this method - /// to support JIT compilation. - /// - /// For Beginners: JIT (Just-In-Time) compilation converts the layer's operations - /// into optimized native code for 5-10x faster inference. - /// - /// To support JIT compilation, a layer must: - /// 1. Implement this method to export its computation graph - /// 2. Set SupportsJitCompilation to true - /// 3. Use ComputationNode and TensorOperations to build the graph - /// - /// All layers are required to implement this method, even if they set SupportsJitCompilation = false. - /// - /// - public virtual ComputationNode ExportComputationGraph(List> inputNodes) - { - throw new NotSupportedException( - $"{GetType().Name} does not implement ExportComputationGraph. " + - "Use GradientTape-based autodiff instead."); - } - - /// - /// Gets whether this layer supports JIT compilation. - /// - /// True if the layer can be JIT compiled, false otherwise. - /// - /// - /// This property indicates whether the layer has implemented ExportComputationGraph() - /// and can benefit from JIT compilation. All layers MUST implement this property. - /// - /// For Beginners: JIT compilation can make inference 5-10x faster by converting - /// the layer's operations into optimized native code. - /// - /// Layers should return false if they: - /// - Have not yet implemented a working ExportComputationGraph() - /// - Use dynamic operations that change based on input data - /// - Are too simple to benefit from JIT compilation - /// - /// When false, the layer will use the standard Forward() method instead. - /// - /// - public virtual bool SupportsJitCompilation => false; - /// /// Declares the named input ports this layer accepts. /// Default: single port named "input" with the layer's input shape. @@ -2584,93 +2533,6 @@ public virtual Dictionary GetDiagnostics() return diagnostics; } - /// - /// Applies the layer's configured activation function to a computation graph node. - /// - /// The computation node to apply activation to. - /// The computation node with activation applied. - /// Thrown if input is null. - /// Thrown if activation does not support JIT. - /// - /// - /// This helper method delegates to the activation's ApplyToGraph method, - /// following the Open/Closed Principle. Adding new activations does not require - /// modifying layer code. - /// - /// For Beginners: This method adds the activation function to the computation graph. - /// - /// Instead of the layer code checking what type of activation is configured (which would - /// require changing the layer every time a new activation is added), this method simply - /// asks the activation to add itself to the graph. This makes the code more maintainable - /// and extensible. - /// - /// - protected ComputationNode ApplyActivationToGraph(ComputationNode input) - { - if (input == null) - throw new ArgumentNullException(nameof(input)); - - // Check scalar activation first - if (ScalarActivation is not null) - { - if (!ScalarActivation.SupportsJitCompilation) - { - throw new NotSupportedException( - $"Activation {ScalarActivation.GetType().Name} does not support JIT compilation. " + - $"Either the gradient computation is not implemented yet, or the activation " + - $"uses operations not compatible with computation graphs."); - } - - return ScalarActivation.ApplyToGraph(input); - } - - // Check vector activation - if (VectorActivation is not null) - { - if (!VectorActivation.SupportsJitCompilation) - { - throw new NotSupportedException( - $"Activation {VectorActivation.GetType().Name} does not support JIT compilation. " + - $"Either the gradient computation is not implemented yet, or the activation " + - $"uses operations not compatible with computation graphs."); - } - - return VectorActivation.ApplyToGraph(input); - } - - // No activation configured (identity) - return input; - } - - /// - /// Checks if the layer's current activation function supports JIT compilation. - /// - /// True if the activation can be JIT compiled, false otherwise. - /// - /// - /// This method checks whether the layer's configured activation function supports - /// JIT compilation by querying the activation's SupportsJitCompilation property. - /// If no activation is configured, returns true (identity function is always JIT-compatible). - /// - /// For Beginners: This method checks if the activation is ready for JIT compilation. - /// - /// The layer uses this to determine if it can export a computation graph for faster inference. - /// If the activation does not support JIT yet (because gradients are not implemented), the - /// layer will fall back to the standard execution path. - /// - /// - protected bool CanActivationBeJitted() - { - if (ScalarActivation is not null) - return ScalarActivation.SupportsJitCompilation; - - if (VectorActivation is not null) - return VectorActivation.SupportsJitCompilation; - - // No activation (identity) always supports JIT - return true; - } - /// /// Returns layer-specific metadata for serialization purposes. /// diff --git a/src/NeuralNetworks/Layers/PrincipalNeighbourhoodAggregationLayer.cs b/src/NeuralNetworks/Layers/PrincipalNeighbourhoodAggregationLayer.cs index e5a35ece26..9eea18dde3 100644 --- a/src/NeuralNetworks/Layers/PrincipalNeighbourhoodAggregationLayer.cs +++ b/src/NeuralNetworks/Layers/PrincipalNeighbourhoodAggregationLayer.cs @@ -31,7 +31,6 @@ namespace AiDotNet.NeuralNetworks.Layers; /// BatchMatMul for efficient batched graph operations /// Dual backward pass: BackwardManual() for efficiency, BackwardViaAutodiff() for accuracy /// Full gradient computation through all aggregators and scalers -/// JIT compilation support via ExportComputationGraph() /// Complete GetParameters()/SetParameters() for model persistence /// /// diff --git a/src/NeuralNetworks/Layers/RRDBLayer.cs b/src/NeuralNetworks/Layers/RRDBLayer.cs index 7678e16799..9126db680d 100644 --- a/src/NeuralNetworks/Layers/RRDBLayer.cs +++ b/src/NeuralNetworks/Layers/RRDBLayer.cs @@ -357,18 +357,6 @@ public override void ResetState() } } - /// - public ComputationNode BuildComputationGraph(ComputationNode inputNode, string namePrefix) - { - // Pass through 3 Residual Dense Blocks sequentially - var x = _rdbBlocks[0].BuildComputationGraph(inputNode, $"{namePrefix}rdb0_"); - x = _rdbBlocks[1].BuildComputationGraph(x, $"{namePrefix}rdb1_"); - x = _rdbBlocks[2].BuildComputationGraph(x, $"{namePrefix}rdb2_"); - - // Global residual: output = RDB3_output * residualScale + input - var scaledOutput = ScaleNode(x, _residualScale, $"{namePrefix}global_scale"); - return TensorOperations.Add(scaledOutput, inputNode); - } /// /// Scales a computation node by a scalar value using element-wise multiplication. diff --git a/src/NeuralNetworks/Layers/RRDBNetGenerator.cs b/src/NeuralNetworks/Layers/RRDBNetGenerator.cs index 26efd42095..afe86f2120 100644 --- a/src/NeuralNetworks/Layers/RRDBNetGenerator.cs +++ b/src/NeuralNetworks/Layers/RRDBNetGenerator.cs @@ -515,44 +515,6 @@ public override void ResetState() #region JIT Compilation - /// - public ComputationNode BuildComputationGraph(ComputationNode inputNode, string namePrefix) - { - // Initial feature extraction - var x = BuildConvNode(_convFirst, inputNode, $"{namePrefix}conv_first_"); - - // Store for global residual - var conv1Output = x; - - // RRDB blocks - for (int i = 0; i < _rrdbBlocks.Length; i++) - { - x = _rrdbBlocks[i].BuildComputationGraph(x, $"{namePrefix}rrdb{i}_"); - } - - // Trunk conv - x = BuildConvNode(_trunkConv, x, $"{namePrefix}trunk_conv_"); - - // Global residual - x = TensorOperations.Add(x, conv1Output); - - // Upsampling stages - for (int i = 0; i < _upsampleConvs.Length; i++) - { - x = BuildConvNode(_upsampleConvs[i], x, $"{namePrefix}upsample{i}_conv_"); - x = TensorOperations.PixelShuffle(x, 2); - x = TensorOperations.LeakyReLU(x, 0.2); - } - - // HR conv + LeakyReLU - x = BuildConvNode(_hrConv, x, $"{namePrefix}hr_conv_"); - x = TensorOperations.LeakyReLU(x, 0.2); - - // Final conv - x = BuildConvNode(_convLast, x, $"{namePrefix}conv_last_"); - - return x; - } /// /// Builds a Conv2D computation node from a ConvolutionalLayer. diff --git a/src/NeuralNetworks/Layers/ResidualDenseBlock.cs b/src/NeuralNetworks/Layers/ResidualDenseBlock.cs index f58b9a310d..6818b56965 100644 --- a/src/NeuralNetworks/Layers/ResidualDenseBlock.cs +++ b/src/NeuralNetworks/Layers/ResidualDenseBlock.cs @@ -769,39 +769,6 @@ public override void ResetState() } } - /// - public ComputationNode BuildComputationGraph(ComputationNode inputNode, string namePrefix) - { - // x0 = input - var x0 = inputNode; - - // Conv1: x1 = LeakyReLU(Conv(x0)) - var conv1Output = BuildConvNode(_convLayers[0], x0, $"{namePrefix}conv1_"); - var x1 = TensorOperations.LeakyReLU(conv1Output, 0.2); - - // Conv2: x2 = LeakyReLU(Conv(concat(x0, x1))) - var concat1 = TensorOperations.Concat([x0, x1], axis: 1); - var conv2Output = BuildConvNode(_convLayers[1], concat1, $"{namePrefix}conv2_"); - var x2 = TensorOperations.LeakyReLU(conv2Output, 0.2); - - // Conv3: x3 = LeakyReLU(Conv(concat(x0, x1, x2))) - var concat2 = TensorOperations.Concat([concat1, x2], axis: 1); - var conv3Output = BuildConvNode(_convLayers[2], concat2, $"{namePrefix}conv3_"); - var x3 = TensorOperations.LeakyReLU(conv3Output, 0.2); - - // Conv4: x4 = LeakyReLU(Conv(concat(x0, x1, x2, x3))) - var concat3 = TensorOperations.Concat([concat2, x3], axis: 1); - var conv4Output = BuildConvNode(_convLayers[3], concat3, $"{namePrefix}conv4_"); - var x4 = TensorOperations.LeakyReLU(conv4Output, 0.2); - - // Conv5: x5 = Conv(concat(x0, x1, x2, x3, x4)) - NO activation - var concat4 = TensorOperations.Concat([concat3, x4], axis: 1); - var x5 = BuildConvNode(_convLayers[4], concat4, $"{namePrefix}conv5_"); - - // Local residual: output = x5 * residualScale + x0 - var scaledX5 = ScaleNode(x5, _residualScale, $"{namePrefix}residual_scale"); - return TensorOperations.Add(scaledX5, x0); - } /// /// Builds a Conv2D computation node from a ConvolutionalLayer. diff --git a/src/NeuralNetworks/Layers/SparseLinearLayer.cs b/src/NeuralNetworks/Layers/SparseLinearLayer.cs index d3076952ff..31f02d1901 100644 --- a/src/NeuralNetworks/Layers/SparseLinearLayer.cs +++ b/src/NeuralNetworks/Layers/SparseLinearLayer.cs @@ -435,28 +435,6 @@ public override void ResetState() _biasesGradient = null; } - /// - /// Applies the activation function to a computation node. - /// - private ComputationNode ApplyActivationToComputationNode(ComputationNode node) - { - // ScalarActivation is guaranteed non-null here since this method is only called when ScalarActivation is not null - if (ScalarActivation is null) - throw new InvalidOperationException("ScalarActivation cannot be null when applying activation to computation node."); - - // Use ApplyToGraph - the layer's SupportsJitCompilation property ensures this is only - // called when the activation supports JIT compilation - if (ScalarActivation.SupportsJitCompilation) - { - return ScalarActivation.ApplyToGraph(node); - } - - // This should never be reached if SupportsJitCompilation is checked before ExportComputationGraph - throw new InvalidOperationException( - $"Internal error: Activation function '{ScalarActivation.GetType().Name}' does not support JIT compilation. " + - "This indicates the layer's SupportsJitCompilation property was not checked before calling ExportComputationGraph."); - } - /// /// Transposes a matrix using O(1) stride-based view. /// diff --git a/src/NeuralNetworks/Layers/SpyNetLayer.cs b/src/NeuralNetworks/Layers/SpyNetLayer.cs index 21d5e522ac..700da49ee0 100644 --- a/src/NeuralNetworks/Layers/SpyNetLayer.cs +++ b/src/NeuralNetworks/Layers/SpyNetLayer.cs @@ -1230,156 +1230,6 @@ protected override void Dispose(bool disposing) /// public new int[] GetOutputShape() => [2, _inputHeight, _inputWidth]; - /// - public ComputationNode BuildComputationGraph(ComputationNode inputNode, string namePrefix) - { - if (!SupportsJitCompilation) - throw new InvalidOperationException("Layer modules not initialized for JIT compilation."); - - // Input is concatenated frames [batch, 2*channels, height, width] - // Split into two frames - var splitNodes = TensorOperations.Split(inputNode, 2, axis: 1); - var frame1Node = splitNodes[0]; - var frame2Node = splitNodes[1]; - - // Build pyramid nodes for both frames - var pyramid1 = BuildPyramidGraph(frame1Node, $"{namePrefix}pyr1_"); - var pyramid2 = BuildPyramidGraph(frame2Node, $"{namePrefix}pyr2_"); - - // Initialize flow at coarsest level (zeros - constant node) - int coarseH = _inputHeight >> (_numLevels - 1); - int coarseW = _inputWidth >> (_numLevels - 1); - var zeroFlow = new Tensor(new[] { 1, 2, coarseH, coarseW }); - var flowNode = TensorOperations.Constant(zeroFlow, $"{namePrefix}zero_flow"); - - // Coarse-to-fine refinement - for (int level = _numLevels - 1; level >= 0; level--) - { - var img1Node = pyramid1[level]; - var img2Node = pyramid2[level]; - - // Upsample flow if not at coarsest level - if (level < _numLevels - 1) - { - int targetH = _inputHeight >> level; - int targetW = _inputWidth >> level; - flowNode = TensorOperations.Upsample(flowNode, 2); - // Scale flow values by 2 for upsampling - var scaleNode = TensorOperations.Constant( - CreateScaleTensor(flowNode.Value._shape, 2.0), $"{namePrefix}scale_{level}"); - flowNode = TensorOperations.ElementwiseMultiply(flowNode, scaleNode); - } - - // Warp img2 using flow via GridSample - // First create sampling grid from flow - var gridNode = CreateGridFromFlowGraph(flowNode, $"{namePrefix}grid_{level}_"); - var warpedNode = TensorOperations.GridSample(img2Node, gridNode); - - // Concatenate [img1, warped2, flow] for basic module input - var moduleInputNode = TensorOperations.Concat( - new List> { img1Node, warpedNode, flowNode }, axis: 1); - - // Get residual flow from basic module - var residualFlowNode = _basicModules[level].ExportComputationGraph( - new List> { moduleInputNode }); - - // Extract first 2 channels as residual flow (if module outputs more) - // Add residual to current flow - flowNode = TensorOperations.Add(flowNode, residualFlowNode); - } - - return flowNode; - } - - private List> BuildPyramidGraph(ComputationNode imageNode, string namePrefix) - { - var pyramid = new List> { imageNode }; - var currentNode = imageNode; - - for (int i = 1; i < _numLevels; i++) - { - // Downsample by factor of 2 using average pooling - currentNode = TensorOperations.AvgPool2D( - currentNode, - poolSize: new[] { 2, 2 }, - strides: new[] { 2, 2 }); - pyramid.Add(currentNode); - } - - return pyramid; - } - - private ComputationNode CreateGridFromFlowGraph(ComputationNode flowNode, string namePrefix) - { - // Create identity grid and add flow to get sampling positions - // Grid should be [batch, height, width, 2] in normalized coordinates [-1, 1] - var flowShape = flowNode.Value._shape; - int batch = flowShape[0]; - int height = flowShape[2]; - int width = flowShape[3]; - - // Create base identity grid - var identityGrid = CreateIdentityGrid(batch, height, width); - var identityNode = TensorOperations.Constant(identityGrid, $"{namePrefix}identity"); - - // Reshape flow from [B, 2, H, W] to [B, H, W, 2] and normalize - var permutedFlow = TensorOperations.Permute(flowNode, 0, 2, 3, 1); - - // Scale flow to normalized coordinates: flow / (dim - 1) * 2 - T widthScale = NumOps.FromDouble(2.0 / (width - 1)); - T heightScale = NumOps.FromDouble(2.0 / (height - 1)); - var scaleData = new T[batch * height * width * 2]; - for (int b = 0; b < batch; b++) - { - for (int h = 0; h < height; h++) - { - for (int w = 0; w < width; w++) - { - int idx = b * height * width * 2 + h * width * 2 + w * 2; - scaleData[idx] = widthScale; // x scale - scaleData[idx + 1] = heightScale; // y scale - } - } - } - var scaleTensor = new Tensor(new[] { batch, height, width, 2 }, new Vector(scaleData)); - var scaleNode = TensorOperations.Constant(scaleTensor, $"{namePrefix}scale"); - - var scaledFlow = TensorOperations.ElementwiseMultiply(permutedFlow, scaleNode); - var grid = TensorOperations.Add(identityNode, scaledFlow); - - return grid; - } - - private Tensor CreateIdentityGrid(int batch, int height, int width) - { - var data = new T[batch * height * width * 2]; - for (int b = 0; b < batch; b++) - { - for (int h = 0; h < height; h++) - { - for (int w = 0; w < width; w++) - { - int idx = b * height * width * 2 + h * width * 2 + w * 2; - // Normalized coordinates [-1, 1] - data[idx] = NumOps.FromDouble(2.0 * w / (width - 1) - 1.0); // x - data[idx + 1] = NumOps.FromDouble(2.0 * h / (height - 1) - 1.0); // y - } - } - } - return new Tensor(new[] { batch, height, width, 2 }, new Vector(data)); - } - - private Tensor CreateScaleTensor(int[] shape, double scale) - { - int totalSize = 1; - foreach (var dim in shape) totalSize *= dim; - var data = new T[totalSize]; - T scaleVal = NumOps.FromDouble(scale); - for (int i = 0; i < totalSize; i++) - data[i] = scaleVal; - return new Tensor(shape, new Vector(data)); - } - #endregion #region Parameter Management diff --git a/src/NeuralNetworks/Layers/SqueezeAndExcitationLayer.cs b/src/NeuralNetworks/Layers/SqueezeAndExcitationLayer.cs index cb5d9570d2..39c0d3685a 100644 --- a/src/NeuralNetworks/Layers/SqueezeAndExcitationLayer.cs +++ b/src/NeuralNetworks/Layers/SqueezeAndExcitationLayer.cs @@ -1206,25 +1206,6 @@ private Tensor ApplyTensorActivationDerivative(Tensor input, bool isFirstA return result; } - private Autodiff.ComputationNode ApplyActivationToGraphNode(Autodiff.ComputationNode input, bool isFirst) - { - if (isFirst) - { - if (_firstVectorActivation != null && _firstVectorActivation.SupportsJitCompilation) - return _firstVectorActivation.ApplyToGraph(input); - if (_firstActivation != null && _firstActivation.SupportsJitCompilation) - return _firstActivation.ApplyToGraph(input); - } - else - { - if (_secondVectorActivation != null && _secondVectorActivation.SupportsJitCompilation) - return _secondVectorActivation.ApplyToGraph(input); - if (_secondActivation != null && _secondActivation.SupportsJitCompilation) - return _secondActivation.ApplyToGraph(input); - } - return input; - } - /// /// Updates the layer's parameters using the calculated gradients and the specified learning rate. /// @@ -1585,56 +1566,6 @@ public override Dictionary GetDiagnostics() return diagnostics; } - /// - public ComputationNode BuildComputationGraph(ComputationNode inputNode, string namePrefix) - { - if (_weights1 == null || _weights2 == null || _bias1 == null || _bias2 == null) - throw new InvalidOperationException("Layer weights not initialized. Initialize the layer before compiling."); - - // Squeeze: Global Average Pooling across spatial dimensions - var squeezed = TensorOperations.ReduceMean(inputNode, axes: new[] { 1, 2 }, keepDims: false); - - // Excitation: First fully connected layer (weights and biases are already Tensor) - var weights1Node = TensorOperations.Constant(_weights1, $"{namePrefix}se_weights1"); - var bias1Node = TensorOperations.Constant(_bias1, $"{namePrefix}se_bias1"); - - var fc1Output = TensorOperations.MatrixMultiply(squeezed, weights1Node); - fc1Output = TensorOperations.Add(fc1Output, bias1Node); - - // Apply first activation (default: ReLU) - if (_firstActivation != null && _firstActivation.SupportsJitCompilation) - { - fc1Output = _firstActivation.ApplyToGraph(fc1Output); - } - else if (_firstVectorActivation == null) - { - fc1Output = TensorOperations.ReLU(fc1Output); - } - - // Excitation: Second fully connected layer (weights and biases are already Tensor) - var weights2Node = TensorOperations.Constant(_weights2, $"{namePrefix}se_weights2"); - var bias2Node = TensorOperations.Constant(_bias2, $"{namePrefix}se_bias2"); - - var fc2Output = TensorOperations.MatrixMultiply(fc1Output, weights2Node); - fc2Output = TensorOperations.Add(fc2Output, bias2Node); - - // Apply second activation (default: Sigmoid) - if (_secondActivation != null && _secondActivation.SupportsJitCompilation) - { - fc2Output = _secondActivation.ApplyToGraph(fc2Output); - } - else if (_secondVectorActivation == null) - { - fc2Output = TensorOperations.Sigmoid(fc2Output); - } - - // Scale: Multiply input by excitation weights (with broadcasting) - // fc2Output has shape [batch, channels], inputNode has shape [batch, height, width, channels] - // ElementwiseMultiply should handle broadcasting automatically - var scaledOutput = TensorOperations.ElementwiseMultiply(inputNode, fc2Output); - - return scaledOutput; - } public override void ClearGradients() { diff --git a/src/NeuralNetworks/Layers/TransitionLayer.cs b/src/NeuralNetworks/Layers/TransitionLayer.cs index a9bdc5def5..68f3859e6d 100644 --- a/src/NeuralNetworks/Layers/TransitionLayer.cs +++ b/src/NeuralNetworks/Layers/TransitionLayer.cs @@ -440,35 +440,5 @@ public override void ResetState() _pool.ResetState(); } - /// - public ComputationNode BuildComputationGraph(ComputationNode inputNode, string namePrefix) - { - // BN - var bnNode = TensorOperations.BatchNorm( - inputNode, - gamma: TensorOperations.Constant(_bn.GetGamma(), $"{namePrefix}bn_gamma"), - beta: TensorOperations.Constant(_bn.GetBeta(), $"{namePrefix}bn_beta"), - runningMean: _bn.GetRunningMean(), - runningVar: _bn.GetRunningVariance(), - training: false, - epsilon: NumOps.ToDouble(_bn.GetEpsilon())); - - // ReLU - var reluNode = TensorOperations.ReLU(bnNode); - - // Conv 1x1 - var convBiases = _conv.GetBiases(); - var convNode = TensorOperations.Conv2D( - reluNode, - TensorOperations.Constant(_conv.GetFilters(), $"{namePrefix}conv_kernel"), - convBiases is not null ? TensorOperations.Constant(convBiases, $"{namePrefix}conv_bias") : null, - stride: new int[] { _conv.Stride, _conv.Stride }, - padding: new int[] { _conv.Padding, _conv.Padding }); - - // Average Pooling 2x2, stride 2 - var poolNode = TensorOperations.AvgPool2D(convNode, poolSize: new int[] { 2, 2 }, strides: new int[] { 2, 2 }); - - return poolNode; - } } diff --git a/src/NeuralNetworks/Layers/UNetDiscriminator.cs b/src/NeuralNetworks/Layers/UNetDiscriminator.cs index 8bae2375f3..a2a7aea827 100644 --- a/src/NeuralNetworks/Layers/UNetDiscriminator.cs +++ b/src/NeuralNetworks/Layers/UNetDiscriminator.cs @@ -391,48 +391,6 @@ public override void ResetState() #endregion - #region JIT Compilation - - /// - public ComputationNode BuildComputationGraph(ComputationNode inputNode, string namePrefix) - { - // Initial conv + LeakyReLU - var x = BuildConvNode(_convFirst, inputNode, $"{namePrefix}conv_first_"); - x = TensorOperations.LeakyReLU(x, 0.2); - - // Encoder path - store skip connections - var skipConnections = new ComputationNode[_numBlocks]; - for (int i = 0; i < _numBlocks; i++) - { - skipConnections[i] = x; - x = _encoderBlocks[i].BuildComputationGraph(x, $"{namePrefix}enc{i}_"); - } - - // Decoder path - use skip connections - for (int i = 0; i < _numBlocks; i++) - { - int skipIdx = _numBlocks - 1 - i; - x = _decoderBlocks[i].BuildComputationGraph(x, skipConnections[skipIdx], $"{namePrefix}dec{i}_"); - } - - // Final conv - x = BuildConvNode(_convLast, x, $"{namePrefix}conv_last_"); - - return x; - } - - private static ComputationNode BuildConvNode(ConvolutionalLayer conv, ComputationNode input, string namePrefix) - { - var biases = conv.GetBiases(); - return TensorOperations.Conv2D( - input, - TensorOperations.Constant(conv.GetFilters(), $"{namePrefix}kernel"), - biases is not null ? TensorOperations.Constant(biases, $"{namePrefix}bias") : null, - stride: new int[] { conv.Stride, conv.Stride }, - padding: new int[] { conv.Padding, conv.Padding }); - } - - #endregion } @@ -562,28 +520,6 @@ public override void ResetState() _conv2.ResetState(); } - public ComputationNode BuildComputationGraph(ComputationNode inputNode, string namePrefix) - { - var biases1 = _conv1.GetBiases(); - var x = TensorOperations.Conv2D( - inputNode, - TensorOperations.Constant(_conv1.GetFilters(), $"{namePrefix}conv1_kernel"), - biases1 is not null ? TensorOperations.Constant(biases1, $"{namePrefix}conv1_bias") : null, - stride: new int[] { _downsample ? 2 : 1, _downsample ? 2 : 1 }, - padding: new int[] { 1, 1 }); - x = TensorOperations.LeakyReLU(x, 0.2); - - var biases2 = _conv2.GetBiases(); - x = TensorOperations.Conv2D( - x, - TensorOperations.Constant(_conv2.GetFilters(), $"{namePrefix}conv2_kernel"), - biases2 is not null ? TensorOperations.Constant(biases2, $"{namePrefix}conv2_bias") : null, - stride: new int[] { 1, 1 }, - padding: new int[] { 1, 1 }); - x = TensorOperations.LeakyReLU(x, 0.2); - - return x; - } } @@ -830,44 +766,7 @@ public override void ResetState() _conv2.ResetState(); } - public ComputationNode BuildComputationGraph(ComputationNode inputNode, string namePrefix) - { - return BuildComputationGraph(inputNode, null, namePrefix); - } - - public ComputationNode BuildComputationGraph(ComputationNode inputNode, ComputationNode? skipNode, string namePrefix) - { - // Upsample (bilinear) - var x = TensorOperations.Upsample(inputNode, 2); - - // Concatenate with skip - if (skipNode != null) - { - x = TensorOperations.Concat([x, skipNode], axis: 1); - } - - // Conv1 + LeakyReLU - var biases1 = _conv1.GetBiases(); - x = TensorOperations.Conv2D( - x, - TensorOperations.Constant(_conv1.GetFilters(), $"{namePrefix}conv1_kernel"), - biases1 is not null ? TensorOperations.Constant(biases1, $"{namePrefix}conv1_bias") : null, - stride: new int[] { 1, 1 }, - padding: new int[] { 1, 1 }); - x = TensorOperations.LeakyReLU(x, 0.2); - - // Conv2 + LeakyReLU - var biases2 = _conv2.GetBiases(); - x = TensorOperations.Conv2D( - x, - TensorOperations.Constant(_conv2.GetFilters(), $"{namePrefix}conv2_kernel"), - biases2 is not null ? TensorOperations.Constant(biases2, $"{namePrefix}conv2_bias") : null, - stride: new int[] { 1, 1 }, - padding: new int[] { 1, 1 }); - x = TensorOperations.LeakyReLU(x, 0.2); - return x; - } } diff --git a/src/NeuralNetworks/MobileNetV2Network.cs b/src/NeuralNetworks/MobileNetV2Network.cs index 46569a0206..e537b38b43 100644 --- a/src/NeuralNetworks/MobileNetV2Network.cs +++ b/src/NeuralNetworks/MobileNetV2Network.cs @@ -275,7 +275,6 @@ public Tensor Forward(Tensor input) return output; } - /// public override Dictionary> GetNamedLayerActivations(Tensor input) { // Mirror Forward's 3D → 4D expansion so the channel-broadcasted @@ -289,10 +288,11 @@ public override Dictionary> GetNamedLayerActivations(Tensor } /// - public override Tensor Predict(Tensor input) - { - return Forward(input); - } + /// + /// Routes inference through for + /// compiled-plan replay; remains the eager fallback. + /// + protected override Tensor PredictEager(Tensor input) => Forward(input); /// public override void Train(Tensor input, Tensor expectedOutput) diff --git a/src/NeuralNetworks/NeuralNetworkBase.cs b/src/NeuralNetworks/NeuralNetworkBase.cs index 791c2ecd90..65b1651c7f 100644 --- a/src/NeuralNetworks/NeuralNetworkBase.cs +++ b/src/NeuralNetworks/NeuralNetworkBase.cs @@ -328,6 +328,11 @@ protected NeuralNetworkBase(NeuralNetworkArchitecture architecture, ILossFunc LossFunction = lossFunction; _cachedParameterCount = null; _sensitiveFeatures = new Vector(0); + // Concrete subclass's type name threads through so disk-cached plans in + // PlanCache.Current don't collide between different model classes. + _compileHost = new CompiledModelHost( + shapeMode: SymbolicShapeMode.BatchDynamic, + modelIdentity: GetType().FullName ?? GetType().Name); } /// @@ -2212,7 +2217,7 @@ internal void SetGradientCheckpointingSegmentSize(int segmentSize) /// rather than re-implementing the compile+cache+fallback dance per /// model family. /// - private readonly CompiledModelHost _compileHost = new(); + private readonly CompiledModelHost _compileHost; /// /// Tracks input shapes whose compilation has previously failed on this @@ -4926,90 +4931,6 @@ protected virtual void Dispose(bool disposing) } } - #region IJitCompilable Implementation - - /// - /// - /// - /// Default is true because the base class's now - /// routes through , which auto-compiles when - /// TensorCodecOptions.EnableCompilation is on AND the model's op - /// graph is traceable, falling back to eager otherwise. So every - /// subclass is "JIT-capable" in the - /// effective sense: JIT is attempted, and failures degrade gracefully - /// to eager without the user noticing. - /// - /// - /// Subclasses whose forward path is known to be incompatible with graph - /// capture (non-Engine tensor access, scalar control flow that bakes at - /// trace time, layers whose outputs depend on mutable instance state) - /// should override this to return false — that signals "don't even - /// try" so tooling can short-circuit and users know to expect eager-only - /// performance. - /// - /// For Beginners: JIT (Just-In-Time) compilation optimizes neural networks for faster predictions. - /// - /// Instead of executing each layer one by one at runtime, JIT compilation: - /// - Analyzes the entire network structure - /// - Combines and optimizes operations - /// - Generates specialized native code - /// - Results in 5-10x faster predictions - /// - /// This is especially beneficial for: - /// - Production deployment (real-time predictions) - /// - Batch inference (processing many examples) - /// - Edge devices (mobile, embedded systems) - /// - /// - public virtual bool SupportsJitCompilation => true; - - /// - /// - /// - /// Exports the neural network as a computation graph for JIT compilation. - /// The graph represents the forward pass through all layers in sequence. - /// - /// For Beginners: This method converts the neural network into a computation graph. - /// - /// A computation graph is like a flowchart that describes: - /// 1. How data flows through each layer - /// 2. What operations each layer performs - /// 3. How layer outputs connect to the next layer's inputs - /// - /// The JIT compiler uses this graph to: - /// - Optimize the operations (remove redundancy) - /// - Fuse operations together (combine multiple steps) - /// - Generate fast native code - /// - /// For example, a simple network: - /// Input → Dense Layer → ReLU → Dense Layer → Output - /// - /// Becomes a graph: - /// input_node → matmul_node → add_bias_node → relu_node → matmul_node → add_bias_node - /// - /// The JIT compiler can then optimize this graph (e.g., fuse bias addition with matmul) - /// to create highly efficient code. - /// - /// - public virtual ComputationNode ExportComputationGraph(List> inputNodes) - { - throw new NotSupportedException("JIT compilation has been removed."); - } - - protected virtual ComputationNode ConvertLayerToGraph(ILayer layer, ComputationNode input) - { - if (layer is Layers.LayerBase layerBase) - { - var layerInputs = new List> { input }; - return layerBase.ExportComputationGraph(layerInputs); - } - throw new NotSupportedException( - $"Layer {layer.GetType().Name} does not support computation graph export."); - } - - - #endregion - #region ILayeredModel Implementation /// diff --git a/src/NeuralNetworks/PlanCache.cs b/src/NeuralNetworks/PlanCache.cs new file mode 100644 index 0000000000..5ef8ae6bf3 --- /dev/null +++ b/src/NeuralNetworks/PlanCache.cs @@ -0,0 +1,156 @@ +using System.Security.Cryptography; +using System.Text; +using AiDotNet.Tensors.Engines; +using AiDotNet.Tensors.Engines.Compilation; +using AiDotNet.Tensors.Engines.Compilation.Serialization; + +namespace AiDotNet.NeuralNetworks; + +/// +/// Disk-backed store for compiled inference plans. Persists the traced plan after the +/// first compilation so subsequent process starts load the pre-compiled plan instead +/// of re-tracing + re-compiling. Directly wraps Tensors' +/// and . +/// +/// +/// +/// PyTorch-parity equivalent: torch.jit.save(traced_module, path) + +/// torch.jit.load(path). The facade integration is opt-in via +/// AiModelBuilder.ConfigurePlanCaching(directory); once configured, save/load +/// is transparent to the caller. +/// +/// +/// Plans are keyed by (modelTypeName, T, structureVersion, inputShapeHash, +/// hardwareFingerprint). Plans compiled on one host cannot be loaded on a host with +/// a different — Tensors rejects the load and we +/// fall through to a fresh compile. +/// +/// +public sealed class PlanCache +{ + private static PlanCache? _current; + + /// + /// The currently-active plan cache, or null if caching is disabled. Set via + /// . + /// consults this during Predict to decide whether to attempt disk load/save. + /// + public static PlanCache? Current => _current; + + public static void SetCurrent(PlanCache? cache) + { + _current = cache; + } + + public string Directory { get; } + + public PlanCache(string directory) + { + Directory = directory ?? throw new ArgumentNullException(nameof(directory)); + System.IO.Directory.CreateDirectory(directory); + } + + /// + /// Computes a stable filename for a plan identified by model type, element type, + /// structure version, and input shape. Hardware fingerprint is not part of the + /// key — checks that at load time. + /// + public string GetPlanPath(string modelTypeName, Type elementType, int structureVersion, int[] inputShape) + { + var hash = ComputeShapeHash(inputShape); + var safeModelName = SanitizeFilename(modelTypeName); + return Path.Combine( + Directory, + $"{safeModelName}_{elementType.Name}_v{structureVersion}_s{hash}.plan"); + } + + /// + /// Attempts to load a pre-compiled inference plan from disk. Returns null if the + /// file doesn't exist, is incompatible with the current host, or fails to + /// deserialize. A null return cleanly triggers fresh compilation upstream. + /// + public async Task?> TryLoadInferenceAsync( + string modelTypeName, + int structureVersion, + int[] inputShape, + IEngine engine, + CancellationToken cancellationToken = default) + { + var path = GetPlanPath(modelTypeName, typeof(T), structureVersion, inputShape); + if (!File.Exists(path)) + { + return null; + } + + try + { + return await CompiledPlanLoader.LoadInferenceAsync(path, engine, cancellationToken) + .ConfigureAwait(false); + } + catch (Exception ex) + { + // Incompatible plan, corrupted file, or stream error — fall through to + // recompile. Trace the failure so perf regressions don't go silent. + System.Diagnostics.Trace.TraceWarning( + $"PlanCache: load failed for '{path}': {ex.GetType().Name}: {ex.Message}"); + return null; + } + } + + /// + /// Persists a compiled plan to disk via atomic write (tmp-file + rename) so a + /// crash partway through doesn't leave a corrupt cache entry. + /// + public async Task SaveInferenceAsync( + ICompiledPlan plan, + string modelTypeName, + int structureVersion, + int[] inputShape, + CancellationToken cancellationToken = default) + { + var path = GetPlanPath(modelTypeName, typeof(T), structureVersion, inputShape); + var tmpPath = path + ".tmp"; + + try + { + using (var fs = new FileStream(tmpPath, FileMode.Create, FileAccess.Write, FileShare.None)) + { + await plan.SaveAsync(fs, cancellationToken).ConfigureAwait(false); + } + + if (File.Exists(path)) + { + File.Delete(path); + } + File.Move(tmpPath, path); + } + catch (Exception ex) + { + System.Diagnostics.Trace.TraceWarning( + $"PlanCache: save failed for '{path}': {ex.GetType().Name}: {ex.Message}"); + try { if (File.Exists(tmpPath)) File.Delete(tmpPath); } catch { } + } + } + + private static string SanitizeFilename(string name) + { + var invalid = Path.GetInvalidFileNameChars(); + var sb = new StringBuilder(name.Length); + foreach (var c in name) + { + sb.Append(Array.IndexOf(invalid, c) >= 0 ? '_' : c); + } + return sb.ToString(); + } + + private static string ComputeShapeHash(int[] shape) + { + using var sha = SHA256.Create(); + var bytes = new byte[shape.Length * 4]; + Buffer.BlockCopy(shape, 0, bytes, 0, bytes.Length); + var hash = sha.ComputeHash(bytes); + var sb = new StringBuilder(16); + for (int i = 0; i < 8; i++) sb.Append(hash[i].ToString("x2")); + return sb.ToString(); + } +} diff --git a/src/NeuralNetworks/ResNetNetwork.cs b/src/NeuralNetworks/ResNetNetwork.cs index b3e8d84e95..10c4cd2b65 100644 --- a/src/NeuralNetworks/ResNetNetwork.cs +++ b/src/NeuralNetworks/ResNetNetwork.cs @@ -497,10 +497,13 @@ public override void UpdateParameters(Vector parameters) /// /// The input tensor to make a prediction for. /// The predicted output tensor containing class probabilities. - public override Tensor Predict(Tensor input) - { - return Forward(input); - } + /// + /// Routes inference through so the + /// forward pass gets traced and replayed by CompiledModelHost after warmup — + /// matching PyTorch's torch.compile default. The eager forward is + /// , which retains the GPU-resident optimization path. + /// + protected override Tensor PredictEager(Tensor input) => Forward(input); /// /// Trains the ResNet network using the provided input and expected output. diff --git a/src/NeuralNetworks/SiameseNeuralNetwork.cs b/src/NeuralNetworks/SiameseNeuralNetwork.cs index c4f9ff109f..901ed91dfc 100644 --- a/src/NeuralNetworks/SiameseNeuralNetwork.cs +++ b/src/NeuralNetworks/SiameseNeuralNetwork.cs @@ -301,11 +301,11 @@ public override void UpdateParameters(Vector parameters) } } - /// - public override Tensor Predict(Tensor input) - { - return Forward(input); - } + /// + /// Routes inference through for + /// compiled-plan replay; remains the eager fallback. + /// + protected override Tensor PredictEager(Tensor input) => Forward(input); /// /// Trains the model on pairs of inputs using a similarity learning objective. diff --git a/src/NeuralNetworks/SuperNet.cs b/src/NeuralNetworks/SuperNet.cs index 58dfff71a1..8229e05964 100644 --- a/src/NeuralNetworks/SuperNet.cs +++ b/src/NeuralNetworks/SuperNet.cs @@ -1540,82 +1540,5 @@ public override void LoadState(Stream stream) } } - #region IJitCompilable Implementation - - /// - /// Exports a single operation as a computation graph. - /// - private ComputationNode ExportOperationGraph(ComputationNode input, int opIdx, string weightKey) - { - // Get or create weight constants - Vector? weight = null; - if (_weights.TryGetValue(weightKey, out var w)) - { - weight = w; - } - - switch (opIdx) - { - case 0: // Identity - return input; - - case 1: // 3x3 Conv (simplified as weighted pass) - if (weight != null) - { - var weightTensor = new Tensor(new[] { weight.Length }); - for (int i = 0; i < weight.Length; i++) - { - weightTensor[i] = NumOps.Add(NumOps.One, weight[i]); - } - var weightNode = TensorOperations.Constant(weightTensor, $"weights_{weightKey}"); - return TensorOperations.ElementwiseMultiply(input, weightNode); - } - return input; - - case 2: // 5x5 Conv (simplified) - if (weight != null) - { - var weightTensor = new Tensor(new[] { weight.Length }); - for (int i = 0; i < weight.Length; i++) - { - weightTensor[i] = NumOps.Add(NumOps.One, NumOps.Multiply(NumOps.FromDouble(1.5), weight[i])); - } - var weightNode = TensorOperations.Constant(weightTensor, $"weights_{weightKey}"); - return TensorOperations.ElementwiseMultiply(input, weightNode); - } - return input; - - case 3: // MaxPool (simplified as scaling) - { - var scaleTensor = new Tensor(new[] { 1 }); - scaleTensor[0] = NumOps.FromDouble(0.9); - var scaleNode = TensorOperations.Constant(scaleTensor, $"maxpool_scale_{weightKey}"); - return TensorOperations.ElementwiseMultiply(input, scaleNode); - } - - case 4: // AvgPool (simplified as scaling) - { - var scaleTensor = new Tensor(new[] { 1 }); - scaleTensor[0] = NumOps.FromDouble(0.8); - var scaleNode = TensorOperations.Constant(scaleTensor, $"avgpool_scale_{weightKey}"); - return TensorOperations.ElementwiseMultiply(input, scaleNode); - } - - default: - return input; - } - } - - /// - /// Performs forward pass through the model (required by IJitCompilable). - /// - /// The input tensor. - /// The output tensor. - public Tensor Forward(Tensor input) - { - return Predict(input); - } - - #endregion } } diff --git a/src/NeuralNetworks/SyntheticData/AutoDiffTabGenerator.cs b/src/NeuralNetworks/SyntheticData/AutoDiffTabGenerator.cs index e4397c1558..b30ea7e559 100644 --- a/src/NeuralNetworks/SyntheticData/AutoDiffTabGenerator.cs +++ b/src/NeuralNetworks/SyntheticData/AutoDiffTabGenerator.cs @@ -981,7 +981,4 @@ private static Vector TensorToVector(Tensor t, int length) #endregion - #region IJitCompilable Override - - #endregion } diff --git a/src/NeuralNetworks/SyntheticData/CTABGANPlusGenerator.cs b/src/NeuralNetworks/SyntheticData/CTABGANPlusGenerator.cs index fe17c4a1cb..d55ea712cd 100644 --- a/src/NeuralNetworks/SyntheticData/CTABGANPlusGenerator.cs +++ b/src/NeuralNetworks/SyntheticData/CTABGANPlusGenerator.cs @@ -1257,7 +1257,4 @@ private Tensor ScaleTensor(Tensor tensor, double scale) #endregion - #region IJitCompilable Override - - #endregion } diff --git a/src/NeuralNetworks/SyntheticData/CTGANGenerator.cs b/src/NeuralNetworks/SyntheticData/CTGANGenerator.cs index 25996052fd..5e7e223c4e 100644 --- a/src/NeuralNetworks/SyntheticData/CTGANGenerator.cs +++ b/src/NeuralNetworks/SyntheticData/CTGANGenerator.cs @@ -1065,7 +1065,4 @@ private static Tensor CloneTensor(Tensor source) #endregion - #region IJitCompilable Override - - #endregion } diff --git a/src/NeuralNetworks/SyntheticData/CausalGANGenerator.cs b/src/NeuralNetworks/SyntheticData/CausalGANGenerator.cs index 5fe2bd754e..4366f4fb08 100644 --- a/src/NeuralNetworks/SyntheticData/CausalGANGenerator.cs +++ b/src/NeuralNetworks/SyntheticData/CausalGANGenerator.cs @@ -1179,7 +1179,4 @@ private static Tensor CloneTensor(Tensor source) #endregion - #region IJitCompilable Override - - #endregion } diff --git a/src/NeuralNetworks/SyntheticData/CopulaGANGenerator.cs b/src/NeuralNetworks/SyntheticData/CopulaGANGenerator.cs index 5bf3d438e0..1171274a89 100644 --- a/src/NeuralNetworks/SyntheticData/CopulaGANGenerator.cs +++ b/src/NeuralNetworks/SyntheticData/CopulaGANGenerator.cs @@ -1335,7 +1335,4 @@ private static Tensor CloneTensor(Tensor source) #endregion - #region IJitCompilable Override - - #endregion } diff --git a/src/NeuralNetworks/SyntheticData/DPCTGANGenerator.cs b/src/NeuralNetworks/SyntheticData/DPCTGANGenerator.cs index 35fca2a590..cb45b4ecc5 100644 --- a/src/NeuralNetworks/SyntheticData/DPCTGANGenerator.cs +++ b/src/NeuralNetworks/SyntheticData/DPCTGANGenerator.cs @@ -1162,7 +1162,4 @@ private static Tensor CloneTensor(Tensor source) #endregion - #region IJitCompilable Override - - #endregion } diff --git a/src/NeuralNetworks/SyntheticData/FinDiffGenerator.cs b/src/NeuralNetworks/SyntheticData/FinDiffGenerator.cs index 8f438c64fd..95faf2ef23 100644 --- a/src/NeuralNetworks/SyntheticData/FinDiffGenerator.cs +++ b/src/NeuralNetworks/SyntheticData/FinDiffGenerator.cs @@ -736,7 +736,4 @@ private static Vector TensorToVector(Tensor t, int length) #endregion - #region IJitCompilable Override - - #endregion } diff --git a/src/NeuralNetworks/SyntheticData/GOGGLEGenerator.cs b/src/NeuralNetworks/SyntheticData/GOGGLEGenerator.cs index 0bb69d47ae..d4304159f7 100644 --- a/src/NeuralNetworks/SyntheticData/GOGGLEGenerator.cs +++ b/src/NeuralNetworks/SyntheticData/GOGGLEGenerator.cs @@ -652,7 +652,4 @@ private static Vector TensorToVector(Tensor t, int length) #endregion - #region IJitCompilable Override - - #endregion } diff --git a/src/NeuralNetworks/SyntheticData/MedSynthGenerator.cs b/src/NeuralNetworks/SyntheticData/MedSynthGenerator.cs index 69443f40a3..cf63dffb1d 100644 --- a/src/NeuralNetworks/SyntheticData/MedSynthGenerator.cs +++ b/src/NeuralNetworks/SyntheticData/MedSynthGenerator.cs @@ -603,7 +603,8 @@ private void UpdateDiscriminator(T lr) #region Discriminator Layer List /// - /// Builds a combined list of discriminator layers for TapeLayerBridge. + /// Builds a combined list of discriminator layers (dense + dropout + output) + /// for gradient-penalty and related analyses. /// private IReadOnlyList> BuildDiscLayerList() { @@ -867,7 +868,4 @@ private static Vector TensorToVector(Tensor t, int length) #endregion - #region IJitCompilable Override - - #endregion } diff --git a/src/NeuralNetworks/SyntheticData/MisGANGenerator.cs b/src/NeuralNetworks/SyntheticData/MisGANGenerator.cs index b5df0b525a..7efba091e0 100644 --- a/src/NeuralNetworks/SyntheticData/MisGANGenerator.cs +++ b/src/NeuralNetworks/SyntheticData/MisGANGenerator.cs @@ -604,7 +604,8 @@ private void ApplyMaskGradientPenalty(Vector real, Vector fake, T scaledLr } /// - /// Builds a combined list of data discriminator layers for TapeLayerBridge. + /// Builds a combined list of data discriminator layers for gradient-penalty + /// and related analyses. /// private IReadOnlyList> BuildDataDiscLayerList() { @@ -619,7 +620,8 @@ private IReadOnlyList> BuildDataDiscLayerList() } /// - /// Builds a combined list of mask discriminator layers for TapeLayerBridge. + /// Builds a combined list of mask discriminator layers for gradient-penalty + /// and related analyses. /// private IReadOnlyList> BuildMaskDiscLayerList() { @@ -986,7 +988,4 @@ private static Tensor CloneTensor(Tensor source) #endregion - #region IJitCompilable Override - - #endregion } diff --git a/src/NeuralNetworks/SyntheticData/OCTGANGenerator.cs b/src/NeuralNetworks/SyntheticData/OCTGANGenerator.cs index 58663cc87f..525c1619b9 100644 --- a/src/NeuralNetworks/SyntheticData/OCTGANGenerator.cs +++ b/src/NeuralNetworks/SyntheticData/OCTGANGenerator.cs @@ -984,7 +984,4 @@ protected override IFullModel, Tensor> CreateNewInstance() #endregion - #region IJitCompilable Override - - #endregion } diff --git a/src/NeuralNetworks/SyntheticData/PATEGANGenerator.cs b/src/NeuralNetworks/SyntheticData/PATEGANGenerator.cs index e456e17be1..c65406d845 100644 --- a/src/NeuralNetworks/SyntheticData/PATEGANGenerator.cs +++ b/src/NeuralNetworks/SyntheticData/PATEGANGenerator.cs @@ -1020,7 +1020,4 @@ protected override IFullModel, Tensor> CreateNewInstance() #endregion - #region IJitCompilable Override - - #endregion } diff --git a/src/NeuralNetworks/SyntheticData/REaLTabFormerGenerator.cs b/src/NeuralNetworks/SyntheticData/REaLTabFormerGenerator.cs index aa8ab70b83..78151a07e0 100644 --- a/src/NeuralNetworks/SyntheticData/REaLTabFormerGenerator.cs +++ b/src/NeuralNetworks/SyntheticData/REaLTabFormerGenerator.cs @@ -828,7 +828,4 @@ private static Vector TensorToVector(Tensor t, int length) #endregion - #region IJitCompilable Override - - #endregion } diff --git a/src/NeuralNetworks/SyntheticData/SyntheticTabularGeneratorBase.cs b/src/NeuralNetworks/SyntheticData/SyntheticTabularGeneratorBase.cs index 9c0fffd21b..d4f4c9f8f0 100644 --- a/src/NeuralNetworks/SyntheticData/SyntheticTabularGeneratorBase.cs +++ b/src/NeuralNetworks/SyntheticData/SyntheticTabularGeneratorBase.cs @@ -466,70 +466,4 @@ protected Vector CreateStandardNormalVector(int length) return v; } - #region IJitCompilable Implementation - - /// - /// Gets whether this generator supports JIT compilation for accelerated generation. - /// - /// - /// - /// Generators with simple MLP-based architectures (GAN generators) can export their - /// generator network as a computation graph for JIT compilation, which accelerates - /// the neural network forward pass during synthetic data generation. - /// - /// - /// For Beginners: JIT compilation converts the generator's neural network into - /// optimized native code. This makes the Generate() method faster, especially when - /// generating large numbers of synthetic rows. Generators based on diffusion models, - /// autoregressive transformers, or statistical methods typically cannot be JIT compiled - /// because they use iterative or dynamic computation patterns. - /// - /// - public virtual bool SupportsJitCompilation => false; - - /// - /// Exports the generator network's computation graph for JIT compilation. - /// - /// List to populate with input computation nodes (the noise vector). - /// The output computation node representing the generator's raw output. - /// - /// - /// Subclasses that support JIT compilation should override this method to export their - /// generator network as a computation graph using TensorOperations. The exported graph - /// represents the forward pass from noise input to raw generated output (before inverse - /// data transformation). - /// - /// - /// For Beginners: This creates a "recipe" of the generator's calculations that - /// the JIT compiler can optimize into fast native code. Only the neural network portion - /// is exported — the data transformation and post-processing steps remain interpreted. - /// - /// - /// Thrown when the generator does not support JIT compilation. - public virtual ComputationNode ExportComputationGraph(List> inputNodes) - { - throw new NotSupportedException( - $"{GetType().Name} does not support JIT compilation. Check SupportsJitCompilation before calling."); - } - - /// - /// Exports a sequential MLP generator as a computation graph, with optional residual concatenation. - /// Delegates to for the implementation. - /// - protected static ComputationNode ExportMLPGeneratorGraph( - List> inputNodes, - int inputSize, - IReadOnlyList> hiddenLayers, - IReadOnlyList>? bnLayers, - ILayer outputLayer, - TapeLayerBridge.HiddenActivation hiddenAct = TapeLayerBridge.HiddenActivation.ReLU, - TapeLayerBridge.HiddenActivation outputAct = TapeLayerBridge.HiddenActivation.None, - bool useResidualConcat = false) - { - return TapeLayerBridge.ExportMLPGeneratorGraph( - inputNodes, inputSize, hiddenLayers, bnLayers, outputLayer, - hiddenAct, outputAct, useResidualConcat); - } - - #endregion } diff --git a/src/NeuralNetworks/SyntheticData/TVAEGenerator.cs b/src/NeuralNetworks/SyntheticData/TVAEGenerator.cs index 8ee6836044..136319dab4 100644 --- a/src/NeuralNetworks/SyntheticData/TVAEGenerator.cs +++ b/src/NeuralNetworks/SyntheticData/TVAEGenerator.cs @@ -974,9 +974,6 @@ private static Tensor VectorToTensor(Vector v) #endregion - #region IJitCompilable Override - - #endregion #region Tensor Shape Helpers diff --git a/src/NeuralNetworks/SyntheticData/TabDDPMGenerator.cs b/src/NeuralNetworks/SyntheticData/TabDDPMGenerator.cs index 933253144a..1c5c4f157e 100644 --- a/src/NeuralNetworks/SyntheticData/TabDDPMGenerator.cs +++ b/src/NeuralNetworks/SyntheticData/TabDDPMGenerator.cs @@ -1070,7 +1070,4 @@ private static Vector TensorToVector(Tensor t, int length) #endregion - #region IJitCompilable Override - - #endregion } diff --git a/src/NeuralNetworks/SyntheticData/TabFlowGenerator.cs b/src/NeuralNetworks/SyntheticData/TabFlowGenerator.cs index 5c50ee2ce1..ddc52321db 100644 --- a/src/NeuralNetworks/SyntheticData/TabFlowGenerator.cs +++ b/src/NeuralNetworks/SyntheticData/TabFlowGenerator.cs @@ -890,7 +890,4 @@ private static Vector TensorToVector(Tensor t, int length) #endregion - #region IJitCompilable Override - - #endregion } diff --git a/src/NeuralNetworks/SyntheticData/TabLLMGenGenerator.cs b/src/NeuralNetworks/SyntheticData/TabLLMGenGenerator.cs index 84b9f7809a..411725be7d 100644 --- a/src/NeuralNetworks/SyntheticData/TabLLMGenGenerator.cs +++ b/src/NeuralNetworks/SyntheticData/TabLLMGenGenerator.cs @@ -836,7 +836,4 @@ private Vector GetParameterVector() #endregion - #region IJitCompilable Override - - #endregion } diff --git a/src/NeuralNetworks/SyntheticData/TabSynGenerator.cs b/src/NeuralNetworks/SyntheticData/TabSynGenerator.cs index 5e5c9d3308..ef26609af2 100644 --- a/src/NeuralNetworks/SyntheticData/TabSynGenerator.cs +++ b/src/NeuralNetworks/SyntheticData/TabSynGenerator.cs @@ -1203,9 +1203,6 @@ private static Vector TensorToVector(Tensor t, int length) #endregion - #region IJitCompilable Override - - #endregion #region Tensor Shape Helpers diff --git a/src/NeuralNetworks/SyntheticData/TabTransformerGenGenerator.cs b/src/NeuralNetworks/SyntheticData/TabTransformerGenGenerator.cs index ca52cc81ad..0b6ebdca42 100644 --- a/src/NeuralNetworks/SyntheticData/TabTransformerGenGenerator.cs +++ b/src/NeuralNetworks/SyntheticData/TabTransformerGenGenerator.cs @@ -732,7 +732,4 @@ private static Vector TensorToVector(Tensor t, int length) #endregion - #region IJitCompilable Override - - #endregion } diff --git a/src/NeuralNetworks/SyntheticData/TableGANGenerator.cs b/src/NeuralNetworks/SyntheticData/TableGANGenerator.cs index 039d93b540..495b97e295 100644 --- a/src/NeuralNetworks/SyntheticData/TableGANGenerator.cs +++ b/src/NeuralNetworks/SyntheticData/TableGANGenerator.cs @@ -758,7 +758,4 @@ private static Tensor CloneTensor(Tensor source) #endregion - #region IJitCompilable Override - - #endregion } diff --git a/src/NeuralNetworks/SyntheticData/TapeLayerBridge.cs b/src/NeuralNetworks/SyntheticData/TapeLayerBridge.cs deleted file mode 100644 index e4aaa67f6f..0000000000 --- a/src/NeuralNetworks/SyntheticData/TapeLayerBridge.cs +++ /dev/null @@ -1,183 +0,0 @@ -using AiDotNet.Interfaces; -using AiDotNet.NeuralNetworks.Layers; -using AiDotNet.Tensors.Engines.Autodiff; -using AiDotNet.Tensors.LinearAlgebra; - -namespace AiDotNet.NeuralNetworks.SyntheticData; - -/// -/// Bridges the existing layer system with GradientTape for proper automatic differentiation. -/// -/// The numeric type. -/// -/// -/// TapeLayerBridge eliminates the need for manual gradient computation -/// (ManualLinearBackward) in GAN generators. It uses the GradientTape autodiff system -/// to automatically compute gradients of a network output with respect to its input, -/// which is essential for WGAN-GP gradient penalty computation. -/// -/// -/// For Beginners: In WGAN-GP training, we need to know how the discriminator's output -/// changes when we slightly change its input. Previously, this was computed manually by -/// extracting weights and multiplying backwards through each layer. This utility does -/// the same thing automatically using the GradientTape system, which: -/// - Is less error-prone (no manual weight index calculations) -/// - Handles all activation functions automatically -/// - Produces correct gradients via the chain rule -/// - Supports any layer configuration without custom backward code -/// -/// -/// Supported Layer Types: -/// -/// FullyConnectedLayer — linear transform via TensorOperations.MatrixMultiply -/// DropoutLayer — skipped (gradient penalty uses eval mode) -/// BatchNormalizationLayer — affine transform using running statistics -/// -/// -/// -/// Usage Example (WGAN-GP gradient penalty): -/// -/// var inputGrad = TapeLayerBridge<double>.ComputeInputGradient( -/// interpolatedInput, -/// discriminatorLayers, -/// HiddenActivation.LeakyReLU, -/// applyActivationOnLast: false); -/// double gradNorm = ComputeL2Norm(inputGrad); -/// double penalty = (gradNorm - 1.0) * (gradNorm - 1.0); -/// -/// -/// -public static class TapeLayerBridge -{ - /// - /// Specifies the activation function applied between hidden layers. - /// - /// - /// - /// GAN discriminators typically use IdentityActivation on their FullyConnectedLayers - /// and apply activations manually. This enum tells the bridge which activation to apply - /// between layers during the TensorOperations-based forward pass. - /// - /// - public enum HiddenActivation - { - /// No activation (identity). - None, - /// LeakyReLU with alpha=0.2 (standard for GAN discriminators). - LeakyReLU, - /// Standard ReLU activation. - ReLU, - /// Sigmoid activation. - Sigmoid, - /// Tanh activation. - Tanh, - /// SiLU/Swish activation. - SiLU, - /// GELU activation. - GELU - } - - /// - /// Exports an MLP-based generator network as a JIT-compilable computation graph. - /// - /// - /// - /// This helper constructs a computation graph by chaining fully connected layers, - /// optional batch normalization layers, and activation functions. It supports the - /// CTGAN-style residual architecture where the original input is concatenated back - /// at each hidden layer. - /// - /// - /// For Beginners: Most GAN generators use the same basic MLP structure for - /// their generator network. This method converts that structure into a computation - /// graph that the JIT compiler can optimize for 2-10x faster generation. - /// The graph excludes column-specific output activations (Tanh per continuous column, - /// Softmax per categorical group), which are applied separately after the MLP forward. - /// - /// - /// List to populate with the input (noise) variable node. - /// The size of the noise input vector. - /// The hidden fully connected layers (excluding output). - /// Batch normalization layers (one per hidden layer), or null. - /// The final output fully connected layer. - /// Activation to apply between hidden layers. - /// Activation to apply on the output layer (use None for identity). - /// If true, concatenates the original input at each hidden layer. - /// The output computation node. - public static ComputationNode ExportMLPGeneratorGraph( - List> inputNodes, - int inputSize, - IReadOnlyList> hiddenLayers, - IReadOnlyList>? bnLayers, - ILayer outputLayer, - HiddenActivation hiddenAct = HiddenActivation.ReLU, - HiddenActivation outputAct = HiddenActivation.None, - bool useResidualConcat = false) - { - var inputNode = TensorOperations.Variable( - new Tensor([1, inputSize]), "generator_input", requiresGradient: false); - inputNodes.Add(inputNode); - - var current = inputNode; - - for (int i = 0; i < hiddenLayers.Count; i++) - { - // Residual concatenation (CTGAN-style skip connections) - if (useResidualConcat && i > 0) - { - current = TensorOperations.Concat( - new List> { current, inputNode }, axis: 1); - } - - // Forward through FC layer - var fcInputs = new List> { current }; - current = ((Layers.LayerBase)hiddenLayers[i]).ExportComputationGraph(fcInputs); - - // Forward through BN layer if present - if (bnLayers is not null && i < bnLayers.Count) - { - var bnInputs = new List> { current }; - current = ((Layers.LayerBase)bnLayers[i]).ExportComputationGraph(bnInputs); - } - - // Apply hidden activation - current = ApplyComputationNodeActivation(current, hiddenAct, 0.2); - } - - // Output layer with optional residual concat - if (useResidualConcat) - { - current = TensorOperations.Concat( - new List> { current, inputNode }, axis: 1); - } - - var outInputs = new List> { current }; - current = ((Layers.LayerBase)outputLayer).ExportComputationGraph(outInputs); - - // Apply output activation - current = ApplyComputationNodeActivation(current, outputAct, 0.2); - - return current; - } - - /// - /// Applies activation using ComputationNode API (legacy, for ExportMLPGeneratorGraph). - /// - private static ComputationNode ApplyComputationNodeActivation( - ComputationNode node, - HiddenActivation activation, - double leakyAlpha) - { - return activation switch - { - HiddenActivation.LeakyReLU => TensorOperations.LeakyReLU(node, leakyAlpha), - HiddenActivation.ReLU => TensorOperations.ReLU(node), - HiddenActivation.Sigmoid => TensorOperations.Sigmoid(node), - HiddenActivation.Tanh => TensorOperations.Tanh(node), - HiddenActivation.SiLU => TensorOperations.Swish(node), - HiddenActivation.GELU => TensorOperations.GELU(node), - HiddenActivation.None => node, - _ => node, - }; - } -} diff --git a/src/NeuralNetworks/SyntheticData/TimeGANGenerator.cs b/src/NeuralNetworks/SyntheticData/TimeGANGenerator.cs index b2f68ddb15..d82e61ffad 100644 --- a/src/NeuralNetworks/SyntheticData/TimeGANGenerator.cs +++ b/src/NeuralNetworks/SyntheticData/TimeGANGenerator.cs @@ -608,7 +608,8 @@ private void TrainDiscriminatorStep(List> sequences, int startIdx, int #region Discriminator Layer List /// - /// Builds a combined list of discriminator layers for TapeLayerBridge. + /// Builds a combined list of discriminator layers (dense + dropout + output) + /// for gradient-penalty and related analyses. /// private IReadOnlyList> BuildDiscLayerList() { @@ -887,7 +888,4 @@ private static Vector TensorToVector(Tensor t, int length) #endregion - #region IJitCompilable Override - - #endregion } diff --git a/src/NeuralNetworks/UNet3D.cs b/src/NeuralNetworks/UNet3D.cs index 98c3bb76be..1b605499be 100644 --- a/src/NeuralNetworks/UNet3D.cs +++ b/src/NeuralNetworks/UNet3D.cs @@ -236,10 +236,11 @@ public Tensor Forward(Tensor input) /// The input voxel grid tensor. /// The predicted segmentation map. /// - public override Tensor Predict(Tensor input) - { - return Forward(input); - } + /// + /// Routes inference through for + /// compiled-plan replay; remains the eager fallback. + /// + protected override Tensor PredictEager(Tensor input) => Forward(input); /// /// Trains the network on a single batch of input-output pairs. diff --git a/src/NeuralNetworks/VGGNetwork.cs b/src/NeuralNetworks/VGGNetwork.cs index f24eba4bb9..5796d77aad 100644 --- a/src/NeuralNetworks/VGGNetwork.cs +++ b/src/NeuralNetworks/VGGNetwork.cs @@ -339,10 +339,11 @@ public override void UpdateParameters(Vector parameters) /// probability is the network's prediction. /// /// - public override Tensor Predict(Tensor input) - { - return Forward(input); - } + /// + /// Routes inference through for + /// compiled-plan replay; remains the eager fallback. + /// + protected override Tensor PredictEager(Tensor input) => Forward(input); /// /// Trains the VGG network using the provided input and expected output. diff --git a/src/NeuralNetworks/VoxelCNN.cs b/src/NeuralNetworks/VoxelCNN.cs index 7c430c7050..7b001dd7f5 100644 --- a/src/NeuralNetworks/VoxelCNN.cs +++ b/src/NeuralNetworks/VoxelCNN.cs @@ -229,10 +229,11 @@ public Tensor Forward(Tensor input) /// The input voxel grid tensor. /// The predicted class probabilities or scores. /// - public override Tensor Predict(Tensor input) - { - return Forward(input); - } + /// + /// Routes inference through for + /// compiled-plan replay; remains the eager fallback. + /// + protected override Tensor PredictEager(Tensor input) => Forward(input); /// /// Trains the network on a single batch of input-output pairs. diff --git a/src/NeuralNetworks/Word2Vec.cs b/src/NeuralNetworks/Word2Vec.cs index d192c6c9ad..c48805469a 100644 --- a/src/NeuralNetworks/Word2Vec.cs +++ b/src/NeuralNetworks/Word2Vec.cs @@ -317,10 +317,11 @@ public override void UpdateParameters(Vector parameters) /// For Beginners: This is identical to the Forward pass—it takes word IDs and /// returns their embeddings. /// - public override Tensor Predict(Tensor input) - { - return Forward(input); - } + /// + /// Routes inference through for + /// compiled-plan replay; remains the eager fallback. + /// + protected override Tensor PredictEager(Tensor input) => Forward(input); /// /// Trains the Word2Vec model on a single batch of target and context pairs. diff --git a/src/OnlineLearning/OnlineLearningModelBase.cs b/src/OnlineLearning/OnlineLearningModelBase.cs index d35f3d4fe2..5116a369a3 100644 --- a/src/OnlineLearning/OnlineLearningModelBase.cs +++ b/src/OnlineLearning/OnlineLearningModelBase.cs @@ -92,11 +92,6 @@ public abstract class OnlineLearningModelBase : IOnlineLearningModel, IMod /// public virtual bool SupportsParameterInitialization => ParameterCount > 0; - /// - /// Gets whether JIT compilation is supported. - /// - public virtual bool SupportsJitCompilation => false; - /// /// Initializes a new instance of the OnlineLearningModelBase class. /// @@ -459,14 +454,6 @@ public virtual void LoadState(Stream stream) Deserialize(serializedData); } - /// - /// Exports the computation graph for JIT compilation. - /// - public virtual ComputationNode ExportComputationGraph(List> inputNodes) - { - throw new NotSupportedException("JIT compilation is not supported for this online learning model."); - } - #endregion #region Helper Methods diff --git a/src/Regression/AdaBoostR2Regression.cs b/src/Regression/AdaBoostR2Regression.cs index 095eb98082..ca0ebb6eff 100644 --- a/src/Regression/AdaBoostR2Regression.cs +++ b/src/Regression/AdaBoostR2Regression.cs @@ -672,7 +672,4 @@ public override IFullModel, Vector> Clone() return clone; } - #region IJitCompilable Implementation Override - - #endregion } diff --git a/src/Regression/DecisionTreeAsyncRegressionBase.cs b/src/Regression/DecisionTreeAsyncRegressionBase.cs index 20d09c9897..088ec9e8fe 100644 --- a/src/Regression/DecisionTreeAsyncRegressionBase.cs +++ b/src/Regression/DecisionTreeAsyncRegressionBase.cs @@ -1129,140 +1129,4 @@ public T SoftTreeTemperature #endregion - #region IJitCompilable Implementation - - /// - /// Gets whether this model currently supports JIT compilation. - /// - /// - /// true when is enabled and the tree has been trained; - /// false otherwise. - /// - /// - /// - /// When is enabled, the decision tree can be exported as a - /// differentiable computation graph using soft (sigmoid-based) gating. This enables - /// JIT compilation for optimized inference. - /// - /// - /// When is disabled, JIT compilation is not supported because - /// traditional hard decision trees use branching logic that cannot be represented as - /// a static computation graph. - /// - /// For Beginners: JIT compilation is available when soft tree mode is enabled. - /// - /// In soft tree mode, the discrete if-then decisions are replaced with smooth sigmoid - /// functions that can be compiled into an optimized computation graph. This gives you - /// the interpretability of decision trees with the speed of JIT-compiled models. - /// - /// - public virtual bool SupportsJitCompilation => UseSoftTree && Root != null; - - /// - /// Exports the model's computation graph for JIT compilation. - /// - /// List to populate with input computation nodes. - /// The root node of the exported computation graph. - /// - /// Thrown when is false. - /// - /// - /// Thrown when the tree has not been trained (Root is null). - /// - /// - /// - /// When soft tree mode is enabled, this exports the tree as a differentiable computation - /// graph using operations. Each internal - /// node becomes a soft split operation that computes sigmoid-weighted combinations of - /// left and right subtree outputs. - /// - /// For Beginners: This method converts the decision tree into a computation graph. - /// - /// In soft tree mode, each decision node becomes a smooth blend: - /// - Instead of "go left OR right", it computes "X% left + Y% right" - /// - The percentages are determined by the sigmoid function - /// - This creates a smooth, differentiable function that can be JIT compiled - /// - /// - public virtual AiDotNet.Autodiff.ComputationNode ExportComputationGraph(List> inputNodes) - { - if (!UseSoftTree) - { - throw new NotSupportedException( - "Async decision trees do not support JIT compilation in hard tree mode because they use " + - "discrete branching logic (if-then-else rules).\n\n" + - "To enable JIT compilation, set UseSoftTree = true to use soft (differentiable) decision trees " + - "with sigmoid-based gating."); - } - - if (Root == null) - { - throw new InvalidOperationException( - "Cannot export computation graph: the decision tree has not been trained. " + - "Call Train() or TrainAsync() first to build the tree structure."); - } - - // Get the number of features from the tree structure - int numFeatures = GetMaxFeatureIndexFromTree(Root) + 1; - - // Create input variable node - var inputTensor = new Tensor(new[] { numFeatures }); - var input = Autodiff.TensorOperations.Variable(inputTensor, "input"); - inputNodes.Add(input); - - // Recursively export the tree as soft split operations - return ExportNodeAsComputationGraph(Root, input); - } - - /// - /// Gets the maximum feature index used in the tree. - /// - /// The root node of the tree to scan. - /// The maximum feature index found. - private int GetMaxFeatureIndexFromTree(DecisionTreeNode? node) - { - if (node == null || node.IsLeaf) - return -1; - - int maxIndex = node.FeatureIndex; - int leftMax = GetMaxFeatureIndexFromTree(node.Left); - int rightMax = GetMaxFeatureIndexFromTree(node.Right); - - return Math.Max(maxIndex, Math.Max(leftMax, rightMax)); - } - - /// - /// Recursively exports a tree node as a computation graph. - /// - /// The node to export. - /// The input computation node. - /// A computation node representing this subtree. - private Autodiff.ComputationNode ExportNodeAsComputationGraph( - DecisionTreeNode node, - Autodiff.ComputationNode input) - { - if (node.IsLeaf) - { - // Leaf node: return constant prediction value - var leafTensor = new Tensor(new[] { 1 }); - leafTensor[0] = node.Prediction; - return Autodiff.TensorOperations.Constant(leafTensor, $"leaf_{node.GetHashCode()}"); - } - - // Internal node: export as SoftSplit operation - // Recursively export left and right subtrees - var leftOutput = ExportNodeAsComputationGraph(node.Left!, input); - var rightOutput = ExportNodeAsComputationGraph(node.Right!, input); - - // Use SoftSplit operation: output = sigmoid((threshold - x[feature]) / temp) * left + (1 - sigmoid) * right - return Autodiff.TensorOperations.SoftSplit( - input, - leftOutput, - rightOutput, - node.FeatureIndex, - node.SplitValue, - SoftTreeTemperature); - } - - #endregion } diff --git a/src/Regression/DecisionTreeRegressionBase.cs b/src/Regression/DecisionTreeRegressionBase.cs index 2aa4ba1988..72f1b21714 100644 --- a/src/Regression/DecisionTreeRegressionBase.cs +++ b/src/Regression/DecisionTreeRegressionBase.cs @@ -1197,116 +1197,4 @@ public virtual void LoadState(Stream stream) /// public T SoftTreeTemperature { get; set; } - /// - /// Gets a value indicating whether this model supports JIT (Just-In-Time) compilation. - /// - /// - /// - /// When is enabled, the decision tree can be exported as a - /// differentiable computation graph using soft (sigmoid-based) gating. This enables - /// JIT compilation for optimized inference. - /// - /// - /// When is disabled, JIT compilation is not supported because - /// traditional hard decision trees use branching logic that cannot be represented as - /// a static computation graph. - /// - /// - public virtual bool SupportsJitCompilation => UseSoftTree && Root != null; - - /// - /// Exports the model's computation as a graph of operations. - /// - /// The input nodes for the computation graph. - /// The root node of the exported computation graph. - /// - /// Thrown when is false. - /// - /// - /// Thrown when the tree has not been trained (Root is null). - /// - /// - /// - /// When soft tree mode is enabled, this exports the tree as a differentiable computation - /// graph using operations. Each internal - /// node becomes a soft split operation that computes sigmoid-weighted combinations of - /// left and right subtree outputs. - /// - /// - public virtual AiDotNet.Autodiff.ComputationNode ExportComputationGraph(List> inputNodes) - { - if (!UseSoftTree) - { - throw new NotSupportedException( - "Decision tree regression models do not support JIT compilation in hard tree mode because they use:\n" + - "- Tree-based branching logic with dynamic conditions\n" + - "- Recursive tree traversal that depends on input values\n" + - "- Conditional splits that cannot be represented as static tensor operations\n\n" + - "To enable JIT compilation, set UseSoftTree = true to use soft (differentiable) decision trees " + - "with sigmoid-based gating."); - } - - if (Root == null) - { - throw new InvalidOperationException( - "Cannot export computation graph: the decision tree has not been trained. " + - "Call Train() first to build the tree structure."); - } - - // Get the number of features from the tree structure - int numFeatures = GetMaxFeatureIndexFromTree(Root) + 1; - - // Create input variable node - var inputTensor = new Tensor(new[] { numFeatures }); - var input = Autodiff.TensorOperations.Variable(inputTensor, "input"); - inputNodes.Add(input); - - // Recursively export the tree as soft split operations - return ExportNodeAsComputationGraph(Root, input); - } - - /// - /// Recursively exports a tree node as a computation graph. - /// - private Autodiff.ComputationNode ExportNodeAsComputationGraph( - DecisionTreeNode node, - Autodiff.ComputationNode input) - { - if (node.IsLeaf) - { - // Leaf node: return constant prediction value - var leafTensor = new Tensor(new[] { 1 }); - leafTensor[0] = node.Prediction; - return Autodiff.TensorOperations.Constant(leafTensor, $"leaf_{node.GetHashCode()}"); - } - - // Internal node: export as SoftSplit operation - // Recursively export left and right subtrees - var leftOutput = ExportNodeAsComputationGraph(node.Left!, input); - var rightOutput = ExportNodeAsComputationGraph(node.Right!, input); - - // Use SoftSplit operation: output = sigmoid((threshold - x[feature]) / temp) * left + (1 - sigmoid) * right - return Autodiff.TensorOperations.SoftSplit( - input, - leftOutput, - rightOutput, - node.FeatureIndex, - node.SplitValue, - SoftTreeTemperature); - } - - /// - /// Gets the maximum feature index used in the tree. - /// - private int GetMaxFeatureIndexFromTree(DecisionTreeNode? node) - { - if (node == null || node.IsLeaf) - return -1; - - int maxIndex = node.FeatureIndex; - int leftMax = GetMaxFeatureIndexFromTree(node.Left); - int rightMax = GetMaxFeatureIndexFromTree(node.Right); - - return Math.Max(maxIndex, Math.Max(leftMax, rightMax)); - } } diff --git a/src/Regression/ExtremelyRandomizedTreesRegression.cs b/src/Regression/ExtremelyRandomizedTreesRegression.cs index 73652e93ee..f6b062d039 100644 --- a/src/Regression/ExtremelyRandomizedTreesRegression.cs +++ b/src/Regression/ExtremelyRandomizedTreesRegression.cs @@ -575,7 +575,4 @@ public override IEnumerable GetActiveFeatureIndices() return base.GetActiveFeatureIndices(); } - #region IJitCompilable Implementation Override - - #endregion } diff --git a/src/Regression/GradientBoostingRegression.cs b/src/Regression/GradientBoostingRegression.cs index 992b44d2bf..12de812943 100644 --- a/src/Regression/GradientBoostingRegression.cs +++ b/src/Regression/GradientBoostingRegression.cs @@ -583,7 +583,4 @@ public override IFullModel, Vector> Clone() return clone; } - #region IJitCompilable Implementation Override - - #endregion } diff --git a/src/Regression/NonLinearRegressionBase.cs b/src/Regression/NonLinearRegressionBase.cs index 49a5c2d215..cc90bfa8e1 100644 --- a/src/Regression/NonLinearRegressionBase.cs +++ b/src/Regression/NonLinearRegressionBase.cs @@ -1236,263 +1236,4 @@ public virtual void LoadState(Stream stream) Deserialize(data); } - #region IJitCompilable Implementation - - /// - /// - /// - /// Non-linear regression models support JIT compilation for all kernel types: - /// - Linear kernel: Fully supported (dot product) - /// - RBF kernel: Fully supported (Gaussian similarity) - /// - Sigmoid kernel: Fully supported (tanh-based similarity) - /// - Polynomial kernel: Fully supported (power operation) - /// - Laplacian kernel: Fully supported (L1 norm using sqrt(x^2) approximation) - /// - /// For Beginners: JIT (Just-In-Time) compilation can speed up kernel-based models. - /// - /// Non-linear models use kernel functions to capture complex patterns. JIT compilation - /// optimizes these computations for faster predictions. All kernel types are supported: - /// - Linear kernels (simple dot products) - /// - RBF kernels (Gaussian similarity based on distance) - /// - Sigmoid kernels (tanh-based similarity) - /// - Polynomial kernels (captures polynomial relationships) - /// - Laplacian kernels (L1 distance-based similarity) - /// - /// For large models with many support vectors, JIT can provide 3-5x speedup. - /// - /// - public virtual bool SupportsJitCompilation - { - get - { - // Check if we have a trained model - if (SupportVectors == null || SupportVectors.Rows == 0 || Alphas == null || Alphas.Length == 0) - return false; - - // Check if kernel type is supported - return Options.KernelType == KernelType.Linear || - Options.KernelType == KernelType.RBF || - Options.KernelType == KernelType.Sigmoid || - Options.KernelType == KernelType.Polynomial || - Options.KernelType == KernelType.Laplacian; - } - } - - /// - /// - /// - /// Exports the non-linear regression model as a computation graph. - /// The graph represents: output = B + sum(alpha[i] * kernel(input, supportVector[i])) - /// - /// For Beginners: This converts the kernel-based model to a computation graph. - /// - /// The computation graph represents: - /// 1. For each support vector: - /// - Compute kernel similarity between input and support vector - /// - Multiply by alpha coefficient (weight) - /// 2. Sum all weighted kernel values - /// 3. Add bias term (B) - /// - /// Kernel functions measure similarity: - /// - Linear: Simple dot product (like correlation) - /// - RBF: Gaussian distance (close points are similar) - /// - Sigmoid: Tanh-based similarity - /// - /// The JIT compiler optimizes this complex computation into fast native code. - /// - /// - public virtual ComputationNode ExportComputationGraph(List> inputNodes) - { - // Validation - if (SupportVectors == null || SupportVectors.Rows == 0) - { - throw new InvalidOperationException("Cannot export computation graph: Model has not been trained yet."); - } - - if (!SupportsJitCompilation) - { - throw new NotSupportedException($"JIT compilation is not supported for kernel type: {Options.KernelType}"); - } - - // Create input node (placeholder for input features) - // Shape: [1, feature_count] (single example) - var featureCount = SupportVectors.Columns; - var inputShape = new int[] { 1, featureCount }; - var inputTensor = new Tensor(inputShape); - var inputNode = new ComputationNode(inputTensor); - inputNodes.Add(inputNode); - - // Accumulator for summing all kernel results - ComputationNode? sumNode = null; - - // Process each support vector - for (int i = 0; i < SupportVectors.Rows; i++) - { - // Create support vector node - var svShape = new int[] { 1, featureCount }; - var svData = new T[featureCount]; - for (int j = 0; j < featureCount; j++) - { - svData[j] = SupportVectors[i, j]; - } - var svTensor = new Tensor(svShape, new Vector(svData)); - var svNode = new ComputationNode(svTensor); - - // Compute kernel value based on kernel type - ComputationNode kernelNode = Options.KernelType switch - { - KernelType.Linear => ComputeLinearKernel(inputNode, svNode), - KernelType.RBF => ComputeRBFKernel(inputNode, svNode), - KernelType.Sigmoid => ComputeSigmoidKernel(inputNode, svNode), - KernelType.Polynomial => ComputePolynomialKernel(inputNode, svNode), - KernelType.Laplacian => ComputeLaplacianKernel(inputNode, svNode), - _ => throw new NotSupportedException($"Kernel type {Options.KernelType} is not supported for JIT compilation") - }; - - // Multiply by alpha coefficient - var alphaTensor = CreateFilledTensorLike(kernelNode, Alphas[i]); - var alphaNode = TensorOperations.Constant(alphaTensor, $"alpha_{i}"); - var weightedNode = TensorOperations.ElementwiseMultiply(kernelNode, alphaNode); - - // Add to accumulator - if (sumNode == null) - { - sumNode = weightedNode; - } - else - { - sumNode = TensorOperations.Add(sumNode, weightedNode); - } - } - - // Add bias term - var biasTensor = CreateFilledTensorLike(sumNode!, B); - var biasNode = TensorOperations.Constant(biasTensor, "bias"); - var outputNode = TensorOperations.Add(sumNode!, biasNode); - - return outputNode; - } - - /// - /// Computes linear kernel: x1 · x2 (dot product). - /// - private ComputationNode ComputeLinearKernel(ComputationNode x1, ComputationNode x2) - { - // Element-wise multiply - var product = TensorOperations.ElementwiseMultiply(x1, x2); - - // Sum all elements to get the dot product (scalar) - return TensorOperations.Sum(product); - } - - /// - /// Computes RBF kernel: exp(-gamma * ||x1 - x2||^2). - /// - private ComputationNode ComputeRBFKernel(ComputationNode x1, ComputationNode x2) - { - // Compute difference: x1 - x2 - var diff = TensorOperations.Subtract(x1, x2); - - // Square: (x1 - x2)^2 - var squared = TensorOperations.ElementwiseMultiply(diff, diff); - - // Sum squared differences to get ||x1 - x2||^2 (scalar) - var sumSquared = TensorOperations.Sum(squared); - - // Multiply by -gamma - var gammaTensor = CreateFilledTensorLike(sumSquared, NumOps.FromDouble(-Options.Gamma)); - var gammaNode = TensorOperations.Constant(gammaTensor, "gamma"); - var scaled = TensorOperations.ElementwiseMultiply(sumSquared, gammaNode); - - // Exp(-gamma * ||x1 - x2||^2) - var result = TensorOperations.Exp(scaled); - - return result; - } - - /// - /// Computes Sigmoid kernel: tanh(gamma * (x1 · x2) + coef0). - /// - private ComputationNode ComputeSigmoidKernel(ComputationNode x1, ComputationNode x2) - { - // Dot product: x1 · x2 = sum(x1 * x2) - var product = TensorOperations.ElementwiseMultiply(x1, x2); - var dotProduct = TensorOperations.Sum(product); - - // Multiply by gamma - var gammaTensor = CreateFilledTensorLike(dotProduct, NumOps.FromDouble(Options.Gamma)); - var gammaNode = TensorOperations.Constant(gammaTensor, "gamma"); - var scaled = TensorOperations.ElementwiseMultiply(dotProduct, gammaNode); - - // Add coef0 - var coef0Tensor = CreateFilledTensorLike(scaled, NumOps.FromDouble(Options.Coef0)); - var coef0Node = TensorOperations.Constant(coef0Tensor, "coef0"); - var sum = TensorOperations.Add(scaled, coef0Node); - - // Tanh - var result = TensorOperations.Tanh(sum); - - return result; - } - - /// - /// Computes Polynomial kernel: (gamma * (x1 · x2) + coef0) ^ degree. - /// - private ComputationNode ComputePolynomialKernel(ComputationNode x1, ComputationNode x2) - { - // Dot product: x1 · x2 = sum(x1 * x2) - var product = TensorOperations.ElementwiseMultiply(x1, x2); - var dotProduct = TensorOperations.Sum(product); - - // Multiply by gamma - var gammaTensor = CreateFilledTensorLike(dotProduct, NumOps.FromDouble(Options.Gamma)); - var gammaNode = TensorOperations.Constant(gammaTensor, "gamma"); - var scaled = TensorOperations.ElementwiseMultiply(dotProduct, gammaNode); - - // Add coef0 - var coef0Tensor = CreateFilledTensorLike(scaled, NumOps.FromDouble(Options.Coef0)); - var coef0Node = TensorOperations.Constant(coef0Tensor, "coef0"); - var sum = TensorOperations.Add(scaled, coef0Node); - - // Power(sum, degree) - var result = TensorOperations.Power(sum, Options.PolynomialDegree); - - return result; - } - - /// - /// Computes Laplacian kernel: exp(-gamma * |x1 - x2|_1). - /// - private ComputationNode ComputeLaplacianKernel(ComputationNode x1, ComputationNode x2) - { - // Compute difference: x1 - x2 - var diff = TensorOperations.Subtract(x1, x2); - - // Compute |x1 - x2| using sqrt((x1-x2)^2) as approximation of abs - // Note: This works for element-wise absolute value - var squared = TensorOperations.ElementwiseMultiply(diff, diff); - var absDiff = TensorOperations.Sqrt(squared); - - // Sum absolute differences to get L1 norm (|x1 - x2|_1) - var l1Norm = TensorOperations.Sum(absDiff); - - // Multiply by -gamma - var gammaTensor = CreateFilledTensorLike(l1Norm, NumOps.FromDouble(-Options.Gamma)); - var gammaNode = TensorOperations.Constant(gammaTensor, "gamma"); - var scaled = TensorOperations.ElementwiseMultiply(l1Norm, gammaNode); - - // Exp(-gamma * |x1 - x2|_1) - var result = TensorOperations.Exp(scaled); - - return result; - } - - #endregion - - private static Tensor CreateFilledTensorLike(ComputationNode referenceNode, T value) - { - var tensor = new Tensor((int[])referenceNode.Value._shape); - tensor.Fill(value); - return tensor; - } } diff --git a/src/Regression/RandomForestRegression.cs b/src/Regression/RandomForestRegression.cs index 406e591ab8..de257c55ec 100644 --- a/src/Regression/RandomForestRegression.cs +++ b/src/Regression/RandomForestRegression.cs @@ -469,7 +469,4 @@ public override IFullModel, Vector> Clone() return clone; } - #region IJitCompilable Implementation Override - - #endregion } diff --git a/src/Regression/RegressionBase.cs b/src/Regression/RegressionBase.cs index 9407c138db..de60dc200e 100644 --- a/src/Regression/RegressionBase.cs +++ b/src/Regression/RegressionBase.cs @@ -1133,106 +1133,4 @@ public virtual void LoadState(Stream stream) Deserialize(serializedData); } - #region IJitCompilable Implementation - - /// - /// - /// - /// Regression models support JIT compilation for accelerated inference. - /// The computation graph represents the linear regression formula: - /// output = input @ coefficients + intercept (if HasIntercept) - /// - /// For Beginners: JIT (Just-In-Time) compilation optimizes the model for faster predictions. - /// - /// Instead of performing matrix operations step-by-step at runtime, JIT compilation: - /// - Analyzes the model's structure ahead of time - /// - Generates optimized native code - /// - Results in 5-10x faster predictions - /// - /// This is especially beneficial for: - /// - Real-time prediction systems - /// - High-throughput applications - /// - Batch processing of many predictions - /// - /// - public virtual bool SupportsJitCompilation => true; - - /// - /// - /// - /// Exports the regression model as a computation graph for JIT compilation. - /// The graph represents: output = input @ coefficients + intercept - /// - /// For Beginners: This method converts the regression model into a computation graph. - /// - /// A computation graph is like a recipe that describes: - /// 1. Take input features (a matrix) - /// 2. Multiply by learned coefficients - /// 3. Add intercept (if the model uses one) - /// 4. Return predictions - /// - /// The JIT compiler uses this graph to: - /// - Optimize the operations - /// - Combine steps where possible - /// - Generate fast native code - /// - /// For linear regression: y = X * w + b - /// - X: input features - /// - w: coefficients (weights) - /// - b: intercept (bias) - /// - /// - public virtual ComputationNode ExportComputationGraph(List> inputNodes) - { - if (inputNodes == null) - { - throw new ArgumentNullException(nameof(inputNodes)); - } - - // Validation: Ensure model is trained - if (Coefficients == null || Coefficients.Length == 0) - { - throw new InvalidOperationException("Cannot export computation graph: Model has not been trained yet."); - } - - // Create input node (placeholder for input features) - // Shape: [batch_size, feature_count] - var inputShape = new int[] { 1, Coefficients.Length }; - var inputTensor = new Tensor(inputShape); - var inputNode = new ComputationNode(inputTensor); - inputNodes.Add(inputNode); - - // Convert coefficients Vector to Tensor - // Shape: [feature_count, 1] for matrix multiplication - var coeffShape = new int[] { Coefficients.Length, 1 }; - var coeffData = new T[Coefficients.Length]; - for (int i = 0; i < Coefficients.Length; i++) - { - coeffData[i] = Coefficients[i]; - } - var coeffTensor = new Tensor(coeffShape, new Vector(coeffData)); - var coeffNode = new ComputationNode(coeffTensor); - - // MatMul: input @ coefficients - // Result shape: [batch_size, 1] - var outputNode = TensorOperations.MatrixMultiply(inputNode, coeffNode); - - // Add intercept if used - if (HasIntercept) - { - // Convert scalar intercept to Tensor - // Shape: [1, 1] (scalar broadcasted) - var interceptShape = new int[] { 1, 1 }; - var interceptData = new T[] { Intercept }; - var interceptTensor = new Tensor(interceptShape, new Vector(interceptData)); - var interceptNode = new ComputationNode(interceptTensor); - - // Add: (input @ coefficients) + intercept - outputNode = TensorOperations.Add(outputNode, interceptNode); - } - - return outputNode; - } - - #endregion } diff --git a/src/ReinforcementLearning/Agents/DeepReinforcementLearningAgentBase.cs b/src/ReinforcementLearning/Agents/DeepReinforcementLearningAgentBase.cs index 791bea3aa2..5ccc25f389 100644 --- a/src/ReinforcementLearning/Agents/DeepReinforcementLearningAgentBase.cs +++ b/src/ReinforcementLearning/Agents/DeepReinforcementLearningAgentBase.cs @@ -30,9 +30,10 @@ namespace AiDotNet.ReinforcementLearning.Agents; /// - Model-based methods (Dreamer, MuZero, World Models) /// - Transformer-based methods (Decision Transformer) /// -/// JIT Compilation Support: Deep RL agents support JIT compilation for policy inference -/// when their underlying neural networks support IJitCompilable. The JIT-compiled policy network -/// provides fast, deterministic action selection (without exploration) suitable for deployment. +/// Auto-Compile: Policy inference goes through the standard neural-network path, +/// which is auto-compiled by Tensors' AutoTracer once the input-shape pattern repeats. No +/// explicit compile call is required. Users can opt out via +/// TensorCodecOptions.Current.EnableCompilation = false. /// /// public abstract class DeepReinforcementLearningAgentBase : ReinforcementLearningAgentBase diff --git a/src/ReinforcementLearning/Agents/ReinforcementLearningAgentBase.cs b/src/ReinforcementLearning/Agents/ReinforcementLearningAgentBase.cs index 1ea1b781ac..a53d0b7cbf 100644 --- a/src/ReinforcementLearning/Agents/ReinforcementLearningAgentBase.cs +++ b/src/ReinforcementLearning/Agents/ReinforcementLearningAgentBase.cs @@ -505,110 +505,6 @@ public virtual void LoadState(Stream stream) } } - // ===== IJitCompilable, Vector> Implementation ===== - - /// - /// Gets whether this RL agent supports JIT compilation. - /// - /// - /// False for the base class. Derived classes may override to return true if they support JIT compilation. - /// - /// - /// - /// Most RL agents do not directly support JIT compilation because: - /// - They use layer-based neural networks without direct computation graph export - /// - Tabular methods use lookup tables rather than mathematical operations - /// - Policy selection often involves dynamic branching based on exploration strategies - /// - /// - /// Deep RL agents that use neural networks (DQN, PPO, SAC, etc.) may override this - /// to delegate JIT compilation to their underlying policy or value networks if those - /// networks support computation graph export. - /// - /// For Beginners: JIT compilation speeds up models by converting them to optimized code. - /// - /// RL agents typically don't support JIT compilation directly because: - /// - They combine multiple networks (policy, value, target networks) - /// - They use exploration strategies with random decisions - /// - The action selection process is complex and dynamic - /// - /// However, the underlying neural networks used by deep RL agents (like the Q-network in DQN) - /// can potentially be JIT compiled separately for faster inference. - /// - /// - public virtual bool SupportsJitCompilation => false; - - /// - /// Exports the agent's computation graph for JIT compilation. - /// - /// List to populate with input computation nodes. - /// The output computation node representing the agent's prediction. - /// - /// RL agents do not support direct JIT compilation. Use the underlying neural network for JIT compilation if needed. - /// - /// - /// - /// The base RL agent class does not support JIT compilation because RL agents are complex - /// systems that combine multiple components: - /// - Policy networks (select actions) - /// - Value networks (estimate state/action values) - /// - Target networks (provide stable training targets) - /// - Exploration strategies (epsilon-greedy, noise injection, etc.) - /// - Experience replay buffers - /// - /// - /// The action selection process in RL involves: - /// 1. Forward pass through policy/value network - /// 2. Exploration decision (random vs greedy) - /// 3. Action sampling or selection - /// 4. Potential action noise injection - /// - /// This complex pipeline with dynamic branching is not suitable for JIT compilation. - /// - /// Workaround for Deep RL Agents: - /// If you need to accelerate inference for deep RL agents (DQN, PPO, SAC, etc.), - /// consider JIT compiling the underlying neural networks separately: - /// - /// - /// // For DQN agent with Q-network - /// var dqnAgent = new DQNAgent<double>(options); - /// - /// // Access the Q-network directly if exposed - /// // (This requires the agent to expose its networks publicly or via a property) - /// var qNetwork = dqnAgent.QNetwork; // hypothetical property - /// - /// // JIT compile the Q-network for faster inference - /// if (qNetwork.SupportsJitCompilation) - /// { - /// var inputNodes = new List<ComputationNode<double>>(); - /// var graphOutput = qNetwork.ExportComputationGraph(inputNodes); - /// var jitCompiler = new JitCompiler<double>(graphOutput, inputNodes); - /// // Use jitCompiler.Evaluate() for fast Q-value computation - /// } - /// - /// - /// For Tabular RL Agents: - /// Tabular methods (Q-Learning, SARSA, etc.) use lookup tables rather than neural networks. - /// They perform dictionary lookups which cannot be JIT compiled. These agents are already - /// very fast for small state spaces and do not benefit from JIT compilation. - /// - /// - public virtual Autodiff.ComputationNode ExportComputationGraph(List> inputNodes) - { - throw new NotSupportedException( - "RL agents do not support direct JIT compilation. " + - "The agent's action selection involves complex processes including exploration strategies, " + - "multiple neural networks (policy, value, target), and dynamic branching that cannot be " + - "represented as a static computation graph. " + - "\n\n" + - "For deep RL agents (DQN, PPO, SAC, etc.), if you need faster inference, consider: " + - "\n1. Disabling exploration during inference (set training=false in SelectAction) " + - "\n2. Using the agent's Predict() method which uses the greedy policy " + - "\n3. JIT compiling the underlying neural networks separately if they are exposed " + - "\n\n" + - "For tabular RL agents (Q-Learning, SARSA, etc.), JIT compilation is not applicable " + - "as they use lookup tables which are already very fast for small state spaces."); - } } diff --git a/src/SurvivalAnalysis/SurvivalModelBase.cs b/src/SurvivalAnalysis/SurvivalModelBase.cs index 8de80daeae..b2206ee6a6 100644 --- a/src/SurvivalAnalysis/SurvivalModelBase.cs +++ b/src/SurvivalAnalysis/SurvivalModelBase.cs @@ -97,20 +97,6 @@ public abstract class SurvivalModelBase : ISurvivalModel, IModelShape public virtual Vector SanitizeParameters(Vector parameters) => parameters; - /// - /// Gets whether JIT compilation is supported. - /// - /// - /// - /// For Beginners: JIT (Just-In-Time) compilation can significantly accelerate - /// model inference by compiling the computation graph to optimized machine code. - /// Parametric models like Cox Proportional Hazards support JIT since their predictions - /// follow a clear mathematical formula. Non-parametric models like Kaplan-Meier are - /// harder to JIT compile since they rely on table lookups. - /// - /// - public virtual bool SupportsJitCompilation => false; - /// /// Initializes a new instance of the SurvivalModelBase class. /// @@ -743,13 +729,5 @@ public virtual void LoadState(Stream stream) Deserialize(serializedData); } - /// - /// Exports the computation graph for JIT compilation. - /// - public virtual ComputationNode ExportComputationGraph(List> inputNodes) - { - throw new NotSupportedException("JIT compilation is not supported for survival models."); - } - #endregion } diff --git a/src/TimeSeries/NBEATSBlock.cs b/src/TimeSeries/NBEATSBlock.cs index 137de73cf6..395a40d56e 100644 --- a/src/TimeSeries/NBEATSBlock.cs +++ b/src/TimeSeries/NBEATSBlock.cs @@ -1,710 +1,767 @@ -using AiDotNet.Attributes; -using AiDotNet.Autodiff; -using AiDotNet.Enums; - -namespace AiDotNet.TimeSeries; - -/// -/// Represents a single block in the N-BEATS architecture. -/// -/// The numeric type used for calculations (e.g., float, double). -/// -/// -/// Each N-BEATS block consists of: -/// 1. A stack of fully connected layers (the "theta" network) -/// 2. A basis expansion layer for generating backcast (reconstruction of input) -/// 3. A basis expansion layer for generating forecast (prediction of future) -/// -/// -/// The block architecture implements a doubly residual stacking principle: -/// - Backcast residual: Input minus backcast is passed to the next block -/// - Forecast addition: Forecasts from all blocks are summed for the final prediction -/// -/// For Beginners: A block is the basic building unit of N-BEATS. Think of it like -/// a specialized predictor that: -/// 1. Looks at the input time series -/// 2. Tries to reconstruct what it saw (backcast) -/// 3. Predicts the future (forecast) -/// 4. Passes the "leftover" patterns it couldn't explain to the next block -/// -/// Multiple blocks work together, with each one focusing on different aspects of the data. -/// -/// -internal class NBEATSBlock : NeuralNetworks.Layers.LayerBase -{ - private readonly int _lookbackWindow; - private readonly int _forecastHorizon; - private readonly int _hiddenLayerSize; - private readonly int _numHiddenLayers; - private readonly int _thetaSizeBackcast; - private readonly int _thetaSizeForecast; - private readonly bool _useInterpretableBasis; - private readonly int _polynomialDegree; - - /// - /// Initializes a new instance with default settings. - /// - public NBEATSBlock() - : this(64, 16, 128, 4, 64, 16, false) - { - } - - /// - /// Weights for the fully connected layers (theta network), stored as Tensor<T> - /// for tape-based automatic differentiation. - /// - private List> _fcWeights; - - /// - /// Biases for the fully connected layers (theta network), stored as Tensor<T> - /// for tape-based automatic differentiation. - /// - private List> _fcBiases; - - /// - /// Precomputed basis matrix for backcast expansion: [lookbackWindow, thetaSizeBackcast]. - /// - private Tensor _basisBackcast; - - /// - /// Precomputed basis matrix for forecast expansion: [forecastHorizon, thetaSizeForecast]. - /// - private Tensor _basisForecast; - - /// - /// Gets the total number of trainable parameters in the block. - /// - public override int ParameterCount - { - get - { - int count = 0; - foreach (var weight in _fcWeights) - { - count += weight.Length; - } - foreach (var bias in _fcBiases) - { - count += bias.Length; - } - return count; - } - } - - /// - /// Initializes a new instance of the NBEATSBlock class. - /// - /// The number of historical time steps used as input. - /// The number of future time steps to predict. - /// The size of hidden layers in the fully connected network. - /// The number of hidden layers. - /// The size of the theta vector for backcast basis expansion. - /// The size of the theta vector for forecast basis expansion. - /// Whether to use interpretable basis functions. - /// The polynomial degree for trend basis (if interpretable). - /// - /// For Beginners: This creates a new block with specific parameters: - /// - lookbackWindow: How far back in time the block looks - /// - forecastHorizon: How far forward in time the block predicts - /// - hiddenLayerSize: How many neurons in each hidden layer (bigger = more capacity) - /// - numHiddenLayers: How many hidden layers (deeper = more complex patterns) - /// - useInterpretableBasis: Whether to use human-understandable basis functions - /// - /// - public NBEATSBlock( - int lookbackWindow, - int forecastHorizon, - int hiddenLayerSize, - int numHiddenLayers, - int thetaSizeBackcast, - int thetaSizeForecast, - bool useInterpretableBasis, - int polynomialDegree = 3) - : base(new[] { lookbackWindow }, new[] { lookbackWindow + forecastHorizon }) - { - if (lookbackWindow <= 0) - { - throw new ArgumentException("Lookback window must be positive.", nameof(lookbackWindow)); - } - if (forecastHorizon <= 0) - { - throw new ArgumentException("Forecast horizon must be positive.", nameof(forecastHorizon)); - } - if (hiddenLayerSize <= 0) - { - throw new ArgumentException("Hidden layer size must be positive.", nameof(hiddenLayerSize)); - } - if (numHiddenLayers <= 0) - { - throw new ArgumentException("Number of hidden layers must be positive.", nameof(numHiddenLayers)); - } - - _lookbackWindow = lookbackWindow; - _forecastHorizon = forecastHorizon; - _hiddenLayerSize = hiddenLayerSize; - _numHiddenLayers = numHiddenLayers; - _thetaSizeBackcast = thetaSizeBackcast; - _thetaSizeForecast = thetaSizeForecast; - _useInterpretableBasis = useInterpretableBasis; - _polynomialDegree = polynomialDegree; - - _fcWeights = new List>(); - _fcBiases = new List>(); - - if (_useInterpretableBasis) - { - // Interpretable blocks: fixed polynomial basis (not trainable) - // Per Oreshkin et al. 2020 Section 3.3 - _basisBackcast = ComputeBasisTensor(_thetaSizeBackcast, _lookbackWindow); - _basisForecast = ComputeBasisTensor(_thetaSizeForecast, _forecastHorizon); - } - else - { - // Generic blocks: V_b and V_f are fully learnable linear functions. - // Per Oreshkin et al. 2020 Section 3.2: - // "In the generic architecture, we do not restrict g^b and g^f to a - // particular functional form, and instead make them fully learnable" - // Initialize near identity for stable initial behavior. - var data_b = new T[_lookbackWindow * _thetaSizeBackcast]; - var data_f = new T[_forecastHorizon * _thetaSizeForecast]; - for (int i = 0; i < _lookbackWindow; i++) - for (int j = 0; j < _thetaSizeBackcast; j++) - data_b[i * _thetaSizeBackcast + j] = (i == j) ? NumOps.One : NumOps.Zero; - for (int i = 0; i < _forecastHorizon; i++) - for (int j = 0; j < _thetaSizeForecast; j++) - data_f[i * _thetaSizeForecast + j] = (i == j) ? NumOps.One : NumOps.Zero; - _basisBackcast = new Tensor(data_b, new[] { _lookbackWindow, _thetaSizeBackcast }); - _basisForecast = new Tensor(data_f, new[] { _forecastHorizon, _thetaSizeForecast }); - } - - InitializeWeights(); - } - - /// - /// Initializes the weights and biases for the fully connected layers. - /// Uses He initialization for ReLU networks and registers all parameters as trainable - /// for tape-based autodiff. - /// - private void InitializeWeights() - { - var random = RandomHelper.CreateSeededRandom(42); - - // First layer: lookbackWindow -> hiddenLayerSize - int inputSize = _lookbackWindow; - double stddev = Math.Sqrt(2.0 / inputSize); - var weight = CreateWeightTensor(_hiddenLayerSize, inputSize, stddev, random); - _fcWeights.Add(weight); - RegisterTrainableParameter(weight, PersistentTensorRole.Weights); - - var bias = CreateBiasTensor(_hiddenLayerSize, 0.01); - _fcBiases.Add(bias); - RegisterTrainableParameter(bias, PersistentTensorRole.Biases); - - // Hidden layers: hiddenLayerSize -> hiddenLayerSize - for (int layer = 1; layer < _numHiddenLayers; layer++) - { - stddev = Math.Sqrt(2.0 / _hiddenLayerSize); - weight = CreateWeightTensor(_hiddenLayerSize, _hiddenLayerSize, stddev, random); - _fcWeights.Add(weight); - RegisterTrainableParameter(weight, PersistentTensorRole.Weights); - - bias = CreateBiasTensor(_hiddenLayerSize, 0.01); - _fcBiases.Add(bias); - RegisterTrainableParameter(bias, PersistentTensorRole.Biases); - } - - // Output layer for backcast theta: hiddenLayerSize -> thetaSizeBackcast - stddev = Math.Sqrt(2.0 / (_hiddenLayerSize + _thetaSizeBackcast)); - weight = CreateWeightTensor(_thetaSizeBackcast, _hiddenLayerSize, stddev, random); - _fcWeights.Add(weight); - RegisterTrainableParameter(weight, PersistentTensorRole.Weights); - - bias = CreateBiasTensor(_thetaSizeBackcast, 0.0); - _fcBiases.Add(bias); - RegisterTrainableParameter(bias, PersistentTensorRole.Biases); - - // Output layer for forecast theta: hiddenLayerSize -> thetaSizeForecast - stddev = Math.Sqrt(2.0 / (_hiddenLayerSize + _thetaSizeForecast)); - weight = CreateWeightTensor(_thetaSizeForecast, _hiddenLayerSize, stddev, random); - _fcWeights.Add(weight); - RegisterTrainableParameter(weight, PersistentTensorRole.Weights); - - bias = CreateBiasTensor(_thetaSizeForecast, 0.0); - _fcBiases.Add(bias); - RegisterTrainableParameter(bias, PersistentTensorRole.Biases); - - // For generic blocks: register V_b and V_f as trainable - // Per Oreshkin et al. 2020 Section 3.2 - if (!_useInterpretableBasis) - { - RegisterTrainableParameter(_basisBackcast, PersistentTensorRole.Weights); - RegisterTrainableParameter(_basisForecast, PersistentTensorRole.Weights); - } - } - - /// - /// Creates a weight tensor with He initialization. - /// - private Tensor CreateWeightTensor(int rows, int cols, double stddev, Random random) - { - var data = new T[rows * cols]; - for (int i = 0; i < data.Length; i++) - { - data[i] = NumOps.FromDouble(random.NextDouble() * stddev * 2 - stddev); - } - return new Tensor(new[] { rows, cols }, new Vector(data)); - } - - /// - /// Creates a bias tensor initialized to a constant value. - /// - private Tensor CreateBiasTensor(int size, double initValue) - { - var data = new T[size]; - for (int i = 0; i < size; i++) - { - data[i] = NumOps.FromDouble(initValue); - } - return new Tensor(new[] { size }, new Vector(data)); - } - - /// - /// LayerBase Forward -- uses tape-tracked Engine operations for automatic differentiation. - /// Output tensor layout: [backcast(lookbackWindow) | forecast(forecastHorizon)]. - /// - public override Tensor Forward(Tensor input) - { - // Use Engine.Reshape for tape-tracked reshaping - var x = Engine.Reshape(input, [_lookbackWindow, 1]); - - // Pass through hidden layers with ReLU - for (int layer = 0; layer < _numHiddenLayers; layer++) - { - // Linear: y = W @ x + b - var linear = Engine.TensorMatMul(_fcWeights[layer], x); - // Add bias: reshape bias to column [hidden, 1] - var biasCol = Engine.Reshape(_fcBiases[layer], [_hiddenLayerSize, 1]); - linear = Engine.TensorAdd(linear, biasCol); - // ReLU activation - x = Engine.ReLU(linear); - } - - // Compute theta for backcast: [thetaSizeBackcast, 1] - int backcastLayerIdx = _numHiddenLayers; - var thetaBackcast = Engine.TensorMatMul(_fcWeights[backcastLayerIdx], x); - var bcBiasCol = Engine.Reshape(_fcBiases[backcastLayerIdx], [_thetaSizeBackcast, 1]); - thetaBackcast = Engine.TensorAdd(thetaBackcast, bcBiasCol); - - // Compute theta for forecast: [thetaSizeForecast, 1] - int forecastLayerIdx = _numHiddenLayers + 1; - var thetaForecast = Engine.TensorMatMul(_fcWeights[forecastLayerIdx], x); - var fcBiasCol = Engine.Reshape(_fcBiases[forecastLayerIdx], [_thetaSizeForecast, 1]); - thetaForecast = Engine.TensorAdd(thetaForecast, fcBiasCol); - - // Basis expansion: backcast = B_backcast @ theta_backcast - var backcast = Engine.TensorMatMul(_basisBackcast, thetaBackcast); // [lookbackWindow, 1] - // Basis expansion: forecast = B_forecast @ theta_forecast - var forecast = Engine.TensorMatMul(_basisForecast, thetaForecast); // [forecastHorizon, 1] - - // Concatenate backcast and forecast into output: flatten to 1D - var backcastFlat = Engine.Reshape(backcast, [_lookbackWindow]); - var forecastFlat = Engine.Reshape(forecast, [_forecastHorizon]); - - // Engine.TensorConcatenate along axis 0 is a 1:1 replacement for the scalar - // copy loop: it produces the same [lookbackWindow + forecastHorizon] 1D tensor - // by copying backcastFlat elements followed by forecastFlat elements. - var output = Engine.TensorConcatenate([backcastFlat, forecastFlat], axis: 0); - - return output; - } - - /// - /// Tape-tracked forward pass that returns separate backcast and forecast tensors. - /// Used by the NBEATSModel during training for residual block-by-block processing. - /// - public (Tensor backcast, Tensor forecast) ForwardTape(Tensor input) - { - // Use Engine.Reshape (tape-tracked) instead of tensor.Reshape (not tracked) - var x = Engine.Reshape(input, [_lookbackWindow, 1]); - - // Pass through hidden layers with ReLU - for (int layer = 0; layer < _numHiddenLayers; layer++) - { - var linear = Engine.TensorMatMul(_fcWeights[layer], x); - var biasCol = Engine.Reshape(_fcBiases[layer], [_hiddenLayerSize, 1]); - linear = Engine.TensorAdd(linear, biasCol); - x = Engine.ReLU(linear); - } - - // Compute theta for backcast - int backcastLayerIdx = _numHiddenLayers; - var thetaBackcast = Engine.TensorMatMul(_fcWeights[backcastLayerIdx], x); - var bcBiasCol = Engine.Reshape(_fcBiases[backcastLayerIdx], [_thetaSizeBackcast, 1]); - thetaBackcast = Engine.TensorAdd(thetaBackcast, bcBiasCol); - - // Compute theta for forecast - int forecastLayerIdx = _numHiddenLayers + 1; - var thetaForecast = Engine.TensorMatMul(_fcWeights[forecastLayerIdx], x); - var fcBiasCol = Engine.Reshape(_fcBiases[forecastLayerIdx], [_thetaSizeForecast, 1]); - thetaForecast = Engine.TensorAdd(thetaForecast, fcBiasCol); - - // Basis expansion — use Engine.Reshape for tape-tracked reshape - var backcastRaw = Engine.TensorMatMul(_basisBackcast, thetaBackcast); - var backcast = Engine.Reshape(backcastRaw, [_lookbackWindow]); - var forecastRaw = Engine.TensorMatMul(_basisForecast, thetaForecast); - var forecast = Engine.Reshape(forecastRaw, [_forecastHorizon]); - - return (backcast, forecast); - } - - public override bool SupportsTraining => true; - - public override void ResetState() { /* stateless layer -- no recurrent state to reset */ } - - /// - /// No-op: tape-based training handles parameter updates through the optimizer. - /// - public override void UpdateParameters(T learningRate) { } - - /// - /// Non-tape forward pass for inference (used by PredictSingle). - /// Uses plain matrix/vector operations without tape overhead. - /// - public (Vector backcast, Vector forecast) ForwardInternal(Vector input) - { - if (input.Length != _lookbackWindow) - { - throw new ArgumentException( - $"Input length ({input.Length}) must match lookback window ({_lookbackWindow}).", - nameof(input)); - } - - // Pass through fully connected layers with ReLU activation - Vector x = input.Clone(); - - for (int layer = 0; layer < _numHiddenLayers; layer++) - { - // Linear transformation: y = Wx + b using tensor operations - var xCol = new Tensor(new[] { x.Length, 1 }, x); - var wxResult = Engine.TensorMatMul(_fcWeights[layer], xCol); - Vector linear = new Vector(_hiddenLayerSize); - var biasVec = _fcBiases[layer].ToVector(); - for (int i = 0; i < _hiddenLayerSize; i++) - { - linear[i] = NumOps.Add(biasVec[i], wxResult[i, 0]); - } - - // ReLU activation - x = new Vector(linear.Length); - for (int i = 0; i < linear.Length; i++) - { - x[i] = NumOps.GreaterThan(linear[i], NumOps.Zero) ? linear[i] : NumOps.Zero; - } - } - - // Compute theta for backcast - int backcastLayerIdx = _numHiddenLayers; - var xColTheta = new Tensor(new[] { x.Length, 1 }, x); - var bcWx = Engine.TensorMatMul(_fcWeights[backcastLayerIdx], xColTheta); - var bcBiasVec = _fcBiases[backcastLayerIdx].ToVector(); - Vector thetaBackcast = new Vector(_thetaSizeBackcast); - for (int i = 0; i < _thetaSizeBackcast; i++) - { - thetaBackcast[i] = NumOps.Add(bcBiasVec[i], bcWx[i, 0]); - } - - // Compute theta for forecast - int forecastLayerIdx = _numHiddenLayers + 1; - var fcWx = Engine.TensorMatMul(_fcWeights[forecastLayerIdx], xColTheta); - var fcBiasVec = _fcBiases[forecastLayerIdx].ToVector(); - Vector thetaForecast = new Vector(_thetaSizeForecast); - for (int i = 0; i < _thetaSizeForecast; i++) - { - thetaForecast[i] = NumOps.Add(fcBiasVec[i], fcWx[i, 0]); - } - - // Apply basis expansion - Vector backcast = ApplyBasisExpansion(thetaBackcast, _lookbackWindow); - Vector forecast = ApplyBasisExpansion(thetaForecast, _forecastHorizon); - - return (backcast, forecast); - } - - /// - /// Computes the basis matrix as a Tensor for tape-tracked operations. - /// Shape: [outputLength, thetaSize]. - /// - private Tensor ComputeBasisTensor(int thetaSize, int outputLength) - { - var data = new T[outputLength * thetaSize]; - - if (_useInterpretableBasis) - { - for (int t = 0; t < outputLength; t++) - { - double tNormalized = (double)t / outputLength; - for (int p = 0; p < Math.Min(thetaSize, _polynomialDegree + 1); p++) - { - data[t * thetaSize + p] = NumOps.FromDouble(Math.Pow(tNormalized, p)); - } - } - } - else - { - // Generic basis per Oreshkin et al. (2020): when thetaSize == outputLength, - // theta IS the output directly (identity basis). When they differ, use a - // simple identity-like mapping (1 on the diagonal, 0 elsewhere). - for (int t = 0; t < outputLength; t++) - { - for (int k = 0; k < thetaSize; k++) - { - data[t * thetaSize + k] = (t == k) - ? NumOps.One - : NumOps.Zero; - } - } - } - - return new Tensor(new[] { outputLength, thetaSize }, new Vector(data)); - } - - /// - /// Computes the basis matrix as a Matrix (for legacy operations). - /// Shape: [outputLength, thetaSize]. - /// - private Matrix ComputeBasisMatrix(int thetaSize, int outputLength) - { - var basis = new Matrix(outputLength, thetaSize); - - if (_useInterpretableBasis) - { - for (int t = 0; t < outputLength; t++) - { - double tNormalized = (double)t / outputLength; - for (int p = 0; p < Math.Min(thetaSize, _polynomialDegree + 1); p++) - { - basis[t, p] = NumOps.FromDouble(Math.Pow(tNormalized, p)); - } - } - } - else - { - // Generic basis: identity matrix (theta IS the output) - for (int t = 0; t < outputLength; t++) - { - for (int k = 0; k < thetaSize; k++) - { - basis[t, k] = (t == k) ? NumOps.One : NumOps.Zero; - } - } - } - - return basis; - } - - private Vector ApplyBasisExpansion(Vector theta, int outputLength) - { - Vector output = new Vector(outputLength); - - if (_useInterpretableBasis) - { - for (int t = 0; t < outputLength; t++) - { - T value = NumOps.Zero; - T tNormalized = NumOps.FromDouble((double)t / outputLength); - - for (int p = 0; p < Math.Min(theta.Length, _polynomialDegree + 1); p++) - { - T power = NumOps.One; - for (int k = 0; k < p; k++) - { - power = NumOps.Multiply(power, tNormalized); - } - value = NumOps.Add(value, NumOps.Multiply(theta[p], power)); - } - - output[t] = value; - } - } - else - { - // Generic basis: identity — theta[t] maps directly to output[t] - for (int t = 0; t < outputLength; t++) - { - output[t] = (t < theta.Length) ? theta[t] : NumOps.Zero; - } - } - - return output; - } - - /// - /// Gets all parameters (weights and biases) as a single vector. - /// - public override Vector GetParameters() - { - var parameters = new List(); - - foreach (var weight in _fcWeights) - { - var vec = weight.ToVector(); - parameters.AddRange(vec); - } - - foreach (var bias in _fcBiases) - { - var vec = bias.ToVector(); - parameters.AddRange(vec); - } - - return new Vector(parameters.ToArray()); - } - - /// - /// Sets all parameters (weights and biases) from a single vector. - /// - public override void SetParameters(Vector parameters) - { - if (parameters.Length != ParameterCount) - { - throw new ArgumentException( - $"Expected {ParameterCount} parameters, but got {parameters.Length}.", - nameof(parameters)); - } - - int idx = 0; - - for (int w = 0; w < _fcWeights.Count; w++) - { - var weight = _fcWeights[w]; - int len = weight.Length; - var data = new T[len]; - for (int i = 0; i < len; i++) - { - data[i] = parameters[idx++]; - } - int rows = weight.Shape[0]; - int cols = weight.Shape[1]; - _fcWeights[w] = new Tensor(new[] { rows, cols }, new Vector(data)); - } - - for (int b = 0; b < _fcBiases.Count; b++) - { - var bias = _fcBiases[b]; - int len = bias.Length; - var data = new T[len]; - for (int i = 0; i < len; i++) - { - data[i] = parameters[idx++]; - } - _fcBiases[b] = new Tensor(new[] { len }, new Vector(data)); - } - - // Re-register trainable parameters after replacing tensors - ReRegisterParameters(); - } - - /// - /// Re-registers all weight and bias tensors as trainable parameters. - /// Called after SetParameters replaces tensor instances. - /// - private void ReRegisterParameters() - { - // Clear and re-register (RegisterTrainableParameter handles dedup) - foreach (var w in _fcWeights) - RegisterTrainableParameter(w, PersistentTensorRole.Weights); - foreach (var b in _fcBiases) - RegisterTrainableParameter(b, PersistentTensorRole.Biases); - } - - /// - /// Exports the block as computation graph nodes for JIT compilation. - /// - public (ComputationNode backcast, ComputationNode forecast) ExportComputationGraph(ComputationNode inputNode) - { - var numOps = MathHelper.GetNumericOperations(); - var x = inputNode; - - for (int layer = 0; layer < _numHiddenLayers; layer++) - { - var weightTensor = _fcWeights[layer]; - var weightNode = TensorOperations.Constant(weightTensor, $"block_fc{layer}_weight"); - - var biasTensor = _fcBiases[layer]; - var biasNode = TensorOperations.Constant(biasTensor, $"block_fc{layer}_bias"); - - var linear = TensorOperations.MatrixVectorMultiply(weightNode, x); - linear = TensorOperations.Add(linear, biasNode); - - x = TensorOperations.ReLU(linear); - } - - // Compute theta for backcast - var backcastWeightNode = TensorOperations.Constant(_fcWeights[_numHiddenLayers], "block_backcast_weight"); - var backcastBiasNode = TensorOperations.Constant(_fcBiases[_numHiddenLayers], "block_backcast_bias"); - - var thetaBackcast = TensorOperations.MatrixVectorMultiply(backcastWeightNode, x); - thetaBackcast = TensorOperations.Add(thetaBackcast, backcastBiasNode); - - // Compute theta for forecast - var forecastWeightNode = TensorOperations.Constant(_fcWeights[_numHiddenLayers + 1], "block_forecast_weight"); - var forecastBiasNode = TensorOperations.Constant(_fcBiases[_numHiddenLayers + 1], "block_forecast_bias"); - - var thetaForecast = TensorOperations.MatrixVectorMultiply(forecastWeightNode, x); - thetaForecast = TensorOperations.Add(thetaForecast, forecastBiasNode); - - // Apply basis expansion - var backcastNode = ApplyBasisExpansionGraph(thetaBackcast, _lookbackWindow, isBackcast: true); - var forecastNode = ApplyBasisExpansionGraph(thetaForecast, _forecastHorizon, isBackcast: false); - - return (backcastNode, forecastNode); - } - - /// - /// Applies basis expansion in the computation graph. - /// - private ComputationNode ApplyBasisExpansionGraph(ComputationNode theta, int outputLength, bool isBackcast) - { - var numOps = MathHelper.GetNumericOperations(); - - if (_useInterpretableBasis) - { - var basisData = new T[outputLength * theta.Value.Shape[0]]; - int thetaSize = theta.Value.Shape[0]; - - for (int t = 0; t < outputLength; t++) - { - double tNormalized = (double)t / outputLength; - for (int p = 0; p < Math.Min(thetaSize, _polynomialDegree + 1); p++) - { - double power = Math.Pow(tNormalized, p); - basisData[t * thetaSize + p] = numOps.FromDouble(power); - } - } - - var basisTensor = new Tensor(new[] { outputLength, thetaSize }, new Vector(basisData)); - var basisNode = TensorOperations.Constant(basisTensor, isBackcast ? "backcast_basis" : "forecast_basis"); - - return TensorOperations.MatrixVectorMultiply(basisNode, theta); - } - else - { - var basisData = new T[outputLength * theta.Value.Shape[0]]; - int thetaSize = theta.Value.Shape[0]; - - for (int t = 0; t < outputLength; t++) - { - for (int k = 0; k < thetaSize; k++) - { - double cosValue = Math.Cos(2.0 * Math.PI * k * t / outputLength); - basisData[t * thetaSize + k] = numOps.FromDouble(cosValue); - } - } - - var basisTensor = new Tensor(new[] { outputLength, thetaSize }, new Vector(basisData)); - var basisNode = TensorOperations.Constant(basisTensor, isBackcast ? "backcast_basis" : "forecast_basis"); - - return TensorOperations.MatrixVectorMultiply(basisNode, theta); - } - } -} +using AiDotNet.Attributes; +using AiDotNet.Autodiff; +using AiDotNet.Enums; + +namespace AiDotNet.TimeSeries; + +/// +/// Represents a single block in the N-BEATS architecture. +/// +/// The numeric type used for calculations (e.g., float, double). +/// +/// +/// Each N-BEATS block consists of: +/// 1. A stack of fully connected layers (the "theta" network) +/// 2. A basis expansion layer for generating backcast (reconstruction of input) +/// 3. A basis expansion layer for generating forecast (prediction of future) +/// +/// +/// The block architecture implements a doubly residual stacking principle: +/// - Backcast residual: Input minus backcast is passed to the next block +/// - Forecast addition: Forecasts from all blocks are summed for the final prediction +/// +/// For Beginners: A block is the basic building unit of N-BEATS. Think of it like +/// a specialized predictor that: +/// 1. Looks at the input time series +/// 2. Tries to reconstruct what it saw (backcast) +/// 3. Predicts the future (forecast) +/// 4. Passes the "leftover" patterns it couldn't explain to the next block +/// +/// Multiple blocks work together, with each one focusing on different aspects of the data. +/// +/// +internal class NBEATSBlock : NeuralNetworks.Layers.LayerBase +{ + private readonly int _lookbackWindow; + private readonly int _forecastHorizon; + private readonly int _hiddenLayerSize; + private readonly int _numHiddenLayers; + private readonly int _thetaSizeBackcast; + private readonly int _thetaSizeForecast; + private readonly bool _useInterpretableBasis; + private readonly int _polynomialDegree; + + /// + /// Initializes a new instance with default settings. + /// + public NBEATSBlock() + : this(64, 16, 128, 4, 64, 16, false) + { + } + + /// + /// Weights for the fully connected layers (theta network), stored as Tensor<T> + /// for tape-based automatic differentiation. + /// + private List> _fcWeights; + + /// + /// Biases for the fully connected layers (theta network), stored as Tensor<T> + /// for tape-based automatic differentiation. + /// + private List> _fcBiases; + + /// + /// Precomputed basis matrix for backcast expansion: [lookbackWindow, thetaSizeBackcast]. + /// + private Tensor _basisBackcast; + + /// + /// Precomputed basis matrix for forecast expansion: [forecastHorizon, thetaSizeForecast]. + /// + private Tensor _basisForecast; + + /// + /// Gets the total number of trainable parameters in the block. + /// + public override int ParameterCount + { + get + { + int count = 0; + foreach (var weight in _fcWeights) + { + count += weight.Length; + } + foreach (var bias in _fcBiases) + { + count += bias.Length; + } + // Generic blocks: V_b and V_f bases are learnable per Oreshkin et al. + // 2020 Section 3.2. Interpretable blocks use fixed polynomial bases + // that aren't trainable, so don't include them in the parameter count. + if (!_useInterpretableBasis) + { + count += _basisBackcast.Length; + count += _basisForecast.Length; + } + return count; + } + } + + /// + /// Initializes a new instance of the NBEATSBlock class. + /// + /// The number of historical time steps used as input. + /// The number of future time steps to predict. + /// The size of hidden layers in the fully connected network. + /// The number of hidden layers. + /// The size of the theta vector for backcast basis expansion. + /// The size of the theta vector for forecast basis expansion. + /// Whether to use interpretable basis functions. + /// The polynomial degree for trend basis (if interpretable). + /// + /// For Beginners: This creates a new block with specific parameters: + /// - lookbackWindow: How far back in time the block looks + /// - forecastHorizon: How far forward in time the block predicts + /// - hiddenLayerSize: How many neurons in each hidden layer (bigger = more capacity) + /// - numHiddenLayers: How many hidden layers (deeper = more complex patterns) + /// - useInterpretableBasis: Whether to use human-understandable basis functions + /// + /// + /// + /// Validates and returns the corresponding + /// LayerBase input shape. Runs BEFORE the base ctor so invalid values surface + /// as with the argument name instead of a + /// downstream shape error. + /// + private static int[] CreateInputShape(int lookbackWindow) + { + if (lookbackWindow <= 0) + { + throw new ArgumentException("Lookback window must be positive.", nameof(lookbackWindow)); + } + return new[] { lookbackWindow }; + } + + /// + /// Validates (and re-checks lookback for + /// consistency) and returns the corresponding LayerBase output shape. + /// + private static int[] CreateOutputShape(int lookbackWindow, int forecastHorizon) + { + if (lookbackWindow <= 0) + { + throw new ArgumentException("Lookback window must be positive.", nameof(lookbackWindow)); + } + if (forecastHorizon <= 0) + { + throw new ArgumentException("Forecast horizon must be positive.", nameof(forecastHorizon)); + } + return new[] { lookbackWindow + forecastHorizon }; + } + + public NBEATSBlock( + int lookbackWindow, + int forecastHorizon, + int hiddenLayerSize, + int numHiddenLayers, + int thetaSizeBackcast, + int thetaSizeForecast, + bool useInterpretableBasis, + int polynomialDegree = 3) + : base( + CreateInputShape(lookbackWindow), + CreateOutputShape(lookbackWindow, forecastHorizon)) + { + // Primary-argument validation happens inside the static shape factories + // above so `lookbackWindow` / `forecastHorizon` are rejected BEFORE + // LayerBase consumes them — users see the nameof(...)-tagged + // ArgumentException instead of a downstream shape error from the base. + // (The two blocks that previously validated those here are now in + // CreateInputShape / CreateOutputShape below.) + if (hiddenLayerSize <= 0) + { + throw new ArgumentException("Hidden layer size must be positive.", nameof(hiddenLayerSize)); + } + if (numHiddenLayers <= 0) + { + throw new ArgumentException("Number of hidden layers must be positive.", nameof(numHiddenLayers)); + } + if (thetaSizeBackcast <= 0) + { + throw new ArgumentException("Backcast theta size must be positive.", nameof(thetaSizeBackcast)); + } + if (thetaSizeForecast <= 0) + { + throw new ArgumentException("Forecast theta size must be positive.", nameof(thetaSizeForecast)); + } + if (useInterpretableBasis && polynomialDegree < 0) + { + throw new ArgumentException("Polynomial degree must be non-negative for interpretable basis.", nameof(polynomialDegree)); + } + // Interpretable-basis builders cap usable theta at polynomialDegree + 1 + // (ComputeBasisTensor populates only that many rows; ApplyBasisExpansion + // slices to the same count). Silently accepting oversized theta sizes + // would allocate trainable weights that are mathematically disconnected + // from the output — dead parameters that waste memory and mask bugs + // during gradient checks. + if (useInterpretableBasis && thetaSizeBackcast > polynomialDegree + 1) + { + throw new ArgumentException( + $"Backcast theta size ({thetaSizeBackcast}) cannot exceed polynomialDegree + 1 ({polynomialDegree + 1}) for interpretable basis.", + nameof(thetaSizeBackcast)); + } + if (useInterpretableBasis && thetaSizeForecast > polynomialDegree + 1) + { + throw new ArgumentException( + $"Forecast theta size ({thetaSizeForecast}) cannot exceed polynomialDegree + 1 ({polynomialDegree + 1}) for interpretable basis.", + nameof(thetaSizeForecast)); + } + + _lookbackWindow = lookbackWindow; + _forecastHorizon = forecastHorizon; + _hiddenLayerSize = hiddenLayerSize; + _numHiddenLayers = numHiddenLayers; + _thetaSizeBackcast = thetaSizeBackcast; + _thetaSizeForecast = thetaSizeForecast; + _useInterpretableBasis = useInterpretableBasis; + _polynomialDegree = polynomialDegree; + + _fcWeights = new List>(); + _fcBiases = new List>(); + + if (_useInterpretableBasis) + { + // Interpretable blocks: fixed polynomial basis (not trainable) + // Per Oreshkin et al. 2020 Section 3.3 + _basisBackcast = ComputeBasisTensor(_thetaSizeBackcast, _lookbackWindow); + _basisForecast = ComputeBasisTensor(_thetaSizeForecast, _forecastHorizon); + } + else + { + // Generic blocks: V_b and V_f are fully learnable linear functions. + // Per Oreshkin et al. 2020 Section 3.2: + // "In the generic architecture, we do not restrict g^b and g^f to a + // particular functional form, and instead make them fully learnable" + // Initialize near identity for stable initial behavior. + var data_b = new T[_lookbackWindow * _thetaSizeBackcast]; + var data_f = new T[_forecastHorizon * _thetaSizeForecast]; + for (int i = 0; i < _lookbackWindow; i++) + for (int j = 0; j < _thetaSizeBackcast; j++) + data_b[i * _thetaSizeBackcast + j] = (i == j) ? NumOps.One : NumOps.Zero; + for (int i = 0; i < _forecastHorizon; i++) + for (int j = 0; j < _thetaSizeForecast; j++) + data_f[i * _thetaSizeForecast + j] = (i == j) ? NumOps.One : NumOps.Zero; + _basisBackcast = new Tensor(data_b, new[] { _lookbackWindow, _thetaSizeBackcast }); + _basisForecast = new Tensor(data_f, new[] { _forecastHorizon, _thetaSizeForecast }); + } + + InitializeWeights(); + } + + /// + /// Initializes the weights and biases for the fully connected layers. + /// Uses He initialization for ReLU networks and registers all parameters as trainable + /// for tape-based autodiff. + /// + private void InitializeWeights() + { + var random = RandomHelper.CreateSeededRandom(42); + + // First layer: lookbackWindow -> hiddenLayerSize + int inputSize = _lookbackWindow; + double stddev = Math.Sqrt(2.0 / inputSize); + var weight = CreateWeightTensor(_hiddenLayerSize, inputSize, stddev, random); + _fcWeights.Add(weight); + RegisterTrainableParameter(weight, PersistentTensorRole.Weights); + + var bias = CreateBiasTensor(_hiddenLayerSize, 0.01); + _fcBiases.Add(bias); + RegisterTrainableParameter(bias, PersistentTensorRole.Biases); + + // Hidden layers: hiddenLayerSize -> hiddenLayerSize + for (int layer = 1; layer < _numHiddenLayers; layer++) + { + stddev = Math.Sqrt(2.0 / _hiddenLayerSize); + weight = CreateWeightTensor(_hiddenLayerSize, _hiddenLayerSize, stddev, random); + _fcWeights.Add(weight); + RegisterTrainableParameter(weight, PersistentTensorRole.Weights); + + bias = CreateBiasTensor(_hiddenLayerSize, 0.01); + _fcBiases.Add(bias); + RegisterTrainableParameter(bias, PersistentTensorRole.Biases); + } + + // Output layer for backcast theta: hiddenLayerSize -> thetaSizeBackcast + stddev = Math.Sqrt(2.0 / (_hiddenLayerSize + _thetaSizeBackcast)); + weight = CreateWeightTensor(_thetaSizeBackcast, _hiddenLayerSize, stddev, random); + _fcWeights.Add(weight); + RegisterTrainableParameter(weight, PersistentTensorRole.Weights); + + bias = CreateBiasTensor(_thetaSizeBackcast, 0.0); + _fcBiases.Add(bias); + RegisterTrainableParameter(bias, PersistentTensorRole.Biases); + + // Output layer for forecast theta: hiddenLayerSize -> thetaSizeForecast + stddev = Math.Sqrt(2.0 / (_hiddenLayerSize + _thetaSizeForecast)); + weight = CreateWeightTensor(_thetaSizeForecast, _hiddenLayerSize, stddev, random); + _fcWeights.Add(weight); + RegisterTrainableParameter(weight, PersistentTensorRole.Weights); + + bias = CreateBiasTensor(_thetaSizeForecast, 0.0); + _fcBiases.Add(bias); + RegisterTrainableParameter(bias, PersistentTensorRole.Biases); + + // For generic blocks: register V_b and V_f as trainable + // Per Oreshkin et al. 2020 Section 3.2 + if (!_useInterpretableBasis) + { + RegisterTrainableParameter(_basisBackcast, PersistentTensorRole.Weights); + RegisterTrainableParameter(_basisForecast, PersistentTensorRole.Weights); + } + } + + /// + /// Creates a weight tensor with He initialization. + /// + private Tensor CreateWeightTensor(int rows, int cols, double stddev, Random random) + { + var data = new T[rows * cols]; + for (int i = 0; i < data.Length; i++) + { + data[i] = NumOps.FromDouble(random.NextDouble() * stddev * 2 - stddev); + } + return new Tensor(new[] { rows, cols }, new Vector(data)); + } + + /// + /// Creates a bias tensor initialized to a constant value. + /// + private Tensor CreateBiasTensor(int size, double initValue) + { + var data = new T[size]; + for (int i = 0; i < size; i++) + { + data[i] = NumOps.FromDouble(initValue); + } + return new Tensor(new[] { size }, new Vector(data)); + } + + /// + /// LayerBase Forward -- uses tape-tracked Engine operations for automatic differentiation. + /// Output tensor layout: [backcast(lookbackWindow) | forecast(forecastHorizon)]. + /// + public override Tensor Forward(Tensor input) + { + // Use Engine.Reshape for tape-tracked reshaping + var x = Engine.Reshape(input, [_lookbackWindow, 1]); + + // Pass through hidden layers with ReLU + for (int layer = 0; layer < _numHiddenLayers; layer++) + { + // Linear: y = W @ x + b + var linear = Engine.TensorMatMul(_fcWeights[layer], x); + // Add bias: reshape bias to column [hidden, 1] + var biasCol = Engine.Reshape(_fcBiases[layer], [_hiddenLayerSize, 1]); + linear = Engine.TensorAdd(linear, biasCol); + // ReLU activation + x = Engine.ReLU(linear); + } + + // Compute theta for backcast: [thetaSizeBackcast, 1] + int backcastLayerIdx = _numHiddenLayers; + var thetaBackcast = Engine.TensorMatMul(_fcWeights[backcastLayerIdx], x); + var bcBiasCol = Engine.Reshape(_fcBiases[backcastLayerIdx], [_thetaSizeBackcast, 1]); + thetaBackcast = Engine.TensorAdd(thetaBackcast, bcBiasCol); + + // Compute theta for forecast: [thetaSizeForecast, 1] + int forecastLayerIdx = _numHiddenLayers + 1; + var thetaForecast = Engine.TensorMatMul(_fcWeights[forecastLayerIdx], x); + var fcBiasCol = Engine.Reshape(_fcBiases[forecastLayerIdx], [_thetaSizeForecast, 1]); + thetaForecast = Engine.TensorAdd(thetaForecast, fcBiasCol); + + // Basis expansion: backcast = B_backcast @ theta_backcast + var backcast = Engine.TensorMatMul(_basisBackcast, thetaBackcast); // [lookbackWindow, 1] + // Basis expansion: forecast = B_forecast @ theta_forecast + var forecast = Engine.TensorMatMul(_basisForecast, thetaForecast); // [forecastHorizon, 1] + + // Concatenate backcast and forecast into output: flatten to 1D + var backcastFlat = Engine.Reshape(backcast, [_lookbackWindow]); + var forecastFlat = Engine.Reshape(forecast, [_forecastHorizon]); + + // Engine.TensorConcatenate along axis 0 is a 1:1 replacement for the scalar + // copy loop: it produces the same [lookbackWindow + forecastHorizon] 1D tensor + // by copying backcastFlat elements followed by forecastFlat elements. + var output = Engine.TensorConcatenate([backcastFlat, forecastFlat], axis: 0); + + return output; + } + + /// + /// Tape-tracked forward pass that returns separate backcast and forecast tensors. + /// Used by the NBEATSModel during training for residual block-by-block processing. + /// + public (Tensor backcast, Tensor forecast) ForwardTape(Tensor input) + { + // Use Engine.Reshape (tape-tracked) instead of tensor.Reshape (not tracked) + var x = Engine.Reshape(input, [_lookbackWindow, 1]); + + // Pass through hidden layers with ReLU + for (int layer = 0; layer < _numHiddenLayers; layer++) + { + var linear = Engine.TensorMatMul(_fcWeights[layer], x); + var biasCol = Engine.Reshape(_fcBiases[layer], [_hiddenLayerSize, 1]); + linear = Engine.TensorAdd(linear, biasCol); + x = Engine.ReLU(linear); + } + + // Compute theta for backcast + int backcastLayerIdx = _numHiddenLayers; + var thetaBackcast = Engine.TensorMatMul(_fcWeights[backcastLayerIdx], x); + var bcBiasCol = Engine.Reshape(_fcBiases[backcastLayerIdx], [_thetaSizeBackcast, 1]); + thetaBackcast = Engine.TensorAdd(thetaBackcast, bcBiasCol); + + // Compute theta for forecast + int forecastLayerIdx = _numHiddenLayers + 1; + var thetaForecast = Engine.TensorMatMul(_fcWeights[forecastLayerIdx], x); + var fcBiasCol = Engine.Reshape(_fcBiases[forecastLayerIdx], [_thetaSizeForecast, 1]); + thetaForecast = Engine.TensorAdd(thetaForecast, fcBiasCol); + + // Basis expansion — use Engine.Reshape for tape-tracked reshape + var backcastRaw = Engine.TensorMatMul(_basisBackcast, thetaBackcast); + var backcast = Engine.Reshape(backcastRaw, [_lookbackWindow]); + var forecastRaw = Engine.TensorMatMul(_basisForecast, thetaForecast); + var forecast = Engine.Reshape(forecastRaw, [_forecastHorizon]); + + return (backcast, forecast); + } + + public override bool SupportsTraining => true; + + public override void ResetState() { /* stateless layer -- no recurrent state to reset */ } + + /// + /// Throws : this block is trained + /// through the tape-based optimizer path and has no eager scalar-step update. + /// + /// + /// N-BEATS blocks register their parameters via RegisterTrainableParameter + /// and are updated by the compiled training plan that + /// drives. Calling UpdateParameters(learningRate) directly bypasses + /// that path and would silently lose updates, so fail fast to catch the + /// misuse at the training boundary rather than later as a silent accuracy + /// regression. + /// + public override void UpdateParameters(T learningRate) + { + throw new InvalidOperationException( + $"{nameof(NBEATSBlock)} uses tape-based optimization. " + + "Update parameters through the optimizer / training step, " + + "not directly via UpdateParameters(learningRate)."); + } + + /// + /// Non-tape forward pass for inference (used by PredictSingle). + /// Uses plain matrix/vector operations without tape overhead. + /// + public (Vector backcast, Vector forecast) ForwardInternal(Vector input) + { + if (input.Length != _lookbackWindow) + { + throw new ArgumentException( + $"Input length ({input.Length}) must match lookback window ({_lookbackWindow}).", + nameof(input)); + } + + // Pass through fully connected layers with ReLU activation + Vector x = input.Clone(); + + for (int layer = 0; layer < _numHiddenLayers; layer++) + { + // Linear transformation: y = Wx + b using tensor operations + var xCol = new Tensor(new[] { x.Length, 1 }, x); + var wxResult = Engine.TensorMatMul(_fcWeights[layer], xCol); + Vector linear = new Vector(_hiddenLayerSize); + var biasVec = _fcBiases[layer].ToVector(); + for (int i = 0; i < _hiddenLayerSize; i++) + { + linear[i] = NumOps.Add(biasVec[i], wxResult[i, 0]); + } + + // ReLU activation + x = new Vector(linear.Length); + for (int i = 0; i < linear.Length; i++) + { + x[i] = NumOps.GreaterThan(linear[i], NumOps.Zero) ? linear[i] : NumOps.Zero; + } + } + + // Compute theta for backcast + int backcastLayerIdx = _numHiddenLayers; + var xColTheta = new Tensor(new[] { x.Length, 1 }, x); + var bcWx = Engine.TensorMatMul(_fcWeights[backcastLayerIdx], xColTheta); + var bcBiasVec = _fcBiases[backcastLayerIdx].ToVector(); + Vector thetaBackcast = new Vector(_thetaSizeBackcast); + for (int i = 0; i < _thetaSizeBackcast; i++) + { + thetaBackcast[i] = NumOps.Add(bcBiasVec[i], bcWx[i, 0]); + } + + // Compute theta for forecast + int forecastLayerIdx = _numHiddenLayers + 1; + var fcWx = Engine.TensorMatMul(_fcWeights[forecastLayerIdx], xColTheta); + var fcBiasVec = _fcBiases[forecastLayerIdx].ToVector(); + Vector thetaForecast = new Vector(_thetaSizeForecast); + for (int i = 0; i < _thetaSizeForecast; i++) + { + thetaForecast[i] = NumOps.Add(fcBiasVec[i], fcWx[i, 0]); + } + + // Apply basis expansion. Pass the matching basis tensor so generic + // blocks multiply by their learned V_b / V_f matrices (keeping this + // path consistent with Forward() / ForwardTape() and with the + // parameter export/import of _basisBackcast / _basisForecast). + Vector backcast = ApplyBasisExpansion(thetaBackcast, _basisBackcast, _lookbackWindow); + Vector forecast = ApplyBasisExpansion(thetaForecast, _basisForecast, _forecastHorizon); + + return (backcast, forecast); + } + + /// + /// Computes the basis matrix as a Tensor for tape-tracked operations. + /// Shape: [outputLength, thetaSize]. + /// + private Tensor ComputeBasisTensor(int thetaSize, int outputLength) + { + var data = new T[outputLength * thetaSize]; + + if (_useInterpretableBasis) + { + for (int t = 0; t < outputLength; t++) + { + double tNormalized = (double)t / outputLength; + for (int p = 0; p < Math.Min(thetaSize, _polynomialDegree + 1); p++) + { + data[t * thetaSize + p] = NumOps.FromDouble(Math.Pow(tNormalized, p)); + } + } + } + else + { + // Generic basis per Oreshkin et al. (2020): when thetaSize == outputLength, + // theta IS the output directly (identity basis). When they differ, use a + // simple identity-like mapping (1 on the diagonal, 0 elsewhere). + for (int t = 0; t < outputLength; t++) + { + for (int k = 0; k < thetaSize; k++) + { + data[t * thetaSize + k] = (t == k) + ? NumOps.One + : NumOps.Zero; + } + } + } + + return new Tensor(new[] { outputLength, thetaSize }, new Vector(data)); + } + + /// + /// Computes the basis matrix as a Matrix (for legacy operations). + /// Shape: [outputLength, thetaSize]. + /// + private Matrix ComputeBasisMatrix(int thetaSize, int outputLength) + { + var basis = new Matrix(outputLength, thetaSize); + + if (_useInterpretableBasis) + { + for (int t = 0; t < outputLength; t++) + { + double tNormalized = (double)t / outputLength; + for (int p = 0; p < Math.Min(thetaSize, _polynomialDegree + 1); p++) + { + basis[t, p] = NumOps.FromDouble(Math.Pow(tNormalized, p)); + } + } + } + else + { + // Generic basis: identity matrix (theta IS the output) + for (int t = 0; t < outputLength; t++) + { + for (int k = 0; k < thetaSize; k++) + { + basis[t, k] = (t == k) ? NumOps.One : NumOps.Zero; + } + } + } + + return basis; + } + + /// + /// Expands the theta coefficients into an output time series of the requested length. + /// + /// The theta coefficient vector produced by the fc head. + /// + /// The basis matrix for the generic branch — shape [outputLength, theta.Length]. + /// Ignored when is true (the closed-form + /// polynomial basis is computed on the fly from ). + /// + /// Length of the expanded output vector. + private Vector ApplyBasisExpansion(Vector theta, Tensor basis, int outputLength) + { + Vector output = new Vector(outputLength); + + if (_useInterpretableBasis) + { + for (int t = 0; t < outputLength; t++) + { + T value = NumOps.Zero; + T tNormalized = NumOps.FromDouble((double)t / outputLength); + + for (int p = 0; p < Math.Min(theta.Length, _polynomialDegree + 1); p++) + { + T power = NumOps.One; + for (int k = 0; k < p; k++) + { + power = NumOps.Multiply(power, tNormalized); + } + value = NumOps.Add(value, NumOps.Multiply(theta[p], power)); + } + + output[t] = value; + } + } + else + { + // Generic basis: output = basis · theta. Must use the learned V_b/V_f + // matrices per Oreshkin et al. 2020 Section 3.2 — they round-trip through + // GetParameters/SetParameters as trainable weights, and the tape-based + // Forward path multiplies by the same tensors. Returning theta directly + // here (as the pre-fix code did) made PredictSingle diverge from both + // training and model-load state. + for (int t = 0; t < outputLength; t++) + { + T value = NumOps.Zero; + for (int k = 0; k < theta.Length; k++) + { + value = NumOps.Add(value, NumOps.Multiply(basis[t, k], theta[k])); + } + output[t] = value; + } + } + + return output; + } + + /// + /// Gets all parameters (weights and biases) as a single vector. + /// + public override Vector GetParameters() + { + var parameters = new List(); + + foreach (var weight in _fcWeights) + { + var vec = weight.ToVector(); + parameters.AddRange(vec); + } + + foreach (var bias in _fcBiases) + { + var vec = bias.ToVector(); + parameters.AddRange(vec); + } + + // Generic blocks: include trainable V_b / V_f bases so export round-trips + // don't drop learned basis state. Ordering (fc weights, fc biases, then + // bases) must match SetParameters. + if (!_useInterpretableBasis) + { + parameters.AddRange(_basisBackcast.ToVector()); + parameters.AddRange(_basisForecast.ToVector()); + } + + return new Vector(parameters.ToArray()); + } + + /// + /// Sets all parameters (weights and biases) from a single vector. + /// + public override void SetParameters(Vector parameters) + { + if (parameters.Length != ParameterCount) + { + throw new ArgumentException( + $"Expected {ParameterCount} parameters, but got {parameters.Length}.", + nameof(parameters)); + } + + int idx = 0; + + for (int w = 0; w < _fcWeights.Count; w++) + { + var weight = _fcWeights[w]; + int len = weight.Length; + var data = new T[len]; + for (int i = 0; i < len; i++) + { + data[i] = parameters[idx++]; + } + int rows = weight.Shape[0]; + int cols = weight.Shape[1]; + _fcWeights[w] = new Tensor(new[] { rows, cols }, new Vector(data)); + } + + for (int b = 0; b < _fcBiases.Count; b++) + { + var bias = _fcBiases[b]; + int len = bias.Length; + var data = new T[len]; + for (int i = 0; i < len; i++) + { + data[i] = parameters[idx++]; + } + _fcBiases[b] = new Tensor(new[] { len }, new Vector(data)); + } + + // Generic blocks: restore trainable V_b / V_f bases. Must match the + // order GetParameters produced them in. + if (!_useInterpretableBasis) + { + int backcastLen = _basisBackcast.Length; + var backcastData = new T[backcastLen]; + for (int i = 0; i < backcastLen; i++) + { + backcastData[i] = parameters[idx++]; + } + _basisBackcast = new Tensor(_basisBackcast.Shape.ToArray(), new Vector(backcastData)); + + int forecastLen = _basisForecast.Length; + var forecastData = new T[forecastLen]; + for (int i = 0; i < forecastLen; i++) + { + forecastData[i] = parameters[idx++]; + } + _basisForecast = new Tensor(_basisForecast.Shape.ToArray(), new Vector(forecastData)); + } + + // Re-register trainable parameters after replacing tensors + ReRegisterParameters(); + } + + /// + /// Re-registers all weight and bias tensors as trainable parameters. + /// Called after SetParameters replaces tensor instances. + /// + private void ReRegisterParameters() + { + // Clear and re-register (RegisterTrainableParameter handles dedup) + foreach (var w in _fcWeights) + RegisterTrainableParameter(w, PersistentTensorRole.Weights); + foreach (var b in _fcBiases) + RegisterTrainableParameter(b, PersistentTensorRole.Biases); + + // Generic blocks also learn the basis matrices — re-register them after + // SetParameters replaces the tensor instances. Interpretable blocks use + // fixed polynomial bases that are not trainable, so skip. + if (!_useInterpretableBasis) + { + RegisterTrainableParameter(_basisBackcast, PersistentTensorRole.Weights); + RegisterTrainableParameter(_basisForecast, PersistentTensorRole.Weights); + } + } + +} diff --git a/src/TimeSeries/TimeSeriesModelBase.cs b/src/TimeSeries/TimeSeriesModelBase.cs index 80a1d97c77..48ca6c5000 100644 --- a/src/TimeSeries/TimeSeriesModelBase.cs +++ b/src/TimeSeries/TimeSeriesModelBase.cs @@ -1,2070 +1,1957 @@ -using System.Threading; -using AiDotNet.Autodiff; -using AiDotNet.Helpers; -using AiDotNet.Tensors.Engines.Autodiff; - -namespace AiDotNet.TimeSeries; - -/// -/// Provides a base class for all time series forecasting models in the library. -/// -/// The numeric data type used for calculations (e.g., float, double). -/// -/// -/// This abstract class defines the common interface and functionality that all time series models share, -/// including training, prediction, evaluation, and serialization/deserialization capabilities. -/// -/// -/// Time series models capture temporal dependencies in data and use patterns learned from historical -/// observations to predict future values. This base class provides the foundation for implementing -/// various time series forecasting algorithms like ARIMA, Exponential Smoothing, TBATS, and more complex -/// machine learning approaches. -/// -/// -/// For Beginners: -/// A time series model helps predict future values based on past observations. -/// -/// Think of a time series like a sequence of measurements taken over time - for example, -/// daily temperatures, monthly sales, or hourly website visits. These models analyze the patterns -/// in historical data to make predictions about what will happen next. -/// -/// This base class is like a blueprint that all specific time series models follow. -/// It ensures that every model can: -/// - Be trained on historical data to learn patterns -/// - Make predictions for future periods based on what it learned -/// - Evaluate how accurate its predictions are compared to actual values -/// - Be saved to disk and loaded later without retraining -/// -/// Time series models are used in many real-world applications, including: -/// - Weather forecasting -/// - Stock market prediction -/// - Demand planning for retail -/// - Energy consumption forecasting -/// - Website traffic prediction -/// -/// -public abstract class TimeSeriesModelBase : ITimeSeriesModel, IConfigurableModel, IModelShape -{ - /// - /// Configuration options for the time series model. - /// - /// - /// - /// These options control the core behavior of the time series model, including how much - /// historical data is considered, whether trends or seasonality are modeled, and how errors - /// are handled. - /// - /// - /// For Beginners: - /// Think of these options as settings that determine how the model works: - /// - LagOrder: How many past values to consider (like remembering the last 5 days to predict tomorrow) - /// - IncludeTrend: Whether to account for ongoing trends (like sales steadily increasing over time) - /// - SeasonalPeriod: Whether there are regular patterns (like retail sales spiking every December) - /// - AutocorrelationCorrection: Whether to fix systematic errors in predictions - /// - /// - protected TimeSeriesRegressionOptions Options { get; set; } - - /// - public virtual ModelOptions GetOptions() => Options; - - /// - /// Provides numeric operations for the specific type T. - /// - /// - /// - /// This property provides mathematical operations appropriate for the generic type T, - /// allowing the algorithm to work consistently with different numeric types like - /// float, double, or decimal. - /// - /// - /// For Beginners: - /// This is a helper that knows how to do math (addition, multiplication, etc.) with - /// your specific number type, whether that's a regular double, a precise decimal value, - /// or something else. It allows the model to work with different types of numbers - /// without changing its core logic. - /// - /// - protected INumericOperations NumOps { get; private set; } - - /// - /// Gets the global execution engine for vector operations. - /// - /// - /// - /// This property provides access to the execution engine (CPU or GPU) for performing - /// vectorized operations. The engine is determined by the global AiDotNetEngine configuration - /// and allows automatic fallback from GPU to CPU when GPU is not available. - /// - /// - /// For Beginners: - /// This gives access to either CPU or GPU processing for faster computations. - /// The system automatically chooses the best available option and falls back to CPU - /// if GPU acceleration is not available. - /// - /// - protected IEngine Engine => AiDotNetEngine.Current; - - /// - /// Gets or sets the trained model parameters. - /// - /// - /// - /// Contains the values that the model has learned during training, such as coefficients - /// for different lags, trend components, and seasonal factors. - /// - /// - /// For Beginners: - /// These are the numerical values the model learns during training that tell it exactly - /// how much influence each past observation should have on the prediction. They're like - /// the recipe ingredients with specific measurements that the model has figured out work best. - /// - /// - protected Vector ModelParameters { get; set; } - - /// - /// Indicates whether the model has been trained. - /// - /// - /// - /// This flag is set to true after the model has been successfully trained on data. - /// - /// - /// For Beginners: - /// This is like a switch that gets turned on once the model has learned from your data. - /// It helps prevent errors by making sure you don't try to use the model for predictions - /// before it's ready. - /// - /// - protected bool IsTrained { get; set; } = false; - - /// - /// The default loss function used for gradient computation. - /// - private readonly ILossFunction _defaultLossFunction; - - /// - /// Gets the last computed error metrics when the model was evaluated. - /// - /// - /// - /// Contains accuracy metrics calculated during model evaluation, such as MAE, RMSE, and MAPE. - /// - /// - /// For Beginners: - /// These numbers tell you how accurate the model's predictions are compared to actual values. - /// Lower numbers mean better predictions. They're like a scorecard for the model's performance. - /// - /// - protected Dictionary LastEvaluationMetrics { get; private set; } = new Dictionary(); - - /// - /// Initializes a new instance of the TimeSeriesModelBase class with the specified options. - /// - /// The configuration options for the time series model. - /// Thrown when options is null. - /// Thrown when options contain invalid values. - /// - /// - /// This constructor validates the provided options, initializes the model with the specified - /// configuration, and sets up the numeric operations appropriate for the data type. - /// - /// - /// For Beginners: - /// This constructor sets up the basic configuration for any time series model. - /// - /// It takes an options object that specifies important settings like: - /// - How many past values to consider (lag order) - /// - Whether to include a trend component (like steady growth or decline) - /// - The length of seasonal patterns (e.g., 7 for weekly, 12 for monthly) - /// - Whether to correct for autocorrelation in errors (systematic errors) - /// - /// It also checks that these settings make sense - for example, you can't have a negative - /// number of past values or a seasonal period less than 2. - /// - /// - protected TimeSeriesModelBase(TimeSeriesRegressionOptions options) - { - // Validate options - if (options == null) - { - throw new ArgumentNullException(nameof(options), "Time series options cannot be null."); - } - - ValidateOptions(options); - - Options = options; - NumOps = MathHelper.GetNumericOperations(); - ModelParameters = new Vector(0); // Initialize with empty vector - _defaultLossFunction = options.LossFunction ?? new MeanSquaredErrorLoss(); - } - - /// - /// Validates the provided time series options to ensure they are within acceptable ranges. - /// - /// The options to validate. - /// Thrown when any option is invalid. - /// - /// - /// Checks that LagOrder is non-negative, SeasonalPeriod is either 0 (no seasonality) or at least 2, - /// and that other parameters have reasonable values. - /// - /// - /// For Beginners: - /// This method makes sure the settings you've chosen for your model make logical sense. - /// For example, you can't look back a negative number of time periods, and a seasonal - /// pattern must repeat at least every 2 periods to be considered seasonal. - /// - /// - protected virtual void ValidateOptions(TimeSeriesRegressionOptions options) - { - if (options.LagOrder < 0) - { - throw new ArgumentException("Lag order must be non-negative.", nameof(options)); - } - - if (options.SeasonalPeriod < 0) - { - throw new ArgumentException("Seasonal period must be non-negative.", nameof(options)); - } - - if (options.SeasonalPeriod == 1) - { - throw new ArgumentException("Seasonal period must be at least 2 if seasonality is enabled.", nameof(options)); - } - - // Additional model-specific validation can be implemented in derived classes - } - - /// - /// Trains the time series model using the provided input data and target values. - /// - /// The input features matrix. - /// The target values vector. - /// Thrown when x or y is null. - /// Thrown when the dimensions of x and y don't match or when the data is insufficient. - /// - /// - /// This method validates the input data, prepares the model for training, performs the actual - /// training algorithm, and sets the IsTrained flag once complete. - /// - /// - /// For Beginners: - /// Training is the process where the model learns patterns from historical data. - /// - /// During training, the model analyzes the relationship between: - /// - Input features (x): These might include past values, time indicators, or external factors - /// - Target values (y): The actual observed values we want to predict - /// - /// After training, the model will have learned parameters that capture the patterns - /// in your data, which it can then use to make predictions for new inputs. - /// - /// This is an abstract method, meaning each specific model type (ARIMA, TBATS, etc.) - /// will implement its own training algorithm. - /// - /// - /// - /// Cancellation token that is active during training. Derived classes should check - /// TrainingCancellationToken.IsCancellationRequested in their training loops - /// to support both caller-initiated cancellation and wall-clock timeout from - /// . - /// - protected CancellationToken TrainingCancellationToken { get; private set; } = CancellationToken.None; - - /// - /// Auto-scaled guard threshold computed from training data. - /// Set to 1000 * max(|y|) during training. Falls back to 1e15 if not trained. - /// - private double _autoGuardThreshold = 1e15; - - public void Train(Matrix x, Vector y) => Train(x, y, CancellationToken.None); - - public void Train(Matrix x, Vector y, CancellationToken callerToken) - { - // Fail fast if already cancelled — don't discard existing trained state - callerToken.ThrowIfCancellationRequested(); - - // Input validation - ValidateTrainingInputs(x, y); - - // Reset model state before training - Reset(); - - // Auto-scale guard threshold from training data: 1000x max observed absolute value. - // This adapts overflow protection to the dataset scale instead of a fixed magic number. - double maxAbsY = 0; - for (int i = 0; i < y.Length; i++) - { - double absVal = Math.Abs(NumOps.ToDouble(y[i])); - if (absVal > maxAbsY && !double.IsNaN(absVal) && !double.IsInfinity(absVal)) - maxAbsY = absVal; - } - _autoGuardThreshold = maxAbsY > 0 ? maxAbsY * 1000.0 : 1e15; - - // Create a linked CancellationTokenSource that combines: - // 1. The caller's token (for external cancellation) - // 2. A wall-clock timeout from MaxTrainingTimeSeconds (safety net) - // Whichever fires first wins. - using var linkedCts = CancellationTokenSource.CreateLinkedTokenSource(callerToken); - if (Options.MaxTrainingTimeSeconds > 0) - { - linkedCts.CancelAfter(TimeSpan.FromSeconds(Options.MaxTrainingTimeSeconds)); - } - TrainingCancellationToken = linkedCts.Token; - - try - { - // Perform model-specific training (implemented by derived classes) - TrainCore(x, y); - } - catch (OperationCanceledException) - { - // Training was cancelled by MaxTrainingTimeSeconds timeout. - // The model may have partial training — mark as trained with whatever - // state was achieved. This is industry standard: early stopping produces - // a usable (if suboptimal) model rather than failing completely. - } - finally - { - TrainingCancellationToken = CancellationToken.None; - } - - // Mark the model as trained (even after early cancellation) - IsTrained = true; - } - - /// - /// Performs the model-specific training algorithm. - /// - /// The input features matrix. - /// The target values vector. - /// - /// - /// This abstract method must be implemented by derived classes to perform the actual model training. - /// - /// - /// For Beginners: - /// This is where the specific math and algorithms for each type of time series model are implemented. - /// Different models (like ARIMA, Exponential Smoothing, etc.) will have their own unique ways of - /// finding patterns in the data. - /// - /// - protected virtual void TrainCore(Matrix x, Vector y) - { - // Default: tape-based training handles parameter updates - } - - /// - /// Validates the training input data before proceeding with training. - /// - /// The input features matrix. - /// The target values vector. - /// Thrown when x or y is null. - /// Thrown when the dimensions of x and y don't match or when the data is insufficient. - /// - /// - /// This method verifies that the input data meets the requirements for model training, - /// including checking dimensions, sample size, and consistency. - /// - /// - /// For Beginners: - /// Before the model starts learning, this method checks that your data is valid and properly formatted. - /// It ensures that: - /// - You have provided both input features and target values - /// - The number of examples matches the number of target values - /// - You have enough data points to train the model effectively - /// - There are no obvious inconsistencies in your data structure - /// - /// - protected virtual void ValidateTrainingInputs(Matrix x, Vector y) - { - if (x == null) - { - throw new ArgumentNullException(nameof(x), "Input features matrix cannot be null."); - } - - if (y == null) - { - throw new ArgumentNullException(nameof(y), "Target values vector cannot be null."); - } - - if (x.Rows != y.Length) - { - throw new ArgumentException( - $"Number of rows in input matrix ({x.Rows}) must match the length of target vector ({y.Length})."); - } - - if (x.Rows <= Options.LagOrder) - { - throw new ArgumentException( - $"Number of samples ({x.Rows}) must be greater than lag order ({Options.LagOrder})."); - } - - // Check for sufficient data to handle seasonality - if (Options.SeasonalPeriod > 0 && x.Rows < 2 * Options.SeasonalPeriod) - { - throw new ArgumentException( - $"For seasonal models, the number of samples ({x.Rows}) should be at least twice the seasonal period ({Options.SeasonalPeriod})."); - } - - // Additional validation can be added in derived classes - } - - /// - /// Generates forecasts using the trained time series model. - /// - /// The input features matrix. - /// A vector of forecasted values. - /// Thrown when the model has not been trained. - /// Thrown when input is null. - /// Thrown when input has incorrect dimensions. - /// - /// - /// This method validates that the model is trained and the input data is valid, then - /// generates predictions for each row in the input matrix using the model-specific - /// prediction algorithm. - /// - /// - /// For Beginners: - /// This method uses the patterns learned during training to predict future values. - /// - /// The input matrix typically contains: - /// - Past values of the time series - /// - Time indicators (e.g., month, day of week) - /// - Any external factors that might influence the forecast - /// - /// The output is a vector of predicted values, one for each row in the input matrix. - /// Each prediction represents what the model thinks will happen at that future time point. - /// - /// - public virtual Vector Predict(Matrix input) - { - // Suppress tape recording during inference - using var _noGrad = new NoGradScope(); - - // Check if model is trained - if (!IsTrained) - { - throw new InvalidOperationException("The model must be trained before making predictions."); - } - - // Validate input - ValidatePredictionInput(input); - - // Create output vector for predictions - var predictions = new Vector(input.Rows); - - // Generate predictions for each input row - for (int i = 0; i < input.Rows; i++) - { - predictions[i] = PredictSingle(input.GetRow(i)); - } - - return predictions; - } - - /// - /// Validates the input data for prediction. - /// - /// The input features matrix. - /// Thrown when input is null. - /// Thrown when input has incorrect dimensions. - /// - /// - /// This method verifies that the input data for prediction is valid and has the correct dimensions. - /// - /// - /// For Beginners: - /// Before making predictions, this method checks that your input data is properly formatted. - /// It ensures that: - /// - You have provided input features - /// - The input has the correct structure (number of features/columns) - /// - The data meets any model-specific requirements - /// - /// - /// - /// Guards a prediction value against NaN, Infinity, and extreme overflow. - /// Returns a clamped finite value. All time series models should call this - /// before storing predictions to prevent cascading numerical instability - /// in recursive/autoregressive forecasting loops. - /// - /// The raw prediction value. - /// - /// Maximum allowed absolute value. If not specified, uses the priority chain: - /// (1) user-configured , - /// (2) auto-scaled from training data (1000x max |y|), - /// (3) fallback to 1e15. - /// - /// A finite, clamped value. - protected T GuardPrediction(T value, double maxAbsValue = -1) - { - // Priority chain: explicit parameter > user option > auto-scaled > 1e15 fallback - double threshold = maxAbsValue > 0 - ? maxAbsValue - : Options.MaxPredictionAbsValue ?? _autoGuardThreshold; - - var d = NumOps.ToDouble(value); - if (double.IsNaN(d) || double.IsInfinity(d) || Math.Abs(d) > threshold) - { - var safe = double.IsNaN(d) ? 0.0 : d; - var clamped = MathPolyfill.Clamp(safe, -threshold, threshold); - return NumOps.FromDouble(clamped); - } - return value; - } - - protected virtual void ValidatePredictionInput(Matrix input) - { - if (input == null) - { - throw new ArgumentNullException(nameof(input), "Input features matrix cannot be null."); - } - - // Additional validation can be added in derived classes - } - - /// - /// Generates a prediction for a single input vector. - /// - /// The input feature vector. - /// The predicted value. - /// - /// - /// This abstract method must be implemented by derived classes to generate a prediction - /// for a single input vector using the model-specific algorithm. - /// - /// - /// For Beginners: - /// This method takes a single row of input data (representing one time point) and - /// calculates what the model predicts will happen at that point. Each type of - /// time series model will have its own way of calculating this prediction based - /// on the patterns it learned during training. - /// - /// - public abstract T PredictSingle(Vector input); - - /// - /// Evaluates the performance of the trained model on test data. - /// - /// The input features matrix for testing. - /// The actual target values for testing. - /// A dictionary containing evaluation metrics. - /// Thrown when the model has not been trained. - /// Thrown when xTest or yTest is null. - /// Thrown when the dimensions of xTest and yTest don't match. - /// - /// - /// This method calculates various error metrics by comparing the model's predictions - /// on the test data to the actual values, providing a quantitative assessment of - /// model performance. - /// - /// - /// For Beginners: - /// This method tests how well the model performs by comparing its predictions to actual values. - /// - /// It works by: - /// 1. Using the model to make predictions based on the test inputs - /// 2. Comparing these predictions to the actual test values - /// 3. Calculating various error metrics to quantify the accuracy - /// - /// Common metrics include: - /// - Mean Absolute Error (MAE): Average of absolute differences between predictions and actual values - /// - Root Mean Squared Error (RMSE): Square root of the average squared differences - /// - Mean Absolute Percentage Error (MAPE): Average percentage differences - /// - /// These metrics help you understand how accurate your model is and compare different models. - /// Lower values indicate better performance for all these metrics. - /// - /// - public virtual Dictionary EvaluateModel(Matrix xTest, Vector yTest) - { - // Check if model is trained - if (!IsTrained) - { - throw new InvalidOperationException("The model must be trained before evaluation."); - } - - // Validate inputs - if (xTest == null) - { - throw new ArgumentNullException(nameof(xTest), "Test features matrix cannot be null."); - } - - if (yTest == null) - { - throw new ArgumentNullException(nameof(yTest), "Test target vector cannot be null."); - } - - if (xTest.Rows != yTest.Length) - { - throw new ArgumentException( - $"Number of rows in test matrix ({xTest.Rows}) must match the length of test vector ({yTest.Length})."); - } - - // Generate predictions - Vector predictions = Predict(xTest); - - // Calculate error metrics - Dictionary metrics = CalculateErrorMetrics(predictions, yTest); - - // Store metrics for later reference - LastEvaluationMetrics = metrics; - - return metrics; - } - - /// - /// Calculates error metrics by comparing predictions to actual values. - /// - /// The predicted values. - /// The actual values. - /// A dictionary containing error metrics. - /// - /// - /// This method computes standard error metrics for time series forecasting, including - /// MAE, RMSE, MAPE, and others as appropriate for the model type. - /// - /// - /// For Beginners: - /// This method calculates how far off the model's predictions are from the actual values. - /// It computes several different ways of measuring the prediction errors: - /// - /// - MAE (Mean Absolute Error): The average magnitude of errors, ignoring whether they're positive or negative - /// - RMSE (Root Mean Squared Error): Emphasizes larger errors by squaring them before averaging - /// - MAPE (Mean Absolute Percentage Error): Shows errors as percentages of the actual values - /// - /// These metrics help you understand not just how accurate the model is overall, - /// but also what kinds of errors it tends to make. - /// - /// - protected virtual Dictionary CalculateErrorMetrics(Vector predictions, Vector actuals) - { - int n = predictions.Length; - var metrics = new Dictionary(); - - // Calculate MAE (Mean Absolute Error) - T sumAbsoluteError = NumOps.Zero; - for (int i = 0; i < n; i++) - { - T error = NumOps.Subtract(predictions[i], actuals[i]); - sumAbsoluteError = NumOps.Add(sumAbsoluteError, NumOps.Abs(error)); - } - T mae = NumOps.Divide(sumAbsoluteError, NumOps.FromDouble(n)); - metrics["MAE"] = mae; - - // Calculate MSE (Mean Squared Error) and RMSE (Root Mean Squared Error) - T sumSquaredError = NumOps.Zero; - for (int i = 0; i < n; i++) - { - T error = NumOps.Subtract(predictions[i], actuals[i]); - sumSquaredError = NumOps.Add(sumSquaredError, NumOps.Square(error)); - } - T mse = NumOps.Divide(sumSquaredError, NumOps.FromDouble(n)); - T rmse = NumOps.Sqrt(mse); - metrics["MSE"] = mse; - metrics["RMSE"] = rmse; - - // Calculate MAPE (Mean Absolute Percentage Error) - // Only if actuals don't contain zeros or very small values - bool canCalculateMape = true; - T sumAbsolutePercentageError = NumOps.Zero; - for (int i = 0; i < n; i++) - { - if (NumOps.LessThan(NumOps.Abs(actuals[i]), NumOps.FromDouble(1e-10))) - { - canCalculateMape = false; - break; - } - - T percentageError = NumOps.Divide( - NumOps.Abs(NumOps.Subtract(predictions[i], actuals[i])), - NumOps.Abs(actuals[i]) - ); - sumAbsolutePercentageError = NumOps.Add(sumAbsolutePercentageError, percentageError); - } - - if (canCalculateMape) - { - T mape = NumOps.Multiply( - NumOps.Divide(sumAbsolutePercentageError, NumOps.FromDouble(n)), - NumOps.FromDouble(100) // Convert to percentage - ); - metrics["MAPE"] = mape; - } - - return metrics; - } - - /// - /// Serializes the model to a byte array for storage or transmission. - /// - /// A byte array containing the serialized model. - /// - /// - /// This method serializes the common components of the model (options, trained status, parameters) - /// and then calls the model-specific serialization method to handle specialized data. - /// - /// - /// For Beginners: - /// Serialization converts the model's state into a format that can be saved to disk - /// or transmitted over a network. - /// - /// This method: - /// 1. Creates a memory stream to hold the serialized data - /// 2. Writes the common configuration options shared by all models - /// 3. Writes whether the model has been trained - /// 4. Writes the model parameters learned during training - /// 5. Calls the model-specific serialization method to write specialized data - /// 6. Returns everything as a byte array - /// - /// This allows you to save a trained model and load it later without having to retrain it, - /// which can save significant time for complex models trained on large datasets. - /// - /// - public virtual byte[] Serialize() - { - ModelPersistenceGuard.EnforceBeforeSerialize(); - using var ms = new MemoryStream(); - using var writer = new BinaryWriter(ms); - - // Serialize common options - writer.Write(Options.LagOrder); - writer.Write(Options.IncludeTrend); - writer.Write(Options.SeasonalPeriod); - writer.Write(Options.AutocorrelationCorrection); - writer.Write((int)Options.ModelType); - - // Serialize trained state - writer.Write(IsTrained); - - // Serialize model parameters if trained - if (IsTrained) - { - writer.Write(ModelParameters.Length); - for (int i = 0; i < ModelParameters.Length; i++) - { - writer.Write(Convert.ToDouble(ModelParameters[i])); - } - - // Serialize evaluation metrics - writer.Write(LastEvaluationMetrics.Count); - foreach (var kvp in LastEvaluationMetrics) - { - writer.Write(kvp.Key); - writer.Write(Convert.ToDouble(kvp.Value)); - } - } - - // Serialize auto-scaled guard threshold (persists training-data-aware overflow protection) - writer.Write(_autoGuardThreshold); - - // Let derived classes serialize their specific data - SerializeCore(writer); - - return ms.ToArray(); - } - - /// - /// Deserializes the model from a byte array. - /// - /// The byte array containing the serialized model. - /// Thrown when data is null. - /// Thrown when the serialized data is corrupted or incompatible. - /// - /// - /// This method deserializes the common components of the model (options, trained status, parameters) - /// and then calls the model-specific deserialization method to handle specialized data. - /// - /// - /// For Beginners: - /// Deserialization is the process of loading a previously saved model from a byte array. - /// - /// This method: - /// 1. Creates a memory stream from the provided byte array - /// 2. Reads the common configuration options shared by all models - /// 3. Reads whether the model has been trained - /// 4. Reads the model parameters learned during training - /// 5. Calls the model-specific deserialization method to read specialized data - /// - /// After deserialization, the model is restored to the same state it was in when serialized, - /// allowing you to make predictions without retraining the model. - /// - /// This is particularly useful for: - /// - Deploying models to production environments - /// - Sharing models between different applications - /// - Saving computation time by not having to retrain complex models - /// - /// - public virtual void Deserialize(byte[] data) - { - ModelPersistenceGuard.EnforceBeforeDeserialize(); - if (data == null) - { - throw new ArgumentNullException(nameof(data), "Serialized data cannot be null."); - } - - try - { - using var ms = new MemoryStream(data); - using var reader = new BinaryReader(ms); - - // Deserialize common options - Options.LagOrder = reader.ReadInt32(); - Options.IncludeTrend = reader.ReadBoolean(); - Options.SeasonalPeriod = reader.ReadInt32(); - Options.AutocorrelationCorrection = reader.ReadBoolean(); - Options.ModelType = (TimeSeriesModelType)reader.ReadInt32(); - - // Deserialize trained state - IsTrained = reader.ReadBoolean(); - - // Deserialize model parameters if trained - if (IsTrained) - { - int parameterCount = reader.ReadInt32(); - ModelParameters = new Vector(parameterCount); - for (int i = 0; i < parameterCount; i++) - { - ModelParameters[i] = NumOps.FromDouble(reader.ReadDouble()); - } - - // Deserialize evaluation metrics - int metricsCount = reader.ReadInt32(); - LastEvaluationMetrics.Clear(); - for (int i = 0; i < metricsCount; i++) - { - string key = reader.ReadString(); - T value = NumOps.FromDouble(reader.ReadDouble()); - LastEvaluationMetrics[key] = value; - } - } - - // Deserialize auto-scaled guard threshold (backwards-compatible) - try - { - _autoGuardThreshold = reader.ReadDouble(); - } - catch (EndOfStreamException) - { - _autoGuardThreshold = 1e15; // Pre-patch model - } - - // Let derived classes deserialize their specific data - DeserializeCore(reader); - } - catch (Exception ex) - { - throw new InvalidOperationException("Failed to deserialize model data. The data may be corrupted or incompatible with this model version.", ex); - } - } - - /// - /// Serializes model-specific data to the binary writer. - /// - /// The binary writer to write to. - /// - /// - /// This abstract method must be implemented by each specific model type to save - /// its unique parameters and state. - /// - /// - /// For Beginners: - /// This method is responsible for saving the specific details that make each type of - /// time series model unique. Different models have different internal structures and parameters - /// that need to be saved separately from the common elements. - /// - /// For example: - /// - An ARIMA model would save its AR, I, and MA coefficients - /// - A TBATS model would save its level, trend, and seasonal components - /// - A neural network model would save its weights and biases - /// - /// This separation allows the base class to handle common serialization tasks - /// while each model type handles its specialized data. - /// - /// - protected abstract void SerializeCore(BinaryWriter writer); - - /// - /// Deserializes model-specific data from the binary reader. - /// - /// The binary reader to read from. - /// - /// - /// This abstract method must be implemented by each specific model type to load - /// its unique parameters and state. - /// - /// - /// For Beginners: - /// This method is responsible for loading the specific details that make each type of - /// time series model unique. It reads exactly what was written by SerializeCore, in the - /// same order, reconstructing the specialized parts of the model. - /// - /// It's the counterpart to SerializeCore and should read data in exactly the same - /// order and format that it was written. - /// - /// This separation allows the base class to handle common deserialization tasks - /// while each model type handles its specialized data. - /// - /// - protected abstract void DeserializeCore(BinaryReader reader); - - /// - /// Gets metadata about the time series model. - /// - /// A ModelMetaData object containing information about the model. - /// - /// - /// This method provides comprehensive metadata about the model, including its type, - /// configuration options, training status, evaluation metrics, and information about - /// which features/lags are most important. - /// - /// - /// For Beginners: - /// This method provides important information about the model that can help you understand - /// its characteristics and performance. - /// - /// The metadata includes: - /// - The type of model (e.g., ARIMA, TBATS, Neural Network) - /// - Configuration details (e.g., lag order, seasonality period) - /// - Whether the model has been trained - /// - Performance metrics from the last evaluation - /// - Information about which features (time periods) are most influential - /// - /// This information is useful for documentation, model comparison, and debugging. - /// It's like a complete summary of everything important about the model. - /// - /// - public abstract ModelMetadata GetModelMetadata(); - - /// - /// Gets the trainable parameters of the model as a vector. - /// - /// A vector containing all trainable parameters of the model. - /// Thrown when the model has not been trained. - /// - /// - /// This method returns all the parameters learned during training, combined into a single vector. - /// These parameters determine how the model makes predictions based on input data. - /// - /// - /// For Beginners: - /// This method returns all the numerical values that the model has learned during training. - /// - /// For time series models, these parameters typically include: - /// - Coefficients for each lag (how much each past value influences the prediction) - /// - Trend coefficients (if trend is included) - /// - Seasonal coefficients (if seasonality is included) - /// - Error correction terms (if autocorrelation correction is enabled) - /// - /// These parameters can be: - /// - Analyzed to understand what the model has learned - /// - Saved for later use - /// - Modified to adjust the model's behavior - /// - Transferred to another model with the same structure - /// - /// - /// - /// Time series models do not support random parameter initialization from the optimizer. - /// They must be trained on sequential data to learn meaningful coefficients. - /// - public virtual bool SupportsParameterInitialization => false; - /// - public virtual Vector SanitizeParameters(Vector parameters) => parameters; - - - public virtual Vector GetParameters() - { - if (!IsTrained && (ModelParameters == null || ModelParameters.Length == 0)) - { - throw new InvalidOperationException("Cannot get parameters for an untrained model."); - } - - if (ModelParameters == null || ModelParameters.Length == 0) - { - throw new InvalidOperationException("Model parameters have not been initialized."); - } - - return ModelParameters.Clone(); - } - - /// - /// Creates a new model with the specified parameters. - /// - /// The vector of parameters to use for the new model. - /// A new model instance with the specified parameters. - /// Thrown when parameters is null. - /// Thrown when the parameters vector has incorrect length. - /// - /// - /// This method creates a clone of the current model but replaces its parameters with the - /// provided values. This allows for creating variations of a model without retraining. - /// - /// - /// For Beginners: - /// This method creates a copy of the current model but with different parameter values. - /// - /// This allows you to: - /// - Create a model with manually specified parameters (e.g., from expert knowledge) - /// - Make small adjustments to a trained model without full retraining - /// - Implement ensemble models that combine multiple parameter sets - /// - Perform what-if analysis by changing specific parameters - /// - /// The parameters must be in the same order and have the same meaning as those - /// returned by the GetParameters method. - /// - /// - public virtual IFullModel, Vector> WithParameters(Vector parameters) - { - if (parameters == null) - { - throw new ArgumentNullException(nameof(parameters), "Parameters vector cannot be null."); - } - - // Create a clone of the current model - var newModel = (TimeSeriesModelBase)this.Clone(); - - // Apply the new parameters to the cloned model - newModel.ApplyParameters(parameters); - - // Mark as trained since parameters have been specified - newModel.IsTrained = true; - - return newModel; - } - - /// - /// Applies the provided parameters to the model. - /// - /// The vector of parameters to apply. - /// Thrown when the parameters vector is invalid. - /// - /// - /// This method applies the provided parameter values to the model, updating its internal state - /// to reflect the new parameters. The implementation is model-specific and should be overridden - /// by derived classes as needed. - /// - /// - /// For Beginners: - /// This method updates the model's internal parameters with new values. - /// It's the counterpart to GetParameters and should understand the parameter - /// vector in exactly the same way. - /// - /// For example, if the first 5 elements of the parameters vector represent - /// lag coefficients, this method should apply them as lag coefficients in - /// the model's internal structure. - /// - /// - protected virtual void ApplyParameters(Vector parameters) - { - if (parameters == null) - { - throw new ArgumentNullException(nameof(parameters), "Parameters vector cannot be null."); - } - - // Store the parameters - ModelParameters = parameters.Clone(); - - // Derived classes should override this to apply parameters to their specific structures - } - - /// - /// Gets the indices of features (lags/time periods) actively used by the model. - /// - /// A collection of indices representing the active features. - /// Thrown when the model has not been trained. - /// - /// - /// This method identifies which input features (lags) have significant impact on the model's - /// predictions, based on their corresponding parameter values. - /// - /// - /// For Beginners: - /// This method tells you which past time periods (lags) are most important for predictions. - /// - /// For example, if the result includes indices [1, 7, 12], this means: - /// - The value from 1 period ago strongly influences the prediction - /// - The value from 7 periods ago strongly influences the prediction (could be weekly seasonality) - /// - The value from 12 periods ago strongly influences the prediction (could be yearly for monthly data) - /// - /// These active features are determined by the model's structure and learned parameters. - /// For instance, in an ARIMA model, non-zero AR coefficients indicate active features. - /// - /// Understanding active features helps interpret how the model works and which - /// historical points matter most for forecasting. - /// - /// - public virtual IEnumerable GetActiveFeatureIndices() - { - if (!IsTrained) - { - throw new InvalidOperationException("The model must be trained before getting active feature indices."); - } - - List activeIndices = new List(); - - // Consider common lag patterns based on model configuration - for (int lag = 1; lag <= Options.LagOrder; lag++) - { - if (IsFeatureUsed(lag)) - { - activeIndices.Add(lag); - } - } - - // If seasonal, also include seasonal lags - if (Options.SeasonalPeriod > 0) - { - for (int s = 1; s <= 4; s++) // Consider up to 4 seasonal lags - { - int seasonalLag = s * Options.SeasonalPeriod; - if (seasonalLag <= Options.LagOrder && IsFeatureUsed(seasonalLag)) - { - activeIndices.Add(seasonalLag); - } - } - } - - return activeIndices; - } - - /// - /// Determines if a specific feature (lag) is actively used by the model. - /// - /// The index of the feature to check. - /// True if the feature is actively used; otherwise, false. - /// Thrown when the model has not been trained. - /// Thrown when featureIndex is negative or exceeds the maximum lag order. - /// - /// - /// This method determines whether a specific lag has a significant impact on the model's predictions, - /// based on its corresponding parameter value. The threshold for significance is model-specific. - /// - /// - /// For Beginners: - /// This method checks if a specific past time period (lag) has a significant - /// influence on the model's predictions. - /// - /// For example: - /// - IsFeatureUsed(1) checks if the value from 1 period ago matters - /// - IsFeatureUsed(7) checks if the value from 7 periods ago matters - /// - IsFeatureUsed(12) checks if the value from 12 periods ago matters - /// - /// A feature is typically considered "used" if its coefficient or weight - /// in the model is significantly different from zero. - /// - /// This information helps understand which historical points the model - /// considers important when making predictions. - /// - /// - public virtual bool IsFeatureUsed(int featureIndex) - { - if (!IsTrained) - { - throw new InvalidOperationException("The model must be trained before checking feature usage."); - } - - if (featureIndex < 0) - { - throw new ArgumentOutOfRangeException(nameof(featureIndex), "Feature index cannot be negative."); - } - - if (featureIndex > Options.LagOrder) - { - // For indices beyond the lag order, check if it's a valid seasonal lag - if (Options.SeasonalPeriod > 0 && featureIndex % Options.SeasonalPeriod == 0) - { - return NumOps.GreaterThan(GetFeatureImportance(featureIndex), NumOps.FromDouble(0.01)); - } - - return false; - } - - // For standard lags, check if the feature importance exceeds a threshold - T importance = GetFeatureImportance(featureIndex); - return NumOps.GreaterThan(importance, NumOps.FromDouble(0.01)); - } - - /// - /// Gets the importance of a specific feature (lag). - /// - /// The index of the feature. - /// A value indicating the feature's importance. - /// Thrown when the model has not been trained. - /// Thrown when featureIndex is negative. - /// - /// - /// This method calculates the importance of a specific lag in the model's predictions, - /// based on its parameter value and the model's structure. The implementation is model-specific. - /// - /// - /// For Beginners: - /// This method estimates how important a specific past time period is - /// for making predictions. Higher values indicate more influential features. - /// - /// For example, in many time series models: - /// - Recent lags (like lag 1) often have higher importance - /// - Seasonal lags (like lag 7 for weekly data) often have higher importance - /// - Some lags may have near-zero importance, meaning they don't affect predictions much - /// - /// This information helps understand the model's internal logic and which past - /// time periods it considers most predictive of future values. - /// - /// - protected virtual T GetFeatureImportance(int featureIndex) - { - if (!IsTrained) - { - throw new InvalidOperationException("The model must be trained before getting feature importance."); - } - - if (featureIndex < 0) - { - throw new ArgumentOutOfRangeException(nameof(featureIndex), "Feature index cannot be negative."); - } - - // Default implementation - derived classes should override with model-specific logic - // For time series models, standard importance calculation might consider: - // 1. The magnitude of coefficients for each lag - // 2. The recency of the lag (more recent lags may be more important) - // 3. Seasonal patterns (lags at seasonal intervals may be more important) - - // As a simple default, if the feature index is within the parameter range, use its absolute value - if (featureIndex < ModelParameters.Length) - { - return NumOps.Abs(ModelParameters[featureIndex]); - } - - // Otherwise, define some heuristic defaults - if (featureIndex == 1) - { - // The most recent lag is usually important - return NumOps.FromDouble(0.5); - } - else if (Options.SeasonalPeriod > 0 && featureIndex % Options.SeasonalPeriod == 0) - { - // Seasonal lags are usually important - return NumOps.FromDouble(0.3); - } - else if (featureIndex <= 3) - { - // Recent lags are moderately important - return NumOps.FromDouble(0.2); - } - - // Default to very low importance for other lags - return NumOps.FromDouble(0.01); - } - - /// - /// Sets the parameters for this model. - /// - /// A vector containing the model parameters. - /// - /// If the model is untrained (ModelParameters is empty), this method will - /// resize ModelParameters to accept the incoming parameters. This allows - /// optimizers to initialize untrained models with random parameters. - /// - public virtual void SetParameters(Vector parameters) - { - // If model is untrained (empty parameters), resize to accept the new parameters - // This allows optimizers to initialize untrained models with random parameters - if (ModelParameters.Length == 0 && parameters.Length > 0) - { - ModelParameters = new Vector(parameters.Length); - } - - if (parameters.Length != ModelParameters.Length) - { - throw new ArgumentException($"Expected {ModelParameters.Length} parameters, but got {parameters.Length}", nameof(parameters)); - } - - for (int i = 0; i < ModelParameters.Length; i++) - { - ModelParameters[i] = parameters[i]; - } - } - - /// - /// Sets the active feature indices for this model. - /// - /// The indices of features to activate. - public virtual void SetActiveFeatureIndices(IEnumerable featureIndices) - { - var activeSet = new HashSet(featureIndices); - - for (int i = 0; i < ModelParameters.Length; i++) - { - if (!activeSet.Contains(i)) - { - ModelParameters[i] = NumOps.Zero; - } - } - } - - /// - /// Gets the feature importance scores as a dictionary. - /// - /// A dictionary mapping feature names to their importance scores. - public virtual Dictionary GetFeatureImportance() - { - var result = new Dictionary(); - - for (int i = 0; i < ModelParameters.Length; i++) - { - string featureName = $"Lag_{i + 1}"; - result[featureName] = NumOps.Abs(ModelParameters[i]); - } - - return result; - } - - /// - /// Creates a deep copy of the time series model. - /// - /// A new instance that is a deep copy of this model. - /// - /// - /// This method creates a completely independent copy of the model, with all parameters, - /// options, and internal state duplicated. Modifications to the copy will not affect the - /// original, and vice versa. - /// - /// - /// For Beginners: - /// This method creates a completely independent copy of the current model. - /// - /// A deep copy means that all components of the model are duplicated, - /// including: - /// - Configuration options - /// - Learned parameters - /// - Internal state variables - /// - /// This is useful when you need to: - /// - Create multiple variations of a model for experimentation - /// - Save a model at a specific point during training - /// - Use the same model structure for different datasets - /// - /// Changes to the copy won't affect the original model and vice versa. - /// - /// - public virtual IFullModel, Vector> DeepCopy() - { - // Create a new instance through serialization/deserialization for a true deep copy - byte[] serialized = this.Serialize(); - var newModel = (TimeSeriesModelBase)CreateInstance(); - newModel.Deserialize(serialized); - - return newModel; - } - - /// - /// Creates a clone of the time series model. - /// - /// A new instance that is a clone of this model. - /// - /// - /// This method creates a copy of the model that shares the same options but has independent - /// parameter values. It's a lighter-weight alternative to DeepCopy for cases where a complete - /// independent copy is not needed. - /// - /// - /// For Beginners: - /// This method creates a copy of the current model with the same configuration - /// and parameters. - /// - /// While DeepCopy creates a fully independent duplicate of everything in the model, - /// Clone sometimes creates a more lightweight copy that might share some non-essential - /// components with the original (depending on the specific model implementation). - /// - /// This is useful for: - /// - Creating variations of a model for ensemble methods - /// - Saving a snapshot of the model before making changes - /// - Creating multiple instances for parallel training - /// - /// - public virtual IFullModel, Vector> Clone() - { - // Use DeepCopy (serialize/deserialize) to ensure all internal state is preserved. - // The lighter Clone approach only copies ModelParameters but misses subclass-specific - // state (e.g. _arCoefficients, _trainedSeries) that Predict actually uses. - return DeepCopy(); - } - - /// - /// Creates a new instance of the derived model class. - /// - /// A new instance of the same model type. - /// - /// - /// This abstract factory method must be implemented by derived classes to create a new - /// instance of their specific type. It's used by Clone and DeepCopy to ensure that - /// the correct derived type is instantiated. - /// - /// - /// For Beginners: - /// This method creates a new, empty instance of the specific model type. - /// It's used during cloning and deep copying to ensure that the copy - /// is of the same specific type as the original. - /// - /// For example, if the original model is an ARIMA model, this method - /// would create a new ARIMA model. If it's a TBATS model, it would - /// create a new TBATS model. - /// - /// - protected abstract IFullModel, Vector> CreateInstance(); - - /// - /// Resets the model to its untrained state. - /// - /// - /// - /// This method clears all trained parameters and returns the model to its initial untrained state. - /// - /// - /// For Beginners: - /// This method erases all the patterns the model has learned. - /// - /// After calling this method: - /// - All coefficients and learned parameters are cleared - /// - The model behaves as if it was never trained - /// - You would need to train it again before making predictions - /// - /// This is useful when you want to: - /// - Experiment with different training data on the same model - /// - Retrain a model from scratch with new parameters - /// - Reset a model that might have been trained incorrectly - /// - /// - public virtual void Reset() - { - // Clear model parameters - ModelParameters = new Vector(0); - - // Reset trained flag - IsTrained = false; - - // Clear evaluation metrics - LastEvaluationMetrics.Clear(); - - // Derived classes should override this to reset any additional state - } - - /// - /// Clips a value to be within the specified range. - /// - /// The value to clip. - /// The minimum allowed value. - /// The maximum allowed value. - /// The clipped value. - /// - /// - /// This utility method constrains a value to be within the specified range. - /// If the value is less than the minimum, the minimum is returned. - /// If the value is greater than the maximum, the maximum is returned. - /// Otherwise, the original value is returned. - /// - /// - /// For Beginners: - /// This method ensures a value stays within a specified range (between min and max). - /// It's like setting boundaries that a value cannot cross. - /// - /// For example, if you clip a value with min=0 and max=1: - /// - If the value is -0.5, it returns 0 (the minimum) - /// - If the value is 1.5, it returns 1 (the maximum) - /// - If the value is 0.7, it returns 0.7 (unchanged, as it's within range) - /// - /// This is useful for: - /// - Preventing parameters from taking extreme values - /// - Constraining predictions to reasonable ranges - /// - Implementing optimization algorithms that require bounded parameters - /// - /// - protected T Clip(T value, T min, T max) - { - if (NumOps.LessThan(value, min)) - { - return min; - } - - if (NumOps.GreaterThan(value, max)) - { - return max; - } - - return value; - } - - /// - /// Generates a forecast for multiple steps ahead. - /// - /// The historical time series data. - /// The number of steps to forecast. - /// A vector containing the forecasted values. - /// Thrown when the model has not been trained. - /// Thrown when history is null. - /// Thrown when steps is not positive or history is insufficient. - /// - /// - /// This method generates a multi-step forecast using the history data as the starting point. - /// For each step, it makes a prediction and then updates the history with the predicted value - /// to generate the next prediction. - /// - /// - /// For Beginners: - /// This method predicts multiple future values in sequence. - /// - /// For example, if you have daily data and want to forecast the next 7 days: - /// 1. It first predicts day 1 using your historical data - /// 2. Then it adds that prediction to the history - /// 3. Then it predicts day 2 using the updated history (including the day 1 prediction) - /// 4. And so on, until it has predicted all 7 days - /// - /// This approach lets you make predictions further into the future, - /// but be aware that errors tend to accumulate with each step (predictions - /// become less accurate the further ahead you forecast). - /// - /// - public virtual Vector Forecast(Vector history, int steps) - { - if (!IsTrained) - { - throw new InvalidOperationException("The model must be trained before forecasting."); - } - - if (history == null) - { - throw new ArgumentNullException(nameof(history), "History cannot be null."); - } - - if (steps <= 0) - { - throw new ArgumentException("Number of forecast steps must be positive.", nameof(steps)); - } - - if (history.Length < Options.LagOrder) - { - throw new ArgumentException( - $"History length ({history.Length}) must be at least equal to lag order ({Options.LagOrder}).", - nameof(history)); - } - - // Create a working copy of the history that we can extend - List extendedHistory = new List(history.Length + steps); - for (int i = 0; i < history.Length; i++) - { - extendedHistory.Add(history[i]); - } - - // Generate forecasts one step at a time - Vector forecasts = new Vector(steps); - for (int step = 0; step < steps; step++) - { - // Prepare input features for this forecast step - Vector features = PrepareForecastFeatures(extendedHistory, step); - - // Make prediction - T forecast = PredictSingle(features); - - // Store forecast - forecasts[step] = forecast; - - // Add forecast to extended history for next step - extendedHistory.Add(forecast); - } - - return forecasts; - } - - /// - /// Prepares input features for a forecast step using the extended history. - /// - /// The historical data including any previous forecasts. - /// The current forecast step (0-based). - /// A vector of input features for the forecast. - /// - /// - /// This method extracts the appropriate lags and constructs any additional features - /// needed for the forecast, such as trend indicators or seasonal dummies. - /// - /// - /// For Beginners: - /// This method prepares the input data needed to make a forecast for a specific step. - /// It typically extracts recent values, seasonal patterns, and trend indicators from - /// the history (which may include previous predictions for multi-step forecasts). - /// - /// - protected virtual Vector PrepareForecastFeatures(List extendedHistory, int step) - { - // This is a basic implementation that derived classes should override - // to include model-specific feature preparation - - // For a simple AR model, we would just include the last LagOrder values - int historyLength = extendedHistory.Count; - int featureCount = Options.LagOrder; - - // Add space for trend if included - if (Options.IncludeTrend) - { - featureCount += 1; - } - - // Add space for seasonal dummies if seasonal - if (Options.SeasonalPeriod > 0) - { - featureCount += Options.SeasonalPeriod; - } - - Vector features = new Vector(featureCount); - int featureIndex = 0; - - // Add lag features - for (int lag = 1; lag <= Options.LagOrder; lag++) - { - if (historyLength - lag >= 0) - { - features[featureIndex++] = extendedHistory[historyLength - lag]; - } - else - { - // Not enough history for this lag, use a default value - features[featureIndex++] = NumOps.Zero; - } - } - - // Add trend feature if included - if (Options.IncludeTrend) - { - features[featureIndex++] = NumOps.FromDouble(step + 1); - } - - // Add seasonal dummies if seasonal - if (Options.SeasonalPeriod > 0) - { - int season = (historyLength + step) % Options.SeasonalPeriod; - for (int s = 0; s < Options.SeasonalPeriod; s++) - { - features[featureIndex++] = NumOps.FromDouble(s == season ? 1.0 : 0.0); - } - } - - return features; - } - - public virtual int ParameterCount - { - get { return ModelParameters.Length; } - } - - /// - public virtual int[] GetInputShape() - { - // LagOrder defines the lookback window. Exogenous features are handled - // as additional columns in the input matrix, so effective input width - // equals LagOrder * (1 + exogenousFeatureCount). Since the exogenous - // count varies per dataset, we report the lag dimension here; subclasses - // with known exogenous counts should override. - return Options.LagOrder > 0 ? new[] { Options.LagOrder } : Array.Empty(); - } - - /// - public virtual int[] GetOutputShape() - { - // Forecasts one value per step; multi-horizon models should override - return new[] { 1 }; - } - - /// - public virtual DynamicShapeInfo GetDynamicShapeInfo() - { - return DynamicShapeInfo.None; - } - - - public virtual void SaveModel(string filePath) - { - if (string.IsNullOrWhiteSpace(filePath)) - throw new ArgumentException("File path must not be null or empty.", nameof(filePath)); - - try - { - var data = Serialize(); - var directory = Path.GetDirectoryName(filePath); - if (!string.IsNullOrEmpty(directory) && !Directory.Exists(directory)) - Directory.CreateDirectory(directory); - byte[] envelopedData = ModelFileHeader.WrapWithHeader( - data, this, GetInputShape(), GetOutputShape(), SerializationFormat.Binary, - GetDynamicShapeInfo()); - File.WriteAllBytes(filePath, envelopedData); - } - catch (IOException ex) { throw new InvalidOperationException($"Failed to save model to '{filePath}': {ex.Message}", ex); } - catch (UnauthorizedAccessException ex) { throw new InvalidOperationException($"Access denied when saving model to '{filePath}': {ex.Message}", ex); } - catch (System.Security.SecurityException ex) { throw new InvalidOperationException($"Security error when saving model to '{filePath}': {ex.Message}", ex); } - } - - public virtual void LoadModel(string filePath) - { - if (string.IsNullOrWhiteSpace(filePath)) - throw new ArgumentException("File path must not be null or empty.", nameof(filePath)); - - try - { - var data = File.ReadAllBytes(filePath); - - // Extract payload from AIMF envelope if present; use raw bytes for legacy files - if (ModelFileHeader.HasHeader(data)) - { - data = ModelFileHeader.ExtractPayload(data); - } - - Deserialize(data); - } - catch (FileNotFoundException ex) { throw new FileNotFoundException($"The specified model file does not exist: {filePath}", filePath, ex); } - catch (IOException ex) { throw new InvalidOperationException($"File I/O error while loading model from '{filePath}': {ex.Message}", ex); } - catch (UnauthorizedAccessException ex) { throw new InvalidOperationException($"Access denied when loading model from '{filePath}': {ex.Message}", ex); } - catch (System.Security.SecurityException ex) { throw new InvalidOperationException($"Security error when loading model from '{filePath}': {ex.Message}", ex); } - catch (Exception ex) { throw new InvalidOperationException($"Failed to deserialize model from file '{filePath}'. The file may be corrupted or incompatible: {ex.Message}", ex); } - } - - public virtual ILossFunction DefaultLossFunction => _defaultLossFunction; - - public virtual Vector ComputeGradients(Matrix input, Vector target, ILossFunction? lossFunction = null) - { - var loss = lossFunction ?? DefaultLossFunction; - - // Primary path: layer-level backpropagation for exact gradients. - // Available for NeuralNetworkBase-derived time series models (Autoformer, Informer, etc.). - try - { - var predicted = Predict(input); - - var lossGrad = loss.CalculateDerivative(predicted, target); - var lossGradTensor = Tensor.FromVector(lossGrad); - - BackpropagateLayers(lossGradTensor); - - var gradients = GetLayerParameterGradients(); - - bool hasValidGradients = false; - for (int i = 0; i < Math.Min(gradients.Length, 100); i++) - { - if (!NumOps.Equals(gradients[i], NumOps.Zero)) - { - hasValidGradients = true; - break; - } - } - - if (hasValidGradients) - return gradients; - } - catch (NotSupportedException) - { - // Expected for models without layer-level backprop — fall through to SPSA - } - catch (Exception ex) - { - System.Diagnostics.Trace.TraceWarning( - $"Layer backpropagation failed for {GetType().Name}, falling back to SPSA: {ex.Message}"); - } - - // Fallback: SPSA gradient approximation — estimates ALL N gradients with just 2 forward - // passes per sample (vs 2N for per-parameter finite differences). - // Reference: Spall, J.C., IEEE TAC, 1992. - var parameters = GetParameters(); - var gradients_spsa = new Vector(parameters.Length); - T epsilon = NumOps.FromDouble(1e-3); - T twoEpsilon = NumOps.Multiply(epsilon, NumOps.FromDouble(2.0)); - var rng = Tensors.Helpers.RandomHelper.CreateSeededRandom(42); - int numSamples = 3; - - var delta = new Vector(parameters.Length); - - for (int s = 0; s < numSamples; s++) - { - for (int i = 0; i < parameters.Length; i++) - delta[i] = rng.NextDouble() < 0.5 ? NumOps.FromDouble(-1.0) : NumOps.FromDouble(1.0); - - var eDelta = Engine.Multiply(delta, epsilon); - - var modelPlus = (TimeSeriesModelBase)WithParameters(Engine.Add(parameters, eDelta)); - var lossPlus = loss.CalculateLoss(modelPlus.Predict(input), target); - - var modelMinus = (TimeSeriesModelBase)WithParameters(Engine.Subtract(parameters, eDelta)); - var lossMinus = loss.CalculateLoss(modelMinus.Predict(input), target); - - T lossDiff = NumOps.Subtract(lossPlus, lossMinus); - var scaledDelta = Engine.Multiply(delta, twoEpsilon); - gradients_spsa = Engine.Add(gradients_spsa, Engine.Divide( - Engine.Fill(parameters.Length, lossDiff), scaledDelta)); - } - - gradients_spsa = Engine.Multiply(gradients_spsa, NumOps.FromDouble(1.0 / numSamples)); - - return gradients_spsa; - } - - /// - /// Backpropagates the loss gradient through the model's neural network layers. - /// Override in NeuralNetworkBase-derived time series models to enable exact gradients. - /// - /// Gradient of the loss w.r.t. the model output. - protected virtual void BackpropagateLayers(Tensor lossGradient) - { - throw new NotSupportedException( - $"{GetType().Name} does not implement BackpropagateLayers. " + - "Override this method in NeuralNetworkBase-derived models."); - } - - /// - /// Extracts accumulated parameter gradients from all layers after backpropagation. - /// - protected virtual Vector GetLayerParameterGradients() - { - throw new NotSupportedException( - $"{GetType().Name} does not implement GetLayerParameterGradients. " + - "Override this method to extract layer-level gradients."); - } - - public virtual void ApplyGradients(Vector gradients, T learningRate) - { - if (gradients == null) - throw new ArgumentNullException(nameof(gradients)); - - var parameters = GetParameters(); - - if (gradients.Length != parameters.Length) - throw new ArgumentException($"Gradient vector length ({gradients.Length}) must match parameter count ({parameters.Length})."); - - parameters = Engine.Subtract(parameters, Engine.Multiply(gradients, learningRate)); - - SetParameters(parameters); - } - - /// - /// Saves the time series model's current state to a stream. - /// - /// The stream to write the model state to. - /// - /// - /// This method serializes the time series model's parameters and configuration. - /// It uses the existing Serialize method and writes the data to the provided stream. - /// - /// For Beginners: This is like creating a snapshot of your trained time series model. - /// - /// When you call SaveState: - /// - All learned parameters and trends are written to the stream - /// - Model configuration and internal state are preserved - /// - /// This is particularly useful for: - /// - Checkpointing during long training sessions - /// - Saving the best model for forecasting - /// - Knowledge distillation from time series models - /// - Deploying forecasting models to production - /// - /// You can later use LoadState to restore the model. - /// - /// - /// Thrown when stream is null. - /// Thrown when there's an error writing to the stream. - public virtual void SaveState(Stream stream) - { - if (stream == null) - throw new ArgumentNullException(nameof(stream)); - - if (!stream.CanWrite) - throw new ArgumentException("Stream must be writable.", nameof(stream)); - - try - { - var data = this.Serialize(); - stream.Write(data, 0, data.Length); - stream.Flush(); - } - catch (IOException ex) - { - throw new IOException($"Failed to save time series model state to stream: {ex.Message}", ex); - } - catch (Exception ex) - { - throw new InvalidOperationException($"Unexpected error while saving time series model state: {ex.Message}", ex); - } - } - - /// - /// Loads the time series model's state from a stream. - /// - /// The stream to read the model state from. - /// - /// - /// This method deserializes a time series model that was previously saved with SaveState. - /// It uses the existing Deserialize method after reading data from the stream. - /// - /// For Beginners: This is like loading a saved snapshot of your time series model. - /// - /// When you call LoadState: - /// - All parameters and trends are read from the stream - /// - Model configuration and state are restored - /// - /// After loading, the model can: - /// - Make forecasts using the restored parameters - /// - Continue training from where it left off - /// - Be deployed to production for time series prediction - /// - /// This is essential for: - /// - Resuming interrupted training sessions - /// - Loading the best model for forecasting - /// - Deploying trained models to production - /// - Knowledge distillation workflows - /// - /// - /// Thrown when stream is null. - /// Thrown when there's an error reading from the stream. - /// Thrown when the stream contains invalid or incompatible data. - public virtual void LoadState(Stream stream) - { - if (stream == null) - throw new ArgumentNullException(nameof(stream)); - - if (!stream.CanRead) - throw new ArgumentException("Stream must be readable.", nameof(stream)); - - try - { - using var ms = new MemoryStream(); - stream.CopyTo(ms); - var data = ms.ToArray(); - - if (data.Length == 0) - throw new InvalidOperationException("Stream contains no data."); - - this.Deserialize(data); - } - catch (IOException ex) - { - throw new IOException($"Failed to read time series model state from stream: {ex.Message}", ex); - } - catch (InvalidOperationException) - { - // Re-throw InvalidOperationException from Deserialize - throw; - } - catch (Exception ex) - { - throw new InvalidOperationException( - $"Failed to deserialize time series model state. The stream may contain corrupted or incompatible data: {ex.Message}", ex); - } - } - - #region IJitCompilable Implementation - - /// - /// - /// - /// Time series models support JIT compilation for accelerated inference. - /// The computation graph represents the linear time series model formula. - /// - /// For Beginners: JIT (Just-In-Time) compilation optimizes time series models for faster predictions. - /// - /// Time series models often involve computing weighted sums of past observations and features. - /// JIT compilation: - /// - Analyzes the model's structure - /// - Optimizes the mathematical operations - /// - Generates specialized native code - /// - Results in 3-7x faster predictions - /// - /// This is especially beneficial for: - /// - Real-time forecasting systems - /// - High-frequency time series (e.g., financial tick data) - /// - Large-scale forecasting (predicting many series simultaneously) - /// - /// Note: JIT compilation works best for linear time series models (AR, ARMA, etc.). - /// More complex models (e.g., those with non-linear transformations) may have - /// limited JIT support. - /// - /// - public virtual bool SupportsJitCompilation - { - get - { - // Check if model is trained and has parameters - return IsTrained && ModelParameters != null && ModelParameters.Length > 0; - } - } - - /// - /// - /// - /// Exports the time series model as a computation graph for JIT compilation. - /// The graph represents the linear model formula: output = input @ model_parameters - /// - /// For Beginners: This method converts the time series model into a computation graph. - /// - /// A computation graph is like a recipe that describes: - /// 1. Take input features (past observations, seasonal indicators, etc.) - /// 2. Multiply by learned model parameters (weights) - /// 3. Return prediction - /// - /// The JIT compiler uses this graph to: - /// - Optimize the operations - /// - Combine steps where possible - /// - Generate fast native code - /// - /// For time series models: - /// - Input: [lag_1, lag_2, ..., lag_p, seasonal_features, trend_features] - /// - Parameters: [φ₁, φ₂, ..., φ_p, seasonal_coeffs, trend_coeffs] - /// - Output: prediction = sum(input[i] * parameters[i]) - /// - /// This is similar to linear regression but specifically structured for time series data. - /// - /// - public virtual ComputationNode ExportComputationGraph(List> inputNodes) - { - // Validation: Ensure inputNodes is not null - if (inputNodes == null) - { - throw new ArgumentNullException(nameof(inputNodes), "Input nodes list cannot be null."); - } - - // Validation: Ensure model is trained - if (!IsTrained) - { - throw new InvalidOperationException("Cannot export computation graph: Model has not been trained yet."); - } - - if (ModelParameters == null || ModelParameters.Length == 0) - { - throw new InvalidOperationException("Cannot export computation graph: Model has no parameters."); - } - - // Create input node (placeholder for input features) - // Time series input shape: [1, feature_count] - // Features typically include: lag values, seasonal indicators, trend components - var featureCount = ModelParameters.Length; - var inputShape = new int[] { 1, featureCount }; - var inputTensor = new Tensor(inputShape); - var inputNode = new ComputationNode(inputTensor); - inputNodes.Add(inputNode); - - // Convert model parameters Vector to Tensor - // Shape: [feature_count, 1] for matrix multiplication - var paramShape = new int[] { featureCount, 1 }; - var paramData = new T[featureCount]; - for (int i = 0; i < featureCount; i++) - { - paramData[i] = ModelParameters[i]; - } - var paramTensor = new Tensor(paramShape, new Vector(paramData)); - var paramNode = new ComputationNode(paramTensor); - - // MatMul: input @ parameters - // Result shape: [1, 1] (single prediction) - var outputNode = TensorOperations.MatrixMultiply(inputNode, paramNode); - - // Note: Most time series models don't have an explicit intercept term - // as it's often absorbed into the parameters or handled during preprocessing. - // If your specific model has an intercept, override this method to add it. - - return outputNode; - } - - #endregion -} +using System.Threading; +using AiDotNet.Autodiff; +using AiDotNet.Helpers; +using AiDotNet.Tensors.Engines.Autodiff; + +namespace AiDotNet.TimeSeries; + +/// +/// Provides a base class for all time series forecasting models in the library. +/// +/// The numeric data type used for calculations (e.g., float, double). +/// +/// +/// This abstract class defines the common interface and functionality that all time series models share, +/// including training, prediction, evaluation, and serialization/deserialization capabilities. +/// +/// +/// Time series models capture temporal dependencies in data and use patterns learned from historical +/// observations to predict future values. This base class provides the foundation for implementing +/// various time series forecasting algorithms like ARIMA, Exponential Smoothing, TBATS, and more complex +/// machine learning approaches. +/// +/// +/// For Beginners: +/// A time series model helps predict future values based on past observations. +/// +/// Think of a time series like a sequence of measurements taken over time - for example, +/// daily temperatures, monthly sales, or hourly website visits. These models analyze the patterns +/// in historical data to make predictions about what will happen next. +/// +/// This base class is like a blueprint that all specific time series models follow. +/// It ensures that every model can: +/// - Be trained on historical data to learn patterns +/// - Make predictions for future periods based on what it learned +/// - Evaluate how accurate its predictions are compared to actual values +/// - Be saved to disk and loaded later without retraining +/// +/// Time series models are used in many real-world applications, including: +/// - Weather forecasting +/// - Stock market prediction +/// - Demand planning for retail +/// - Energy consumption forecasting +/// - Website traffic prediction +/// +/// +public abstract class TimeSeriesModelBase : ITimeSeriesModel, IConfigurableModel, IModelShape +{ + /// + /// Configuration options for the time series model. + /// + /// + /// + /// These options control the core behavior of the time series model, including how much + /// historical data is considered, whether trends or seasonality are modeled, and how errors + /// are handled. + /// + /// + /// For Beginners: + /// Think of these options as settings that determine how the model works: + /// - LagOrder: How many past values to consider (like remembering the last 5 days to predict tomorrow) + /// - IncludeTrend: Whether to account for ongoing trends (like sales steadily increasing over time) + /// - SeasonalPeriod: Whether there are regular patterns (like retail sales spiking every December) + /// - AutocorrelationCorrection: Whether to fix systematic errors in predictions + /// + /// + protected TimeSeriesRegressionOptions Options { get; set; } + + /// + public virtual ModelOptions GetOptions() => Options; + + /// + /// Provides numeric operations for the specific type T. + /// + /// + /// + /// This property provides mathematical operations appropriate for the generic type T, + /// allowing the algorithm to work consistently with different numeric types like + /// float, double, or decimal. + /// + /// + /// For Beginners: + /// This is a helper that knows how to do math (addition, multiplication, etc.) with + /// your specific number type, whether that's a regular double, a precise decimal value, + /// or something else. It allows the model to work with different types of numbers + /// without changing its core logic. + /// + /// + protected INumericOperations NumOps { get; private set; } + + /// + /// Gets the global execution engine for vector operations. + /// + /// + /// + /// This property provides access to the execution engine (CPU or GPU) for performing + /// vectorized operations. The engine is determined by the global AiDotNetEngine configuration + /// and allows automatic fallback from GPU to CPU when GPU is not available. + /// + /// + /// For Beginners: + /// This gives access to either CPU or GPU processing for faster computations. + /// The system automatically chooses the best available option and falls back to CPU + /// if GPU acceleration is not available. + /// + /// + protected IEngine Engine => AiDotNetEngine.Current; + + /// + /// Gets or sets the trained model parameters. + /// + /// + /// + /// Contains the values that the model has learned during training, such as coefficients + /// for different lags, trend components, and seasonal factors. + /// + /// + /// For Beginners: + /// These are the numerical values the model learns during training that tell it exactly + /// how much influence each past observation should have on the prediction. They're like + /// the recipe ingredients with specific measurements that the model has figured out work best. + /// + /// + protected Vector ModelParameters { get; set; } + + /// + /// Indicates whether the model has been trained. + /// + /// + /// + /// This flag is set to true after the model has been successfully trained on data. + /// + /// + /// For Beginners: + /// This is like a switch that gets turned on once the model has learned from your data. + /// It helps prevent errors by making sure you don't try to use the model for predictions + /// before it's ready. + /// + /// + protected bool IsTrained { get; set; } = false; + + /// + /// The default loss function used for gradient computation. + /// + private readonly ILossFunction _defaultLossFunction; + + /// + /// Gets the last computed error metrics when the model was evaluated. + /// + /// + /// + /// Contains accuracy metrics calculated during model evaluation, such as MAE, RMSE, and MAPE. + /// + /// + /// For Beginners: + /// These numbers tell you how accurate the model's predictions are compared to actual values. + /// Lower numbers mean better predictions. They're like a scorecard for the model's performance. + /// + /// + protected Dictionary LastEvaluationMetrics { get; private set; } = new Dictionary(); + + /// + /// Initializes a new instance of the TimeSeriesModelBase class with the specified options. + /// + /// The configuration options for the time series model. + /// Thrown when options is null. + /// Thrown when options contain invalid values. + /// + /// + /// This constructor validates the provided options, initializes the model with the specified + /// configuration, and sets up the numeric operations appropriate for the data type. + /// + /// + /// For Beginners: + /// This constructor sets up the basic configuration for any time series model. + /// + /// It takes an options object that specifies important settings like: + /// - How many past values to consider (lag order) + /// - Whether to include a trend component (like steady growth or decline) + /// - The length of seasonal patterns (e.g., 7 for weekly, 12 for monthly) + /// - Whether to correct for autocorrelation in errors (systematic errors) + /// + /// It also checks that these settings make sense - for example, you can't have a negative + /// number of past values or a seasonal period less than 2. + /// + /// + protected TimeSeriesModelBase(TimeSeriesRegressionOptions options) + { + // Validate options + if (options == null) + { + throw new ArgumentNullException(nameof(options), "Time series options cannot be null."); + } + + ValidateOptions(options); + + Options = options; + NumOps = MathHelper.GetNumericOperations(); + ModelParameters = new Vector(0); // Initialize with empty vector + _defaultLossFunction = options.LossFunction ?? new MeanSquaredErrorLoss(); + } + + /// + /// Validates the provided time series options to ensure they are within acceptable ranges. + /// + /// The options to validate. + /// Thrown when any option is invalid. + /// + /// + /// Checks that LagOrder is non-negative, SeasonalPeriod is either 0 (no seasonality) or at least 2, + /// and that other parameters have reasonable values. + /// + /// + /// For Beginners: + /// This method makes sure the settings you've chosen for your model make logical sense. + /// For example, you can't look back a negative number of time periods, and a seasonal + /// pattern must repeat at least every 2 periods to be considered seasonal. + /// + /// + protected virtual void ValidateOptions(TimeSeriesRegressionOptions options) + { + if (options.LagOrder < 0) + { + throw new ArgumentException("Lag order must be non-negative.", nameof(options)); + } + + if (options.SeasonalPeriod < 0) + { + throw new ArgumentException("Seasonal period must be non-negative.", nameof(options)); + } + + if (options.SeasonalPeriod == 1) + { + throw new ArgumentException("Seasonal period must be at least 2 if seasonality is enabled.", nameof(options)); + } + + // Additional model-specific validation can be implemented in derived classes + } + + /// + /// Trains the time series model using the provided input data and target values. + /// + /// The input features matrix. + /// The target values vector. + /// Thrown when x or y is null. + /// Thrown when the dimensions of x and y don't match or when the data is insufficient. + /// + /// + /// This method validates the input data, prepares the model for training, performs the actual + /// training algorithm, and sets the IsTrained flag once complete. + /// + /// + /// For Beginners: + /// Training is the process where the model learns patterns from historical data. + /// + /// During training, the model analyzes the relationship between: + /// - Input features (x): These might include past values, time indicators, or external factors + /// - Target values (y): The actual observed values we want to predict + /// + /// After training, the model will have learned parameters that capture the patterns + /// in your data, which it can then use to make predictions for new inputs. + /// + /// This is an abstract method, meaning each specific model type (ARIMA, TBATS, etc.) + /// will implement its own training algorithm. + /// + /// + /// + /// Cancellation token that is active during training. Derived classes should check + /// TrainingCancellationToken.IsCancellationRequested in their training loops + /// to support both caller-initiated cancellation and wall-clock timeout from + /// . + /// + protected CancellationToken TrainingCancellationToken { get; private set; } = CancellationToken.None; + + /// + /// Auto-scaled guard threshold computed from training data. + /// Set to 1000 * max(|y|) during training. Falls back to 1e15 if not trained. + /// + private double _autoGuardThreshold = 1e15; + + public void Train(Matrix x, Vector y) => Train(x, y, CancellationToken.None); + + public void Train(Matrix x, Vector y, CancellationToken callerToken) + { + // Fail fast if already cancelled — don't discard existing trained state + callerToken.ThrowIfCancellationRequested(); + + // Input validation + ValidateTrainingInputs(x, y); + + // Reset model state before training + Reset(); + + // Auto-scale guard threshold from training data: 1000x max observed absolute value. + // This adapts overflow protection to the dataset scale instead of a fixed magic number. + double maxAbsY = 0; + for (int i = 0; i < y.Length; i++) + { + double absVal = Math.Abs(NumOps.ToDouble(y[i])); + if (absVal > maxAbsY && !double.IsNaN(absVal) && !double.IsInfinity(absVal)) + maxAbsY = absVal; + } + _autoGuardThreshold = maxAbsY > 0 ? maxAbsY * 1000.0 : 1e15; + + // Create a linked CancellationTokenSource that combines: + // 1. The caller's token (for external cancellation) + // 2. A wall-clock timeout from MaxTrainingTimeSeconds (safety net) + // Whichever fires first wins. + using var linkedCts = CancellationTokenSource.CreateLinkedTokenSource(callerToken); + if (Options.MaxTrainingTimeSeconds > 0) + { + linkedCts.CancelAfter(TimeSpan.FromSeconds(Options.MaxTrainingTimeSeconds)); + } + TrainingCancellationToken = linkedCts.Token; + + try + { + // Perform model-specific training (implemented by derived classes) + TrainCore(x, y); + } + catch (OperationCanceledException) + { + // Training was cancelled by MaxTrainingTimeSeconds timeout. + // The model may have partial training — mark as trained with whatever + // state was achieved. This is industry standard: early stopping produces + // a usable (if suboptimal) model rather than failing completely. + } + finally + { + TrainingCancellationToken = CancellationToken.None; + } + + // Mark the model as trained (even after early cancellation) + IsTrained = true; + } + + /// + /// Performs the model-specific training algorithm. + /// + /// The input features matrix. + /// The target values vector. + /// + /// + /// This abstract method must be implemented by derived classes to perform the actual model training. + /// + /// + /// For Beginners: + /// This is where the specific math and algorithms for each type of time series model are implemented. + /// Different models (like ARIMA, Exponential Smoothing, etc.) will have their own unique ways of + /// finding patterns in the data. + /// + /// + protected virtual void TrainCore(Matrix x, Vector y) + { + // Default: tape-based training handles parameter updates + } + + /// + /// Validates the training input data before proceeding with training. + /// + /// The input features matrix. + /// The target values vector. + /// Thrown when x or y is null. + /// Thrown when the dimensions of x and y don't match or when the data is insufficient. + /// + /// + /// This method verifies that the input data meets the requirements for model training, + /// including checking dimensions, sample size, and consistency. + /// + /// + /// For Beginners: + /// Before the model starts learning, this method checks that your data is valid and properly formatted. + /// It ensures that: + /// - You have provided both input features and target values + /// - The number of examples matches the number of target values + /// - You have enough data points to train the model effectively + /// - There are no obvious inconsistencies in your data structure + /// + /// + protected virtual void ValidateTrainingInputs(Matrix x, Vector y) + { + if (x == null) + { + throw new ArgumentNullException(nameof(x), "Input features matrix cannot be null."); + } + + if (y == null) + { + throw new ArgumentNullException(nameof(y), "Target values vector cannot be null."); + } + + if (x.Rows != y.Length) + { + throw new ArgumentException( + $"Number of rows in input matrix ({x.Rows}) must match the length of target vector ({y.Length})."); + } + + if (x.Rows <= Options.LagOrder) + { + throw new ArgumentException( + $"Number of samples ({x.Rows}) must be greater than lag order ({Options.LagOrder})."); + } + + // Check for sufficient data to handle seasonality + if (Options.SeasonalPeriod > 0 && x.Rows < 2 * Options.SeasonalPeriod) + { + throw new ArgumentException( + $"For seasonal models, the number of samples ({x.Rows}) should be at least twice the seasonal period ({Options.SeasonalPeriod})."); + } + + // Additional validation can be added in derived classes + } + + /// + /// Generates forecasts using the trained time series model. + /// + /// The input features matrix. + /// A vector of forecasted values. + /// Thrown when the model has not been trained. + /// Thrown when input is null. + /// Thrown when input has incorrect dimensions. + /// + /// + /// This method validates that the model is trained and the input data is valid, then + /// generates predictions for each row in the input matrix using the model-specific + /// prediction algorithm. + /// + /// + /// For Beginners: + /// This method uses the patterns learned during training to predict future values. + /// + /// The input matrix typically contains: + /// - Past values of the time series + /// - Time indicators (e.g., month, day of week) + /// - Any external factors that might influence the forecast + /// + /// The output is a vector of predicted values, one for each row in the input matrix. + /// Each prediction represents what the model thinks will happen at that future time point. + /// + /// + public virtual Vector Predict(Matrix input) + { + // Suppress tape recording during inference + using var _noGrad = new NoGradScope(); + + // Check if model is trained + if (!IsTrained) + { + throw new InvalidOperationException("The model must be trained before making predictions."); + } + + // Validate input + ValidatePredictionInput(input); + + // Create output vector for predictions + var predictions = new Vector(input.Rows); + + // Generate predictions for each input row + for (int i = 0; i < input.Rows; i++) + { + predictions[i] = PredictSingle(input.GetRow(i)); + } + + return predictions; + } + + /// + /// Validates the input data for prediction. + /// + /// The input features matrix. + /// Thrown when input is null. + /// Thrown when input has incorrect dimensions. + /// + /// + /// This method verifies that the input data for prediction is valid and has the correct dimensions. + /// + /// + /// For Beginners: + /// Before making predictions, this method checks that your input data is properly formatted. + /// It ensures that: + /// - You have provided input features + /// - The input has the correct structure (number of features/columns) + /// - The data meets any model-specific requirements + /// + /// + /// + /// Guards a prediction value against NaN, Infinity, and extreme overflow. + /// Returns a clamped finite value. All time series models should call this + /// before storing predictions to prevent cascading numerical instability + /// in recursive/autoregressive forecasting loops. + /// + /// The raw prediction value. + /// + /// Maximum allowed absolute value. If not specified, uses the priority chain: + /// (1) user-configured , + /// (2) auto-scaled from training data (1000x max |y|), + /// (3) fallback to 1e15. + /// + /// A finite, clamped value. + protected T GuardPrediction(T value, double maxAbsValue = -1) + { + // Priority chain: explicit parameter > user option > auto-scaled > 1e15 fallback + double threshold = maxAbsValue > 0 + ? maxAbsValue + : Options.MaxPredictionAbsValue ?? _autoGuardThreshold; + + var d = NumOps.ToDouble(value); + if (double.IsNaN(d) || double.IsInfinity(d) || Math.Abs(d) > threshold) + { + var safe = double.IsNaN(d) ? 0.0 : d; + var clamped = MathPolyfill.Clamp(safe, -threshold, threshold); + return NumOps.FromDouble(clamped); + } + return value; + } + + protected virtual void ValidatePredictionInput(Matrix input) + { + if (input == null) + { + throw new ArgumentNullException(nameof(input), "Input features matrix cannot be null."); + } + + // Additional validation can be added in derived classes + } + + /// + /// Generates a prediction for a single input vector. + /// + /// The input feature vector. + /// The predicted value. + /// + /// + /// This abstract method must be implemented by derived classes to generate a prediction + /// for a single input vector using the model-specific algorithm. + /// + /// + /// For Beginners: + /// This method takes a single row of input data (representing one time point) and + /// calculates what the model predicts will happen at that point. Each type of + /// time series model will have its own way of calculating this prediction based + /// on the patterns it learned during training. + /// + /// + public abstract T PredictSingle(Vector input); + + /// + /// Evaluates the performance of the trained model on test data. + /// + /// The input features matrix for testing. + /// The actual target values for testing. + /// A dictionary containing evaluation metrics. + /// Thrown when the model has not been trained. + /// Thrown when xTest or yTest is null. + /// Thrown when the dimensions of xTest and yTest don't match. + /// + /// + /// This method calculates various error metrics by comparing the model's predictions + /// on the test data to the actual values, providing a quantitative assessment of + /// model performance. + /// + /// + /// For Beginners: + /// This method tests how well the model performs by comparing its predictions to actual values. + /// + /// It works by: + /// 1. Using the model to make predictions based on the test inputs + /// 2. Comparing these predictions to the actual test values + /// 3. Calculating various error metrics to quantify the accuracy + /// + /// Common metrics include: + /// - Mean Absolute Error (MAE): Average of absolute differences between predictions and actual values + /// - Root Mean Squared Error (RMSE): Square root of the average squared differences + /// - Mean Absolute Percentage Error (MAPE): Average percentage differences + /// + /// These metrics help you understand how accurate your model is and compare different models. + /// Lower values indicate better performance for all these metrics. + /// + /// + public virtual Dictionary EvaluateModel(Matrix xTest, Vector yTest) + { + // Check if model is trained + if (!IsTrained) + { + throw new InvalidOperationException("The model must be trained before evaluation."); + } + + // Validate inputs + if (xTest == null) + { + throw new ArgumentNullException(nameof(xTest), "Test features matrix cannot be null."); + } + + if (yTest == null) + { + throw new ArgumentNullException(nameof(yTest), "Test target vector cannot be null."); + } + + if (xTest.Rows != yTest.Length) + { + throw new ArgumentException( + $"Number of rows in test matrix ({xTest.Rows}) must match the length of test vector ({yTest.Length})."); + } + + // Generate predictions + Vector predictions = Predict(xTest); + + // Calculate error metrics + Dictionary metrics = CalculateErrorMetrics(predictions, yTest); + + // Store metrics for later reference + LastEvaluationMetrics = metrics; + + return metrics; + } + + /// + /// Calculates error metrics by comparing predictions to actual values. + /// + /// The predicted values. + /// The actual values. + /// A dictionary containing error metrics. + /// + /// + /// This method computes standard error metrics for time series forecasting, including + /// MAE, RMSE, MAPE, and others as appropriate for the model type. + /// + /// + /// For Beginners: + /// This method calculates how far off the model's predictions are from the actual values. + /// It computes several different ways of measuring the prediction errors: + /// + /// - MAE (Mean Absolute Error): The average magnitude of errors, ignoring whether they're positive or negative + /// - RMSE (Root Mean Squared Error): Emphasizes larger errors by squaring them before averaging + /// - MAPE (Mean Absolute Percentage Error): Shows errors as percentages of the actual values + /// + /// These metrics help you understand not just how accurate the model is overall, + /// but also what kinds of errors it tends to make. + /// + /// + protected virtual Dictionary CalculateErrorMetrics(Vector predictions, Vector actuals) + { + int n = predictions.Length; + var metrics = new Dictionary(); + + // Calculate MAE (Mean Absolute Error) + T sumAbsoluteError = NumOps.Zero; + for (int i = 0; i < n; i++) + { + T error = NumOps.Subtract(predictions[i], actuals[i]); + sumAbsoluteError = NumOps.Add(sumAbsoluteError, NumOps.Abs(error)); + } + T mae = NumOps.Divide(sumAbsoluteError, NumOps.FromDouble(n)); + metrics["MAE"] = mae; + + // Calculate MSE (Mean Squared Error) and RMSE (Root Mean Squared Error) + T sumSquaredError = NumOps.Zero; + for (int i = 0; i < n; i++) + { + T error = NumOps.Subtract(predictions[i], actuals[i]); + sumSquaredError = NumOps.Add(sumSquaredError, NumOps.Square(error)); + } + T mse = NumOps.Divide(sumSquaredError, NumOps.FromDouble(n)); + T rmse = NumOps.Sqrt(mse); + metrics["MSE"] = mse; + metrics["RMSE"] = rmse; + + // Calculate MAPE (Mean Absolute Percentage Error) + // Only if actuals don't contain zeros or very small values + bool canCalculateMape = true; + T sumAbsolutePercentageError = NumOps.Zero; + for (int i = 0; i < n; i++) + { + if (NumOps.LessThan(NumOps.Abs(actuals[i]), NumOps.FromDouble(1e-10))) + { + canCalculateMape = false; + break; + } + + T percentageError = NumOps.Divide( + NumOps.Abs(NumOps.Subtract(predictions[i], actuals[i])), + NumOps.Abs(actuals[i]) + ); + sumAbsolutePercentageError = NumOps.Add(sumAbsolutePercentageError, percentageError); + } + + if (canCalculateMape) + { + T mape = NumOps.Multiply( + NumOps.Divide(sumAbsolutePercentageError, NumOps.FromDouble(n)), + NumOps.FromDouble(100) // Convert to percentage + ); + metrics["MAPE"] = mape; + } + + return metrics; + } + + /// + /// Serializes the model to a byte array for storage or transmission. + /// + /// A byte array containing the serialized model. + /// + /// + /// This method serializes the common components of the model (options, trained status, parameters) + /// and then calls the model-specific serialization method to handle specialized data. + /// + /// + /// For Beginners: + /// Serialization converts the model's state into a format that can be saved to disk + /// or transmitted over a network. + /// + /// This method: + /// 1. Creates a memory stream to hold the serialized data + /// 2. Writes the common configuration options shared by all models + /// 3. Writes whether the model has been trained + /// 4. Writes the model parameters learned during training + /// 5. Calls the model-specific serialization method to write specialized data + /// 6. Returns everything as a byte array + /// + /// This allows you to save a trained model and load it later without having to retrain it, + /// which can save significant time for complex models trained on large datasets. + /// + /// + public virtual byte[] Serialize() + { + ModelPersistenceGuard.EnforceBeforeSerialize(); + using var ms = new MemoryStream(); + using var writer = new BinaryWriter(ms); + + // Serialize common options + writer.Write(Options.LagOrder); + writer.Write(Options.IncludeTrend); + writer.Write(Options.SeasonalPeriod); + writer.Write(Options.AutocorrelationCorrection); + writer.Write((int)Options.ModelType); + + // Serialize trained state + writer.Write(IsTrained); + + // Serialize model parameters if trained + if (IsTrained) + { + writer.Write(ModelParameters.Length); + for (int i = 0; i < ModelParameters.Length; i++) + { + writer.Write(Convert.ToDouble(ModelParameters[i])); + } + + // Serialize evaluation metrics + writer.Write(LastEvaluationMetrics.Count); + foreach (var kvp in LastEvaluationMetrics) + { + writer.Write(kvp.Key); + writer.Write(Convert.ToDouble(kvp.Value)); + } + } + + // Serialize auto-scaled guard threshold (persists training-data-aware overflow protection) + writer.Write(_autoGuardThreshold); + + // Let derived classes serialize their specific data + SerializeCore(writer); + + return ms.ToArray(); + } + + /// + /// Deserializes the model from a byte array. + /// + /// The byte array containing the serialized model. + /// Thrown when data is null. + /// Thrown when the serialized data is corrupted or incompatible. + /// + /// + /// This method deserializes the common components of the model (options, trained status, parameters) + /// and then calls the model-specific deserialization method to handle specialized data. + /// + /// + /// For Beginners: + /// Deserialization is the process of loading a previously saved model from a byte array. + /// + /// This method: + /// 1. Creates a memory stream from the provided byte array + /// 2. Reads the common configuration options shared by all models + /// 3. Reads whether the model has been trained + /// 4. Reads the model parameters learned during training + /// 5. Calls the model-specific deserialization method to read specialized data + /// + /// After deserialization, the model is restored to the same state it was in when serialized, + /// allowing you to make predictions without retraining the model. + /// + /// This is particularly useful for: + /// - Deploying models to production environments + /// - Sharing models between different applications + /// - Saving computation time by not having to retrain complex models + /// + /// + public virtual void Deserialize(byte[] data) + { + ModelPersistenceGuard.EnforceBeforeDeserialize(); + if (data == null) + { + throw new ArgumentNullException(nameof(data), "Serialized data cannot be null."); + } + + try + { + using var ms = new MemoryStream(data); + using var reader = new BinaryReader(ms); + + // Deserialize common options + Options.LagOrder = reader.ReadInt32(); + Options.IncludeTrend = reader.ReadBoolean(); + Options.SeasonalPeriod = reader.ReadInt32(); + Options.AutocorrelationCorrection = reader.ReadBoolean(); + Options.ModelType = (TimeSeriesModelType)reader.ReadInt32(); + + // Deserialize trained state + IsTrained = reader.ReadBoolean(); + + // Deserialize model parameters if trained + if (IsTrained) + { + int parameterCount = reader.ReadInt32(); + ModelParameters = new Vector(parameterCount); + for (int i = 0; i < parameterCount; i++) + { + ModelParameters[i] = NumOps.FromDouble(reader.ReadDouble()); + } + + // Deserialize evaluation metrics + int metricsCount = reader.ReadInt32(); + LastEvaluationMetrics.Clear(); + for (int i = 0; i < metricsCount; i++) + { + string key = reader.ReadString(); + T value = NumOps.FromDouble(reader.ReadDouble()); + LastEvaluationMetrics[key] = value; + } + } + + // Deserialize auto-scaled guard threshold (backwards-compatible) + try + { + _autoGuardThreshold = reader.ReadDouble(); + } + catch (EndOfStreamException) + { + _autoGuardThreshold = 1e15; // Pre-patch model + } + + // Let derived classes deserialize their specific data + DeserializeCore(reader); + } + catch (Exception ex) + { + throw new InvalidOperationException("Failed to deserialize model data. The data may be corrupted or incompatible with this model version.", ex); + } + } + + /// + /// Serializes model-specific data to the binary writer. + /// + /// The binary writer to write to. + /// + /// + /// This abstract method must be implemented by each specific model type to save + /// its unique parameters and state. + /// + /// + /// For Beginners: + /// This method is responsible for saving the specific details that make each type of + /// time series model unique. Different models have different internal structures and parameters + /// that need to be saved separately from the common elements. + /// + /// For example: + /// - An ARIMA model would save its AR, I, and MA coefficients + /// - A TBATS model would save its level, trend, and seasonal components + /// - A neural network model would save its weights and biases + /// + /// This separation allows the base class to handle common serialization tasks + /// while each model type handles its specialized data. + /// + /// + protected abstract void SerializeCore(BinaryWriter writer); + + /// + /// Deserializes model-specific data from the binary reader. + /// + /// The binary reader to read from. + /// + /// + /// This abstract method must be implemented by each specific model type to load + /// its unique parameters and state. + /// + /// + /// For Beginners: + /// This method is responsible for loading the specific details that make each type of + /// time series model unique. It reads exactly what was written by SerializeCore, in the + /// same order, reconstructing the specialized parts of the model. + /// + /// It's the counterpart to SerializeCore and should read data in exactly the same + /// order and format that it was written. + /// + /// This separation allows the base class to handle common deserialization tasks + /// while each model type handles its specialized data. + /// + /// + protected abstract void DeserializeCore(BinaryReader reader); + + /// + /// Gets metadata about the time series model. + /// + /// A ModelMetaData object containing information about the model. + /// + /// + /// This method provides comprehensive metadata about the model, including its type, + /// configuration options, training status, evaluation metrics, and information about + /// which features/lags are most important. + /// + /// + /// For Beginners: + /// This method provides important information about the model that can help you understand + /// its characteristics and performance. + /// + /// The metadata includes: + /// - The type of model (e.g., ARIMA, TBATS, Neural Network) + /// - Configuration details (e.g., lag order, seasonality period) + /// - Whether the model has been trained + /// - Performance metrics from the last evaluation + /// - Information about which features (time periods) are most influential + /// + /// This information is useful for documentation, model comparison, and debugging. + /// It's like a complete summary of everything important about the model. + /// + /// + public abstract ModelMetadata GetModelMetadata(); + + /// + /// Gets the trainable parameters of the model as a vector. + /// + /// A vector containing all trainable parameters of the model. + /// Thrown when the model has not been trained. + /// + /// + /// This method returns all the parameters learned during training, combined into a single vector. + /// These parameters determine how the model makes predictions based on input data. + /// + /// + /// For Beginners: + /// This method returns all the numerical values that the model has learned during training. + /// + /// For time series models, these parameters typically include: + /// - Coefficients for each lag (how much each past value influences the prediction) + /// - Trend coefficients (if trend is included) + /// - Seasonal coefficients (if seasonality is included) + /// - Error correction terms (if autocorrelation correction is enabled) + /// + /// These parameters can be: + /// - Analyzed to understand what the model has learned + /// - Saved for later use + /// - Modified to adjust the model's behavior + /// - Transferred to another model with the same structure + /// + /// + /// + /// Time series models do not support random parameter initialization from the optimizer. + /// They must be trained on sequential data to learn meaningful coefficients. + /// + public virtual bool SupportsParameterInitialization => false; + /// + public virtual Vector SanitizeParameters(Vector parameters) => parameters; + + + public virtual Vector GetParameters() + { + if (!IsTrained && (ModelParameters == null || ModelParameters.Length == 0)) + { + throw new InvalidOperationException("Cannot get parameters for an untrained model."); + } + + if (ModelParameters == null || ModelParameters.Length == 0) + { + throw new InvalidOperationException("Model parameters have not been initialized."); + } + + return ModelParameters.Clone(); + } + + /// + /// Creates a new model with the specified parameters. + /// + /// The vector of parameters to use for the new model. + /// A new model instance with the specified parameters. + /// Thrown when parameters is null. + /// Thrown when the parameters vector has incorrect length. + /// + /// + /// This method creates a clone of the current model but replaces its parameters with the + /// provided values. This allows for creating variations of a model without retraining. + /// + /// + /// For Beginners: + /// This method creates a copy of the current model but with different parameter values. + /// + /// This allows you to: + /// - Create a model with manually specified parameters (e.g., from expert knowledge) + /// - Make small adjustments to a trained model without full retraining + /// - Implement ensemble models that combine multiple parameter sets + /// - Perform what-if analysis by changing specific parameters + /// + /// The parameters must be in the same order and have the same meaning as those + /// returned by the GetParameters method. + /// + /// + public virtual IFullModel, Vector> WithParameters(Vector parameters) + { + if (parameters == null) + { + throw new ArgumentNullException(nameof(parameters), "Parameters vector cannot be null."); + } + + // Create a clone of the current model + var newModel = (TimeSeriesModelBase)this.Clone(); + + // Apply the new parameters to the cloned model + newModel.ApplyParameters(parameters); + + // Mark as trained since parameters have been specified + newModel.IsTrained = true; + + return newModel; + } + + /// + /// Applies the provided parameters to the model. + /// + /// The vector of parameters to apply. + /// Thrown when the parameters vector is invalid. + /// + /// + /// This method applies the provided parameter values to the model, updating its internal state + /// to reflect the new parameters. The implementation is model-specific and should be overridden + /// by derived classes as needed. + /// + /// + /// For Beginners: + /// This method updates the model's internal parameters with new values. + /// It's the counterpart to GetParameters and should understand the parameter + /// vector in exactly the same way. + /// + /// For example, if the first 5 elements of the parameters vector represent + /// lag coefficients, this method should apply them as lag coefficients in + /// the model's internal structure. + /// + /// + protected virtual void ApplyParameters(Vector parameters) + { + if (parameters == null) + { + throw new ArgumentNullException(nameof(parameters), "Parameters vector cannot be null."); + } + + // Store the parameters + ModelParameters = parameters.Clone(); + + // Derived classes should override this to apply parameters to their specific structures + } + + /// + /// Gets the indices of features (lags/time periods) actively used by the model. + /// + /// A collection of indices representing the active features. + /// Thrown when the model has not been trained. + /// + /// + /// This method identifies which input features (lags) have significant impact on the model's + /// predictions, based on their corresponding parameter values. + /// + /// + /// For Beginners: + /// This method tells you which past time periods (lags) are most important for predictions. + /// + /// For example, if the result includes indices [1, 7, 12], this means: + /// - The value from 1 period ago strongly influences the prediction + /// - The value from 7 periods ago strongly influences the prediction (could be weekly seasonality) + /// - The value from 12 periods ago strongly influences the prediction (could be yearly for monthly data) + /// + /// These active features are determined by the model's structure and learned parameters. + /// For instance, in an ARIMA model, non-zero AR coefficients indicate active features. + /// + /// Understanding active features helps interpret how the model works and which + /// historical points matter most for forecasting. + /// + /// + public virtual IEnumerable GetActiveFeatureIndices() + { + if (!IsTrained) + { + throw new InvalidOperationException("The model must be trained before getting active feature indices."); + } + + List activeIndices = new List(); + + // Consider common lag patterns based on model configuration + for (int lag = 1; lag <= Options.LagOrder; lag++) + { + if (IsFeatureUsed(lag)) + { + activeIndices.Add(lag); + } + } + + // If seasonal, also include seasonal lags + if (Options.SeasonalPeriod > 0) + { + for (int s = 1; s <= 4; s++) // Consider up to 4 seasonal lags + { + int seasonalLag = s * Options.SeasonalPeriod; + if (seasonalLag <= Options.LagOrder && IsFeatureUsed(seasonalLag)) + { + activeIndices.Add(seasonalLag); + } + } + } + + return activeIndices; + } + + /// + /// Determines if a specific feature (lag) is actively used by the model. + /// + /// The index of the feature to check. + /// True if the feature is actively used; otherwise, false. + /// Thrown when the model has not been trained. + /// Thrown when featureIndex is negative or exceeds the maximum lag order. + /// + /// + /// This method determines whether a specific lag has a significant impact on the model's predictions, + /// based on its corresponding parameter value. The threshold for significance is model-specific. + /// + /// + /// For Beginners: + /// This method checks if a specific past time period (lag) has a significant + /// influence on the model's predictions. + /// + /// For example: + /// - IsFeatureUsed(1) checks if the value from 1 period ago matters + /// - IsFeatureUsed(7) checks if the value from 7 periods ago matters + /// - IsFeatureUsed(12) checks if the value from 12 periods ago matters + /// + /// A feature is typically considered "used" if its coefficient or weight + /// in the model is significantly different from zero. + /// + /// This information helps understand which historical points the model + /// considers important when making predictions. + /// + /// + public virtual bool IsFeatureUsed(int featureIndex) + { + if (!IsTrained) + { + throw new InvalidOperationException("The model must be trained before checking feature usage."); + } + + if (featureIndex < 0) + { + throw new ArgumentOutOfRangeException(nameof(featureIndex), "Feature index cannot be negative."); + } + + if (featureIndex > Options.LagOrder) + { + // For indices beyond the lag order, check if it's a valid seasonal lag + if (Options.SeasonalPeriod > 0 && featureIndex % Options.SeasonalPeriod == 0) + { + return NumOps.GreaterThan(GetFeatureImportance(featureIndex), NumOps.FromDouble(0.01)); + } + + return false; + } + + // For standard lags, check if the feature importance exceeds a threshold + T importance = GetFeatureImportance(featureIndex); + return NumOps.GreaterThan(importance, NumOps.FromDouble(0.01)); + } + + /// + /// Gets the importance of a specific feature (lag). + /// + /// The index of the feature. + /// A value indicating the feature's importance. + /// Thrown when the model has not been trained. + /// Thrown when featureIndex is negative. + /// + /// + /// This method calculates the importance of a specific lag in the model's predictions, + /// based on its parameter value and the model's structure. The implementation is model-specific. + /// + /// + /// For Beginners: + /// This method estimates how important a specific past time period is + /// for making predictions. Higher values indicate more influential features. + /// + /// For example, in many time series models: + /// - Recent lags (like lag 1) often have higher importance + /// - Seasonal lags (like lag 7 for weekly data) often have higher importance + /// - Some lags may have near-zero importance, meaning they don't affect predictions much + /// + /// This information helps understand the model's internal logic and which past + /// time periods it considers most predictive of future values. + /// + /// + protected virtual T GetFeatureImportance(int featureIndex) + { + if (!IsTrained) + { + throw new InvalidOperationException("The model must be trained before getting feature importance."); + } + + if (featureIndex < 0) + { + throw new ArgumentOutOfRangeException(nameof(featureIndex), "Feature index cannot be negative."); + } + + // Default implementation - derived classes should override with model-specific logic + // For time series models, standard importance calculation might consider: + // 1. The magnitude of coefficients for each lag + // 2. The recency of the lag (more recent lags may be more important) + // 3. Seasonal patterns (lags at seasonal intervals may be more important) + + // As a simple default, if the feature index is within the parameter range, use its absolute value + if (featureIndex < ModelParameters.Length) + { + return NumOps.Abs(ModelParameters[featureIndex]); + } + + // Otherwise, define some heuristic defaults + if (featureIndex == 1) + { + // The most recent lag is usually important + return NumOps.FromDouble(0.5); + } + else if (Options.SeasonalPeriod > 0 && featureIndex % Options.SeasonalPeriod == 0) + { + // Seasonal lags are usually important + return NumOps.FromDouble(0.3); + } + else if (featureIndex <= 3) + { + // Recent lags are moderately important + return NumOps.FromDouble(0.2); + } + + // Default to very low importance for other lags + return NumOps.FromDouble(0.01); + } + + /// + /// Sets the parameters for this model. + /// + /// A vector containing the model parameters. + /// + /// If the model is untrained (ModelParameters is empty), this method will + /// resize ModelParameters to accept the incoming parameters. This allows + /// optimizers to initialize untrained models with random parameters. + /// + public virtual void SetParameters(Vector parameters) + { + // If model is untrained (empty parameters), resize to accept the new parameters + // This allows optimizers to initialize untrained models with random parameters + if (ModelParameters.Length == 0 && parameters.Length > 0) + { + ModelParameters = new Vector(parameters.Length); + } + + if (parameters.Length != ModelParameters.Length) + { + throw new ArgumentException($"Expected {ModelParameters.Length} parameters, but got {parameters.Length}", nameof(parameters)); + } + + for (int i = 0; i < ModelParameters.Length; i++) + { + ModelParameters[i] = parameters[i]; + } + } + + /// + /// Sets the active feature indices for this model. + /// + /// The indices of features to activate. + public virtual void SetActiveFeatureIndices(IEnumerable featureIndices) + { + var activeSet = new HashSet(featureIndices); + + for (int i = 0; i < ModelParameters.Length; i++) + { + if (!activeSet.Contains(i)) + { + ModelParameters[i] = NumOps.Zero; + } + } + } + + /// + /// Gets the feature importance scores as a dictionary. + /// + /// A dictionary mapping feature names to their importance scores. + public virtual Dictionary GetFeatureImportance() + { + var result = new Dictionary(); + + for (int i = 0; i < ModelParameters.Length; i++) + { + string featureName = $"Lag_{i + 1}"; + result[featureName] = NumOps.Abs(ModelParameters[i]); + } + + return result; + } + + /// + /// Creates a deep copy of the time series model. + /// + /// A new instance that is a deep copy of this model. + /// + /// + /// This method creates a completely independent copy of the model, with all parameters, + /// options, and internal state duplicated. Modifications to the copy will not affect the + /// original, and vice versa. + /// + /// + /// For Beginners: + /// This method creates a completely independent copy of the current model. + /// + /// A deep copy means that all components of the model are duplicated, + /// including: + /// - Configuration options + /// - Learned parameters + /// - Internal state variables + /// + /// This is useful when you need to: + /// - Create multiple variations of a model for experimentation + /// - Save a model at a specific point during training + /// - Use the same model structure for different datasets + /// + /// Changes to the copy won't affect the original model and vice versa. + /// + /// + public virtual IFullModel, Vector> DeepCopy() + { + // Create a new instance through serialization/deserialization for a true deep copy + byte[] serialized = this.Serialize(); + var newModel = (TimeSeriesModelBase)CreateInstance(); + newModel.Deserialize(serialized); + + return newModel; + } + + /// + /// Creates a clone of the time series model. + /// + /// A new instance that is a clone of this model. + /// + /// + /// This method creates a copy of the model that shares the same options but has independent + /// parameter values. It's a lighter-weight alternative to DeepCopy for cases where a complete + /// independent copy is not needed. + /// + /// + /// For Beginners: + /// This method creates a copy of the current model with the same configuration + /// and parameters. + /// + /// While DeepCopy creates a fully independent duplicate of everything in the model, + /// Clone sometimes creates a more lightweight copy that might share some non-essential + /// components with the original (depending on the specific model implementation). + /// + /// This is useful for: + /// - Creating variations of a model for ensemble methods + /// - Saving a snapshot of the model before making changes + /// - Creating multiple instances for parallel training + /// + /// + public virtual IFullModel, Vector> Clone() + { + // Use DeepCopy (serialize/deserialize) to ensure all internal state is preserved. + // The lighter Clone approach only copies ModelParameters but misses subclass-specific + // state (e.g. _arCoefficients, _trainedSeries) that Predict actually uses. + return DeepCopy(); + } + + /// + /// Creates a new instance of the derived model class. + /// + /// A new instance of the same model type. + /// + /// + /// This abstract factory method must be implemented by derived classes to create a new + /// instance of their specific type. It's used by Clone and DeepCopy to ensure that + /// the correct derived type is instantiated. + /// + /// + /// For Beginners: + /// This method creates a new, empty instance of the specific model type. + /// It's used during cloning and deep copying to ensure that the copy + /// is of the same specific type as the original. + /// + /// For example, if the original model is an ARIMA model, this method + /// would create a new ARIMA model. If it's a TBATS model, it would + /// create a new TBATS model. + /// + /// + protected abstract IFullModel, Vector> CreateInstance(); + + /// + /// Resets the model to its untrained state. + /// + /// + /// + /// This method clears all trained parameters and returns the model to its initial untrained state. + /// + /// + /// For Beginners: + /// This method erases all the patterns the model has learned. + /// + /// After calling this method: + /// - All coefficients and learned parameters are cleared + /// - The model behaves as if it was never trained + /// - You would need to train it again before making predictions + /// + /// This is useful when you want to: + /// - Experiment with different training data on the same model + /// - Retrain a model from scratch with new parameters + /// - Reset a model that might have been trained incorrectly + /// + /// + public virtual void Reset() + { + // Clear model parameters + ModelParameters = new Vector(0); + + // Reset trained flag + IsTrained = false; + + // Clear evaluation metrics + LastEvaluationMetrics.Clear(); + + // Derived classes should override this to reset any additional state + } + + /// + /// Clips a value to be within the specified range. + /// + /// The value to clip. + /// The minimum allowed value. + /// The maximum allowed value. + /// The clipped value. + /// + /// + /// This utility method constrains a value to be within the specified range. + /// If the value is less than the minimum, the minimum is returned. + /// If the value is greater than the maximum, the maximum is returned. + /// Otherwise, the original value is returned. + /// + /// + /// For Beginners: + /// This method ensures a value stays within a specified range (between min and max). + /// It's like setting boundaries that a value cannot cross. + /// + /// For example, if you clip a value with min=0 and max=1: + /// - If the value is -0.5, it returns 0 (the minimum) + /// - If the value is 1.5, it returns 1 (the maximum) + /// - If the value is 0.7, it returns 0.7 (unchanged, as it's within range) + /// + /// This is useful for: + /// - Preventing parameters from taking extreme values + /// - Constraining predictions to reasonable ranges + /// - Implementing optimization algorithms that require bounded parameters + /// + /// + protected T Clip(T value, T min, T max) + { + if (NumOps.LessThan(value, min)) + { + return min; + } + + if (NumOps.GreaterThan(value, max)) + { + return max; + } + + return value; + } + + /// + /// Generates a forecast for multiple steps ahead. + /// + /// The historical time series data. + /// The number of steps to forecast. + /// A vector containing the forecasted values. + /// Thrown when the model has not been trained. + /// Thrown when history is null. + /// Thrown when steps is not positive or history is insufficient. + /// + /// + /// This method generates a multi-step forecast using the history data as the starting point. + /// For each step, it makes a prediction and then updates the history with the predicted value + /// to generate the next prediction. + /// + /// + /// For Beginners: + /// This method predicts multiple future values in sequence. + /// + /// For example, if you have daily data and want to forecast the next 7 days: + /// 1. It first predicts day 1 using your historical data + /// 2. Then it adds that prediction to the history + /// 3. Then it predicts day 2 using the updated history (including the day 1 prediction) + /// 4. And so on, until it has predicted all 7 days + /// + /// This approach lets you make predictions further into the future, + /// but be aware that errors tend to accumulate with each step (predictions + /// become less accurate the further ahead you forecast). + /// + /// + public virtual Vector Forecast(Vector history, int steps) + { + if (!IsTrained) + { + throw new InvalidOperationException("The model must be trained before forecasting."); + } + + if (history == null) + { + throw new ArgumentNullException(nameof(history), "History cannot be null."); + } + + if (steps <= 0) + { + throw new ArgumentException("Number of forecast steps must be positive.", nameof(steps)); + } + + if (history.Length < Options.LagOrder) + { + throw new ArgumentException( + $"History length ({history.Length}) must be at least equal to lag order ({Options.LagOrder}).", + nameof(history)); + } + + // Create a working copy of the history that we can extend + List extendedHistory = new List(history.Length + steps); + for (int i = 0; i < history.Length; i++) + { + extendedHistory.Add(history[i]); + } + + // Generate forecasts one step at a time + Vector forecasts = new Vector(steps); + for (int step = 0; step < steps; step++) + { + // Prepare input features for this forecast step + Vector features = PrepareForecastFeatures(extendedHistory, step); + + // Make prediction + T forecast = PredictSingle(features); + + // Store forecast + forecasts[step] = forecast; + + // Add forecast to extended history for next step + extendedHistory.Add(forecast); + } + + return forecasts; + } + + /// + /// Prepares input features for a forecast step using the extended history. + /// + /// The historical data including any previous forecasts. + /// The current forecast step (0-based). + /// A vector of input features for the forecast. + /// + /// + /// This method extracts the appropriate lags and constructs any additional features + /// needed for the forecast, such as trend indicators or seasonal dummies. + /// + /// + /// For Beginners: + /// This method prepares the input data needed to make a forecast for a specific step. + /// It typically extracts recent values, seasonal patterns, and trend indicators from + /// the history (which may include previous predictions for multi-step forecasts). + /// + /// + protected virtual Vector PrepareForecastFeatures(List extendedHistory, int step) + { + // This is a basic implementation that derived classes should override + // to include model-specific feature preparation + + // For a simple AR model, we would just include the last LagOrder values + int historyLength = extendedHistory.Count; + int featureCount = Options.LagOrder; + + // Add space for trend if included + if (Options.IncludeTrend) + { + featureCount += 1; + } + + // Add space for seasonal dummies if seasonal + if (Options.SeasonalPeriod > 0) + { + featureCount += Options.SeasonalPeriod; + } + + Vector features = new Vector(featureCount); + int featureIndex = 0; + + // Add lag features + for (int lag = 1; lag <= Options.LagOrder; lag++) + { + if (historyLength - lag >= 0) + { + features[featureIndex++] = extendedHistory[historyLength - lag]; + } + else + { + // Not enough history for this lag, use a default value + features[featureIndex++] = NumOps.Zero; + } + } + + // Add trend feature if included + if (Options.IncludeTrend) + { + features[featureIndex++] = NumOps.FromDouble(step + 1); + } + + // Add seasonal dummies if seasonal + if (Options.SeasonalPeriod > 0) + { + int season = (historyLength + step) % Options.SeasonalPeriod; + for (int s = 0; s < Options.SeasonalPeriod; s++) + { + features[featureIndex++] = NumOps.FromDouble(s == season ? 1.0 : 0.0); + } + } + + return features; + } + + public virtual int ParameterCount + { + get { return ModelParameters.Length; } + } + + /// + public virtual int[] GetInputShape() + { + // LagOrder defines the lookback window. Exogenous features are handled + // as additional columns in the input matrix, so effective input width + // equals LagOrder * (1 + exogenousFeatureCount). Since the exogenous + // count varies per dataset, we report the lag dimension here; subclasses + // with known exogenous counts should override. + return Options.LagOrder > 0 ? new[] { Options.LagOrder } : Array.Empty(); + } + + /// + public virtual int[] GetOutputShape() + { + // Forecasts one value per step; multi-horizon models should override + return new[] { 1 }; + } + + /// + public virtual DynamicShapeInfo GetDynamicShapeInfo() + { + return DynamicShapeInfo.None; + } + + + public virtual void SaveModel(string filePath) + { + if (string.IsNullOrWhiteSpace(filePath)) + throw new ArgumentException("File path must not be null or empty.", nameof(filePath)); + + try + { + var data = Serialize(); + var directory = Path.GetDirectoryName(filePath); + if (!string.IsNullOrEmpty(directory) && !Directory.Exists(directory)) + Directory.CreateDirectory(directory); + byte[] envelopedData = ModelFileHeader.WrapWithHeader( + data, this, GetInputShape(), GetOutputShape(), SerializationFormat.Binary, + GetDynamicShapeInfo()); + File.WriteAllBytes(filePath, envelopedData); + } + catch (IOException ex) { throw new InvalidOperationException($"Failed to save model to '{filePath}': {ex.Message}", ex); } + catch (UnauthorizedAccessException ex) { throw new InvalidOperationException($"Access denied when saving model to '{filePath}': {ex.Message}", ex); } + catch (System.Security.SecurityException ex) { throw new InvalidOperationException($"Security error when saving model to '{filePath}': {ex.Message}", ex); } + } + + public virtual void LoadModel(string filePath) + { + if (string.IsNullOrWhiteSpace(filePath)) + throw new ArgumentException("File path must not be null or empty.", nameof(filePath)); + + try + { + var data = File.ReadAllBytes(filePath); + + // Extract payload from AIMF envelope if present; use raw bytes for legacy files + if (ModelFileHeader.HasHeader(data)) + { + data = ModelFileHeader.ExtractPayload(data); + } + + Deserialize(data); + } + catch (FileNotFoundException ex) { throw new FileNotFoundException($"The specified model file does not exist: {filePath}", filePath, ex); } + catch (IOException ex) { throw new InvalidOperationException($"File I/O error while loading model from '{filePath}': {ex.Message}", ex); } + catch (UnauthorizedAccessException ex) { throw new InvalidOperationException($"Access denied when loading model from '{filePath}': {ex.Message}", ex); } + catch (System.Security.SecurityException ex) { throw new InvalidOperationException($"Security error when loading model from '{filePath}': {ex.Message}", ex); } + catch (Exception ex) { throw new InvalidOperationException($"Failed to deserialize model from file '{filePath}'. The file may be corrupted or incompatible: {ex.Message}", ex); } + } + + public virtual ILossFunction DefaultLossFunction => _defaultLossFunction; + + public virtual Vector ComputeGradients(Matrix input, Vector target, ILossFunction? lossFunction = null) + { + var loss = lossFunction ?? DefaultLossFunction; + + // Primary path: layer-level backpropagation for exact gradients. + // Available for NeuralNetworkBase-derived time series models (Autoformer, Informer, etc.). + try + { + var predicted = Predict(input); + + var lossGrad = loss.CalculateDerivative(predicted, target); + var lossGradTensor = Tensor.FromVector(lossGrad); + + BackpropagateLayers(lossGradTensor); + + var gradients = GetLayerParameterGradients(); + + bool hasValidGradients = false; + for (int i = 0; i < Math.Min(gradients.Length, 100); i++) + { + if (!NumOps.Equals(gradients[i], NumOps.Zero)) + { + hasValidGradients = true; + break; + } + } + + if (hasValidGradients) + return gradients; + } + catch (NotSupportedException) + { + // Expected for models without layer-level backprop — fall through to SPSA + } + catch (Exception ex) + { + System.Diagnostics.Trace.TraceWarning( + $"Layer backpropagation failed for {GetType().Name}, falling back to SPSA: {ex.Message}"); + } + + // Fallback: SPSA gradient approximation — estimates ALL N gradients with just 2 forward + // passes per sample (vs 2N for per-parameter finite differences). + // Reference: Spall, J.C., IEEE TAC, 1992. + var parameters = GetParameters(); + var gradients_spsa = new Vector(parameters.Length); + T epsilon = NumOps.FromDouble(1e-3); + T twoEpsilon = NumOps.Multiply(epsilon, NumOps.FromDouble(2.0)); + var rng = Tensors.Helpers.RandomHelper.CreateSeededRandom(42); + int numSamples = 3; + + var delta = new Vector(parameters.Length); + + for (int s = 0; s < numSamples; s++) + { + for (int i = 0; i < parameters.Length; i++) + delta[i] = rng.NextDouble() < 0.5 ? NumOps.FromDouble(-1.0) : NumOps.FromDouble(1.0); + + var eDelta = Engine.Multiply(delta, epsilon); + + var modelPlus = (TimeSeriesModelBase)WithParameters(Engine.Add(parameters, eDelta)); + var lossPlus = loss.CalculateLoss(modelPlus.Predict(input), target); + + var modelMinus = (TimeSeriesModelBase)WithParameters(Engine.Subtract(parameters, eDelta)); + var lossMinus = loss.CalculateLoss(modelMinus.Predict(input), target); + + T lossDiff = NumOps.Subtract(lossPlus, lossMinus); + var scaledDelta = Engine.Multiply(delta, twoEpsilon); + gradients_spsa = Engine.Add(gradients_spsa, Engine.Divide( + Engine.Fill(parameters.Length, lossDiff), scaledDelta)); + } + + gradients_spsa = Engine.Multiply(gradients_spsa, NumOps.FromDouble(1.0 / numSamples)); + + return gradients_spsa; + } + + /// + /// Backpropagates the loss gradient through the model's neural network layers. + /// Override in NeuralNetworkBase-derived time series models to enable exact gradients. + /// + /// Gradient of the loss w.r.t. the model output. + protected virtual void BackpropagateLayers(Tensor lossGradient) + { + throw new NotSupportedException( + $"{GetType().Name} does not implement BackpropagateLayers. " + + "Override this method in NeuralNetworkBase-derived models."); + } + + /// + /// Extracts accumulated parameter gradients from all layers after backpropagation. + /// + protected virtual Vector GetLayerParameterGradients() + { + throw new NotSupportedException( + $"{GetType().Name} does not implement GetLayerParameterGradients. " + + "Override this method to extract layer-level gradients."); + } + + public virtual void ApplyGradients(Vector gradients, T learningRate) + { + if (gradients == null) + throw new ArgumentNullException(nameof(gradients)); + + var parameters = GetParameters(); + + if (gradients.Length != parameters.Length) + throw new ArgumentException($"Gradient vector length ({gradients.Length}) must match parameter count ({parameters.Length})."); + + parameters = Engine.Subtract(parameters, Engine.Multiply(gradients, learningRate)); + + SetParameters(parameters); + } + + /// + /// Saves the time series model's current state to a stream. + /// + /// The stream to write the model state to. + /// + /// + /// This method serializes the time series model's parameters and configuration. + /// It uses the existing Serialize method and writes the data to the provided stream. + /// + /// For Beginners: This is like creating a snapshot of your trained time series model. + /// + /// When you call SaveState: + /// - All learned parameters and trends are written to the stream + /// - Model configuration and internal state are preserved + /// + /// This is particularly useful for: + /// - Checkpointing during long training sessions + /// - Saving the best model for forecasting + /// - Knowledge distillation from time series models + /// - Deploying forecasting models to production + /// + /// You can later use LoadState to restore the model. + /// + /// + /// Thrown when stream is null. + /// Thrown when there's an error writing to the stream. + public virtual void SaveState(Stream stream) + { + if (stream == null) + throw new ArgumentNullException(nameof(stream)); + + if (!stream.CanWrite) + throw new ArgumentException("Stream must be writable.", nameof(stream)); + + try + { + var data = this.Serialize(); + stream.Write(data, 0, data.Length); + stream.Flush(); + } + catch (IOException ex) + { + throw new IOException($"Failed to save time series model state to stream: {ex.Message}", ex); + } + catch (Exception ex) + { + throw new InvalidOperationException($"Unexpected error while saving time series model state: {ex.Message}", ex); + } + } + + /// + /// Loads the time series model's state from a stream. + /// + /// The stream to read the model state from. + /// + /// + /// This method deserializes a time series model that was previously saved with SaveState. + /// It uses the existing Deserialize method after reading data from the stream. + /// + /// For Beginners: This is like loading a saved snapshot of your time series model. + /// + /// When you call LoadState: + /// - All parameters and trends are read from the stream + /// - Model configuration and state are restored + /// + /// After loading, the model can: + /// - Make forecasts using the restored parameters + /// - Continue training from where it left off + /// - Be deployed to production for time series prediction + /// + /// This is essential for: + /// - Resuming interrupted training sessions + /// - Loading the best model for forecasting + /// - Deploying trained models to production + /// - Knowledge distillation workflows + /// + /// + /// Thrown when stream is null. + /// Thrown when there's an error reading from the stream. + /// Thrown when the stream contains invalid or incompatible data. + public virtual void LoadState(Stream stream) + { + if (stream == null) + throw new ArgumentNullException(nameof(stream)); + + if (!stream.CanRead) + throw new ArgumentException("Stream must be readable.", nameof(stream)); + + try + { + using var ms = new MemoryStream(); + stream.CopyTo(ms); + var data = ms.ToArray(); + + if (data.Length == 0) + throw new InvalidOperationException("Stream contains no data."); + + this.Deserialize(data); + } + catch (IOException ex) + { + throw new IOException($"Failed to read time series model state from stream: {ex.Message}", ex); + } + catch (InvalidOperationException) + { + // Re-throw InvalidOperationException from Deserialize + throw; + } + catch (Exception ex) + { + throw new InvalidOperationException( + $"Failed to deserialize time series model state. The stream may contain corrupted or incompatible data: {ex.Message}", ex); + } + } + +} diff --git a/src/TransferLearning/Algorithms/TransferRandomForest.cs b/src/TransferLearning/Algorithms/TransferRandomForest.cs index 000b93f5a4..be77692ec7 100644 --- a/src/TransferLearning/Algorithms/TransferRandomForest.cs +++ b/src/TransferLearning/Algorithms/TransferRandomForest.cs @@ -652,7 +652,4 @@ public override void LoadState(Stream stream) } } - #region IJitCompilable Implementation - - #endregion } diff --git a/src/UncertaintyQuantification/Layers/BayesianDenseLayer.cs b/src/UncertaintyQuantification/Layers/BayesianDenseLayer.cs index a94c4315fe..9512c41cc6 100644 --- a/src/UncertaintyQuantification/Layers/BayesianDenseLayer.cs +++ b/src/UncertaintyQuantification/Layers/BayesianDenseLayer.cs @@ -455,6 +455,4 @@ public override void ClearGradients() _biasMeanGradient.Fill(NumOps.Zero); _biasLogVarGradient.Fill(NumOps.Zero); } - public override ComputationNode ExportComputationGraph(List> inputNodes) - => throw new NotSupportedException($"{GetType().Name} does not currently support JIT compilation."); } diff --git a/src/UncertaintyQuantification/Layers/MCDropoutLayer.cs b/src/UncertaintyQuantification/Layers/MCDropoutLayer.cs index b34a07d3a5..df7b4eba65 100644 --- a/src/UncertaintyQuantification/Layers/MCDropoutLayer.cs +++ b/src/UncertaintyQuantification/Layers/MCDropoutLayer.cs @@ -167,6 +167,4 @@ public override LayerBase Clone() copy.SetTrainingMode(IsTrainingMode); return copy; } - public override ComputationNode ExportComputationGraph(List> inputNodes) - => throw new NotSupportedException($"{GetType().Name} does not currently support JIT compilation."); } diff --git a/testconsole/Examples/KnowledgeDistillationExample.cs b/testconsole/Examples/KnowledgeDistillationExample.cs index 52ed43ae59..a49a02d4db 100644 --- a/testconsole/Examples/KnowledgeDistillationExample.cs +++ b/testconsole/Examples/KnowledgeDistillationExample.cs @@ -372,22 +372,6 @@ public void SetActiveFeatureIndices(IEnumerable featureIndices) { } public Vector ComputeGradients(Matrix input, Vector target, ILossFunction? lossFunction = null) => new Vector(0); public void ApplyGradients(Vector gradients, double learningRate) { } - // IJitCompilable implementation - public bool SupportsJitCompilation => true; - - public ComputationNode ExportComputationGraph(List> inputNodes) - { - // Create a simple computation graph for the mock model - var inputShape = new int[] { 1, _inputDim }; - var inputTensor = new Tensor(inputShape); - var inputNode = TensorOperations.Variable(inputTensor, "input"); - inputNodes.Add(inputNode); - - // Simple transformation: mean of inputs - var outputNode = TensorOperations.Mean(inputNode); - return outputNode; - } - public Vector SanitizeParameters(Vector parameters) => parameters; } } diff --git a/testconsole/Examples/MetaLearning/SimpleMetaModel.cs b/testconsole/Examples/MetaLearning/SimpleMetaModel.cs index f7adf8afc2..78050a96bd 100644 --- a/testconsole/Examples/MetaLearning/SimpleMetaModel.cs +++ b/testconsole/Examples/MetaLearning/SimpleMetaModel.cs @@ -161,10 +161,5 @@ public Dictionary GetFeatureImportance() return importance; } - public bool SupportsJitCompilation => false; - - public ComputationNode ExportComputationGraph(List> inputNodes) - => throw new NotSupportedException(); - public Vector SanitizeParameters(Vector parameters) => parameters; } diff --git a/testconsole/Examples/SimpleKnowledgeDistillationExample.cs b/testconsole/Examples/SimpleKnowledgeDistillationExample.cs index 2a0cf51d67..ff40caf14f 100644 --- a/testconsole/Examples/SimpleKnowledgeDistillationExample.cs +++ b/testconsole/Examples/SimpleKnowledgeDistillationExample.cs @@ -171,22 +171,6 @@ public void SetActiveFeatureIndices(IEnumerable featureIndices) { } public Vector ComputeGradients(Matrix input, Vector target, ILossFunction? lossFunction = null) => new Vector(0); public void ApplyGradients(Vector gradients, double learningRate) { } - // IJitCompilable implementation - public bool SupportsJitCompilation => true; - - public ComputationNode ExportComputationGraph(List> inputNodes) - { - // Create a simple computation graph for the mock model - var inputShape = new int[] { 1, _inputDim }; - var inputTensor = new Tensor(inputShape); - var inputNode = TensorOperations.Variable(inputTensor, "input"); - inputNodes.Add(inputNode); - - // Simple transformation: mean of inputs - var outputNode = TensorOperations.Mean(inputNode); - return outputNode; - } - public Vector SanitizeParameters(Vector parameters) => parameters; } } diff --git a/tests/AiDotNet.Tests/AdversarialRobustness/CROWNVerificationTests.cs b/tests/AiDotNet.Tests/AdversarialRobustness/CROWNVerificationTests.cs index 1dda863ded..cb479a0123 100644 --- a/tests/AiDotNet.Tests/AdversarialRobustness/CROWNVerificationTests.cs +++ b/tests/AiDotNet.Tests/AdversarialRobustness/CROWNVerificationTests.cs @@ -726,8 +726,6 @@ public void SetParameters(Vector parameters) { } public void SetActiveFeatureIndices(IEnumerable featureIndices) => _activeFeatures = featureIndices.ToList(); public bool IsFeatureUsed(int featureIndex) => _activeFeatures.Contains(featureIndex); public Dictionary GetFeatureImportance() => Enumerable.Range(0, _inputDim).ToDictionary(i => $"Feature{i}", i => 1.0 / _inputDim); - public bool SupportsJitCompilation => false; - public ComputationNode ExportComputationGraph(List> inputNodes) => throw new NotSupportedException(); public Vector SanitizeParameters(Vector parameters) => parameters; } @@ -786,8 +784,6 @@ public void SetParameters(Vector parameters) { } public void SetActiveFeatureIndices(IEnumerable featureIndices) => _activeFeatures = featureIndices.ToList(); public bool IsFeatureUsed(int featureIndex) => _activeFeatures.Contains(featureIndex); public Dictionary GetFeatureImportance() => Enumerable.Range(0, _inputDim).ToDictionary(i => $"Feature{i}", i => 1.0 / _inputDim); - public bool SupportsJitCompilation => false; - public ComputationNode ExportComputationGraph(List> inputNodes) => throw new NotSupportedException(); public Vector SanitizeParameters(Vector parameters) => parameters; } @@ -844,8 +840,6 @@ public void SetParameters(Vector parameters) { } public void SetActiveFeatureIndices(IEnumerable featureIndices) => _activeFeatures = featureIndices.ToList(); public bool IsFeatureUsed(int featureIndex) => _activeFeatures.Contains(featureIndex); public Dictionary GetFeatureImportance() => Enumerable.Range(0, _inputDim).ToDictionary(i => $"Feature{i}", i => 1.0f / _inputDim); - public bool SupportsJitCompilation => false; - public ComputationNode ExportComputationGraph(List> inputNodes) => throw new NotSupportedException(); public Vector SanitizeParameters(Vector parameters) => parameters; } diff --git a/tests/AiDotNet.Tests/AdversarialRobustness/IntervalBoundPropagationTests.cs b/tests/AiDotNet.Tests/AdversarialRobustness/IntervalBoundPropagationTests.cs index f688c3d818..1c7a759f7f 100644 --- a/tests/AiDotNet.Tests/AdversarialRobustness/IntervalBoundPropagationTests.cs +++ b/tests/AiDotNet.Tests/AdversarialRobustness/IntervalBoundPropagationTests.cs @@ -633,8 +633,6 @@ public void SetParameters(Vector parameters) { } public void SetActiveFeatureIndices(IEnumerable featureIndices) => _activeFeatures = featureIndices.ToList(); public bool IsFeatureUsed(int featureIndex) => _activeFeatures.Contains(featureIndex); public Dictionary GetFeatureImportance() => Enumerable.Range(0, _inputDim).ToDictionary(i => $"Feature{i}", i => 1.0 / _inputDim); - public bool SupportsJitCompilation => false; - public ComputationNode ExportComputationGraph(List> inputNodes) => throw new NotSupportedException(); public Vector SanitizeParameters(Vector parameters) => parameters; } @@ -692,8 +690,6 @@ public void SetParameters(Vector parameters) { } public void SetActiveFeatureIndices(IEnumerable featureIndices) => _activeFeatures = featureIndices.ToList(); public bool IsFeatureUsed(int featureIndex) => _activeFeatures.Contains(featureIndex); public Dictionary GetFeatureImportance() => Enumerable.Range(0, _inputDim).ToDictionary(i => $"Feature{i}", i => 1.0 / _inputDim); - public bool SupportsJitCompilation => false; - public ComputationNode ExportComputationGraph(List> inputNodes) => throw new NotSupportedException(); public Vector SanitizeParameters(Vector parameters) => parameters; } @@ -750,8 +746,6 @@ public void SetParameters(Vector parameters) { } public void SetActiveFeatureIndices(IEnumerable featureIndices) => _activeFeatures = featureIndices.ToList(); public bool IsFeatureUsed(int featureIndex) => _activeFeatures.Contains(featureIndex); public Dictionary GetFeatureImportance() => Enumerable.Range(0, _inputDim).ToDictionary(i => $"Feature{i}", i => 1.0f / _inputDim); - public bool SupportsJitCompilation => false; - public ComputationNode ExportComputationGraph(List> inputNodes) => throw new NotSupportedException(); public Vector SanitizeParameters(Vector parameters) => parameters; } diff --git a/tests/AiDotNet.Tests/Benchmarks/JIT_BENCHMARKS_README.md b/tests/AiDotNet.Tests/Benchmarks/JIT_BENCHMARKS_README.md deleted file mode 100644 index cc1b66bd1b..0000000000 --- a/tests/AiDotNet.Tests/Benchmarks/JIT_BENCHMARKS_README.md +++ /dev/null @@ -1,311 +0,0 @@ -# JIT Compiler Performance Benchmarks - -This file contains comprehensive performance benchmarks for the AiDotNet JIT compiler using BenchmarkDotNet. - -## Benchmarks Overview - -### 1. Simple Operations -- **Graph**: ReLU(Exp(input)) -- **Tensor Size**: 64x64 -- **Operations**: 2 -- **Purpose**: Measure basic compilation and execution overhead - -### 2. Linear Layer -- **Graph**: ReLU(MatMul(input, weights) + bias) -- **Tensor Sizes**: Input: 32x128, Weights: 128x256, Bias: 1x256 -- **Operations**: 3 (fused to 1 with optimization) -- **Purpose**: Measure fusion optimization benefits - -### 3. Deep Network -- **Graph**: 10 sequential linear layers with ReLU -- **Tensor Sizes**: Batch: 16, Features: 128 per layer -- **Operations**: 30 total (10 x [MatMul + Add + ReLU]) -- **Purpose**: Measure performance on realistic networks - -### 4. Compilation Overhead -- **Graph**: Single ReLU operation -- **Purpose**: Measure pure compilation time -- **Note**: Important for understanding first-call latency - -### 5. Cache Performance -- **Graph**: Previously compiled simple graph -- **Purpose**: Measure cache hit performance (should be ~instant) - -## Running the Benchmarks - -### Method 1: Using BenchmarkDotNet Runner - -```bash -cd tests/AiDotNet.Tests -dotnet run -c Release --project AiDotNetTests.csproj --filter "*JitCompiler*" -``` - -### Method 2: Programmatically - -```csharp -using BenchmarkDotNet.Running; -using AiDotNet.Tests.Benchmarks; - -var summary = BenchmarkRunner.Run(); -``` - -### Method 3: From Test Explorer - -Run the `JitCompilerBenchmarkRunner.Main()` method directly. - -## Expected Results - -### Performance Metrics - -Based on typical hardware (Intel i7, 16GB RAM): - -| Benchmark | Mean Time | Allocated | Notes | -|-----------|-----------|-----------|-------| -| Simple ops - JIT | ~0.05ms | < 1KB | Fast element-wise operations | -| Linear layer - JIT | ~0.15ms | < 5KB | Matrix multiplication + fusion | -| Deep network - JIT | ~1.5ms | < 50KB | 10 layers, significant speedup | -| Compilation overhead | ~15ms | ~20KB | One-time cost | -| Cached compilation | ~0.001ms | < 1KB | Near-instant | - -### Expected Speedups - -Compared to interpreted execution: - -- **Simple operations**: 2-3x faster -- **Linear layer**: 3-5x faster (with fusion) -- **Deep network**: 5-10x faster (many optimizations) -- **Cached compilation**: Effectively free (microseconds) - -## Interpreting Results - -### Mean Time -- Lower is better -- Typical variance: ±5-10% -- Outliers are automatically detected and reported - -### Allocated Memory -- Memory allocated per operation -- Lower is better for GC pressure -- JIT should have minimal allocation after compilation - -### Ratio Columns -BenchmarkDotNet will show ratio compared to baseline if you mark one: - -```csharp -[Benchmark(Baseline = true)] -public void InterpretedExecution() { ... } - -[Benchmark] -public void JITExecution() { ... } -``` - -### StdDev / StdErr -- Standard deviation and error -- Lower indicates more consistent performance -- High variance may indicate GC or thermal throttling - -## Performance Tips - -### 1. Compilation is One-Time Cost - -``` -First execution: Compilation (15ms) + Execution (0.15ms) = ~15.15ms -Next executions: Execution only (0.15ms) = 0.15ms -``` - -**Recommendation**: Compile during initialization, execute in hot path. - -### 2. Caching is Extremely Fast - -Cache hit = ~1 microsecond (0.001ms) -- Structure-based caching -- Same graph structure → instant compilation -- Different data → same compiled function - -### 3. Fusion Provides Major Gains - -Example: Linear layer (MatMul + Add + ReLU) -- Without fusion: 3 separate operations -- With fusion: 1 combined operation -- Speedup: 2-3x from fusion alone - -### 4. Deep Networks Benefit Most - -10-layer network: -- Interpreted: ~15ms -- JIT compiled: ~1.5ms -- **Speedup: ~10x** - -More layers = more optimization opportunities! - -## Benchmarking Best Practices - -### 1. Run in Release Mode - -```bash -dotnet run -c Release -``` - -Debug mode includes extra checks and assertions. - -### 2. Close Other Applications - -- Minimize background processes -- Disable antivirus temporarily -- Close browser/IDE if possible - -### 3. Let CPU Stabilize - -- Wait 30 seconds after starting benchmarks -- CPU frequency scaling needs time to stabilize -- First few iterations may be slower - -### 4. Multiple Runs - -BenchmarkDotNet automatically runs: -- 5 warmup iterations (not measured) -- 20 measured iterations -- Statistical analysis on results - -### 5. Check for Thermal Throttling - -If results vary widely: -- CPU may be thermal throttling -- Check CPU temperature -- Ensure good cooling - -## Customizing Benchmarks - -### Add Custom Configuration - -```csharp -[MemoryDiagnoser] -[SimpleJob(launchCount: 1, warmupCount: 5, iterationCount: 20)] -[MinColumn, MaxColumn, MeanColumn, MedianColumn] -public class JitCompilerBenchmarks -{ - // ... benchmarks -} -``` - -### Filter Specific Benchmarks - -```bash -dotnet run -c Release --filter "*Linear*" -``` - -### Export Results - -```csharp -[MarkdownExporter, HtmlExporter, CsvExporter] -public class JitCompilerBenchmarks { } -``` - -Results saved to `BenchmarkDotNet.Artifacts/`. - -## Comparing with Interpreted Execution - -To add interpreted execution benchmarks: - -```csharp -[Benchmark(Baseline = true, Description = "Linear layer - Interpreted")] -public Tensor LinearLayerInterpreted() -{ - // Execute graph using TensorOperations directly - // (Implementation depends on graph execution engine) - return ExecuteGraphDirectly(_linearGraph); -} - -[Benchmark(Description = "Linear layer - JIT Compiled")] -public Tensor[] LinearLayerJIT() -{ - return _linearCompiled!(new[] { _linearInput!, _linearWeights!, _linearBias! }); -} -``` - -BenchmarkDotNet will automatically show relative performance. - -## Troubleshooting - -### "No benchmarks found" - -- Check namespace matches -- Ensure methods are `public` -- Methods must have `[Benchmark]` attribute - -### Out of Memory - -- Reduce tensor sizes -- Reduce number of layers in deep network -- Run fewer iterations - -### Inconsistent Results - -- Close background applications -- Check CPU temperature -- Run with `launchCount: 3` for multiple processes -- Disable CPU frequency scaling - -### Very Slow Compilation - -Normal! First compilation takes ~10-20ms. -- Parsing graph structure -- Building IR -- Running optimizations -- Expression tree compilation -- .NET JIT compilation - -Cache hits should be <0.01ms. - -## Further Analysis - -### Profiling with BenchmarkDotNet - -```csharp -[EtwProfiler] // Windows only -[ConcurrencyVisualizerProfiler] // Requires Concurrency Visualizer -public class JitCompilerBenchmarks { } -``` - -### Memory Profiling - -The `[MemoryDiagnoser]` attribute provides: -- Gen 0/1/2 collections per operation -- Allocated bytes per operation -- Memory traffic analysis - -### CPU Profiling - -Use: -- Visual Studio Profiler -- dotTrace -- PerfView (Windows) -- perf (Linux) - -## Expected Output Example - -``` -BenchmarkDotNet=v0.13.0, OS=Windows 10 -Intel Core i7-9750H CPU 2.60GHz, 1 CPU, 12 logical and 6 physical cores -.NET SDK=8.0.100 - -| Method | Mean | Error | StdDev | Median | Allocated | -|-------------------------------- |---------:|---------:|---------:|---------:|----------:| -| Simple ops - JIT Compiled | 52.3 μs | 1.2 μs | 0.8 μs | 52.1 μs | 752 B | -| Linear layer - JIT Compiled | 145.6 μs | 3.1 μs | 2.1 μs | 145.2 μs | 4.1 KB | -| Deep network - JIT Compiled | 1.48 ms | 0.03 ms | 0.02 ms | 1.47 ms | 45.2 KB | -| Compilation time (simple graph) | 14.2 ms | 0.5 ms | 0.3 ms | 14.1 ms | 18.5 KB | -| Compilation with cache hit | 0.8 μs | 0.1 μs | 0.05 μs | 0.8 μs | 64 B | -``` - -## Conclusion - -The JIT compiler provides significant performance improvements: -- **2-3x** for simple operations -- **3-5x** for fused operations -- **5-10x** for deep networks -- **Near-zero** overhead for cached compilations - -Compilation cost (~15ms) is easily amortized over repeated executions. - -For questions or issues, please file a GitHub issue! diff --git a/tests/AiDotNet.Tests/GlobalUsings.cs b/tests/AiDotNet.Tests/GlobalUsings.cs index 063818334f..154a8bf32a 100644 --- a/tests/AiDotNet.Tests/GlobalUsings.cs +++ b/tests/AiDotNet.Tests/GlobalUsings.cs @@ -5,5 +5,3 @@ // Resolve type ambiguity between AiDotNet.Enums and AiDotNet.Tensors.Helpers global using QuantizationMode = AiDotNet.Enums.QuantizationMode; -global using MemoryLayout = AiDotNet.InferenceOptimization.IR.Common.MemoryLayout; -global using QuantizationParams = AiDotNet.InferenceOptimization.IR.Common.QuantizationParams; diff --git a/tests/AiDotNet.Tests/Helpers/ActiveLearningTestHelper.cs b/tests/AiDotNet.Tests/Helpers/ActiveLearningTestHelper.cs index 4cb98b5d20..d6c8f93028 100644 --- a/tests/AiDotNet.Tests/Helpers/ActiveLearningTestHelper.cs +++ b/tests/AiDotNet.Tests/Helpers/ActiveLearningTestHelper.cs @@ -243,13 +243,5 @@ public void ApplyGradients(Vector gradients, T learningRate) } } - // IJitCompilable - public ComputationNode ExportComputationGraph(List> inputNodes) - { - throw new NotSupportedException("Mock model does not support JIT compilation"); - } - - public bool SupportsJitCompilation => false; - public Vector SanitizeParameters(Vector parameters) => parameters; } diff --git a/tests/AiDotNet.Tests/Helpers/ContinualLearningTestHelper.cs b/tests/AiDotNet.Tests/Helpers/ContinualLearningTestHelper.cs index e35d71faec..3d5e020e3e 100644 --- a/tests/AiDotNet.Tests/Helpers/ContinualLearningTestHelper.cs +++ b/tests/AiDotNet.Tests/Helpers/ContinualLearningTestHelper.cs @@ -413,14 +413,6 @@ public void ApplyGradients(Vector gradients, T learningRate) } } - // IJitCompilable - public ComputationNode ExportComputationGraph(List> inputNodes) - { - throw new NotSupportedException("Mock network does not support JIT compilation"); - } - - public bool SupportsJitCompilation => false; - public T GetLastLoss() => _ops.Zero; public Vector SanitizeParameters(Vector parameters) => parameters; @@ -564,14 +556,6 @@ public void ResetState() // No state to reset for mock } - // IJitCompilable - public ComputationNode ExportComputationGraph(List> inputNodes) - { - throw new NotSupportedException("Mock layer does not support JIT compilation"); - } - - public bool SupportsJitCompilation => false; - // GPU execution public bool CanExecuteOnGpu => false; diff --git a/tests/AiDotNet.Tests/Helpers/MockFullModel.cs b/tests/AiDotNet.Tests/Helpers/MockFullModel.cs index 9de6463416..be6c25c955 100644 --- a/tests/AiDotNet.Tests/Helpers/MockFullModel.cs +++ b/tests/AiDotNet.Tests/Helpers/MockFullModel.cs @@ -127,13 +127,5 @@ public void ApplyGradients(Vector gradients, double learningRate) } } - // IJitCompilable implementation - public bool SupportsJitCompilation => false; - - public ComputationNode ExportComputationGraph(List> inputNodes) - { - throw new NotSupportedException("JIT compilation not supported in mock model"); - } - public Vector SanitizeParameters(Vector parameters) => parameters; } diff --git a/tests/AiDotNet.Tests/Helpers/MockNeuralNetwork.cs b/tests/AiDotNet.Tests/Helpers/MockNeuralNetwork.cs index 472cf5c254..b3a52a00e6 100644 --- a/tests/AiDotNet.Tests/Helpers/MockNeuralNetwork.cs +++ b/tests/AiDotNet.Tests/Helpers/MockNeuralNetwork.cs @@ -273,14 +273,6 @@ public void ApplyGradients(Vector gradients, double learningRate) } } - // IJitCompilable implementation - public bool SupportsJitCompilation => false; - - public ComputationNode ExportComputationGraph(List> inputNodes) - { - throw new NotSupportedException("JIT compilation not supported in mock model"); - } - // ILayeredModel implementation IReadOnlyList> ILayeredModel.Layers => Array.Empty>(); diff --git a/tests/AiDotNet.Tests/InferenceOptimization/AttentionKernelValidationTests.cs b/tests/AiDotNet.Tests/InferenceOptimization/AttentionKernelValidationTests.cs deleted file mode 100644 index 1812eacc92..0000000000 --- a/tests/AiDotNet.Tests/InferenceOptimization/AttentionKernelValidationTests.cs +++ /dev/null @@ -1,121 +0,0 @@ -using System; -using AiDotNet.InferenceOptimization.Kernels; -using AiDotNet.LinearAlgebra; -using Xunit; -using System.Threading.Tasks; - -namespace AiDotNet.Tests.InferenceOptimization; - -public class AttentionKernelValidationTests -{ - [Fact(Timeout = 60000)] - public async Task Execute_MatchesNaiveAttention() - { - var kernel = new AttentionKernel(); - - // [batch=1, seq=2, d=4] - var q = CreateTensor(new[] { 1, 2, 4 }, new float[] { 1, 0, 0, 0, 0, 1, 0, 0 }); - var k = CreateTensor(new[] { 1, 2, 4 }, new float[] { 1, 0, 0, 0, 0, 1, 0, 0 }); - var v = CreateTensor(new[] { 1, 2, 4 }, new float[] { 10, 11, 12, 13, 20, 21, 22, 23 }); - - var actual = kernel.Execute(q, k, v); - var expected = NaiveAttention(q, k, v); - - Assert.Equal(expected.Shape.ToArray(), actual.Shape.ToArray()); - for (int i = 0; i < expected.Length; i++) - { - Assert.Equal(expected[i], actual[i], 5); - } - } - - [Fact(Timeout = 60000)] - public async Task Execute_WithMask_RespectsMaskZeros() - { - var kernel = new AttentionKernel(); - - // [batch=1, seq=2, d=2] - var q = CreateTensor(new[] { 1, 2, 2 }, new float[] { 1, 0, 0, 1 }); - var k = CreateTensor(new[] { 1, 2, 2 }, new float[] { 1, 0, 0, 1 }); - var v = CreateTensor(new[] { 1, 2, 2 }, new float[] { 1, 2, 100, 200 }); - - // Mask: allow only attending to token 0 (mask value 1), disallow token 1 (mask value 0) - var mask = CreateTensor(new[] { 1, 2, 2 }, new float[] { 1, 0, 1, 0 }); - - var actual = kernel.Execute(q, k, v, mask); - - // With token 1 masked out, both rows should match v0. - Assert.Equal(1f, actual[0], 5); - Assert.Equal(2f, actual[1], 5); - Assert.Equal(1f, actual[2], 5); - Assert.Equal(2f, actual[3], 5); - } - - private static Tensor NaiveAttention(Tensor q, Tensor k, Tensor v) - { - int seq = q.Shape[1]; - int d = q.Shape[2]; - float scale = 1f / MathF.Sqrt(d); - - var scores = new float[seq * seq]; - for (int i = 0; i < seq; i++) - { - for (int j = 0; j < seq; j++) - { - float dot = 0f; - for (int t = 0; t < d; t++) - { - dot += q[i * d + t] * k[j * d + t]; - } - - scores[i * seq + j] = dot * scale; - } - } - - for (int i = 0; i < seq; i++) - { - float max = float.NegativeInfinity; - for (int j = 0; j < seq; j++) - { - max = Math.Max(max, scores[i * seq + j]); - } - - float sum = 0f; - for (int j = 0; j < seq; j++) - { - scores[i * seq + j] = MathF.Exp(scores[i * seq + j] - max); - sum += scores[i * seq + j]; - } - - for (int j = 0; j < seq; j++) - { - scores[i * seq + j] /= sum; - } - } - - var result = new Tensor(new[] { 1, seq, v.Shape[2] }); - int dV = v.Shape[2]; - for (int i = 0; i < seq; i++) - { - for (int j = 0; j < dV; j++) - { - float sum = 0f; - for (int t = 0; t < seq; t++) - { - sum += scores[i * seq + t] * v[t * dV + j]; - } - - result[i * dV + j] = sum; - } - } - - return result; - } - - private static Tensor CreateTensor(int[] shape, float[] data) - { - var t = new Tensor(shape); - Assert.Equal(t.Length, data.Length); - t.CopyFromArray(data); - return t; - } -} diff --git a/tests/AiDotNet.Tests/InferenceOptimization/CacheOptimizerTests.cs b/tests/AiDotNet.Tests/InferenceOptimization/CacheOptimizerTests.cs deleted file mode 100644 index bdc814e9ad..0000000000 --- a/tests/AiDotNet.Tests/InferenceOptimization/CacheOptimizerTests.cs +++ /dev/null @@ -1,37 +0,0 @@ -using AiDotNet.Tensors.Engines.Optimization; -using Xunit; -using System.Threading.Tasks; - -namespace AiDotNet.Tests.InferenceOptimization; - -public class CacheOptimizerTests -{ - [Fact(Timeout = 60000)] - public async Task TransposeBlocked_Transposes2DMatrix() - { - // 2x3 - float[] src = new float[] { 1f, 2f, 3f, 4f, 5f, 6f }; - float[] dst = new float[src.Length]; - - CacheOptimizer.TransposeBlocked(src, dst, rows: 2, cols: 3); - - // 3x2 (row-major): [ [1,4], [2,5], [3,6] ] - float[] expected = new float[] { 1f, 4f, 2f, 5f, 3f, 6f }; - Assert.Equal(expected, dst); - } - - [Fact(Timeout = 60000)] - public async Task CopyWithPrefetch_CopiesPrefix() - { - float[] src = new float[] { 1f, 2f, 3f, 4f, 5f }; - float[] dst = new float[] { 0f, 0f, 0f, 0f, 0f }; - - CacheOptimizer.CopyWithPrefetch(src, dst, length: 3); - - Assert.Equal(1f, dst[0]); - Assert.Equal(2f, dst[1]); - Assert.Equal(3f, dst[2]); - Assert.Equal(0f, dst[3]); - Assert.Equal(0f, dst[4]); - } -} diff --git a/tests/AiDotNet.Tests/InferenceOptimization/ConvolutionKernelValidationTests.cs b/tests/AiDotNet.Tests/InferenceOptimization/ConvolutionKernelValidationTests.cs deleted file mode 100644 index 5a4db33b49..0000000000 --- a/tests/AiDotNet.Tests/InferenceOptimization/ConvolutionKernelValidationTests.cs +++ /dev/null @@ -1,23 +0,0 @@ -using System; -using AiDotNet.InferenceOptimization.Kernels; -using AiDotNet.LinearAlgebra; -using Xunit; -using System.Threading.Tasks; - -namespace AiDotNet.Tests.InferenceOptimization; - -public class ConvolutionKernelValidationTests -{ - [Fact(Timeout = 60000)] - public async Task Conv2D_Throws_WhenKernelInChannelsMismatch() - { - var kernel = new ConvolutionKernel(); - - var input = new Tensor(new[] { 1, 3, 5, 5 }); - var badKernel = new Tensor(new[] { 2, 2, 3, 3 }); - - var ex = Assert.Throws(() => kernel.Conv2D(input, badKernel)); - Assert.Contains("kernel.Shape[1] == inChannels", ex.Message, StringComparison.OrdinalIgnoreCase); - } -} - diff --git a/tests/AiDotNet.Tests/InferenceOptimization/GemmKernelValidationTests.cs b/tests/AiDotNet.Tests/InferenceOptimization/GemmKernelValidationTests.cs deleted file mode 100644 index 9343925226..0000000000 --- a/tests/AiDotNet.Tests/InferenceOptimization/GemmKernelValidationTests.cs +++ /dev/null @@ -1,101 +0,0 @@ -using AiDotNet.InferenceOptimization.Kernels; -using AiDotNet.LinearAlgebra; -using Xunit; -using System.Threading.Tasks; - -namespace AiDotNet.Tests.InferenceOptimization; - -public class GemmKernelValidationTests -{ - [Fact(Timeout = 60000)] - public async Task Execute_MatchesNaiveGemm() - { - var kernel = new GemmKernel(); - - // A: 2x3 - var a = CreateTensor(new[] { 2, 3 }, new float[] { 1, 2, 3, 4, 5, 6 }); - // B: 3x2 - var b = CreateTensor(new[] { 3, 2 }, new float[] { 7, 8, 9, 10, 11, 12 }); - - var actual = kernel.Execute(a, b); - var expected = NaiveGemm(a, b); - - Assert.Equal(expected.Shape.ToArray(), actual.Shape.ToArray()); - Assert.Equal(expected.ToArray(), actual.ToArray()); - } - - [Fact(Timeout = 60000)] - public async Task GemmTransposeB_MatchesNaive() - { - var kernel = new GemmKernel(); - - // A: 2x3 - var a = CreateTensor(new[] { 2, 3 }, new float[] { 1, 2, 3, 4, 5, 6 }); - // B: 2x3 (represents B^T; result is 2x2) - var b = CreateTensor(new[] { 2, 3 }, new float[] { 7, 8, 9, 10, 11, 12 }); - - var actual = kernel.GemmTransposeB(a, b); - var expected = NaiveGemmTransposeB(a, b); - - Assert.Equal(expected.Shape.ToArray(), actual.Shape.ToArray()); - Assert.Equal(expected.ToArray(), actual.ToArray()); - } - - private static Tensor NaiveGemm(Tensor a, Tensor b) - { - int m = a.Shape[0]; - int k = a.Shape[1]; - int n = b.Shape[1]; - - var c = new Tensor(new[] { m, n }); - - for (int i = 0; i < m; i++) - { - for (int j = 0; j < n; j++) - { - float sum = 0f; - for (int t = 0; t < k; t++) - { - sum += a[i * k + t] * b[t * n + j]; - } - - c[i * n + j] = sum; - } - } - - return c; - } - - private static Tensor NaiveGemmTransposeB(Tensor a, Tensor b) - { - int m = a.Shape[0]; - int k = a.Shape[1]; - int n = b.Shape[0]; - - var c = new Tensor(new[] { m, n }); - - for (int i = 0; i < m; i++) - { - for (int j = 0; j < n; j++) - { - float sum = 0f; - for (int t = 0; t < k; t++) - { - sum += a[i * k + t] * b[j * k + t]; - } - - c[i * n + j] = sum; - } - } - - return c; - } - - private static Tensor CreateTensor(int[] shape, float[] data) - { - var t = new Tensor(shape); - Assert.Equal(t.Length, data.Length); - t.CopyFromArray(data); - return t; - } -} diff --git a/tests/AiDotNet.Tests/InferenceOptimization/IR/HLIRTests.cs b/tests/AiDotNet.Tests/InferenceOptimization/IR/HLIRTests.cs deleted file mode 100644 index 1afbdff306..0000000000 --- a/tests/AiDotNet.Tests/InferenceOptimization/IR/HLIRTests.cs +++ /dev/null @@ -1,474 +0,0 @@ -using AiDotNet.Enums; -using AiDotNet.InferenceOptimization.IR.Common; -using AiDotNet.InferenceOptimization.IR.HighLevel; -using Xunit; -using System.Threading.Tasks; - -namespace AiDotNet.Tests.InferenceOptimization.IR; - -/// -/// Tests for High-Level IR classes. -/// -public class HLIRTests -{ - #region HLIRNode Tests - - [Fact(Timeout = 60000)] - public async Task HLIRNode_DefaultValues_AreCorrect() - { - var node = new HLIRNode(); - - Assert.Equal(0, node.Id); - Assert.Equal(string.Empty, node.Name); - Assert.NotNull(node.Inputs); - Assert.Empty(node.Inputs); - Assert.NotNull(node.Outputs); - Assert.Empty(node.Outputs); - Assert.NotNull(node.OutputType); - Assert.True(node.CanEliminate); - Assert.False(node.IsFused); - Assert.False(node.IsMarkedForDeletion); - } - - [Fact(Timeout = 60000)] - public async Task HLIRNode_AddInput_CreatesBidirectionalConnection() - { - var inputNode = new HLIRNode { Id = 1, Name = "input" }; - var outputNode = new HLIRNode { Id = 2, Name = "output" }; - - outputNode.AddInput(inputNode); - - Assert.Contains(inputNode, outputNode.Inputs); - Assert.Contains(outputNode, inputNode.Outputs); - } - - [Fact(Timeout = 60000)] - public async Task HLIRNode_AddInput_DoesNotDuplicate() - { - var inputNode = new HLIRNode { Id = 1 }; - var outputNode = new HLIRNode { Id = 2 }; - - outputNode.AddInput(inputNode); - outputNode.AddInput(inputNode); - - Assert.Single(outputNode.Inputs); - Assert.Single(inputNode.Outputs); - } - - [Fact(Timeout = 60000)] - public async Task HLIRNode_RemoveInput_RemovesBidirectionalConnection() - { - var inputNode = new HLIRNode { Id = 1 }; - var outputNode = new HLIRNode { Id = 2 }; - - outputNode.AddInput(inputNode); - outputNode.RemoveInput(inputNode); - - Assert.DoesNotContain(inputNode, outputNode.Inputs); - Assert.DoesNotContain(outputNode, inputNode.Outputs); - } - - [Fact(Timeout = 60000)] - public async Task HLIRNode_ReplaceInput_UpdatesConnections() - { - var oldInput = new HLIRNode { Id = 1 }; - var newInput = new HLIRNode { Id = 2 }; - var node = new HLIRNode { Id = 3 }; - - node.AddInput(oldInput); - node.ReplaceInput(oldInput, newInput); - - Assert.DoesNotContain(oldInput, node.Inputs); - Assert.Contains(newInput, node.Inputs); - Assert.DoesNotContain(node, oldInput.Outputs); - Assert.Contains(node, newInput.Outputs); - } - - [Fact(Timeout = 60000)] - public async Task HLIRNode_HasConsumers_ReturnsCorrectValue() - { - var node = new HLIRNode(); - Assert.False(node.HasConsumers); - - var consumer = new HLIRNode(); - consumer.AddInput(node); - - Assert.True(node.HasConsumers); - Assert.Equal(1, node.ConsumerCount); - } - - [Fact(Timeout = 60000)] - public async Task HLIRNode_Validate_ReturnsTrueForValidNode() - { - var input = new HLIRNode { Id = 1 }; - var output = new HLIRNode { Id = 2, OutputType = new TensorType() }; - output.AddInput(input); - - Assert.True(output.Validate()); - } - - [Fact(Timeout = 60000)] - public async Task HLIRNode_Validate_ReturnsFalseForInvalidId() - { - var node = new HLIRNode { Id = -1 }; - Assert.False(node.Validate()); - } - - [Fact(Timeout = 60000)] - public async Task HLIRNode_Clone_CreatesIndependentCopy() - { - var original = new HLIRNode - { - Id = 1, - Name = "test", - Operation = OperationType.Add, - CanEliminate = true, - OutputType = new TensorType { Shape = new[] { 2, 3 } } - }; - - var clone = original.Clone(); - - Assert.Equal(-1, clone.Id); // Clone gets new ID - Assert.Contains("_clone", clone.Name); - Assert.Equal(original.Operation, clone.Operation); - Assert.Equal(original.CanEliminate, clone.CanEliminate); - } - - [Fact(Timeout = 60000)] - public async Task HLIRNode_AddProvenance_TracksHistory() - { - var node = new HLIRNode(); - - node.AddProvenance("Created"); - node.AddProvenance("Modified"); - - Assert.Equal(2, node.Provenance.Count); - Assert.Contains("Created", node.Provenance[0]); - Assert.Contains("Modified", node.Provenance[1]); - } - - [Fact(Timeout = 60000)] - public async Task HLIRNode_ToString_ReturnsFormattedString() - { - var node = new HLIRNode - { - Id = 1, - Name = "relu1", - Operation = OperationType.ReLU - }; - - var str = node.ToString(); - Assert.Contains("n1", str); - Assert.Contains("relu1", str); - Assert.Contains("ReLU", str); - } - - #endregion - - #region HLIRGraph Tests - - [Fact(Timeout = 60000)] - public async Task HLIRGraph_AddNode_AssignsId() - { - var graph = new HLIRGraph(); - var node = new HLIRNode { Id = -1, Name = "test" }; - - graph.AddNode(node); - - Assert.True(node.Id >= 0); - Assert.Equal(1, graph.NodeCount); - } - - [Fact(Timeout = 60000)] - public async Task HLIRGraph_AddNode_ThrowsOnDuplicateId() - { - var graph = new HLIRGraph(); - var node1 = new HLIRNode { Id = 5 }; - var node2 = new HLIRNode { Id = 5 }; - - graph.AddNode(node1); - - Assert.Throws(() => graph.AddNode(node2)); - } - - [Fact(Timeout = 60000)] - public async Task HLIRGraph_CreateNode_CreatesAndAddsNode() - { - var graph = new HLIRGraph(); - var input = graph.CreateNode(OperationType.Input, "input"); - var relu = graph.CreateNode(OperationType.ReLU, "relu", input); - - Assert.Equal(2, graph.NodeCount); - Assert.Contains(input, relu.Inputs); - } - - [Fact(Timeout = 60000)] - public async Task HLIRGraph_RemoveNode_RemovesNodeAndConnections() - { - var graph = new HLIRGraph(); - var input = graph.CreateNode(OperationType.Input, "input"); - var middle = graph.CreateNode(OperationType.ReLU, "middle", input); - var output = graph.CreateNode(OperationType.Output, "output", middle); - - graph.RemoveNode(middle); - - Assert.Equal(2, graph.NodeCount); - Assert.DoesNotContain(middle, output.Inputs); - Assert.DoesNotContain(output, input.Outputs); - } - - [Fact(Timeout = 60000)] - public async Task HLIRGraph_FindNode_ReturnsCorrectNode() - { - var graph = new HLIRGraph(); - var node = graph.CreateNode(OperationType.Input, "test"); - - Assert.Same(node, graph.FindNode(node.Id)); - Assert.Null(graph.FindNode(999)); - } - - [Fact(Timeout = 60000)] - public async Task HLIRGraph_FindNodesByName_ReturnsMatchingNodes() - { - var graph = new HLIRGraph(); - graph.CreateNode(OperationType.ReLU, "relu1"); - graph.CreateNode(OperationType.ReLU, "relu2"); - graph.CreateNode(OperationType.Add, "add1"); - - var reluNodes = graph.FindNodesByName("relu").ToList(); - - Assert.Equal(2, reluNodes.Count); - } - - [Fact(Timeout = 60000)] - public async Task HLIRGraph_FindNodesByOperation_ReturnsMatchingNodes() - { - var graph = new HLIRGraph(); - graph.CreateNode(OperationType.ReLU, "relu1"); - graph.CreateNode(OperationType.ReLU, "relu2"); - graph.CreateNode(OperationType.Add, "add1"); - - var reluNodes = graph.FindNodesByOperation(OperationType.ReLU).ToList(); - - Assert.Equal(2, reluNodes.Count); - } - - [Fact(Timeout = 60000)] - public async Task HLIRGraph_GetTopologicalOrder_ReturnsCorrectOrder() - { - var graph = new HLIRGraph(); - var input = graph.CreateNode(OperationType.Input, "input"); - var relu = graph.CreateNode(OperationType.ReLU, "relu", input); - var output = graph.CreateNode(OperationType.Output, "output", relu); - - graph.InputNodes.Add(input); - graph.OutputNodes.Add(output); - - var order = graph.GetTopologicalOrder(); - - Assert.Equal(3, order.Count); - Assert.True(order.IndexOf(input) < order.IndexOf(relu)); - Assert.True(order.IndexOf(relu) < order.IndexOf(output)); - } - - [Fact(Timeout = 60000)] - public async Task HLIRGraph_GetTopologicalOrder_DetectsCycle() - { - var graph = new HLIRGraph(); - var node1 = graph.CreateNode(OperationType.Add, "node1"); - var node2 = graph.CreateNode(OperationType.Add, "node2", node1); - - // Create cycle - node1.AddInput(node2); - - Assert.Throws(() => graph.GetTopologicalOrder()); - } - - [Fact(Timeout = 60000)] - public async Task HLIRGraph_Validate_ReturnsTrueForValidGraph() - { - var graph = new HLIRGraph(); - var input = graph.CreateNode(OperationType.Input, "input"); - input.OutputType = new TensorType { Shape = new[] { 2, 3 } }; - var output = graph.CreateNode(OperationType.Output, "output", input); - output.OutputType = new TensorType { Shape = new[] { 2, 3 } }; - - graph.InputNodes.Add(input); - graph.OutputNodes.Add(output); - - var result = graph.Validate(); - - Assert.True(result.IsValid); - Assert.Empty(result.Errors); - } - - [Fact(Timeout = 60000)] - public async Task HLIRGraph_Clone_CreatesIndependentCopy() - { - var graph = new HLIRGraph { Name = "original" }; - var input = graph.CreateNode(OperationType.Input, "input"); - var output = graph.CreateNode(OperationType.Output, "output", input); - graph.InputNodes.Add(input); - graph.OutputNodes.Add(output); - - var clone = graph.Clone(); - - Assert.Equal(graph.NodeCount, clone.NodeCount); - Assert.Contains("_clone", clone.Name); - - // Verify cloned nodes are different objects - var originalNodes = graph.Nodes.ToList(); - var clonedNodes = clone.Nodes.ToList(); - Assert.NotSame(originalNodes[0], clonedNodes[0]); - } - - [Fact(Timeout = 60000)] - public async Task HLIRGraph_GetStatistics_ReturnsCorrectStats() - { - var graph = new HLIRGraph(); - var input = graph.CreateNode(OperationType.Input, "input"); - var relu1 = graph.CreateNode(OperationType.ReLU, "relu1", input); - var relu2 = graph.CreateNode(OperationType.ReLU, "relu2", relu1); - var output = graph.CreateNode(OperationType.Output, "output", relu2); - - graph.InputNodes.Add(input); - graph.OutputNodes.Add(output); - - var stats = graph.GetStatistics(); - - Assert.Equal(4, stats.TotalNodes); - Assert.Equal(1, stats.InputNodes); - Assert.Equal(1, stats.OutputNodes); - Assert.Equal(2, stats.NodesByOperation[OperationType.ReLU]); - } - - [Fact(Timeout = 60000)] - public async Task HLIRGraph_FindPatterns_DetectsSequentialPattern() - { - var graph = new HLIRGraph(); - var input = graph.CreateNode(OperationType.Input, "input"); - var conv = graph.CreateNode(OperationType.Conv2D, "conv", input); - var bn = graph.CreateNode(OperationType.BatchNorm, "bn", conv); - var relu = graph.CreateNode(OperationType.ReLU, "relu", bn); - - var patterns = graph.FindPatterns(OperationType.Conv2D, OperationType.BatchNorm, OperationType.ReLU); - - Assert.Single(patterns); - Assert.Equal(3, patterns[0].Count); - } - - [Fact(Timeout = 60000)] - public async Task HLIRGraph_ReplaceNode_UpdatesAllConnections() - { - var graph = new HLIRGraph(); - var input = graph.CreateNode(OperationType.Input, "input"); - var oldNode = graph.CreateNode(OperationType.ReLU, "old", input); - var output = graph.CreateNode(OperationType.Output, "output", oldNode); - - var newNode = new HLIRNode - { - Id = -1, - Name = "new", - Operation = OperationType.GELU, - OutputType = new TensorType() - }; - - graph.ReplaceNode(oldNode, newNode); - - Assert.Contains(input, newNode.Inputs); - Assert.Contains(newNode, output.Inputs); - Assert.DoesNotContain(oldNode, output.Inputs); - } - - [Fact(Timeout = 60000)] - public async Task HLIRGraph_CompactNodeIds_ReassignsIdsSequentially() - { - var graph = new HLIRGraph(); - var node1 = new HLIRNode { Id = 100 }; - var node2 = new HLIRNode { Id = 200 }; - var node3 = new HLIRNode { Id = 300 }; - - graph.AddNode(node1); - graph.AddNode(node2); - node2.AddInput(node1); - graph.AddNode(node3); - node3.AddInput(node2); - - graph.InputNodes.Add(node1); - graph.OutputNodes.Add(node3); - - graph.CompactNodeIds(); - - var ids = graph.Nodes.Select(n => n.Id).ToList(); - Assert.Contains(0, ids); - Assert.Contains(1, ids); - Assert.Contains(2, ids); - } - - #endregion - - #region OperationCost Tests - - [Fact(Timeout = 60000)] - public async Task OperationCost_ArithmeticIntensity_CalculatesCorrectly() - { - var cost = new OperationCost - { - FLOPs = 1000, - MemoryRead = 50, - MemoryWrite = 50 - }; - - Assert.Equal(10.0, cost.ArithmeticIntensity); - } - - [Fact(Timeout = 60000)] - public async Task OperationCost_IsMemoryBound_DetectsCorrectly() - { - var memBound = new OperationCost { FLOPs = 100, MemoryRead = 100, MemoryWrite = 100 }; - var computeBound = new OperationCost { FLOPs = 10000, MemoryRead = 100, MemoryWrite = 100 }; - - Assert.True(memBound.IsMemoryBound); - Assert.False(computeBound.IsMemoryBound); - } - - #endregion - - #region OptimizationHints Tests - - [Fact(Timeout = 60000)] - public async Task OptimizationHints_DefaultValues_AreCorrect() - { - var hints = new OptimizationHints(); - - Assert.Equal(DeviceType.Auto, hints.PreferredDevice); - Assert.False(hints.PrioritizeMemory); - Assert.False(hints.PrioritizeLatency); - Assert.True(hints.IsFusionCandidate); - Assert.True(hints.EnableVectorization); - Assert.True(hints.EnableParallelization); - } - - [Fact(Timeout = 60000)] - public async Task OptimizationHints_Clone_CreatesIndependentCopy() - { - var original = new OptimizationHints - { - PreferredDevice = DeviceType.GPU, - TileSizes = new[] { 32, 32 }, - EnableVectorization = false - }; - - var clone = original.Clone(); - - Assert.Equal(original.PreferredDevice, clone.PreferredDevice); - Assert.Equal(original.TileSizes, clone.TileSizes); - Assert.Equal(original.EnableVectorization, clone.EnableVectorization); - - // Modify clone and verify original unchanged - clone.TileSizes![0] = 64; - Assert.NotEqual(original.TileSizes[0], clone.TileSizes[0]); - } - - #endregion -} diff --git a/tests/AiDotNet.Tests/InferenceOptimization/IR/HLIRToLLIRLoweringTests.cs b/tests/AiDotNet.Tests/InferenceOptimization/IR/HLIRToLLIRLoweringTests.cs deleted file mode 100644 index ed5b692975..0000000000 --- a/tests/AiDotNet.Tests/InferenceOptimization/IR/HLIRToLLIRLoweringTests.cs +++ /dev/null @@ -1,397 +0,0 @@ -using AiDotNet.Enums; -using AiDotNet.InferenceOptimization.IR.Common; -using AiDotNet.InferenceOptimization.IR.HighLevel; -using AiDotNet.InferenceOptimization.IR.Lowering; -using AiDotNet.InferenceOptimization.IR.LowLevel; -using Xunit; -using System.Threading.Tasks; - -namespace AiDotNet.Tests.InferenceOptimization.IR; - -/// -/// Tests for HLIR to LLIR lowering pipeline. -/// -public class HLIRToLLIRLoweringTests -{ - #region Basic Lowering Tests - - [Fact(Timeout = 60000)] - public async Task Lower_EmptyGraph_ReturnsEmptyLLIRGraph() - { - var hlirGraph = new HLIRGraph(); - var lowering = new HLIRToLLIRLowering(); - - var llirGraph = lowering.Lower(hlirGraph); - - Assert.NotNull(llirGraph); - Assert.Empty(llirGraph.Operations); - } - - [Fact(Timeout = 60000)] - public async Task Lower_SingleInputNode_CreatesInputBuffer() - { - var hlirGraph = new HLIRGraph(); - var inputNode = hlirGraph.CreateNode(OperationType.Input, "input"); - inputNode.OutputType = new TensorType { Shape = new[] { 2, 3 }, DataType = IRDataType.Float32 }; - hlirGraph.InputNodes.Add(inputNode); - hlirGraph.OutputNodes.Add(inputNode); - - var lowering = new HLIRToLLIRLowering(); - var llirGraph = lowering.Lower(hlirGraph); - - Assert.Single(llirGraph.InputIds); - // BufferShapes is keyed by LLIR buffer IDs, not HLIR node IDs - Assert.Contains(llirGraph.InputIds[0], llirGraph.BufferShapes.Keys); - } - - [Fact(Timeout = 60000)] - public async Task Lower_ReLUOperation_CreatesElementwiseOp() - { - var hlirGraph = new HLIRGraph(); - var input = hlirGraph.CreateNode(OperationType.Input, "input"); - input.OutputType = new TensorType { Shape = new[] { 2, 3 }, DataType = IRDataType.Float32 }; - - var relu = hlirGraph.CreateNode(OperationType.ReLU, "relu", input); - relu.OutputType = new TensorType { Shape = new[] { 2, 3 }, DataType = IRDataType.Float32 }; - - hlirGraph.InputNodes.Add(input); - hlirGraph.OutputNodes.Add(relu); - - var lowering = new HLIRToLLIRLowering(); - var llirGraph = lowering.Lower(hlirGraph); - - Assert.Contains(llirGraph.Operations, op => op is ElementwiseOp); - var elementwiseOp = llirGraph.Operations.OfType().First(); - Assert.Equal(ElementwiseOpType.ReLU, elementwiseOp.ElementwiseType); - } - - [Fact(Timeout = 60000)] - public async Task Lower_AddOperation_CreatesElementwiseOp() - { - var hlirGraph = new HLIRGraph(); - var input1 = hlirGraph.CreateNode(OperationType.Input, "input1"); - input1.OutputType = new TensorType { Shape = new[] { 2, 3 }, DataType = IRDataType.Float32 }; - - var input2 = hlirGraph.CreateNode(OperationType.Input, "input2"); - input2.OutputType = new TensorType { Shape = new[] { 2, 3 }, DataType = IRDataType.Float32 }; - - var add = hlirGraph.CreateNode(OperationType.Add, "add", input1, input2); - add.OutputType = new TensorType { Shape = new[] { 2, 3 }, DataType = IRDataType.Float32 }; - - hlirGraph.InputNodes.Add(input1); - hlirGraph.InputNodes.Add(input2); - hlirGraph.OutputNodes.Add(add); - - var lowering = new HLIRToLLIRLowering(); - var llirGraph = lowering.Lower(hlirGraph); - - var addOp = llirGraph.Operations.OfType().FirstOrDefault(op => op.ElementwiseType == ElementwiseOpType.Add); - Assert.NotNull(addOp); - Assert.Equal(2, addOp.InputIds.Length); - } - - [Fact(Timeout = 60000)] - public async Task Lower_MatMulOperation_CreatesMatMulOp() - { - var hlirGraph = new HLIRGraph(); - var input1 = hlirGraph.CreateNode(OperationType.Input, "input1"); - input1.OutputType = new TensorType { Shape = new[] { 128, 256 }, DataType = IRDataType.Float32 }; - - var input2 = hlirGraph.CreateNode(OperationType.Input, "input2"); - input2.OutputType = new TensorType { Shape = new[] { 256, 512 }, DataType = IRDataType.Float32 }; - - var matmul = hlirGraph.CreateNode(OperationType.MatMul, "matmul", input1, input2); - matmul.OutputType = new TensorType { Shape = new[] { 128, 512 }, DataType = IRDataType.Float32 }; - - hlirGraph.InputNodes.Add(input1); - hlirGraph.InputNodes.Add(input2); - hlirGraph.OutputNodes.Add(matmul); - - var lowering = new HLIRToLLIRLowering(); - var llirGraph = lowering.Lower(hlirGraph); - - var matmulOp = llirGraph.Operations.OfType().FirstOrDefault(); - Assert.NotNull(matmulOp); - Assert.Equal(128, matmulOp.M); - Assert.Equal(512, matmulOp.N); - Assert.Equal(256, matmulOp.K); - } - - [Fact(Timeout = 60000)] - public async Task Lower_Conv2DOperation_CreatesConv2DOp() - { - var hlirGraph = new HLIRGraph(); - var input = hlirGraph.CreateNode(OperationType.Input, "input"); - input.OutputType = new TensorType { Shape = new[] { 1, 64, 56, 56 }, DataType = IRDataType.Float32 }; - - // Create kernel/weight input node - var kernel = hlirGraph.CreateNode(OperationType.Constant, "kernel"); - kernel.OutputType = new TensorType { Shape = new[] { 128, 64, 3, 3 }, DataType = IRDataType.Float32 }; // OIHW format - - var conv = hlirGraph.CreateNode(OperationType.Conv2D, "conv", input, kernel); - conv.OutputType = new TensorType { Shape = new[] { 1, 128, 56, 56 }, DataType = IRDataType.Float32 }; - // Add input types for proper lowering - conv.InputTypes.Add(new TensorType { Shape = new[] { 1, 64, 56, 56 }, DataType = IRDataType.Float32 }); - conv.InputTypes.Add(new TensorType { Shape = new[] { 128, 64, 3, 3 }, DataType = IRDataType.Float32 }); - conv.Attributes["kernel_size"] = new[] { 3, 3 }; - conv.Attributes["stride"] = new[] { 1, 1 }; - conv.Attributes["padding"] = new[] { 1, 1 }; - conv.Attributes["input_channels"] = 64; - conv.Attributes["output_channels"] = 128; - - hlirGraph.InputNodes.Add(input); - hlirGraph.OutputNodes.Add(conv); - - var lowering = new HLIRToLLIRLowering(); - var llirGraph = lowering.Lower(hlirGraph); - - var convOp = llirGraph.Operations.OfType().FirstOrDefault(); - Assert.NotNull(convOp); - Assert.Equal(1, convOp.BatchSize); - } - - [Fact(Timeout = 60000)] - public async Task Lower_ReshapeOperation_CreatesMemoryOp() - { - var hlirGraph = new HLIRGraph(); - var input = hlirGraph.CreateNode(OperationType.Input, "input"); - input.OutputType = new TensorType { Shape = new[] { 2, 3, 4 }, DataType = IRDataType.Float32 }; - - var reshape = hlirGraph.CreateNode(OperationType.Reshape, "reshape", input); - reshape.OutputType = new TensorType { Shape = new[] { 6, 4 }, DataType = IRDataType.Float32 }; - - hlirGraph.InputNodes.Add(input); - hlirGraph.OutputNodes.Add(reshape); - - var lowering = new HLIRToLLIRLowering(); - var llirGraph = lowering.Lower(hlirGraph); - - var memOp = llirGraph.Operations.OfType().FirstOrDefault(); - Assert.NotNull(memOp); - Assert.Equal(MemoryOpType.Reshape, memOp.MemoryOpType); - } - - [Fact(Timeout = 60000)] - public async Task Lower_TransposeOperation_CreatesMemoryOp() - { - var hlirGraph = new HLIRGraph(); - var input = hlirGraph.CreateNode(OperationType.Input, "input"); - input.OutputType = new TensorType { Shape = new[] { 2, 3 }, DataType = IRDataType.Float32 }; - - var transpose = hlirGraph.CreateNode(OperationType.Transpose, "transpose", input); - transpose.OutputType = new TensorType { Shape = new[] { 3, 2 }, DataType = IRDataType.Float32 }; - - hlirGraph.InputNodes.Add(input); - hlirGraph.OutputNodes.Add(transpose); - - var lowering = new HLIRToLLIRLowering(); - var llirGraph = lowering.Lower(hlirGraph); - - var memOp = llirGraph.Operations.OfType().FirstOrDefault(); - Assert.NotNull(memOp); - Assert.Equal(MemoryOpType.Transpose, memOp.MemoryOpType); - } - - [Fact(Timeout = 60000)] - public async Task Lower_ConstantNode_CreatesConstantOp() - { - var hlirGraph = new HLIRGraph(); - var constant = hlirGraph.CreateNode(OperationType.Constant, "constant"); - constant.OutputType = new TensorType { Shape = new[] { 2, 3 }, DataType = IRDataType.Float32 }; - - hlirGraph.OutputNodes.Add(constant); - - var lowering = new HLIRToLLIRLowering(); - var llirGraph = lowering.Lower(hlirGraph); - - var constOp = llirGraph.Operations.OfType().FirstOrDefault(); - Assert.NotNull(constOp); - } - - #endregion - - #region Activation Function Tests - - [Theory] - [InlineData(OperationType.ReLU, ElementwiseOpType.ReLU)] - [InlineData(OperationType.Sigmoid, ElementwiseOpType.Sigmoid)] - [InlineData(OperationType.Tanh, ElementwiseOpType.Tanh)] - [InlineData(OperationType.GELU, ElementwiseOpType.GELU)] - public void Lower_ActivationFunction_CreatesCorrectElementwiseOp(OperationType hlirOp, ElementwiseOpType expectedLlirOp) - { - var hlirGraph = new HLIRGraph(); - var input = hlirGraph.CreateNode(OperationType.Input, "input"); - input.OutputType = new TensorType { Shape = new[] { 2, 3 }, DataType = IRDataType.Float32 }; - - var activation = hlirGraph.CreateNode(hlirOp, "activation", input); - activation.OutputType = new TensorType { Shape = new[] { 2, 3 }, DataType = IRDataType.Float32 }; - - hlirGraph.InputNodes.Add(input); - hlirGraph.OutputNodes.Add(activation); - - var lowering = new HLIRToLLIRLowering(); - var llirGraph = lowering.Lower(hlirGraph); - - var elementwiseOp = llirGraph.Operations.OfType().FirstOrDefault(); - Assert.NotNull(elementwiseOp); - Assert.Equal(expectedLlirOp, elementwiseOp.ElementwiseType); - } - - #endregion - - #region Fused Operation Tests - - [Fact(Timeout = 60000)] - public async Task Lower_FusedNode_CreatesFusedOp() - { - var hlirGraph = new HLIRGraph(); - var input = hlirGraph.CreateNode(OperationType.Input, "input"); - input.OutputType = new TensorType { Shape = new[] { 1, 64, 56, 56 }, DataType = IRDataType.Float32 }; - - // Create original nodes that will be marked as fused - var conv = new HLIRNode - { - Id = 10, - Name = "conv", - Operation = OperationType.Conv2D, - OutputType = new TensorType { Shape = new[] { 1, 128, 56, 56 }, DataType = IRDataType.Float32 }, - InputTypes = new List - { - new TensorType { Shape = new[] { 1, 64, 56, 56 }, DataType = IRDataType.Float32 }, - new TensorType { Shape = new[] { 128, 64, 3, 3 }, DataType = IRDataType.Float32 } - } - }; - var relu = new HLIRNode - { - Id = 11, - Name = "relu", - Operation = OperationType.ReLU, - OutputType = new TensorType { Shape = new[] { 1, 128, 56, 56 }, DataType = IRDataType.Float32 } - }; - - // Create fused node - var fused = hlirGraph.CreateNode(OperationType.Add, "fused", input); - fused.OutputType = new TensorType { Shape = new[] { 1, 128, 56, 56 }, DataType = IRDataType.Float32 }; - fused.IsFused = true; - fused.FusedFrom = new List> { conv, relu }; - - hlirGraph.InputNodes.Add(input); - hlirGraph.OutputNodes.Add(fused); - - var lowering = new HLIRToLLIRLowering(); - var llirGraph = lowering.Lower(hlirGraph); - - var fusedOp = llirGraph.Operations.OfType().FirstOrDefault(); - Assert.NotNull(fusedOp); - Assert.Contains("Conv2D", fusedOp.FusionPattern); - Assert.Contains("ReLU", fusedOp.FusionPattern); - } - - #endregion - - #region Complex Graph Tests - - [Fact(Timeout = 60000)] - public async Task Lower_LinearSequence_PreservesOrder() - { - var hlirGraph = new HLIRGraph(); - - var input = hlirGraph.CreateNode(OperationType.Input, "input"); - input.OutputType = new TensorType { Shape = new[] { 10 }, DataType = IRDataType.Float32 }; - - var relu = hlirGraph.CreateNode(OperationType.ReLU, "relu", input); - relu.OutputType = new TensorType { Shape = new[] { 10 }, DataType = IRDataType.Float32 }; - - var sigmoid = hlirGraph.CreateNode(OperationType.Sigmoid, "sigmoid", relu); - sigmoid.OutputType = new TensorType { Shape = new[] { 10 }, DataType = IRDataType.Float32 }; - - hlirGraph.InputNodes.Add(input); - hlirGraph.OutputNodes.Add(sigmoid); - - var lowering = new HLIRToLLIRLowering(); - var llirGraph = lowering.Lower(hlirGraph); - - // Should have at least 2 elementwise operations (ReLU and Sigmoid) - var elementwiseOps = llirGraph.Operations.OfType().ToList(); - Assert.True(elementwiseOps.Count >= 2); - } - - [Fact(Timeout = 60000)] - public async Task Lower_BranchingGraph_HandlesMultipleOutputs() - { - var hlirGraph = new HLIRGraph(); - - var input = hlirGraph.CreateNode(OperationType.Input, "input"); - input.OutputType = new TensorType { Shape = new[] { 10 }, DataType = IRDataType.Float32 }; - - var relu = hlirGraph.CreateNode(OperationType.ReLU, "relu", input); - relu.OutputType = new TensorType { Shape = new[] { 10 }, DataType = IRDataType.Float32 }; - - var sigmoid = hlirGraph.CreateNode(OperationType.Sigmoid, "sigmoid", input); - sigmoid.OutputType = new TensorType { Shape = new[] { 10 }, DataType = IRDataType.Float32 }; - - hlirGraph.InputNodes.Add(input); - hlirGraph.OutputNodes.Add(relu); - hlirGraph.OutputNodes.Add(sigmoid); - - var lowering = new HLIRToLLIRLowering(); - var llirGraph = lowering.Lower(hlirGraph); - - Assert.Equal(2, llirGraph.OutputIds.Count); - } - - #endregion - - #region Data Type Preservation Tests - - [Theory] - [InlineData(IRDataType.Float32)] - [InlineData(IRDataType.Float64)] - [InlineData(IRDataType.Float16)] - public void Lower_PreservesDataType(IRDataType dataType) - { - var hlirGraph = new HLIRGraph(); - var input = hlirGraph.CreateNode(OperationType.Input, "input"); - input.OutputType = new TensorType { Shape = new[] { 10 }, DataType = dataType }; - - var relu = hlirGraph.CreateNode(OperationType.ReLU, "relu", input); - relu.OutputType = new TensorType { Shape = new[] { 10 }, DataType = dataType }; - - hlirGraph.InputNodes.Add(input); - hlirGraph.OutputNodes.Add(relu); - - var lowering = new HLIRToLLIRLowering(); - var llirGraph = lowering.Lower(hlirGraph); - - var elementwiseOp = llirGraph.Operations.OfType().FirstOrDefault(); - Assert.NotNull(elementwiseOp); - Assert.Equal(dataType, elementwiseOp.OutputDataType); - } - - #endregion - - #region Provenance Tracking Tests - - [Fact(Timeout = 60000)] - public async Task Lower_PreservesSourceNodeId() - { - var hlirGraph = new HLIRGraph(); - var input = hlirGraph.CreateNode(OperationType.Input, "input"); - input.OutputType = new TensorType { Shape = new[] { 10 }, DataType = IRDataType.Float32 }; - - var relu = hlirGraph.CreateNode(OperationType.ReLU, "relu", input); - relu.OutputType = new TensorType { Shape = new[] { 10 }, DataType = IRDataType.Float32 }; - - hlirGraph.InputNodes.Add(input); - hlirGraph.OutputNodes.Add(relu); - - var lowering = new HLIRToLLIRLowering(); - var llirGraph = lowering.Lower(hlirGraph); - - var elementwiseOp = llirGraph.Operations.OfType().FirstOrDefault(); - Assert.NotNull(elementwiseOp); - Assert.Equal(relu.Id, elementwiseOp.SourceHLIRNodeId); - } - - #endregion -} diff --git a/tests/AiDotNet.Tests/InferenceOptimization/IR/IRTypesTests.cs b/tests/AiDotNet.Tests/InferenceOptimization/IR/IRTypesTests.cs deleted file mode 100644 index 37c3c83972..0000000000 --- a/tests/AiDotNet.Tests/InferenceOptimization/IR/IRTypesTests.cs +++ /dev/null @@ -1,255 +0,0 @@ -using AiDotNet.InferenceOptimization.IR.Common; -using Xunit; -using System.Threading.Tasks; - -namespace AiDotNet.Tests.InferenceOptimization.IR; - -/// -/// Tests for IR type system classes. -/// -public class IRTypesTests -{ - #region IRDataType Tests - - [Theory] - [InlineData(IRDataType.Float32, true)] - [InlineData(IRDataType.Float64, true)] - [InlineData(IRDataType.Float16, true)] - [InlineData(IRDataType.BFloat16, true)] - [InlineData(IRDataType.Int32, false)] - [InlineData(IRDataType.QInt8, false)] - public void IsFloatingPoint_ReturnsCorrectResult(IRDataType type, bool expected) - { - Assert.Equal(expected, type.IsFloatingPoint()); - } - - [Theory] - [InlineData(IRDataType.Int8, true)] - [InlineData(IRDataType.Int32, true)] - [InlineData(IRDataType.UInt64, true)] - [InlineData(IRDataType.Float32, false)] - [InlineData(IRDataType.QInt8, false)] - public void IsInteger_ReturnsCorrectResult(IRDataType type, bool expected) - { - Assert.Equal(expected, type.IsInteger()); - } - - [Theory] - [InlineData(IRDataType.QInt8, true)] - [InlineData(IRDataType.QUInt8, true)] - [InlineData(IRDataType.QInt4, true)] - [InlineData(IRDataType.QInt2, true)] - [InlineData(IRDataType.Int8, false)] - [InlineData(IRDataType.Float32, false)] - public void IsQuantized_ReturnsCorrectResult(IRDataType type, bool expected) - { - Assert.Equal(expected, type.IsQuantized()); - } - - [Theory] - [InlineData(typeof(float), IRDataType.Float32)] - [InlineData(typeof(double), IRDataType.Float64)] - [InlineData(typeof(int), IRDataType.Int32)] - [InlineData(typeof(byte), IRDataType.UInt8)] - [InlineData(typeof(bool), IRDataType.Bool)] - public void FromSystemType_ConvertsCorrectly(Type systemType, IRDataType expected) - { - Assert.Equal(expected, IRDataTypeExtensions.FromSystemType(systemType)); - } - - [Theory] - [InlineData(IRDataType.Float32, typeof(float))] - [InlineData(IRDataType.Float64, typeof(double))] - [InlineData(IRDataType.Int32, typeof(int))] - [InlineData(IRDataType.UInt8, typeof(byte))] - [InlineData(IRDataType.Bool, typeof(bool))] - public void ToSystemType_ConvertsCorrectly(IRDataType irType, Type expected) - { - Assert.Equal(expected, irType.ToSystemType()); - } - - #endregion - - #region TensorType Tests - - [Fact(Timeout = 60000)] - public async Task TensorType_DefaultValues_AreCorrect() - { - var tensorType = new TensorType(); - - Assert.Equal(IRDataType.Float32, tensorType.DataType); - Assert.Empty(tensorType.Shape.ToArray()); - Assert.Equal(MemoryLayout.RowMajor, tensorType.Layout); - Assert.Equal(DeviceType.Auto, tensorType.Device); - Assert.Null(tensorType.Quantization); - } - - [Theory] - [InlineData(new int[] { 2, 3, 4 }, 24)] - [InlineData(new int[] { 10 }, 10)] - [InlineData(new int[] { }, 1)] // scalar - public void TensorType_NumElements_CalculatesCorrectly(int[] shape, long expected) - { - var tensorType = new TensorType { Shape = shape }; - Assert.Equal(expected, tensorType.NumElements); - } - - [Fact(Timeout = 60000)] - public async Task TensorType_NumElements_ReturnsMinusOneForDynamicShape() - { - var tensorType = new TensorType { Shape = new[] { 2, -1, 4 } }; - Assert.Equal(-1, tensorType.NumElements); - } - - [Fact(Timeout = 60000)] - public async Task TensorType_HasDynamicShape_DetectsDynamicDimensions() - { - var staticType = new TensorType { Shape = new[] { 2, 3, 4 } }; - var dynamicType = new TensorType { Shape = new[] { 2, -1, 4 } }; - - Assert.False(staticType.HasDynamicShape); - Assert.True(dynamicType.HasDynamicShape); - } - - [Theory] - [InlineData(IRDataType.Float32, 4)] - [InlineData(IRDataType.Float64, 8)] - [InlineData(IRDataType.Float16, 2)] - [InlineData(IRDataType.Int8, 1)] - [InlineData(IRDataType.Complex128, 16)] - public void TensorType_ElementSize_ReturnsCorrectSize(IRDataType dataType, int expected) - { - var tensorType = new TensorType { DataType = dataType }; - Assert.Equal(expected, tensorType.ElementSize); - } - - [Fact(Timeout = 60000)] - public async Task TensorType_TotalBytes_CalculatesCorrectly() - { - var tensorType = new TensorType - { - DataType = IRDataType.Float32, - Shape = new[] { 2, 3, 4 } - }; - - Assert.Equal(24 * 4, tensorType.TotalBytes); // 24 elements * 4 bytes - } - - [Theory] - [InlineData(new[] { 3, 4 }, new[] { 3, 4 }, true)] - [InlineData(new[] { 1, 4 }, new[] { 3, 4 }, true)] - [InlineData(new[] { 3, 1 }, new[] { 3, 4 }, true)] - [InlineData(new[] { 4 }, new[] { 3, 4 }, true)] - [InlineData(new[] { 3, 4 }, new[] { 2, 4 }, false)] - public void TensorType_IsBroadcastCompatible_ReturnsCorrectResult(int[] shape1, int[] shape2, bool expected) - { - var type1 = new TensorType { Shape = shape1 }; - var type2 = new TensorType { Shape = shape2 }; - - Assert.Equal(expected, type1.IsBroadcastCompatible(type2)); - } - - [Fact(Timeout = 60000)] - public async Task TensorType_Clone_CreatesIndependentCopy() - { - var original = new TensorType - { - DataType = IRDataType.Float32, - Shape = new[] { 2, 3 }, - Layout = MemoryLayout.NCHW, - Device = DeviceType.GPU - }; - - var clone = original.Clone(); - - Assert.Equal(original.DataType, clone.DataType); - Assert.Equal(original.Shape.ToArray(), clone.Shape.ToArray()); - Assert.Equal(original.Layout, clone.Layout); - Assert.Equal(original.Device, clone.Device); - - // Modify clone and verify original is unchanged - clone.Shape[0] = 999; - Assert.NotEqual(original.Shape[0], clone.Shape[0]); - } - - [Fact(Timeout = 60000)] - public async Task TensorType_ToString_ReturnsFormattedString() - { - var tensorType = new TensorType - { - DataType = IRDataType.Float32, - Shape = new[] { 2, 3, 4 }, - Device = DeviceType.GPU - }; - - var str = tensorType.ToString(); - Assert.Contains("Float32", str); - Assert.Contains("2", str); - Assert.Contains("GPU", str); - } - - #endregion - - #region QuantizationParams Tests - - [Fact(Timeout = 60000)] - public async Task QuantizationParams_DefaultValues_AreCorrect() - { - var qParams = new QuantizationParams(); - - Assert.Equal(1.0, qParams.Scale); - Assert.Equal(0, qParams.ZeroPoint); - Assert.False(qParams.PerChannel); - Assert.Equal(-1, qParams.QuantizationAxis); - } - - [Fact(Timeout = 60000)] - public async Task QuantizationParams_PerChannel_CanBeConfigured() - { - var qParams = new QuantizationParams - { - PerChannel = true, - QuantizationAxis = 0, - PerChannelScales = new[] { 0.1, 0.2, 0.3 }, - PerChannelZeroPoints = new[] { 0, 1, 2 } - }; - - Assert.True(qParams.PerChannel); - Assert.Equal(0, qParams.QuantizationAxis); - Assert.Equal(3, qParams.PerChannelScales!.Length); - Assert.Equal(3, qParams.PerChannelZeroPoints!.Length); - } - - #endregion - - #region MemoryLayout Tests - - [Fact(Timeout = 60000)] - public async Task MemoryLayout_AllValuesAreDefined() - { - Assert.True(Enum.IsDefined(typeof(MemoryLayout), MemoryLayout.RowMajor)); - Assert.True(Enum.IsDefined(typeof(MemoryLayout), MemoryLayout.ColumnMajor)); - Assert.True(Enum.IsDefined(typeof(MemoryLayout), MemoryLayout.NCHW)); - Assert.True(Enum.IsDefined(typeof(MemoryLayout), MemoryLayout.NHWC)); - Assert.True(Enum.IsDefined(typeof(MemoryLayout), MemoryLayout.Tiled4x4)); - Assert.True(Enum.IsDefined(typeof(MemoryLayout), MemoryLayout.Blocked)); - } - - #endregion - - #region DeviceType Tests - - [Fact(Timeout = 60000)] - public async Task DeviceType_AllValuesAreDefined() - { - Assert.True(Enum.IsDefined(typeof(DeviceType), DeviceType.CPU)); - Assert.True(Enum.IsDefined(typeof(DeviceType), DeviceType.GPU)); - Assert.True(Enum.IsDefined(typeof(DeviceType), DeviceType.TPU)); - Assert.True(Enum.IsDefined(typeof(DeviceType), DeviceType.NPU)); - Assert.True(Enum.IsDefined(typeof(DeviceType), DeviceType.FPGA)); - Assert.True(Enum.IsDefined(typeof(DeviceType), DeviceType.Auto)); - Assert.True(Enum.IsDefined(typeof(DeviceType), DeviceType.Any)); - } - - #endregion -} diff --git a/tests/AiDotNet.Tests/InferenceOptimization/IR/LLIRTests.cs b/tests/AiDotNet.Tests/InferenceOptimization/IR/LLIRTests.cs deleted file mode 100644 index 2f45070112..0000000000 --- a/tests/AiDotNet.Tests/InferenceOptimization/IR/LLIRTests.cs +++ /dev/null @@ -1,601 +0,0 @@ -using AiDotNet.InferenceOptimization.IR.Common; -using AiDotNet.InferenceOptimization.IR.LowLevel; -using Xunit; -using System.Threading.Tasks; - -namespace AiDotNet.Tests.InferenceOptimization.IR; - -/// -/// Tests for Low-Level IR classes. -/// -public class LLIRTests -{ - #region LLIROp Tests - - [Fact(Timeout = 60000)] - public async Task MatMulOp_EstimateCost_CalculatesCorrectFLOPs() - { - var matmul = new MatMulOp - { - M = 128, - N = 256, - K = 512, - OutputShape = new[] { 128, 256 } - }; - - var cost = matmul.EstimateCost(); - - // 2 * M * N * K for matmul - Assert.Equal(2L * 128 * 256 * 512, cost.FLOPs); - } - - [Fact(Timeout = 60000)] - public async Task MatMulOp_Validate_ReturnsTrueForValidOp() - { - var matmul = new MatMulOp - { - OutputId = 1, - InputIds = new[] { 0 }, - OutputShape = new[] { 128, 256 }, - M = 128, - N = 256, - K = 512 - }; - - Assert.True(matmul.Validate()); - } - - [Fact(Timeout = 60000)] - public async Task ElementwiseOp_EstimateCost_CalculatesCorrectly() - { - var elementwise = new ElementwiseOp - { - OutputId = 1, - InputIds = new[] { 0 }, - OutputShape = new[] { 2, 3, 4 }, - ElementwiseType = ElementwiseOpType.ReLU - }; - - var cost = elementwise.EstimateCost(); - - Assert.Equal(24, cost.FLOPs); // 2*3*4 = 24 elements - } - - [Fact(Timeout = 60000)] - public async Task ElementwiseOp_FusedMultiplyAdd_DoublesOperationCount() - { - var fma = new ElementwiseOp - { - OutputId = 1, - InputIds = new[] { 0, 1 }, - OutputShape = new[] { 10 }, - ElementwiseType = ElementwiseOpType.FusedMultiplyAdd - }; - - var cost = fma.EstimateCost(); - - Assert.Equal(20, cost.FLOPs); // 10 * 2 for FMA - } - - [Fact(Timeout = 60000)] - public async Task Conv2DOp_EstimateCost_CalculatesCorrectFLOPs() - { - var conv = new Conv2DOp - { - BatchSize = 1, - InputChannels = 64, - OutputChannels = 128, - InputHeight = 56, - InputWidth = 56, - KernelHeight = 3, - KernelWidth = 3, - StrideH = 1, - StrideW = 1, - PadH = 1, - PadW = 1, - OutputShape = new[] { 1, 128, 56, 56 } - }; - - var cost = conv.EstimateCost(); - - // 2 * BatchSize * OutputChannels * OutH * OutW * (InputChannels/Groups) * KH * KW - var expectedFlops = 2L * 1 * 128 * 56 * 56 * 64 * 3 * 3; - Assert.Equal(expectedFlops, cost.FLOPs); - } - - [Fact(Timeout = 60000)] - public async Task ReduceOp_EstimateCost_CalculatesCorrectly() - { - var reduce = new ReduceOp - { - OutputId = 1, - InputIds = new[] { 0 }, - OutputShape = new[] { 4 }, - ReduceType = ReduceType.Sum, - Axes = new[] { 1, 2 } - }; - - var cost = reduce.EstimateCost(); - - Assert.True(cost.FLOPs > 0); - Assert.True(cost.MemoryRead > 0); - } - - [Fact(Timeout = 60000)] - public async Task MemoryOp_EstimateCost_HasZeroFLOPs() - { - var memOp = new MemoryOp - { - OutputId = 1, - InputIds = new[] { 0 }, - OutputShape = new[] { 2, 3, 4 }, - MemoryOpType = MemoryOpType.Reshape - }; - - var cost = memOp.EstimateCost(); - - Assert.Equal(0, cost.FLOPs); - Assert.True(cost.MemoryRead > 0); - Assert.True(cost.MemoryWrite > 0); - } - - [Fact(Timeout = 60000)] - public async Task FusedOp_EstimateCost_CombinesOperations() - { - var op1 = new ElementwiseOp { OutputShape = new[] { 10 }, ElementwiseType = ElementwiseOpType.ReLU }; - var op2 = new ElementwiseOp { OutputShape = new[] { 10 }, ElementwiseType = ElementwiseOpType.Add }; - - var fused = new FusedOp - { - OutputId = 1, - InputIds = new[] { 0 }, - OutputShape = new[] { 10 }, - FusionPattern = "ReLU_Add", - FusedOps = new List { op1, op2 } - }; - - var cost = fused.EstimateCost(); - - Assert.Equal(20, cost.FLOPs); // 10 + 10 - } - - [Fact(Timeout = 60000)] - public async Task ConstantOp_EstimateCost_CalculatesCorrectly() - { - var constant = new ConstantOp - { - OutputId = 0, - OutputShape = new[] { 10, 20 }, - IsParameter = true, - ParameterName = "weights" - }; - - var cost = constant.EstimateCost(); - - Assert.Equal(0, cost.FLOPs); - Assert.True(cost.MemoryRead > 0); - } - - [Fact(Timeout = 60000)] - public async Task LLIROp_ToString_ReturnsFormattedString() - { - var matmul = new MatMulOp - { - OutputId = 2, - InputIds = new[] { 0, 1 }, - OutputShape = new[] { 128, 256 }, - OutputDataType = IRDataType.Float32, - Device = DeviceType.GPU - }; - - var str = matmul.ToString(); - - Assert.Contains("b2", str); - Assert.Contains("MatMul", str); - Assert.Contains("b0", str); - Assert.Contains("GPU", str); - } - - #endregion - - #region ScheduleInfo Tests - - [Fact(Timeout = 60000)] - public async Task ScheduleInfo_DefaultValues_AreCorrect() - { - var schedule = new ScheduleInfo(); - - Assert.Empty(schedule.TileSizes); - Assert.Empty(schedule.LoopOrder); - Assert.Empty(schedule.ParallelAxes); - Assert.Equal(-1, schedule.VectorAxis); - Assert.Equal(1, schedule.VectorWidth); - Assert.Equal(1, schedule.UnrollFactor); - } - - [Fact(Timeout = 60000)] - public async Task ScheduleInfo_Clone_CreatesIndependentCopy() - { - var original = new ScheduleInfo - { - TileSizes = new[] { 32, 32 }, - VectorWidth = 8, - ThreadBlockDims = new[] { 256, 1, 1 } - }; - - var clone = original.Clone(); - - Assert.Equal(original.TileSizes, clone.TileSizes); - Assert.Equal(original.VectorWidth, clone.VectorWidth); - Assert.Equal(original.ThreadBlockDims, clone.ThreadBlockDims); - - // Modify clone - clone.TileSizes[0] = 64; - Assert.NotEqual(original.TileSizes[0], clone.TileSizes[0]); - } - - #endregion - - #region BufferInfo Tests - - [Fact(Timeout = 60000)] - public async Task BufferInfo_DefaultValues_AreCorrect() - { - var buffer = new BufferInfo(); - - Assert.Equal(0, buffer.SizeBytes); - Assert.Equal(64, buffer.Alignment); - Assert.Equal(-1, buffer.MemoryPoolId); - Assert.Equal(MemoryLevel.DRAM, buffer.MemoryLevel); - Assert.False(buffer.CanInPlace); - Assert.False(buffer.IsPersistent); - } - - [Fact(Timeout = 60000)] - public async Task BufferInfo_InPlaceConfiguration_CanBeSet() - { - var buffer = new BufferInfo - { - CanInPlace = true, - InPlaceInputId = 5, - SizeBytes = 1024 - }; - - Assert.True(buffer.CanInPlace); - Assert.Equal(5, buffer.InPlaceInputId); - } - - #endregion - - #region LLIRGraph Tests - - [Fact(Timeout = 60000)] - public async Task LLIRGraph_AddOperation_AssignsOutputId() - { - var graph = new LLIRGraph(); - var op = new ElementwiseOp - { - OutputId = -1, - OutputShape = new[] { 10 }, - ElementwiseType = ElementwiseOpType.ReLU - }; - - graph.AddOperation(op); - - Assert.True(op.OutputId >= 0); - Assert.Single(graph.Operations); - } - - [Fact(Timeout = 60000)] - public async Task LLIRGraph_AllocateBufferId_ReturnsUniqueIds() - { - var graph = new LLIRGraph(); - - var id1 = graph.AllocateBufferId(); - var id2 = graph.AllocateBufferId(); - var id3 = graph.AllocateBufferId(); - - Assert.NotEqual(id1, id2); - Assert.NotEqual(id2, id3); - Assert.NotEqual(id1, id3); - } - - [Fact(Timeout = 60000)] - public async Task LLIRGraph_GetOperationByOutputId_ReturnsCorrectOp() - { - var graph = new LLIRGraph(); - var op1 = new ElementwiseOp { OutputId = 5, OutputShape = new[] { 10 }, ElementwiseType = ElementwiseOpType.ReLU }; - var op2 = new ElementwiseOp { OutputId = 10, OutputShape = new[] { 10 }, ElementwiseType = ElementwiseOpType.Add }; - - graph.AddOperation(op1); - graph.AddOperation(op2); - - Assert.Same(op1, graph.GetOperationByOutputId(5)); - Assert.Same(op2, graph.GetOperationByOutputId(10)); - Assert.Null(graph.GetOperationByOutputId(999)); - } - - [Fact(Timeout = 60000)] - public async Task LLIRGraph_GetConsumers_ReturnsCorrectOperations() - { - var graph = new LLIRGraph(); - graph.InputIds.Add(0); - - var op1 = new ElementwiseOp { OutputId = 1, InputIds = new[] { 0 }, OutputShape = new[] { 10 }, ElementwiseType = ElementwiseOpType.ReLU }; - var op2 = new ElementwiseOp { OutputId = 2, InputIds = new[] { 0 }, OutputShape = new[] { 10 }, ElementwiseType = ElementwiseOpType.Add }; - var op3 = new ElementwiseOp { OutputId = 3, InputIds = new[] { 1 }, OutputShape = new[] { 10 }, ElementwiseType = ElementwiseOpType.Sigmoid }; - - graph.AddOperation(op1); - graph.AddOperation(op2); - graph.AddOperation(op3); - - var consumers = graph.GetConsumers(0).ToList(); - - Assert.Equal(2, consumers.Count); - Assert.Contains(op1, consumers); - Assert.Contains(op2, consumers); - } - - [Fact(Timeout = 60000)] - public async Task LLIRGraph_Validate_ReturnsTrueForValidGraph() - { - var graph = new LLIRGraph(); - graph.InputIds.Add(0); - graph.BufferShapes[0] = new[] { 10 }; - graph.BufferTypes[0] = IRDataType.Float32; - - var op = new ElementwiseOp - { - OutputId = 1, - InputIds = new[] { 0 }, - OutputShape = new[] { 10 }, - ElementwiseType = ElementwiseOpType.ReLU - }; - graph.AddOperation(op); - graph.OutputIds.Add(1); - - var result = graph.Validate(); - - Assert.True(result.IsValid); - } - - [Fact(Timeout = 60000)] - public async Task LLIRGraph_Validate_DetectsUndefinedInput() - { - var graph = new LLIRGraph(); - - var op = new ElementwiseOp - { - OutputId = 1, - InputIds = new[] { 999 }, // Undefined input - OutputShape = new[] { 10 }, - ElementwiseType = ElementwiseOpType.ReLU - }; - graph.AddOperation(op); - - var result = graph.Validate(); - - Assert.False(result.IsValid); - Assert.Contains(result.Errors, e => e.Contains("undefined buffer")); - } - - [Fact(Timeout = 60000)] - public async Task LLIRGraph_ComputeMetrics_CalculatesCorrectly() - { - var graph = new LLIRGraph(); - graph.InputIds.Add(0); - graph.BufferShapes[0] = new[] { 128, 256 }; - - var matmul = new MatMulOp - { - OutputId = 1, - InputIds = new[] { 0 }, - OutputShape = new[] { 128, 256 }, - M = 128, - N = 256, - K = 512 - }; - graph.AddOperation(matmul); - - var metrics = graph.ComputeMetrics(); - - Assert.Equal(1, metrics.OperationCount); - Assert.True(metrics.TotalFLOPs > 0); - Assert.True(metrics.PeakMemoryBytes > 0); - } - - [Fact(Timeout = 60000)] - public async Task LLIRGraph_OptimizeMemory_CreatesMemoryPlan() - { - var graph = new LLIRGraph(); - graph.InputIds.Add(0); - graph.BufferShapes[0] = new[] { 10 }; - graph.BufferTypes[0] = IRDataType.Float32; - - var op1 = new ElementwiseOp { InputIds = new[] { 0 }, OutputShape = new[] { 10 }, ElementwiseType = ElementwiseOpType.ReLU }; - var op2 = new ElementwiseOp { InputIds = new[] { op1.OutputId }, OutputShape = new[] { 10 }, ElementwiseType = ElementwiseOpType.Add }; - - graph.AddOperation(op1); - op2.InputIds = new[] { op1.OutputId }; - graph.AddOperation(op2); - - graph.OptimizeMemory(); - - Assert.NotNull(graph.MemoryPlan); - Assert.True(graph.MemoryPlan.PoolCount >= 0); - } - - [Fact(Timeout = 60000)] - public async Task LLIRGraph_AutoSchedule_SetsSchedulingInfo() - { - var graph = new LLIRGraph(); - graph.DeviceConfig = new DeviceConfiguration { CPUVectorWidth = 8, CPUCores = 4 }; - - var op = new ElementwiseOp - { - OutputId = 0, - OutputShape = new[] { 64, 64 }, - ElementwiseType = ElementwiseOpType.ReLU, - Device = DeviceType.CPU - }; - graph.AddOperation(op); - - graph.AutoSchedule(); - - // AutoSchedule should set vector width for CPU ops - Assert.True(op.Schedule.VectorWidth >= 1); - } - - [Fact(Timeout = 60000)] - public async Task LLIRGraph_ComputeCriticalPath_CalculatesCorrectly() - { - var graph = new LLIRGraph(); - graph.InputIds.Add(0); - graph.BufferShapes[0] = new[] { 10 }; - - var op1 = new ElementwiseOp { OutputId = 1, InputIds = new[] { 0 }, OutputShape = new[] { 10 }, ElementwiseType = ElementwiseOpType.ReLU }; - var op2 = new ElementwiseOp { OutputId = 2, InputIds = new[] { 1 }, OutputShape = new[] { 10 }, ElementwiseType = ElementwiseOpType.Add }; - - graph.AddOperation(op1); - graph.AddOperation(op2); - graph.OutputIds.Add(2); - - var criticalPath = graph.ComputeCriticalPath(); - - Assert.True(criticalPath > 0); - } - - [Fact(Timeout = 60000)] - public async Task LLIRGraph_Clone_CreatesIndependentCopy() - { - var graph = new LLIRGraph { Name = "original" }; - graph.InputIds.Add(0); - graph.BufferShapes[0] = new[] { 10 }; - - var op = new ElementwiseOp { OutputId = 1, InputIds = new[] { 0 }, OutputShape = new[] { 10 }, ElementwiseType = ElementwiseOpType.ReLU }; - graph.AddOperation(op); - - var clone = graph.Clone(); - - Assert.Contains("_clone", clone.Name); - Assert.Equal(graph.Operations.Count, clone.Operations.Count); - Assert.Equal(graph.InputIds, clone.InputIds); - } - - [Fact(Timeout = 60000)] - public async Task LLIRGraph_ComputeStructureHash_ReturnsSameHashForSameStructure() - { - var graph1 = new LLIRGraph(); - graph1.InputIds.Add(0); - graph1.BufferShapes[0] = new[] { 10 }; - var op1 = new ElementwiseOp { OutputId = 1, InputIds = new[] { 0 }, OutputShape = new[] { 10 }, ElementwiseType = ElementwiseOpType.ReLU }; - graph1.AddOperation(op1); - graph1.OutputIds.Add(1); - - var graph2 = new LLIRGraph(); - graph2.InputIds.Add(0); - graph2.BufferShapes[0] = new[] { 10 }; - var op2 = new ElementwiseOp { OutputId = 1, InputIds = new[] { 0 }, OutputShape = new[] { 10 }, ElementwiseType = ElementwiseOpType.ReLU }; - graph2.AddOperation(op2); - graph2.OutputIds.Add(1); - - Assert.Equal(graph1.ComputeStructureHash(), graph2.ComputeStructureHash()); - } - - #endregion - - #region DeviceConfiguration Tests - - [Fact(Timeout = 60000)] - public async Task DeviceConfiguration_DefaultValues_AreReasonable() - { - var config = new DeviceConfiguration(); - - Assert.True(config.CPUCores > 0); - Assert.True(config.CPUVectorWidth > 0); - Assert.True(config.L1CacheBytes > 0); - Assert.True(config.L2CacheBytes > 0); - Assert.True(config.L3CacheBytes > 0); - Assert.True(config.CPUMemoryBandwidth > 0); - Assert.True(config.CPUPeakGFLOPS > 0); - } - - #endregion - - #region OperationMetrics Tests - - [Fact(Timeout = 60000)] - public async Task OperationMetrics_ArithmeticIntensity_CalculatesCorrectly() - { - var metrics = new OperationMetrics - { - FLOPs = 1000, - IntOps = 0, - MemoryRead = 50, - MemoryWrite = 50 - }; - - Assert.Equal(10.0, metrics.ArithmeticIntensity); - } - - [Fact(Timeout = 60000)] - public async Task OperationMetrics_RooflineGFLOPS_CalculatesCorrectly() - { - var metrics = new OperationMetrics - { - FLOPs = 1000, - MemoryRead = 50, - MemoryWrite = 50 - }; - - double peakGFLOPS = 100; - double memBandwidth = 50; - - var roofline = metrics.RooflineGFLOPS(peakGFLOPS, memBandwidth); - - // min(100, 10 * 50) = min(100, 500) = 100 - Assert.Equal(100, roofline); - } - - #endregion - - #region MemoryPlan Tests - - [Fact(Timeout = 60000)] - public async Task MemoryPlan_Validate_ReturnsTrueForValidPlan() - { - var plan = new MemoryPlan - { - PoolCount = 2, - PoolSizes = new long[] { 1024, 2048 }, - BufferAssignments = new Dictionary - { - { 0, (0, 0) }, - { 1, (1, 0) } - } - }; - - var result = plan.Validate(); - - Assert.True(result.IsValid); - } - - [Fact(Timeout = 60000)] - public async Task MemoryPlan_Validate_DetectsInvalidPoolId() - { - var plan = new MemoryPlan - { - PoolCount = 1, - BufferAssignments = new Dictionary - { - { 0, (5, 0) } // Invalid pool ID - } - }; - - var result = plan.Validate(); - - Assert.False(result.IsValid); - } - - #endregion -} diff --git a/tests/AiDotNet.Tests/InferenceOptimization/OptimizationGraphTests.cs b/tests/AiDotNet.Tests/InferenceOptimization/OptimizationGraphTests.cs deleted file mode 100644 index 5502b6cbb2..0000000000 --- a/tests/AiDotNet.Tests/InferenceOptimization/OptimizationGraphTests.cs +++ /dev/null @@ -1,176 +0,0 @@ -using AiDotNet.Enums; -using AiDotNet.InferenceOptimization.Core; -using Xunit; -using System.Threading.Tasks; - -namespace AiDotNet.Tests.InferenceOptimization; - -public class OptimizationGraphTests -{ - [Fact(Timeout = 60000)] - public async Task AddNode_ShouldAddNodeToGraph() - { - // Arrange - var graph = new OptimizationGraph(); - var node = new OptimizationNode - { - OperationType = OperationType.ReLU, - Name = "relu1" - }; - - // Act - graph.AddNode(node); - - // Assert - Assert.Contains(node, graph.Nodes); - Assert.Single(graph.Nodes); - } - - [Fact(Timeout = 60000)] - public async Task RemoveNode_ShouldRemoveNodeFromGraph() - { - // Arrange - var graph = new OptimizationGraph(); - var node = new OptimizationNode - { - OperationType = OperationType.ReLU, - Name = "relu1" - }; - graph.AddNode(node); - - // Act - graph.RemoveNode(node); - - // Assert - Assert.DoesNotContain(node, graph.Nodes); - Assert.Empty(graph.Nodes); - } - - [Fact(Timeout = 60000)] - public async Task FindNodeById_ShouldReturnCorrectNode() - { - // Arrange - var graph = new OptimizationGraph(); - var node = new OptimizationNode - { - Id = "test-id", - OperationType = OperationType.ReLU, - Name = "relu1" - }; - graph.AddNode(node); - - // Act - var found = graph.FindNodeById("test-id"); - - // Assert - Assert.NotNull(found); - Assert.Equal(node, found); - } - - [Fact(Timeout = 60000)] - public async Task GetTopologicalOrder_ShouldReturnValidOrder() - { - // Arrange - var graph = new OptimizationGraph(); - - var input = new OptimizationNode { OperationType = OperationType.Input, Name = "input" }; - var conv = new OptimizationNode { OperationType = OperationType.Convolution, Name = "conv" }; - var relu = new OptimizationNode { OperationType = OperationType.ReLU, Name = "relu" }; - var output = new OptimizationNode { OperationType = OperationType.Output, Name = "output" }; - - conv.AddInput(input); - relu.AddInput(conv); - output.AddInput(relu); - - graph.AddNode(input); - graph.AddNode(conv); - graph.AddNode(relu); - graph.AddNode(output); - - // Act - var order = graph.GetTopologicalOrder(); - - // Assert - Assert.Equal(4, order.Count); - Assert.True(order.IndexOf(input) < order.IndexOf(conv)); - Assert.True(order.IndexOf(conv) < order.IndexOf(relu)); - Assert.True(order.IndexOf(relu) < order.IndexOf(output)); - } - - [Fact(Timeout = 60000)] - public async Task Validate_ShouldReturnTrueForValidGraph() - { - // Arrange - var graph = new OptimizationGraph(); - - var input = new OptimizationNode { OperationType = OperationType.Input, Name = "input" }; - var relu = new OptimizationNode { OperationType = OperationType.ReLU, Name = "relu" }; - var output = new OptimizationNode { OperationType = OperationType.Output, Name = "output" }; - - relu.AddInput(input); - output.AddInput(relu); - - graph.AddNode(input); - graph.AddNode(relu); - graph.AddNode(output); - - // Act - var isValid = graph.Validate(); - - // Assert - Assert.True(isValid); - } - - [Fact(Timeout = 60000)] - public async Task Clone_ShouldCreateDeepCopy() - { - // Arrange - var graph = new OptimizationGraph(); - - var input = new OptimizationNode { OperationType = OperationType.Input, Name = "input" }; - var relu = new OptimizationNode { OperationType = OperationType.ReLU, Name = "relu" }; - - relu.AddInput(input); - - graph.AddNode(input); - graph.AddNode(relu); - - // Act - var clonedResult = graph.Clone(); - var cloned = clonedResult as OptimizationGraph; - - // Assert - Assert.NotNull(cloned); - if (cloned != null) - { - Assert.Equal(graph.Nodes.Count, cloned.Nodes.Count); - Assert.NotSame(graph.Nodes[0], cloned.Nodes[0]); - } - } - - [Fact(Timeout = 60000)] - public async Task GetStatistics_ShouldReturnCorrectCounts() - { - // Arrange - var graph = new OptimizationGraph(); - - var input = new OptimizationNode { OperationType = OperationType.Input, Name = "input" }; - var relu1 = new OptimizationNode { OperationType = OperationType.ReLU, Name = "relu1" }; - var relu2 = new OptimizationNode { OperationType = OperationType.ReLU, Name = "relu2" }; - var output = new OptimizationNode { OperationType = OperationType.Output, Name = "output" }; - - graph.AddNode(input); - graph.AddNode(relu1); - graph.AddNode(relu2); - graph.AddNode(output); - - // Act - var stats = graph.GetStatistics(); - - // Assert - Assert.Equal(4, stats.TotalNodes); - Assert.Equal(1, stats.InputNodes); - Assert.Equal(1, stats.OutputNodes); - Assert.Equal(2, stats.OperationTypeCounts[OperationType.ReLU]); - } -} diff --git a/tests/AiDotNet.Tests/InferenceOptimization/OptimizationPassTests.cs b/tests/AiDotNet.Tests/InferenceOptimization/OptimizationPassTests.cs deleted file mode 100644 index 80173052b9..0000000000 --- a/tests/AiDotNet.Tests/InferenceOptimization/OptimizationPassTests.cs +++ /dev/null @@ -1,1742 +0,0 @@ -#nullable disable -using AiDotNet.Enums; -using AiDotNet.InferenceOptimization.Core; -using AiDotNet.InferenceOptimization.Passes; -using AiDotNet.LinearAlgebra; -using Xunit; -using System.Threading.Tasks; - -namespace AiDotNet.Tests.InferenceOptimization; - -public class OptimizationPassTests -{ - [Fact(Timeout = 60000)] - public async Task ConvBatchNormFusionPass_ShouldFuseConvAndBatchNorm() - { - // Arrange - var graph = new OptimizationGraph(); - - var input = new OptimizationNode { OperationType = OperationType.Input, Name = "input" }; - var conv = new OptimizationNode { OperationType = OperationType.Convolution, Name = "conv" }; - var bn = new OptimizationNode { OperationType = OperationType.BatchNormalization, Name = "bn" }; - var output = new OptimizationNode { OperationType = OperationType.Output, Name = "output" }; - - conv.AddInput(input); - bn.AddInput(conv); - output.AddInput(bn); - - graph.AddNode(input); - graph.AddNode(conv); - graph.AddNode(bn); - graph.AddNode(output); - - var pass = new ConvBatchNormFusionPass(); - - // Act - var modified = pass.Apply(graph); - - // Assert - Assert.True(modified); - Assert.Contains(graph.Nodes, n => n.OperationType == OperationType.FusedConvBatchNorm); - Assert.DoesNotContain(graph.Nodes, n => n.Name == "conv"); - Assert.DoesNotContain(graph.Nodes, n => n.Name == "bn"); - } - - [Fact(Timeout = 60000)] - public async Task ConvBatchNormReLUFusionPass_ShouldFuseThreeOperations() - { - // Arrange - var graph = new OptimizationGraph(); - - var input = new OptimizationNode { OperationType = OperationType.Input, Name = "input" }; - var conv = new OptimizationNode { OperationType = OperationType.Convolution, Name = "conv" }; - var bn = new OptimizationNode { OperationType = OperationType.BatchNormalization, Name = "bn" }; - var relu = new OptimizationNode { OperationType = OperationType.ReLU, Name = "relu" }; - var output = new OptimizationNode { OperationType = OperationType.Output, Name = "output" }; - - conv.AddInput(input); - bn.AddInput(conv); - relu.AddInput(bn); - output.AddInput(relu); - - graph.AddNode(input); - graph.AddNode(conv); - graph.AddNode(bn); - graph.AddNode(relu); - graph.AddNode(output); - - var pass = new ConvBatchNormReLUFusionPass(); - - // Act - var modified = pass.Apply(graph); - - // Assert - Assert.True(modified); - Assert.Contains(graph.Nodes, n => n.OperationType == OperationType.FusedConvBatchNormReLU); - } - - [Fact(Timeout = 60000)] - public async Task DeadCodeEliminationPass_ShouldRemoveUnusedNodes() - { - // Arrange - var graph = new OptimizationGraph(); - - var input = new OptimizationNode { OperationType = OperationType.Input, Name = "input" }; - var relu1 = new OptimizationNode { OperationType = OperationType.ReLU, Name = "relu1" }; - var relu2 = new OptimizationNode { OperationType = OperationType.ReLU, Name = "relu2" }; // Dead code - var output = new OptimizationNode { OperationType = OperationType.Output, Name = "output" }; - - relu1.AddInput(input); - relu2.AddInput(input); // Not connected to output - output.AddInput(relu1); - - graph.AddNode(input); - graph.AddNode(relu1); - graph.AddNode(relu2); - graph.AddNode(output); - - var pass = new DeadCodeEliminationPass(); - - // Act - var modified = pass.Apply(graph); - - // Assert - Assert.True(modified); - Assert.DoesNotContain(graph.Nodes, n => n.Name == "relu2"); - Assert.Contains(graph.Nodes, n => n.Name == "relu1"); - } - - [Fact(Timeout = 60000)] - public async Task ConstantFoldingPass_ShouldFoldConstants() - { - // Arrange - var graph = new OptimizationGraph(); - - var const1 = new OptimizationNode - { - OperationType = OperationType.Constant, - Name = "const1", - ConstantValue = null // Would be actual tensor - }; - - var const2 = new OptimizationNode - { - OperationType = OperationType.Constant, - Name = "const2", - ConstantValue = null // Would be actual tensor - }; - - var add = new OptimizationNode - { - OperationType = OperationType.Add, - Name = "add" - }; - - add.AddInput(const1); - add.AddInput(const2); - - graph.AddNode(const1); - graph.AddNode(const2); - graph.AddNode(add); - - var pass = new ConstantFoldingPass(); - - // Act - var canApply = pass.CanApply(graph); - - // Assert - Assert.True(canApply); - // Note: Actual folding would require tensor implementation - } - - [Fact(Timeout = 60000)] - public async Task InPlaceOptimizationPass_ShouldMarkEligibleOperations() - { - // Arrange - var graph = new OptimizationGraph(); - - var input = new OptimizationNode { OperationType = OperationType.Input, Name = "input" }; - var relu = new OptimizationNode { OperationType = OperationType.ReLU, Name = "relu" }; - var output = new OptimizationNode { OperationType = OperationType.Output, Name = "output" }; - - relu.AddInput(input); - output.AddInput(relu); - - graph.AddNode(input); - graph.AddNode(relu); - graph.AddNode(output); - - var pass = new InPlaceOptimizationPass(); - - // Act - var modified = pass.Apply(graph); - - // Assert - Assert.True(modified); - Assert.True(relu.CanOperateInPlace); - } - - [Fact(Timeout = 60000)] - public async Task AlgebraicSimplificationPass_ShouldSimplifyIdentities() - { - // Arrange - var graph = new OptimizationGraph(); - - var input = new OptimizationNode { OperationType = OperationType.Input, Name = "input" }; - - var zero = new OptimizationNode - { - OperationType = OperationType.Constant, - Name = "zero", - Metadata = new Dictionary { ["IsZero"] = true } - }; - - var add = new OptimizationNode { OperationType = OperationType.Add, Name = "add" }; - var output = new OptimizationNode { OperationType = OperationType.Output, Name = "output" }; - - add.AddInput(input); - add.AddInput(zero); - output.AddInput(add); - - graph.AddNode(input); - graph.AddNode(zero); - graph.AddNode(add); - graph.AddNode(output); - - var pass = new AlgebraicSimplificationPass(); - - // Act - var modified = pass.Apply(graph); - - // Assert - Assert.True(modified); - // x + 0 should be simplified to x - Assert.DoesNotContain(graph.Nodes, n => n.Name == "add"); - } - - [Fact(Timeout = 60000)] - public async Task GraphOptimizer_ShouldApplyMultiplePasses() - { - // Arrange - var graph = new OptimizationGraph(); - - var input = new OptimizationNode { OperationType = OperationType.Input, Name = "input" }; - var conv = new OptimizationNode { OperationType = OperationType.Convolution, Name = "conv" }; - var bn = new OptimizationNode { OperationType = OperationType.BatchNormalization, Name = "bn" }; - var relu = new OptimizationNode { OperationType = OperationType.ReLU, Name = "relu" }; - var output = new OptimizationNode { OperationType = OperationType.Output, Name = "output" }; - - conv.AddInput(input); - bn.AddInput(conv); - relu.AddInput(bn); - output.AddInput(relu); - - graph.AddNode(input); - graph.AddNode(conv); - graph.AddNode(bn); - graph.AddNode(relu); - graph.AddNode(output); - - var options = OptimizationOptions.FromLevel(OptimizationLevel.Standard); - var optimizer = new GraphOptimizer(options); - - // Capture original count before optimization (handles in-place mutation) - var originalCount = graph.Nodes.Count; - - // Act - var optimizedGraph = optimizer.Optimize(graph); - - // Assert - Assert.NotNull(optimizedGraph); - // Should have fewer nodes due to fusion - Assert.True(optimizedGraph.Nodes.Count < originalCount); - } - - [Fact(Timeout = 60000)] - public async Task OptimizationOptions_ShouldConfigureCorrectly() - { - // Arrange & Act - var basicOptions = OptimizationOptions.FromLevel(OptimizationLevel.Basic); - var standardOptions = OptimizationOptions.FromLevel(OptimizationLevel.Standard); - var aggressiveOptions = OptimizationOptions.FromLevel(OptimizationLevel.Aggressive); - - // Assert - Assert.False(basicOptions.EnableOperatorFusion); - Assert.True(standardOptions.EnableOperatorFusion); - Assert.True(aggressiveOptions.EnableMemoryReuse); - } - - #region ConstantFoldingPass Tests - - [Fact(Timeout = 60000)] - public async Task ConstantFoldingPass_FoldAdd_ShouldComputeCorrectResult() - { - // Arrange - var graph = new OptimizationGraph(); - - var const1 = new OptimizationNode - { - OperationType = OperationType.Constant, - Name = "const1", - ConstantValue = new Tensor(new[] { 2, 2 }, new Vector(new[] { 1.0, 2.0, 3.0, 4.0 })), - OutputShape = new[] { 2, 2 } - }; - - var const2 = new OptimizationNode - { - OperationType = OperationType.Constant, - Name = "const2", - ConstantValue = new Tensor(new[] { 2, 2 }, new Vector(new[] { 5.0, 6.0, 7.0, 8.0 })), - OutputShape = new[] { 2, 2 } - }; - - var add = new OptimizationNode - { - OperationType = OperationType.Add, - Name = "add", - OutputShape = new[] { 2, 2 } - }; - - var output = new OptimizationNode - { - OperationType = OperationType.Output, - Name = "output" - }; - - add.AddInput(const1); - add.AddInput(const2); - output.AddInput(add); - - graph.AddNode(const1); - graph.AddNode(const2); - graph.AddNode(add); - graph.AddNode(output); - - var pass = new ConstantFoldingPass(); - - // Act - var modified = pass.Apply(graph); - - // Assert - Assert.True(modified); - Assert.DoesNotContain(graph.Nodes, n => n.Name == "add"); - var foldedNode = graph.Nodes.FirstOrDefault(n => n.Name == "add_folded"); - Assert.NotNull(foldedNode); - Assert.Equal(OperationType.Constant, foldedNode.OperationType); - Assert.NotNull(foldedNode.ConstantValue); - Assert.Equal(6.0, foldedNode.ConstantValue[0]); // 1 + 5 - Assert.Equal(8.0, foldedNode.ConstantValue[1]); // 2 + 6 - Assert.Equal(10.0, foldedNode.ConstantValue[2]); // 3 + 7 - Assert.Equal(12.0, foldedNode.ConstantValue[3]); // 4 + 8 - } - - [Fact(Timeout = 60000)] - public async Task ConstantFoldingPass_FoldSubtract_ShouldComputeCorrectResult() - { - // Arrange - var graph = new OptimizationGraph(); - - var const1 = new OptimizationNode - { - OperationType = OperationType.Constant, - Name = "const1", - ConstantValue = new Tensor(new[] { 3 }, new Vector(new[] { 10.0, 20.0, 30.0 })), - OutputShape = new[] { 3 } - }; - - var const2 = new OptimizationNode - { - OperationType = OperationType.Constant, - Name = "const2", - ConstantValue = new Tensor(new[] { 3 }, new Vector(new[] { 1.0, 2.0, 3.0 })), - OutputShape = new[] { 3 } - }; - - var subtract = new OptimizationNode - { - OperationType = OperationType.Subtract, - Name = "subtract", - OutputShape = new[] { 3 } - }; - - subtract.AddInput(const1); - subtract.AddInput(const2); - - graph.AddNode(const1); - graph.AddNode(const2); - graph.AddNode(subtract); - - var pass = new ConstantFoldingPass(); - - // Act - var modified = pass.Apply(graph); - - // Assert - Assert.True(modified); - var foldedNode = graph.Nodes.FirstOrDefault(n => n.Name == "subtract_folded"); - Assert.NotNull(foldedNode); - Assert.Equal(9.0, foldedNode.ConstantValue[0]); // 10 - 1 - Assert.Equal(18.0, foldedNode.ConstantValue[1]); // 20 - 2 - Assert.Equal(27.0, foldedNode.ConstantValue[2]); // 30 - 3 - } - - [Fact(Timeout = 60000)] - public async Task ConstantFoldingPass_FoldMultiply_ShouldComputeCorrectResult() - { - // Arrange - var graph = new OptimizationGraph(); - - var const1 = new OptimizationNode - { - OperationType = OperationType.Constant, - Name = "const1", - ConstantValue = new Tensor(new[] { 2 }, new Vector(new[] { 2.0, 3.0 })), - OutputShape = new[] { 2 } - }; - - var const2 = new OptimizationNode - { - OperationType = OperationType.Constant, - Name = "const2", - ConstantValue = new Tensor(new[] { 2 }, new Vector(new[] { 4.0, 5.0 })), - OutputShape = new[] { 2 } - }; - - var multiply = new OptimizationNode - { - OperationType = OperationType.Multiply, - Name = "multiply", - OutputShape = new[] { 2 } - }; - - multiply.AddInput(const1); - multiply.AddInput(const2); - - graph.AddNode(const1); - graph.AddNode(const2); - graph.AddNode(multiply); - - var pass = new ConstantFoldingPass(); - - // Act - var modified = pass.Apply(graph); - - // Assert - Assert.True(modified); - var foldedNode = graph.Nodes.FirstOrDefault(n => n.Name == "multiply_folded"); - Assert.NotNull(foldedNode); - Assert.Equal(8.0, foldedNode.ConstantValue[0]); // 2 * 4 - Assert.Equal(15.0, foldedNode.ConstantValue[1]); // 3 * 5 - } - - [Fact(Timeout = 60000)] - public async Task ConstantFoldingPass_FoldDivide_ShouldComputeCorrectResult() - { - // Arrange - var graph = new OptimizationGraph(); - - var const1 = new OptimizationNode - { - OperationType = OperationType.Constant, - Name = "const1", - ConstantValue = new Tensor(new[] { 2 }, new Vector(new[] { 10.0, 20.0 })), - OutputShape = new[] { 2 } - }; - - var const2 = new OptimizationNode - { - OperationType = OperationType.Constant, - Name = "const2", - ConstantValue = new Tensor(new[] { 2 }, new Vector(new[] { 2.0, 4.0 })), - OutputShape = new[] { 2 } - }; - - var divide = new OptimizationNode - { - OperationType = OperationType.Divide, - Name = "divide", - OutputShape = new[] { 2 } - }; - - divide.AddInput(const1); - divide.AddInput(const2); - - graph.AddNode(const1); - graph.AddNode(const2); - graph.AddNode(divide); - - var pass = new ConstantFoldingPass(); - - // Act - var modified = pass.Apply(graph); - - // Assert - Assert.True(modified); - var foldedNode = graph.Nodes.FirstOrDefault(n => n.Name == "divide_folded"); - Assert.NotNull(foldedNode); - Assert.Equal(5.0, foldedNode.ConstantValue[0]); // 10 / 2 - Assert.Equal(5.0, foldedNode.ConstantValue[1]); // 20 / 4 - } - - [Fact(Timeout = 60000)] - public async Task ConstantFoldingPass_FoldPowerWithScalarExponent_ShouldComputeCorrectResult() - { - // Arrange - var graph = new OptimizationGraph(); - - var constNode = new OptimizationNode - { - OperationType = OperationType.Constant, - Name = "const1", - ConstantValue = new Tensor(new[] { 3 }, new Vector(new[] { 2.0, 3.0, 4.0 })), - OutputShape = new[] { 3 } - }; - - var power = new OptimizationNode - { - OperationType = OperationType.Power, - Name = "power", - OutputShape = new[] { 3 }, - Metadata = new Dictionary { ["exponent"] = 2.0 } - }; - - power.AddInput(constNode); - - graph.AddNode(constNode); - graph.AddNode(power); - - var pass = new ConstantFoldingPass(); - - // Act - var modified = pass.Apply(graph); - - // Assert - Assert.True(modified); - var foldedNode = graph.Nodes.FirstOrDefault(n => n.Name == "power_folded"); - Assert.NotNull(foldedNode); - Assert.Equal(4.0, foldedNode.ConstantValue[0]); // 2^2 - Assert.Equal(9.0, foldedNode.ConstantValue[1]); // 3^2 - Assert.Equal(16.0, foldedNode.ConstantValue[2]); // 4^2 - } - - [Fact(Timeout = 60000)] - public async Task ConstantFoldingPass_FoldPowerWithTensorExponent_ShouldComputeCorrectResult() - { - // Arrange - var graph = new OptimizationGraph(); - - var baseNode = new OptimizationNode - { - OperationType = OperationType.Constant, - Name = "base", - ConstantValue = new Tensor(new[] { 2 }, new Vector(new[] { 2.0, 3.0 })), - OutputShape = new[] { 2 } - }; - - var exponentNode = new OptimizationNode - { - OperationType = OperationType.Constant, - Name = "exponent", - ConstantValue = new Tensor(new[] { 2 }, new Vector(new[] { 3.0, 2.0 })), - OutputShape = new[] { 2 } - }; - - var power = new OptimizationNode - { - OperationType = OperationType.Power, - Name = "power", - OutputShape = new[] { 2 } - }; - - power.AddInput(baseNode); - power.AddInput(exponentNode); - - graph.AddNode(baseNode); - graph.AddNode(exponentNode); - graph.AddNode(power); - - var pass = new ConstantFoldingPass(); - - // Act - var modified = pass.Apply(graph); - - // Assert - Assert.True(modified); - var foldedNode = graph.Nodes.FirstOrDefault(n => n.Name == "power_folded"); - Assert.NotNull(foldedNode); - Assert.Equal(8.0, foldedNode.ConstantValue[0]); // 2^3 - Assert.Equal(9.0, foldedNode.ConstantValue[1]); // 3^2 - } - - [Fact(Timeout = 60000)] - public async Task ConstantFoldingPass_FoldSqrt_ShouldComputeCorrectResult() - { - // Arrange - var graph = new OptimizationGraph(); - - var constNode = new OptimizationNode - { - OperationType = OperationType.Constant, - Name = "const1", - ConstantValue = new Tensor(new[] { 3 }, new Vector(new[] { 4.0, 9.0, 16.0 })), - OutputShape = new[] { 3 } - }; - - var sqrt = new OptimizationNode - { - OperationType = OperationType.Sqrt, - Name = "sqrt", - OutputShape = new[] { 3 } - }; - - sqrt.AddInput(constNode); - - graph.AddNode(constNode); - graph.AddNode(sqrt); - - var pass = new ConstantFoldingPass(); - - // Act - var modified = pass.Apply(graph); - - // Assert - Assert.True(modified); - var foldedNode = graph.Nodes.FirstOrDefault(n => n.Name == "sqrt_folded"); - Assert.NotNull(foldedNode); - Assert.Equal(2.0, foldedNode.ConstantValue[0], 5); // sqrt(4) - Assert.Equal(3.0, foldedNode.ConstantValue[1], 5); // sqrt(9) - Assert.Equal(4.0, foldedNode.ConstantValue[2], 5); // sqrt(16) - } - - [Fact(Timeout = 60000)] - public async Task ConstantFoldingPass_FoldExp_ShouldComputeCorrectResult() - { - // Arrange - var graph = new OptimizationGraph(); - - var constNode = new OptimizationNode - { - OperationType = OperationType.Constant, - Name = "const1", - ConstantValue = new Tensor(new[] { 2 }, new Vector(new[] { 0.0, 1.0 })), - OutputShape = new[] { 2 } - }; - - var exp = new OptimizationNode - { - OperationType = OperationType.Exp, - Name = "exp", - OutputShape = new[] { 2 } - }; - - exp.AddInput(constNode); - - graph.AddNode(constNode); - graph.AddNode(exp); - - var pass = new ConstantFoldingPass(); - - // Act - var modified = pass.Apply(graph); - - // Assert - Assert.True(modified); - var foldedNode = graph.Nodes.FirstOrDefault(n => n.Name == "exp_folded"); - Assert.NotNull(foldedNode); - Assert.Equal(1.0, foldedNode.ConstantValue[0], 5); // exp(0) - Assert.Equal(Math.E, foldedNode.ConstantValue[1], 5); // exp(1) - } - - [Fact(Timeout = 60000)] - public async Task ConstantFoldingPass_FoldLog_ShouldComputeCorrectResult() - { - // Arrange - var graph = new OptimizationGraph(); - - var constNode = new OptimizationNode - { - OperationType = OperationType.Constant, - Name = "const1", - ConstantValue = new Tensor(new[] { 2 }, new Vector(new[] { 1.0, Math.E })), - OutputShape = new[] { 2 } - }; - - var log = new OptimizationNode - { - OperationType = OperationType.Log, - Name = "log", - OutputShape = new[] { 2 } - }; - - log.AddInput(constNode); - - graph.AddNode(constNode); - graph.AddNode(log); - - var pass = new ConstantFoldingPass(); - - // Act - var modified = pass.Apply(graph); - - // Assert - Assert.True(modified); - var foldedNode = graph.Nodes.FirstOrDefault(n => n.Name == "log_folded"); - Assert.NotNull(foldedNode); - Assert.Equal(0.0, foldedNode.ConstantValue[0], 5); // log(1) - Assert.Equal(1.0, foldedNode.ConstantValue[1], 5); // log(e) - } - - [Fact(Timeout = 60000)] - public async Task ConstantFoldingPass_FoldMatMul_ShouldComputeCorrectResult() - { - // Arrange - var graph = new OptimizationGraph(); - - // 2x3 matrix - var const1 = new OptimizationNode - { - OperationType = OperationType.Constant, - Name = "const1", - ConstantValue = new Tensor(new[] { 2, 3 }, new Vector(new[] { 1.0, 2.0, 3.0, 4.0, 5.0, 6.0 })), - OutputShape = new[] { 2, 3 } - }; - - // 3x2 matrix - var const2 = new OptimizationNode - { - OperationType = OperationType.Constant, - Name = "const2", - ConstantValue = new Tensor(new[] { 3, 2 }, new Vector(new[] { 7.0, 8.0, 9.0, 10.0, 11.0, 12.0 })), - OutputShape = new[] { 3, 2 } - }; - - var matmul = new OptimizationNode - { - OperationType = OperationType.MatMul, - Name = "matmul", - OutputShape = new[] { 2, 2 } - }; - - matmul.AddInput(const1); - matmul.AddInput(const2); - - graph.AddNode(const1); - graph.AddNode(const2); - graph.AddNode(matmul); - - var pass = new ConstantFoldingPass(); - - // Act - var modified = pass.Apply(graph); - - // Assert - Assert.True(modified); - var foldedNode = graph.Nodes.FirstOrDefault(n => n.Name == "matmul_folded"); - Assert.NotNull(foldedNode); - Assert.NotNull(foldedNode.ConstantValue); - // Result should be 2x2 matrix - Assert.Equal(2, foldedNode.ConstantValue.Shape[0]); - Assert.Equal(2, foldedNode.ConstantValue.Shape[1]); - } - - [Fact(Timeout = 60000)] - public async Task ConstantFoldingPass_ShapeMismatch_ShouldReturnNull() - { - // Arrange - var graph = new OptimizationGraph(); - - var const1 = new OptimizationNode - { - OperationType = OperationType.Constant, - Name = "const1", - ConstantValue = new Tensor(new[] { 2 }, new Vector(new[] { 1.0, 2.0 })), - OutputShape = new[] { 2 } - }; - - var const2 = new OptimizationNode - { - OperationType = OperationType.Constant, - Name = "const2", - ConstantValue = new Tensor(new[] { 3 }, new Vector(new[] { 3.0, 4.0, 5.0 })), - OutputShape = new[] { 3 } - }; - - var add = new OptimizationNode - { - OperationType = OperationType.Add, - Name = "add", - OutputShape = new[] { 2 } - }; - - add.AddInput(const1); - add.AddInput(const2); - - graph.AddNode(const1); - graph.AddNode(const2); - graph.AddNode(add); - - var pass = new ConstantFoldingPass(); - - // Act - var modified = pass.Apply(graph); - - // Assert - Assert.False(modified); // Should not fold due to shape mismatch - Assert.Contains(graph.Nodes, n => n.Name == "add"); // Add node should still exist - } - - [Fact(Timeout = 60000)] - public async Task ConstantFoldingPass_ApplyIteratesUntilNoChanges() - { - // Arrange - var graph = new OptimizationGraph(); - - // Create a chain: const1 + const2 = intermediate, intermediate + const3 = result - var const1 = new OptimizationNode - { - OperationType = OperationType.Constant, - Name = "const1", - ConstantValue = new Tensor(new[] { 2 }, new Vector(new[] { 1.0, 2.0 })), - OutputShape = new[] { 2 } - }; - - var const2 = new OptimizationNode - { - OperationType = OperationType.Constant, - Name = "const2", - ConstantValue = new Tensor(new[] { 2 }, new Vector(new[] { 3.0, 4.0 })), - OutputShape = new[] { 2 } - }; - - var const3 = new OptimizationNode - { - OperationType = OperationType.Constant, - Name = "const3", - ConstantValue = new Tensor(new[] { 2 }, new Vector(new[] { 5.0, 6.0 })), - OutputShape = new[] { 2 } - }; - - var add1 = new OptimizationNode - { - OperationType = OperationType.Add, - Name = "add1", - OutputShape = new[] { 2 } - }; - - var add2 = new OptimizationNode - { - OperationType = OperationType.Add, - Name = "add2", - OutputShape = new[] { 2 } - }; - - add1.AddInput(const1); - add1.AddInput(const2); - add2.AddInput(add1); - add2.AddInput(const3); - - graph.AddNode(const1); - graph.AddNode(const2); - graph.AddNode(const3); - graph.AddNode(add1); - graph.AddNode(add2); - - var pass = new ConstantFoldingPass(); - - // Act - var modified = pass.Apply(graph); - - // Assert - Assert.True(modified); - // Both add operations should be folded through iteration - Assert.DoesNotContain(graph.Nodes, n => n.Name == "add1"); - Assert.DoesNotContain(graph.Nodes, n => n.Name == "add2"); - // Should have a final constant result - var constantNodes = graph.Nodes.Where(n => n.OperationType == OperationType.Constant && n.Name.Contains("folded")).ToList(); - Assert.NotEmpty(constantNodes); - } - - [Fact(Timeout = 60000)] - public async Task ConstantFoldingPass_NonConstantInputs_ShouldNotFold() - { - // Arrange - var graph = new OptimizationGraph(); - - var input = new OptimizationNode - { - OperationType = OperationType.Input, - Name = "input", - OutputShape = new[] { 2 } - }; - - var constNode = new OptimizationNode - { - OperationType = OperationType.Constant, - Name = "const", - ConstantValue = new Tensor(new[] { 2 }, new Vector(new[] { 1.0, 2.0 })), - OutputShape = new[] { 2 } - }; - - var add = new OptimizationNode - { - OperationType = OperationType.Add, - Name = "add", - OutputShape = new[] { 2 } - }; - - add.AddInput(input); - add.AddInput(constNode); - - graph.AddNode(input); - graph.AddNode(constNode); - graph.AddNode(add); - - var pass = new ConstantFoldingPass(); - - // Act - var modified = pass.Apply(graph); - - // Assert - Assert.False(modified); // Should not fold because one input is not constant - Assert.Contains(graph.Nodes, n => n.Name == "add"); - } - - #endregion - - #region ElementwiseFusionPass Tests - - [Fact(Timeout = 60000)] - public async Task ElementwiseFusionPass_FuseTwoElementwiseOps_ShouldCreateFusedNode() - { - // Arrange - var graph = new OptimizationGraph(); - - var input = new OptimizationNode { OperationType = OperationType.Input, Name = "input" }; - var add = new OptimizationNode { OperationType = OperationType.Add, Name = "add" }; - var relu = new OptimizationNode { OperationType = OperationType.ReLU, Name = "relu" }; - var output = new OptimizationNode { OperationType = OperationType.Output, Name = "output" }; - - add.AddInput(input); - relu.AddInput(add); - output.AddInput(relu); - - graph.AddNode(input); - graph.AddNode(add); - graph.AddNode(relu); - graph.AddNode(output); - - var pass = new ElementwiseFusionPass(); - - // Act - var modified = pass.Apply(graph); - - // Assert - Assert.True(modified); - Assert.DoesNotContain(graph.Nodes, n => n.Name == "add"); - Assert.DoesNotContain(graph.Nodes, n => n.Name == "relu"); - var fusedNode = graph.Nodes.FirstOrDefault(n => n.IsFused && n.Name.Contains("fused")); - Assert.NotNull(fusedNode); - Assert.Equal(OperationType.Custom, fusedNode.OperationType); - } - - [Fact(Timeout = 60000)] - public async Task ElementwiseFusionPass_ChainOfThreeOps_ShouldFuseAll() - { - // Arrange - var graph = new OptimizationGraph(); - - var input = new OptimizationNode { OperationType = OperationType.Input, Name = "input" }; - var add = new OptimizationNode { OperationType = OperationType.Add, Name = "add" }; - var multiply = new OptimizationNode { OperationType = OperationType.Multiply, Name = "multiply" }; - var sigmoid = new OptimizationNode { OperationType = OperationType.Sigmoid, Name = "sigmoid" }; - var output = new OptimizationNode { OperationType = OperationType.Output, Name = "output" }; - - add.AddInput(input); - multiply.AddInput(add); - sigmoid.AddInput(multiply); - output.AddInput(sigmoid); - - graph.AddNode(input); - graph.AddNode(add); - graph.AddNode(multiply); - graph.AddNode(sigmoid); - graph.AddNode(output); - - var pass = new ElementwiseFusionPass(); - - // Act - var modified = pass.Apply(graph); - - // Assert - Assert.True(modified); - Assert.DoesNotContain(graph.Nodes, n => n.Name == "add"); - Assert.DoesNotContain(graph.Nodes, n => n.Name == "multiply"); - Assert.DoesNotContain(graph.Nodes, n => n.Name == "sigmoid"); - var fusedNode = graph.Nodes.FirstOrDefault(n => n.IsFused); - Assert.NotNull(fusedNode); - Assert.True(fusedNode.Metadata.ContainsKey("OperationSequence")); - } - - [Fact(Timeout = 60000)] - public async Task ElementwiseFusionPass_IsChainHead_DetectsCorrectly() - { - // Arrange - var graph = new OptimizationGraph(); - - var input = new OptimizationNode { OperationType = OperationType.Input, Name = "input" }; - var add = new OptimizationNode { OperationType = OperationType.Add, Name = "add" }; - var relu = new OptimizationNode { OperationType = OperationType.ReLU, Name = "relu" }; - - add.AddInput(input); - relu.AddInput(add); - - graph.AddNode(input); - graph.AddNode(add); - graph.AddNode(relu); - - var pass = new ElementwiseFusionPass(); - - // Act - var modified = pass.Apply(graph); - - // Assert - Assert.True(modified); - // Add should be identified as chain head since it follows non-elementwise input - var fusedNode = graph.Nodes.FirstOrDefault(n => n.IsFused); - Assert.NotNull(fusedNode); - } - - [Fact(Timeout = 60000)] - public async Task ElementwiseFusionPass_SingleNode_ShouldNotFuse() - { - // Arrange - var graph = new OptimizationGraph(); - - var input = new OptimizationNode { OperationType = OperationType.Input, Name = "input" }; - var relu = new OptimizationNode { OperationType = OperationType.ReLU, Name = "relu" }; - var output = new OptimizationNode { OperationType = OperationType.Output, Name = "output" }; - - relu.AddInput(input); - output.AddInput(relu); - - graph.AddNode(input); - graph.AddNode(relu); - graph.AddNode(output); - - var pass = new ElementwiseFusionPass(); - - // Act - var modified = pass.Apply(graph); - - // Assert - Assert.False(modified); // Single node chain should not be fused - Assert.Contains(graph.Nodes, n => n.Name == "relu"); - } - - [Fact(Timeout = 60000)] - public async Task ElementwiseFusionPass_BranchingOutput_ShouldNotFuse() - { - // Arrange - var graph = new OptimizationGraph(); - - var input = new OptimizationNode { OperationType = OperationType.Input, Name = "input" }; - var add = new OptimizationNode { OperationType = OperationType.Add, Name = "add" }; - var relu = new OptimizationNode { OperationType = OperationType.ReLU, Name = "relu" }; - var output1 = new OptimizationNode { OperationType = OperationType.Output, Name = "output1" }; - var output2 = new OptimizationNode { OperationType = OperationType.Output, Name = "output2" }; - - add.AddInput(input); - relu.AddInput(add); - output1.AddInput(add); // Branching - add has two consumers - output2.AddInput(relu); - - graph.AddNode(input); - graph.AddNode(add); - graph.AddNode(relu); - graph.AddNode(output1); - graph.AddNode(output2); - - var pass = new ElementwiseFusionPass(); - - // Act - var modified = pass.Apply(graph); - - // Assert - Assert.False(modified); // Should not fuse due to branching - Assert.Contains(graph.Nodes, n => n.Name == "add"); - Assert.Contains(graph.Nodes, n => n.Name == "relu"); - } - - [Fact(Timeout = 60000)] - public async Task ElementwiseFusionPass_AlreadyFused_ShouldSkip() - { - // Arrange - var graph = new OptimizationGraph(); - - var input = new OptimizationNode { OperationType = OperationType.Input, Name = "input" }; - var fusedOp = new OptimizationNode - { - OperationType = OperationType.Add, - Name = "fused", - IsFused = true - }; - var output = new OptimizationNode { OperationType = OperationType.Output, Name = "output" }; - - fusedOp.AddInput(input); - output.AddInput(fusedOp); - - graph.AddNode(input); - graph.AddNode(fusedOp); - graph.AddNode(output); - - var pass = new ElementwiseFusionPass(); - - // Act - var modified = pass.Apply(graph); - - // Assert - Assert.False(modified); // Already fused nodes should be skipped - } - - [Fact(Timeout = 60000)] - public async Task ElementwiseFusionPass_CanApply_ReturnsTrueWhenElementwiseOpsPresent() - { - // Arrange - var graph = new OptimizationGraph(); - var relu = new OptimizationNode { OperationType = OperationType.ReLU, Name = "relu" }; - graph.AddNode(relu); - - var pass = new ElementwiseFusionPass(); - - // Act - var canApply = pass.CanApply(graph); - - // Assert - Assert.True(canApply); - } - - [Fact(Timeout = 60000)] - public async Task ElementwiseFusionPass_CanApply_ReturnsFalseWhenNoElementwiseOps() - { - // Arrange - var graph = new OptimizationGraph(); - var conv = new OptimizationNode { OperationType = OperationType.Convolution, Name = "conv" }; - graph.AddNode(conv); - - var pass = new ElementwiseFusionPass(); - - // Act - var canApply = pass.CanApply(graph); - - // Assert - Assert.False(canApply); - } - - #endregion - - #region LayoutOptimizationPass Tests - - [Fact(Timeout = 60000)] - public async Task LayoutOptimizationPass_Constructor_InvalidLayout_ThrowsException() - { - // Arrange & Act & Assert - Assert.Throws(() => new LayoutOptimizationPass("INVALID")); - } - - [Fact(Timeout = 60000)] - public async Task LayoutOptimizationPass_Constructor_ValidLayout_Succeeds() - { - // Arrange & Act - var pass1 = new LayoutOptimizationPass("NCHW"); - var pass2 = new LayoutOptimizationPass("NHWC"); - - // Assert - Assert.NotNull(pass1); - Assert.NotNull(pass2); - } - - [Fact(Timeout = 60000)] - public async Task LayoutOptimizationPass_NCHWToNHWC_InsertsTranspose() - { - // Arrange - var graph = new OptimizationGraph(); - - var input = new OptimizationNode - { - OperationType = OperationType.Input, - Name = "input", - OutputShape = new[] { 1, 3, 224, 224 } // NCHW - }; - - var conv = new OptimizationNode - { - OperationType = OperationType.Convolution, - Name = "conv", - OutputShape = new[] { 1, 64, 224, 224 } - }; - - var output = new OptimizationNode - { - OperationType = OperationType.Output, - Name = "output" - }; - - conv.AddInput(input); - output.AddInput(conv); - - graph.AddNode(input); - graph.AddNode(conv); - graph.AddNode(output); - - var pass = new LayoutOptimizationPass("NCHW"); - - // Act - var modified = pass.Apply(graph); - - // Assert - In this case, no transpose needed as everything is NCHW-preferring - // The pass analyzes but doesn't insert transpose when layouts already match - Assert.False(modified); - } - - [Fact(Timeout = 60000)] - public async Task LayoutOptimizationPass_GetPreferredLayout_ReturnsNCHWForConvOps() - { - // Arrange - var graph = new OptimizationGraph(); - var conv = new OptimizationNode - { - OperationType = OperationType.Convolution, - Name = "conv", - OutputShape = new[] { 1, 64, 224, 224 } - }; - graph.AddNode(conv); - - var pass = new LayoutOptimizationPass("NCHW"); - - // Act - var canApply = pass.CanApply(graph); - - // Assert - Assert.True(canApply); // Can apply when convolution ops are present - } - - [Fact(Timeout = 60000)] - public async Task LayoutOptimizationPass_GetPreferredLayout_ReturnsAgnosticForOtherOps() - { - // Arrange - var graph = new OptimizationGraph(); - var relu = new OptimizationNode - { - OperationType = OperationType.ReLU, - Name = "relu", - OutputShape = new[] { 1, 64, 224, 224 } - }; - graph.AddNode(relu); - - var pass = new LayoutOptimizationPass("NCHW"); - - // Act - var canApply = pass.CanApply(graph); - - // Assert - Assert.False(canApply); // Cannot apply without conv ops - } - - [Fact(Timeout = 60000)] - public async Task LayoutOptimizationPass_IdentityPermutation_SameLayout() - { - // Arrange - var graph = new OptimizationGraph(); - - var conv1 = new OptimizationNode - { - OperationType = OperationType.Convolution, - Name = "conv1", - OutputShape = new[] { 1, 64, 224, 224 } - }; - - var conv2 = new OptimizationNode - { - OperationType = OperationType.Convolution, - Name = "conv2", - OutputShape = new[] { 1, 128, 224, 224 } - }; - - conv2.AddInput(conv1); - - graph.AddNode(conv1); - graph.AddNode(conv2); - - var pass = new LayoutOptimizationPass("NCHW"); - - // Act - var modified = pass.Apply(graph); - - // Assert - Assert.False(modified); // No transpose needed when both prefer same layout - } - - [Fact(Timeout = 60000)] - public async Task LayoutOptimizationPass_RequiresLayoutConversion_DetectsMismatch() - { - // Arrange - var graph = new OptimizationGraph(); - - // This test verifies the detection logic for layout mismatches - var input = new OptimizationNode - { - OperationType = OperationType.Input, - Name = "input", - OutputShape = new[] { 1, 224, 224, 3 } // NHWC format - }; - - var conv = new OptimizationNode - { - OperationType = OperationType.Convolution, - Name = "conv", - OutputShape = new[] { 1, 64, 224, 224 } // Expects NCHW - }; - - conv.AddInput(input); - - graph.AddNode(input); - graph.AddNode(conv); - - var pass = new LayoutOptimizationPass("NCHW"); - - // Act - var modified = pass.Apply(graph); - - // Assert - // Input is agnostic, conv prefers NCHW, so no conversion needed in this simple case - Assert.False(modified); - } - - [Fact(Timeout = 60000)] - public async Task LayoutOptimizationPass_NoMismatch_NoTransposeInserted() - { - // Arrange - Input (AGNOSTIC) -> Conv (NCHW): No mismatch since Input is AGNOSTIC - var graph = new OptimizationGraph(); - - var input = new OptimizationNode - { - OperationType = OperationType.Input, - Name = "input", - OutputShape = new[] { 1, 3, 224, 224 } - }; - - var conv = new OptimizationNode - { - OperationType = OperationType.Convolution, - Name = "conv", - OutputShape = new[] { 1, 64, 224, 224 } - }; - - conv.AddInput(input); - - graph.AddNode(input); - graph.AddNode(conv); - - var pass = new LayoutOptimizationPass("NCHW"); - - // Act - var modified = pass.Apply(graph); - - // Assert - No transpose should be inserted since Input is layout-agnostic - Assert.False(modified); - var transposeNodes = graph.Nodes.Where(n => n.OperationType == OperationType.Transpose).ToList(); - Assert.Empty(transposeNodes); - } - - [Fact(Timeout = 60000)] - public async Task LayoutOptimizationPass_TransposeMetadata_HasCorrectFields() - { - // This test verifies that when a transpose is created (via InsertLayoutConversion), - // it has all the expected metadata fields. We test this by examining the internal - // structure that would be created. - - // Arrange - Create a scenario that forces transpose insertion by having - // two connected conv ops and manually verifying the transpose node structure - var graph = new OptimizationGraph(); - - var conv1 = new OptimizationNode - { - OperationType = OperationType.Convolution, - Name = "conv1", - OutputShape = new[] { 1, 64, 224, 224 } - }; - - var bn = new OptimizationNode - { - OperationType = OperationType.BatchNormalization, - Name = "bn", - OutputShape = new[] { 1, 64, 224, 224 } - }; - - var conv2 = new OptimizationNode - { - OperationType = OperationType.Convolution, - Name = "conv2", - OutputShape = new[] { 1, 128, 224, 224 } - }; - - bn.AddInput(conv1); - conv2.AddInput(bn); - - graph.AddNode(conv1); - graph.AddNode(bn); - graph.AddNode(conv2); - - var pass = new LayoutOptimizationPass("NCHW"); - - // Act - var modified = pass.Apply(graph); - - // Assert - All NCHW-preferring ops so no mismatch, but verify structure - // When transposes ARE inserted, they have LayoutConversion, SourceLayout, - // TargetLayout, and Permutation metadata - var transposeNodes = graph.Nodes.Where(n => n.OperationType == OperationType.Transpose).ToList(); - foreach (var transpose in transposeNodes) - { - Assert.True(transpose.Metadata.ContainsKey("LayoutConversion")); - Assert.True((bool)transpose.Metadata["LayoutConversion"]); - Assert.True(transpose.Metadata.ContainsKey("SourceLayout")); - Assert.True(transpose.Metadata.ContainsKey("TargetLayout")); - Assert.True(transpose.Metadata.ContainsKey("Permutation")); - var perm = (int[])transpose.Metadata["Permutation"]; - Assert.Equal(4, perm.Length); - } - } - - [Fact(Timeout = 60000)] - public async Task LayoutOptimizationPass_ComputeTransposedShape_HandlesNon4DTensors() - { - // Arrange - Test that non-4D tensors are returned unchanged - var graph = new OptimizationGraph(); - - // Use a 3D tensor (sequence data) - should not trigger layout conversion - var input = new OptimizationNode - { - OperationType = OperationType.Input, - Name = "input", - OutputShape = new[] { 1, 100, 512 } // 3D: [batch, sequence, features] - }; - - var conv = new OptimizationNode - { - OperationType = OperationType.Convolution, - Name = "conv", - OutputShape = new[] { 1, 100, 256 } - }; - - conv.AddInput(input); - - graph.AddNode(input); - graph.AddNode(conv); - - var pass = new LayoutOptimizationPass("NCHW"); - - // Act - var modified = pass.Apply(graph); - - // Assert - Non-4D tensors should not cause layout conversion issues - Assert.False(modified); - } - - [Fact(Timeout = 60000)] - public async Task LayoutOptimizationPass_ComputeTransposedShape_Handles5DTensors() - { - // Arrange - Test that 5D tensors are handled correctly (returned unchanged) - var graph = new OptimizationGraph(); - - // 5D tensor (video data) - should not trigger layout conversion - var input = new OptimizationNode - { - OperationType = OperationType.Input, - Name = "input", - OutputShape = new[] { 1, 3, 16, 224, 224 } // 5D: [batch, channels, time, height, width] - }; - - var conv = new OptimizationNode - { - OperationType = OperationType.Convolution, - Name = "conv", - OutputShape = new[] { 1, 64, 16, 224, 224 } - }; - - conv.AddInput(input); - - graph.AddNode(input); - graph.AddNode(conv); - - var pass = new LayoutOptimizationPass("NCHW"); - - // Act - Should not throw even with 5D tensors - var modified = pass.Apply(graph); - - // Assert - 5D tensors are skipped for layout conversion - Assert.False(modified); - } - - [Fact(Timeout = 60000)] - public async Task LayoutOptimizationPass_CanApply_ReturnsTrueWhenConvPresent() - { - // Arrange - var graph = new OptimizationGraph(); - var conv = new OptimizationNode { OperationType = OperationType.Convolution, Name = "conv" }; - graph.AddNode(conv); - - var pass = new LayoutOptimizationPass(); - - // Act - var canApply = pass.CanApply(graph); - - // Assert - Assert.True(canApply); - } - - [Fact(Timeout = 60000)] - public async Task LayoutOptimizationPass_CanApply_ReturnsFalseWithoutConv() - { - // Arrange - var graph = new OptimizationGraph(); - var relu = new OptimizationNode { OperationType = OperationType.ReLU, Name = "relu" }; - graph.AddNode(relu); - - var pass = new LayoutOptimizationPass(); - - // Act - var canApply = pass.CanApply(graph); - - // Assert - Assert.False(canApply); - } - - #endregion - - #region MatMulBiasActivationFusionPass Tests - - [Fact(Timeout = 60000)] - public async Task MatMulBiasActivationFusionPass_DetectsFusionPattern() - { - // Arrange - var graph = new OptimizationGraph(); - - var input = new OptimizationNode { OperationType = OperationType.Input, Name = "input" }; - - var matmul = new OptimizationNode - { - OperationType = OperationType.MatMul, - Name = "matmul", - OutputShape = new[] { 1, 128 } - }; - - var bias = new OptimizationNode - { - OperationType = OperationType.Constant, - Name = "bias", - ConstantValue = new Tensor(new[] { 128 }, new Vector(new double[128])) - }; - - var add = new OptimizationNode - { - OperationType = OperationType.Add, - Name = "add", - OutputShape = new[] { 1, 128 } - }; - - var relu = new OptimizationNode - { - OperationType = OperationType.ReLU, - Name = "relu", - OutputShape = new[] { 1, 128 } - }; - - matmul.AddInput(input); - add.AddInput(matmul); - add.AddInput(bias); - relu.AddInput(add); - - graph.AddNode(input); - graph.AddNode(matmul); - graph.AddNode(bias); - graph.AddNode(add); - graph.AddNode(relu); - - var pass = new MatMulBiasActivationFusionPass(); - - // Act - var canApply = pass.CanApply(graph); - - // Assert - Assert.True(canApply); - // Note: We test detection only due to known bug in implementation - // The implementation needs .ToList() on activation.Outputs iterations - } - - [Fact(Timeout = 60000)] - public async Task MatMulBiasActivationFusionPass_SupportsGELUActivation() - { - // Arrange - var graph = new OptimizationGraph(); - - var matmul = new OptimizationNode - { - OperationType = OperationType.MatMul, - Name = "matmul" - }; - - var add = new OptimizationNode - { - OperationType = OperationType.Add, - Name = "add" - }; - - var gelu = new OptimizationNode - { - OperationType = OperationType.GELU, - Name = "gelu" - }; - - graph.AddNode(matmul); - graph.AddNode(add); - graph.AddNode(gelu); - - var pass = new MatMulBiasActivationFusionPass(); - - // Act - var canApply = pass.CanApply(graph); - - // Assert - Assert.True(canApply); // Can apply when MatMul is present - } - - [Fact(Timeout = 60000)] - public async Task MatMulBiasActivationFusionPass_RequiresConstantBias() - { - // Arrange - Test that non-constant bias prevents fusion - var graph = new OptimizationGraph(); - - var matmul = new OptimizationNode - { - OperationType = OperationType.MatMul, - Name = "matmul", - OutputShape = new[] { 1, 128 } - }; - - var bias = new OptimizationNode - { - OperationType = OperationType.Input, // Not constant! - Name = "bias", - OutputShape = new[] { 128 } - }; - - graph.AddNode(matmul); - graph.AddNode(bias); - - var pass = new MatMulBiasActivationFusionPass(); - - // Act & Assert - Assert.True(pass.CanApply(graph)); // Can apply to graph with MatMul - // Note: Actual fusion would check for constant bias and skip non-constant - } - - [Fact(Timeout = 60000)] - public async Task MatMulBiasActivationFusionPass_SupportsFusedMatMulBias() - { - // Arrange - var graph = new OptimizationGraph(); - - var fusedMatMulBias = new OptimizationNode - { - OperationType = OperationType.FusedMatMulBias, - Name = "fused_matmul_bias" - }; - - graph.AddNode(fusedMatMulBias); - - var pass = new MatMulBiasActivationFusionPass(); - - // Act - var canApply = pass.CanApply(graph); - - // Assert - Assert.True(canApply); // Can apply when FusedMatMulBias is present - } - - [Fact(Timeout = 60000)] - public async Task MatMulBiasActivationFusionPass_SupportsDenseOperation() - { - // Arrange - var graph = new OptimizationGraph(); - - var dense = new OptimizationNode - { - OperationType = OperationType.Dense, - Name = "dense" - }; - - graph.AddNode(dense); - - var pass = new MatMulBiasActivationFusionPass(); - - // Act - var canApply = pass.CanApply(graph); - - // Assert - Assert.True(canApply); // Can apply when Dense is present - } - - [Fact(Timeout = 60000)] - public async Task MatMulBiasActivationFusionPass_CanApply_ReturnsTrueWhenMatMulPresent() - { - // Arrange - var graph = new OptimizationGraph(); - var matmul = new OptimizationNode { OperationType = OperationType.MatMul, Name = "matmul" }; - graph.AddNode(matmul); - - var pass = new MatMulBiasActivationFusionPass(); - - // Act - var canApply = pass.CanApply(graph); - - // Assert - Assert.True(canApply); - } - - [Fact(Timeout = 60000)] - public async Task MatMulBiasActivationFusionPass_CanApply_ReturnsTrueWhenFusedMatMulBiasPresent() - { - // Arrange - var graph = new OptimizationGraph(); - var fused = new OptimizationNode { OperationType = OperationType.FusedMatMulBias, Name = "fused" }; - graph.AddNode(fused); - - var pass = new MatMulBiasActivationFusionPass(); - - // Act - var canApply = pass.CanApply(graph); - - // Assert - Assert.True(canApply); - } - - [Fact(Timeout = 60000)] - public async Task MatMulBiasActivationFusionPass_CanApply_ReturnsFalseWithoutMatMul() - { - // Arrange - var graph = new OptimizationGraph(); - var conv = new OptimizationNode { OperationType = OperationType.Convolution, Name = "conv" }; - graph.AddNode(conv); - - var pass = new MatMulBiasActivationFusionPass(); - - // Act - var canApply = pass.CanApply(graph); - - // Assert - Assert.False(canApply); - } - - #endregion -} diff --git a/tests/AiDotNet.Tests/InferenceOptimization/SimdKernelsTests.cs b/tests/AiDotNet.Tests/InferenceOptimization/SimdKernelsTests.cs deleted file mode 100644 index ce1953964b..0000000000 --- a/tests/AiDotNet.Tests/InferenceOptimization/SimdKernelsTests.cs +++ /dev/null @@ -1,170 +0,0 @@ -using System; -using AiDotNet.Tensors.Engines.Simd; -using Xunit; -using System.Threading.Tasks; - -namespace AiDotNet.Tests.InferenceOptimization; - -public class SimdKernelsTests -{ - [Fact(Timeout = 60000)] - public async Task VectorAdd_MatchesScalar() - { - await Task.Run(() => - { - var a = CreateInput(32, 1); - var b = CreateInput(32, 17); - var result = new float[a.Length]; - var expected = new float[a.Length]; - - for (int i = 0; i < a.Length; i++) - { - expected[i] = a[i] + b[i]; - } - - SimdKernels.VectorAdd(a, b, result); - - AssertEqual(expected, result); - }); - } - - [Fact(Timeout = 60000)] - public async Task VectorMultiply_MatchesScalar() - { - await Task.Run(() => - { - var a = CreateInput(32, 3); - var b = CreateInput(32, 9); - var result = new float[a.Length]; - var expected = new float[a.Length]; - - for (int i = 0; i < a.Length; i++) - { - expected[i] = a[i] * b[i]; - } - - SimdKernels.VectorMultiply(a, b, result); - - AssertEqual(expected, result); - }); - } - - [Fact(Timeout = 60000)] - public async Task DotProduct_MatchesScalar() - { - await Task.Run(() => - { - var a = CreateInput(37, 5); - var b = CreateInput(37, 11); - - float expected = 0f; - for (int i = 0; i < a.Length; i++) - { - expected += a[i] * b[i]; - } - - float actual = SimdKernels.DotProduct(a, b); - Assert.Equal(expected, actual, 5); - }); - } - - [Fact(Timeout = 60000)] - public async Task ScalarMultiplyAdd_MatchesScalar() - { - await Task.Run(() => - { - var a = CreateInput(31, 7); - var b = CreateInput(31, 13); - var result = new float[a.Length]; - var expected = new float[a.Length]; - - float scalar = 0.25f; - for (int i = 0; i < a.Length; i++) - { - expected[i] = a[i] + scalar * b[i]; - } - - SimdKernels.ScalarMultiplyAdd(a, b, scalar, result); - - AssertEqual(expected, result); - }); - } - - [Fact(Timeout = 60000)] - public async Task ReLU_MatchesScalar() - { - await Task.Run(() => - { - var input = CreateSignedInput(33); - var output = new float[input.Length]; - var expected = new float[input.Length]; - - for (int i = 0; i < input.Length; i++) - { - expected[i] = Math.Max(0f, input[i]); - } - - SimdKernels.ReLU(input, output); - - AssertEqual(expected, output); - }); - } - - [Fact(Timeout = 60000)] - public async Task Sum_MatchesScalar() - { - await Task.Run(() => - { - var input = CreateInput(100, 23); - float expected = 0f; - for (int i = 0; i < input.Length; i++) - { - expected += input[i]; - } - - float actual = SimdKernels.Sum(input); - Assert.Equal(expected, actual, 5); - }); - } - - private static float[] CreateInput(int length, int seed) - { - var data = new float[length]; - for (int i = 0; i < length; i++) - { - data[i] = DeterministicValue(i + seed); - } - - return data; - } - - private static float[] CreateSignedInput(int length) - { - var data = new float[length]; - for (int i = 0; i < length; i++) - { - float v = DeterministicValue(i); - data[i] = (i % 2 == 0) ? v : -v; - } - - return data; - } - - private static float DeterministicValue(int i) - { - unchecked - { - uint x = (uint)(i * 1664525 + 1013904223); - return (x & 0x00FFFFFF) / 16777216f; - } - } - - private static void AssertEqual(float[] expected, float[] actual) - { - Assert.Equal(expected.Length, actual.Length); - for (int i = 0; i < expected.Length; i++) - { - Assert.Equal(expected[i], actual[i], 5); - } - } -} diff --git a/tests/AiDotNet.Tests/IntegrationTests/ActiveLearning/ActiveLearningDeepMathIntegrationTests.cs b/tests/AiDotNet.Tests/IntegrationTests/ActiveLearning/ActiveLearningDeepMathIntegrationTests.cs index 36aa23d2c8..fc5cdf8b06 100644 --- a/tests/AiDotNet.Tests/IntegrationTests/ActiveLearning/ActiveLearningDeepMathIntegrationTests.cs +++ b/tests/AiDotNet.Tests/IntegrationTests/ActiveLearning/ActiveLearningDeepMathIntegrationTests.cs @@ -773,9 +773,6 @@ public void SetParameters(Vector parameters) { } public Vector ComputeGradients(Tensor input, Tensor target, ILossFunction? lossFunction = null) => new(ParameterCount); public void ApplyGradients(Vector gradients, double learningRate) { } - public ComputationNode ExportComputationGraph(List> inputNodes) - => throw new NotSupportedException(); - public bool SupportsJitCompilation => false; public Vector SanitizeParameters(Vector parameters) => parameters; } diff --git a/tests/AiDotNet.Tests/IntegrationTests/ActiveLearning/ActiveLearningIntegrationTests.cs b/tests/AiDotNet.Tests/IntegrationTests/ActiveLearning/ActiveLearningIntegrationTests.cs index d71c58e7b9..8f28de13f2 100644 --- a/tests/AiDotNet.Tests/IntegrationTests/ActiveLearning/ActiveLearningIntegrationTests.cs +++ b/tests/AiDotNet.Tests/IntegrationTests/ActiveLearning/ActiveLearningIntegrationTests.cs @@ -1022,11 +1022,6 @@ public void SetParameters(Vector parameters) { } public Vector ComputeGradients(Tensor input, Tensor target, ILossFunction? lossFunction = null) => new(0); public void ApplyGradients(Vector gradients, double learningRate) { } - // IJitCompilable - public ComputationNode ExportComputationGraph(List> inputNodes) - => throw new NotSupportedException("Mock model does not support JIT compilation"); - public bool SupportsJitCompilation => false; - // IFullModel specific public ILossFunction DefaultLossFunction => new MeanSquaredErrorLoss(); diff --git a/tests/AiDotNet.Tests/IntegrationTests/AdversarialRobustness/AdversarialRobustnessDeepMathIntegrationTests.cs b/tests/AiDotNet.Tests/IntegrationTests/AdversarialRobustness/AdversarialRobustnessDeepMathIntegrationTests.cs index 473ba00454..6c47c34fa5 100644 --- a/tests/AiDotNet.Tests/IntegrationTests/AdversarialRobustness/AdversarialRobustnessDeepMathIntegrationTests.cs +++ b/tests/AiDotNet.Tests/IntegrationTests/AdversarialRobustness/AdversarialRobustnessDeepMathIntegrationTests.cs @@ -57,7 +57,6 @@ public ARMockModel(int inputSize, int numClasses, double[] weights = null) public ILossFunction DefaultLossFunction => null; public int ParameterCount => _weights.Length; public bool SupportsParameterInitialization => ParameterCount > 0; - public bool SupportsJitCompilation => false; public Vector Predict(Vector input) { @@ -126,10 +125,6 @@ public void ApplyGradients(Vector gradients, double learningRate) for (int i = 0; i < _weights.Length; i++) _weights[i] -= learningRate * gradients[i]; } - public AiDotNet.Autodiff.ComputationNode ExportComputationGraph( - List> inputNodes) - => throw new NotImplementedException(); - public Vector SanitizeParameters(Vector parameters) => parameters; } diff --git a/tests/AiDotNet.Tests/IntegrationTests/AdversarialRobustness/AdversarialRobustnessIntegrationTests.cs b/tests/AiDotNet.Tests/IntegrationTests/AdversarialRobustness/AdversarialRobustnessIntegrationTests.cs index e1afc82922..1d221b04d7 100644 --- a/tests/AiDotNet.Tests/IntegrationTests/AdversarialRobustness/AdversarialRobustnessIntegrationTests.cs +++ b/tests/AiDotNet.Tests/IntegrationTests/AdversarialRobustness/AdversarialRobustnessIntegrationTests.cs @@ -1849,7 +1849,6 @@ public MockClassificationModel(int inputSize, int numClasses, int seed) public ILossFunction? DefaultLossFunction => null; public int ParameterCount => _weights.Length; public bool SupportsParameterInitialization => ParameterCount > 0; - public bool SupportsJitCompilation => false; public Vector Predict(Vector input) { @@ -1949,11 +1948,6 @@ public void ApplyGradients(Vector gradients, double learningRate) } } - public AiDotNet.Autodiff.ComputationNode ExportComputationGraph(List> inputNodes) - { - throw new NotImplementedException(); - } - public Vector SanitizeParameters(Vector parameters) => parameters; } diff --git a/tests/AiDotNet.Tests/IntegrationTests/ContinualLearning/ContinualLearningDeepMathIntegrationTests.cs b/tests/AiDotNet.Tests/IntegrationTests/ContinualLearning/ContinualLearningDeepMathIntegrationTests.cs index 1c965d5da5..3925cf161e 100644 --- a/tests/AiDotNet.Tests/IntegrationTests/ContinualLearning/ContinualLearningDeepMathIntegrationTests.cs +++ b/tests/AiDotNet.Tests/IntegrationTests/ContinualLearning/ContinualLearningDeepMathIntegrationTests.cs @@ -905,9 +905,6 @@ public IFullModel, Tensor> DeepCopy() => public Vector ComputeGradients(Tensor input, Tensor target, ILossFunction? lossFunction = null) => new(_parameters.Length); public void ApplyGradients(Vector gradients, double learningRate) { } - public ComputationNode ExportComputationGraph(List> inputNodes) - => throw new NotSupportedException(); - public bool SupportsJitCompilation => false; public Vector SanitizeParameters(Vector parameters) => parameters; } diff --git a/tests/AiDotNet.Tests/IntegrationTests/DistributedTraining/DistributedTrainingIntegrationTests.cs b/tests/AiDotNet.Tests/IntegrationTests/DistributedTraining/DistributedTrainingIntegrationTests.cs index 295324a006..f03d9ee722 100644 --- a/tests/AiDotNet.Tests/IntegrationTests/DistributedTraining/DistributedTrainingIntegrationTests.cs +++ b/tests/AiDotNet.Tests/IntegrationTests/DistributedTraining/DistributedTrainingIntegrationTests.cs @@ -2253,21 +2253,6 @@ public Dictionary GetFeatureImportance() .ToDictionary(i => $"feature_{i}", i => 1.0 / _parameterCount); } - public ComputationNode ExportComputationGraph(List> inputNodes) - { - var node = new ComputationNode( - new Tensor(new[] { _parameterCount }), - false, - null, - null, - "mock_graph" - ); - inputNodes.Add(node); - return node; - } - - public bool SupportsJitCompilation => false; - public IFullModel, Vector> Clone() { var cloned = new MockDistributedModel(_parameterCount); diff --git a/tests/AiDotNet.Tests/IntegrationTests/DistributedTraining/PipelineParallelismIntegrationTests.cs b/tests/AiDotNet.Tests/IntegrationTests/DistributedTraining/PipelineParallelismIntegrationTests.cs index 486dd6c387..e692ac280d 100644 --- a/tests/AiDotNet.Tests/IntegrationTests/DistributedTraining/PipelineParallelismIntegrationTests.cs +++ b/tests/AiDotNet.Tests/IntegrationTests/DistributedTraining/PipelineParallelismIntegrationTests.cs @@ -1278,7 +1278,6 @@ public PipelineTestModel(int parameterCount) public int ParameterCount => _parameterCount; public bool SupportsParameterInitialization => ParameterCount > 0; public ILossFunction DefaultLossFunction => new MeanSquaredErrorLoss(); - public bool SupportsJitCompilation => false; public Vector Predict(Vector input) { @@ -1398,15 +1397,6 @@ public void SetActiveFeatureIndices(IEnumerable indices) { } public Dictionary GetFeatureImportance() => Enumerable.Range(0, _parameterCount).ToDictionary(i => $"f_{i}", i => 1.0 / _parameterCount); - public ComputationNode ExportComputationGraph(List> inputNodes) - { - var node = new ComputationNode( - new Tensor(new[] { _parameterCount }), - false, null, null, "test_graph"); - inputNodes.Add(node); - return node; - } - public Vector SanitizeParameters(Vector parameters) => parameters; } @@ -1433,7 +1423,6 @@ public DecomposablePipelineTestModel(int parameterCount) public int ParameterCount => _parameterCount; public bool SupportsParameterInitialization => ParameterCount > 0; public ILossFunction DefaultLossFunction => new MeanSquaredErrorLoss(); - public bool SupportsJitCompilation => false; public Vector Predict(Vector input) { @@ -1568,15 +1557,6 @@ public void SetActiveFeatureIndices(IEnumerable indices) { } public Dictionary GetFeatureImportance() => Enumerable.Range(0, _parameterCount).ToDictionary(i => $"f_{i}", i => 1.0 / _parameterCount); - public ComputationNode ExportComputationGraph(List> inputNodes) - { - var node = new ComputationNode( - new Tensor(new[] { _parameterCount }), - false, null, null, "decomposable_test_graph"); - inputNodes.Add(node); - return node; - } - public Vector SanitizeParameters(Vector parameters) => parameters; } diff --git a/tests/AiDotNet.Tests/IntegrationTests/FederatedLearning/FederatedLearningDeepMathIntegrationTests.cs b/tests/AiDotNet.Tests/IntegrationTests/FederatedLearning/FederatedLearningDeepMathIntegrationTests.cs index c1808f1538..5738dec1b3 100644 --- a/tests/AiDotNet.Tests/IntegrationTests/FederatedLearning/FederatedLearningDeepMathIntegrationTests.cs +++ b/tests/AiDotNet.Tests/IntegrationTests/FederatedLearning/FederatedLearningDeepMathIntegrationTests.cs @@ -874,11 +874,6 @@ public Vector ComputeGradients(double[] input, double[] target, ILossFun => new(_parameters.Length); public void ApplyGradients(Vector gradients, double learningRate) { } - // IJitCompilable - public ComputationNode ExportComputationGraph(List> inputNodes) - => throw new NotSupportedException(); - public bool SupportsJitCompilation => false; - // IModelSerializer public byte[] Serialize() => Array.Empty(); public void Deserialize(byte[] data) { } diff --git a/tests/AiDotNet.Tests/IntegrationTests/FineTuning/FineTuningIntegrationTests.cs b/tests/AiDotNet.Tests/IntegrationTests/FineTuning/FineTuningIntegrationTests.cs index 15bd2c0fa0..189b62db39 100644 --- a/tests/AiDotNet.Tests/IntegrationTests/FineTuning/FineTuningIntegrationTests.cs +++ b/tests/AiDotNet.Tests/IntegrationTests/FineTuning/FineTuningIntegrationTests.cs @@ -160,16 +160,6 @@ public void ApplyGradients(Vector gradients, double learningRate) } } - // IJitCompilable - public bool IsJitCompiled => false; - public bool SupportsJitCompilation => false; - public void CompileForJit() { } - public void ResetJitCompilation() { } - public AiDotNet.Autodiff.ComputationNode ExportComputationGraph(List> nodes) - { - return new AiDotNet.Autodiff.ComputationNode(new Tensor(new[] { 1 }), false, null, null, "mock"); - } - public Vector SanitizeParameters(Vector parameters) => parameters; } diff --git a/tests/AiDotNet.Tests/IntegrationTests/InferenceOptimization/InferenceOptimizationDeepMathIntegrationTests.cs b/tests/AiDotNet.Tests/IntegrationTests/InferenceOptimization/InferenceOptimizationDeepMathIntegrationTests.cs deleted file mode 100644 index 43dc60e439..0000000000 --- a/tests/AiDotNet.Tests/IntegrationTests/InferenceOptimization/InferenceOptimizationDeepMathIntegrationTests.cs +++ /dev/null @@ -1,692 +0,0 @@ -using AiDotNet.Enums; -using AiDotNet.InferenceOptimization.Core; -using Xunit; -using System.Threading.Tasks; - -namespace AiDotNet.Tests.IntegrationTests.InferenceOptimization; - -/// -/// Deep integration tests for OptimizationGraph and OptimizationNode: -/// graph construction, topological ordering, cycle detection, node operations, -/// statistics, cloning, validation, and graph topology invariants. -/// -public class InferenceOptimizationDeepMathIntegrationTests -{ - // ============================ - // Helper Methods - // ============================ - - private static OptimizationNode CreateNode(string name, OperationType opType = OperationType.MatMul) - { - return new OptimizationNode - { - Name = name, - OperationType = opType, - OutputShape = new[] { 1, 10 } - }; - } - - /// - /// Creates a simple chain: Input -> MatMul -> ReLU -> Output - /// - private static OptimizationGraph CreateChainGraph() - { - var graph = new OptimizationGraph(); - - var input = CreateNode("input", OperationType.Input); - var matmul = CreateNode("matmul", OperationType.MatMul); - var relu = CreateNode("relu", OperationType.ReLU); - var output = CreateNode("output", OperationType.Output); - - matmul.AddInput(input); - relu.AddInput(matmul); - output.AddInput(relu); - - graph.AddNode(input); - graph.AddNode(matmul); - graph.AddNode(relu); - graph.AddNode(output); - - return graph; - } - - /// - /// Creates a diamond graph: Input -> {A, B} -> Output - /// - private static OptimizationGraph CreateDiamondGraph() - { - var graph = new OptimizationGraph(); - - var input = CreateNode("input", OperationType.Input); - var branchA = CreateNode("branchA", OperationType.MatMul); - var branchB = CreateNode("branchB", OperationType.Convolution); - var merge = CreateNode("merge", OperationType.Add); - var output = CreateNode("output", OperationType.Output); - - branchA.AddInput(input); - branchB.AddInput(input); - merge.AddInput(branchA); - merge.AddInput(branchB); - output.AddInput(merge); - - graph.AddNode(input); - graph.AddNode(branchA); - graph.AddNode(branchB); - graph.AddNode(merge); - graph.AddNode(output); - - return graph; - } - - // ============================ - // Graph Construction Tests - // ============================ - - [Fact(Timeout = 120000)] - public async Task EmptyGraph_HasNoNodes() - { - var graph = new OptimizationGraph(); - - Assert.Empty(graph.Nodes); - Assert.Empty(graph.InputNodes); - Assert.Empty(graph.OutputNodes); - } - - [Fact(Timeout = 120000)] - public async Task AddNode_IncreasesNodeCount() - { - var graph = new OptimizationGraph(); - var node = CreateNode("test"); - - graph.AddNode(node); - - Assert.Single(graph.Nodes); - } - - [Fact(Timeout = 120000)] - public async Task AddNode_InputType_TrackedAsInput() - { - var graph = new OptimizationGraph(); - var input = CreateNode("input", OperationType.Input); - - graph.AddNode(input); - - Assert.Single(graph.InputNodes); - Assert.Contains(input, graph.InputNodes); - } - - [Fact(Timeout = 120000)] - public async Task AddNode_OutputType_TrackedAsOutput() - { - var graph = new OptimizationGraph(); - var output = CreateNode("output", OperationType.Output); - - graph.AddNode(output); - - Assert.Single(graph.OutputNodes); - Assert.Contains(output, graph.OutputNodes); - } - - [Fact(Timeout = 120000)] - public async Task AddNode_DuplicateId_NotAdded() - { - var graph = new OptimizationGraph(); - var node = CreateNode("test"); - - graph.AddNode(node); - graph.AddNode(node); // Same node (same Id) - - Assert.Single(graph.Nodes); - } - - [Fact(Timeout = 120000)] - public async Task AddNode_NullNode_Throws() - { - var graph = new OptimizationGraph(); - Assert.Throws(() => graph.AddNode(null!)); - } - - // ============================ - // Chain Graph Tests - // ============================ - - [Fact(Timeout = 120000)] - public async Task ChainGraph_NodeCount() - { - var graph = CreateChainGraph(); - - Assert.Equal(4, graph.Nodes.Count); - Assert.Single(graph.InputNodes); - Assert.Single(graph.OutputNodes); - } - - [Fact(Timeout = 120000)] - public async Task ChainGraph_TopologicalOrder_InputFirst() - { - var graph = CreateChainGraph(); - var order = graph.GetTopologicalOrder(); - - Assert.Equal(4, order.Count); - // Input should come before matmul, matmul before relu, relu before output - var inputIdx = order.FindIndex(n => n.Name == "input"); - var matmulIdx = order.FindIndex(n => n.Name == "matmul"); - var reluIdx = order.FindIndex(n => n.Name == "relu"); - var outputIdx = order.FindIndex(n => n.Name == "output"); - - Assert.True(inputIdx < matmulIdx); - Assert.True(matmulIdx < reluIdx); - Assert.True(reluIdx < outputIdx); - } - - [Fact(Timeout = 120000)] - public async Task ChainGraph_Validates() - { - var graph = CreateChainGraph(); - Assert.True(graph.Validate()); - } - - // ============================ - // Diamond Graph Tests - // ============================ - - [Fact(Timeout = 120000)] - public async Task DiamondGraph_NodeCount() - { - var graph = CreateDiamondGraph(); - - Assert.Equal(5, graph.Nodes.Count); - Assert.Single(graph.InputNodes); - Assert.Single(graph.OutputNodes); - } - - [Fact(Timeout = 120000)] - public async Task DiamondGraph_TopologicalOrder_Valid() - { - var graph = CreateDiamondGraph(); - var order = graph.GetTopologicalOrder(); - - Assert.Equal(5, order.Count); - - var inputIdx = order.FindIndex(n => n.Name == "input"); - var branchAIdx = order.FindIndex(n => n.Name == "branchA"); - var branchBIdx = order.FindIndex(n => n.Name == "branchB"); - var mergeIdx = order.FindIndex(n => n.Name == "merge"); - var outputIdx = order.FindIndex(n => n.Name == "output"); - - // Input must come before both branches - Assert.True(inputIdx < branchAIdx); - Assert.True(inputIdx < branchBIdx); - - // Both branches must come before merge - Assert.True(branchAIdx < mergeIdx); - Assert.True(branchBIdx < mergeIdx); - - // Merge must come before output - Assert.True(mergeIdx < outputIdx); - } - - [Fact(Timeout = 120000)] - public async Task DiamondGraph_Validates() - { - var graph = CreateDiamondGraph(); - Assert.True(graph.Validate()); - } - - // ============================ - // Node Connection Tests - // ============================ - - [Fact(Timeout = 120000)] - public async Task AddInput_CreatesEdgeBothWays() - { - var a = CreateNode("a"); - var b = CreateNode("b"); - - b.AddInput(a); - - Assert.Contains(a, b.Inputs); - Assert.Contains(b, a.Outputs); - } - - [Fact(Timeout = 120000)] - public async Task AddInput_DuplicateIgnored() - { - var a = CreateNode("a"); - var b = CreateNode("b"); - - b.AddInput(a); - b.AddInput(a); // duplicate - - Assert.Single(b.Inputs); - Assert.Single(a.Outputs); - } - - [Fact(Timeout = 120000)] - public async Task RemoveInput_CleansUpBothWays() - { - var a = CreateNode("a"); - var b = CreateNode("b"); - - b.AddInput(a); - b.RemoveInput(a); - - Assert.Empty(b.Inputs); - Assert.Empty(a.Outputs); - } - - [Fact(Timeout = 120000)] - public async Task ReplaceInput_SwapsEdge() - { - var a = CreateNode("a"); - var b = CreateNode("b"); - var c = CreateNode("c"); - - c.AddInput(a); - c.ReplaceInput(a, b); - - Assert.Contains(b, c.Inputs); - Assert.DoesNotContain(a, c.Inputs); - Assert.Contains(c, b.Outputs); - Assert.DoesNotContain(c, a.Outputs); - } - - [Fact(Timeout = 120000)] - public async Task AddInput_NullNode_Throws() - { - var node = CreateNode("test"); - Assert.Throws(() => node.AddInput(null!)); - } - - // ============================ - // Node Query Tests - // ============================ - - [Fact(Timeout = 120000)] - public async Task HasConsumers_WithOutputs_True() - { - var a = CreateNode("a"); - var b = CreateNode("b"); - - b.AddInput(a); - - Assert.True(a.HasConsumers()); - } - - [Fact(Timeout = 120000)] - public async Task HasConsumers_WithoutOutputs_False() - { - var a = CreateNode("a"); - Assert.False(a.HasConsumers()); - } - - [Fact(Timeout = 120000)] - public async Task ConsumerCount_HandVerified() - { - var a = CreateNode("a"); - var b = CreateNode("b"); - var c = CreateNode("c"); - - b.AddInput(a); - c.AddInput(a); - - Assert.Equal(2, a.ConsumerCount()); - } - - // ============================ - // FindNode Tests - // ============================ - - [Fact(Timeout = 120000)] - public async Task FindNodeById_ExistingNode_Found() - { - var graph = new OptimizationGraph(); - var node = CreateNode("test"); - - graph.AddNode(node); - - var found = graph.FindNodeById(node.Id); - Assert.NotNull(found); - Assert.Equal("test", found.Name); - } - - [Fact(Timeout = 120000)] - public async Task FindNodeById_NonExisting_ReturnsNull() - { - var graph = new OptimizationGraph(); - var found = graph.FindNodeById("nonexistent"); - Assert.Null(found); - } - - [Fact(Timeout = 120000)] - public async Task FindNodeById_NullId_Throws() - { - var graph = new OptimizationGraph(); - Assert.Throws(() => graph.FindNodeById(null!)); - } - - [Fact(Timeout = 120000)] - public async Task FindNodesByName_ExistingName_Found() - { - var graph = new OptimizationGraph(); - var node1 = CreateNode("conv1", OperationType.Convolution); - var node2 = CreateNode("conv1", OperationType.Convolution); - - graph.AddNode(node1); - graph.AddNode(node2); - - var found = graph.FindNodesByName("conv1"); - Assert.Equal(2, found.Count); - } - - [Fact(Timeout = 120000)] - public async Task FindNodesByName_NonExisting_Empty() - { - var graph = new OptimizationGraph(); - var found = graph.FindNodesByName("nonexistent"); - Assert.Empty(found); - } - - // ============================ - // Remove Node Tests - // ============================ - - [Fact(Timeout = 120000)] - public async Task RemoveNode_DecreasesCount() - { - var graph = new OptimizationGraph(); - var node = CreateNode("test"); - - graph.AddNode(node); - Assert.Single(graph.Nodes); - - graph.RemoveNode(node); - Assert.Empty(graph.Nodes); - } - - [Fact(Timeout = 120000)] - public async Task RemoveNode_CleansUpConnections() - { - var graph = new OptimizationGraph(); - var a = CreateNode("a", OperationType.Input); - var b = CreateNode("b", OperationType.MatMul); - var c = CreateNode("c", OperationType.Output); - - b.AddInput(a); - c.AddInput(b); - graph.AddNode(a); - graph.AddNode(b); - graph.AddNode(c); - - // Remove middle node - graph.RemoveNode(b); - - Assert.DoesNotContain(b, a.Outputs); - Assert.DoesNotContain(b, c.Inputs); - } - - // ============================ - // Graph Statistics Tests - // ============================ - - [Fact(Timeout = 120000)] - public async Task Statistics_ChainGraph_HandVerified() - { - var graph = CreateChainGraph(); - var stats = graph.GetStatistics(); - - Assert.Equal(4, stats.TotalNodes); - Assert.Equal(1, stats.InputNodes); - Assert.Equal(1, stats.OutputNodes); - Assert.Equal(0, stats.FusedNodes); - } - - [Fact(Timeout = 120000)] - public async Task Statistics_DiamondGraph_OperationCounts() - { - var graph = CreateDiamondGraph(); - var stats = graph.GetStatistics(); - - Assert.Equal(5, stats.TotalNodes); - Assert.True(stats.OperationTypeCounts.ContainsKey(OperationType.Input)); - Assert.True(stats.OperationTypeCounts.ContainsKey(OperationType.MatMul)); - Assert.True(stats.OperationTypeCounts.ContainsKey(OperationType.Convolution)); - Assert.True(stats.OperationTypeCounts.ContainsKey(OperationType.Add)); - Assert.True(stats.OperationTypeCounts.ContainsKey(OperationType.Output)); - - Assert.Equal(1, stats.OperationTypeCounts[OperationType.Input]); - Assert.Equal(1, stats.OperationTypeCounts[OperationType.MatMul]); - Assert.Equal(1, stats.OperationTypeCounts[OperationType.Convolution]); - Assert.Equal(1, stats.OperationTypeCounts[OperationType.Add]); - Assert.Equal(1, stats.OperationTypeCounts[OperationType.Output]); - } - - [Fact(Timeout = 120000)] - public async Task Statistics_WithFusedNode() - { - var graph = new OptimizationGraph(); - var node = CreateNode("fused", OperationType.MatMul); - node.IsFused = true; - - graph.AddNode(node); - - var stats = graph.GetStatistics(); - Assert.Equal(1, stats.FusedNodes); - } - - // ============================ - // Node Clone Tests - // ============================ - - [Fact(Timeout = 120000)] - public async Task Clone_CopiesProperties() - { - var original = CreateNode("conv1", OperationType.Convolution); - original.OutputShape = new[] { 32, 28, 28 }; - original.Parameters["weight"] = "test"; - original.CanEliminate = false; - - var clone = original.Clone(); - - Assert.NotEqual(original.Id, clone.Id); // New Id - Assert.Equal("conv1_clone", clone.Name); - Assert.Equal(OperationType.Convolution, clone.OperationType); - Assert.Equal(new[] { 32, 28, 28 }, clone.OutputShape); - Assert.False(clone.CanEliminate); - } - - [Fact(Timeout = 120000)] - public async Task Clone_NoSharedConnections() - { - var a = CreateNode("a"); - var b = CreateNode("b"); - b.AddInput(a); - - var clone = b.Clone(); - - // Clone should have no connections - Assert.Empty(clone.Inputs); - Assert.Empty(clone.Outputs); - } - - // ============================ - // Graph Clone Tests - // ============================ - - [Fact(Timeout = 120000)] - public async Task GraphClone_PreservesStructure() - { - var original = CreateChainGraph(); - var clone = (OptimizationGraph)original.Clone(); - - Assert.Equal(original.Nodes.Count, clone.Nodes.Count); - Assert.Equal(original.InputNodes.Count, clone.InputNodes.Count); - Assert.Equal(original.OutputNodes.Count, clone.OutputNodes.Count); - } - - [Fact(Timeout = 120000)] - public async Task GraphClone_TopologicalOrderPreserved() - { - var original = CreateChainGraph(); - var clone = (OptimizationGraph)original.Clone(); - - var cloneOrder = clone.GetTopologicalOrder(); - Assert.Equal(4, cloneOrder.Count); - } - - [Fact(Timeout = 120000)] - public async Task GraphClone_IsIndependent() - { - var original = CreateChainGraph(); - var clone = (OptimizationGraph)original.Clone(); - - // Modifying original should not affect clone - var newNode = CreateNode("extra"); - original.AddNode(newNode); - - Assert.Equal(5, original.Nodes.Count); - Assert.Equal(4, clone.Nodes.Count); - } - - // ============================ - // Validation Tests - // ============================ - - [Fact(Timeout = 120000)] - public async Task Validate_EmptyGraph_Valid() - { - var graph = new OptimizationGraph(); - Assert.True(graph.Validate()); - } - - [Fact(Timeout = 120000)] - public async Task Validate_DisconnectedNode_Invalid() - { - var graph = new OptimizationGraph(); - var input = CreateNode("input", OperationType.Input); - var disconnected = CreateNode("disconnected", OperationType.MatMul); - - graph.AddNode(input); - graph.AddNode(disconnected); - - // disconnected is not reachable from input - Assert.False(graph.Validate()); - } - - [Fact(Timeout = 120000)] - public async Task Validate_ConstantNodes_AllowedDisconnected() - { - var graph = new OptimizationGraph(); - var input = CreateNode("input", OperationType.Input); - var constant = CreateNode("const", OperationType.Constant); - - graph.AddNode(input); - graph.AddNode(constant); - - // Constants are allowed to be disconnected - Assert.True(graph.Validate()); - } - - // ============================ - // Node Default Properties Tests - // ============================ - - [Fact(Timeout = 120000)] - public async Task NewNode_DefaultProperties() - { - var node = new OptimizationNode(); - - Assert.NotNull(node.Id); - Assert.NotEmpty(node.Id); - Assert.Equal(string.Empty, node.Name); - Assert.Empty(node.Inputs); - Assert.Empty(node.Outputs); - Assert.Empty(node.OutputShape); - Assert.Empty(node.Parameters); - Assert.Empty(node.Metadata); - Assert.True(node.CanEliminate); - Assert.False(node.CanOperateInPlace); - Assert.False(node.IsMarkedForDeletion); - Assert.False(node.IsFused); - Assert.Null(node.ConstantValue); - Assert.Null(node.OriginalLayer); - Assert.Null(node.FusedFrom); - } - - // ============================ - // Graph Topology Invariants - // ============================ - - [Fact(Timeout = 120000)] - public async Task TopologicalOrder_EveryNodePrecedesItsConsumers() - { - var graph = CreateDiamondGraph(); - var order = graph.GetTopologicalOrder(); - - var positionMap = new Dictionary(); - for (int i = 0; i < order.Count; i++) - positionMap[order[i].Id] = i; - - // For every edge A -> B: position(A) < position(B) - foreach (var node in graph.Nodes) - { - foreach (var output in node.Outputs) - { - Assert.True(positionMap[node.Id] < positionMap[output.Id], - $"Node {node.Name} should come before {output.Name} in topological order"); - } - } - } - - [Fact(Timeout = 120000)] - public async Task TopologicalOrder_AllNodesIncluded() - { - var graph = CreateChainGraph(); - var order = graph.GetTopologicalOrder(); - - Assert.Equal(graph.Nodes.Count, order.Count); - - var orderIds = new HashSet(order.Select(n => n.Id)); - foreach (var node in graph.Nodes) - Assert.Contains(node.Id, orderIds); - } - - [Fact(Timeout = 120000)] - public async Task Edges_AreBidirectional() - { - // For every node, if B is in A.Outputs, then A should be in B.Inputs - var graph = CreateDiamondGraph(); - - foreach (var node in graph.Nodes) - { - foreach (var output in node.Outputs) - Assert.Contains(node, output.Inputs); - - foreach (var input in node.Inputs) - Assert.Contains(node, input.Outputs); - } - } - - // ============================ - // ToString Tests - // ============================ - - [Fact(Timeout = 120000)] - public async Task Node_ToString_ContainsInfo() - { - var node = CreateNode("conv1", OperationType.Convolution); - node.OutputShape = new[] { 32, 28, 28 }; - - var str = node.ToString(); - Assert.Contains("conv1", str); - Assert.Contains("Convolution", str); - } - - [Fact(Timeout = 120000)] - public async Task Graph_ToString_ContainsInfo() - { - var graph = CreateChainGraph(); - var str = graph.ToString(); - - Assert.Contains("4", str); // 4 nodes - } -} diff --git a/tests/AiDotNet.Tests/IntegrationTests/InferenceOptimization/InferenceOptimizationIntegrationTests.cs b/tests/AiDotNet.Tests/IntegrationTests/InferenceOptimization/InferenceOptimizationIntegrationTests.cs deleted file mode 100644 index 91169b7d86..0000000000 --- a/tests/AiDotNet.Tests/IntegrationTests/InferenceOptimization/InferenceOptimizationIntegrationTests.cs +++ /dev/null @@ -1,1463 +0,0 @@ -using AiDotNet.Enums; -using AiDotNet.InferenceOptimization.Core; -using AiDotNet.InferenceOptimization.IR.Common; -using AiDotNet.InferenceOptimization.Passes; -using Xunit; -using System.Threading.Tasks; - -namespace AiDotNet.Tests.IntegrationTests.InferenceOptimization; - -/// -/// Integration tests for the InferenceOptimization module. -/// Tests cover optimization graphs, nodes, passes, IR types, and various optimization strategies. -/// -public class InferenceOptimizationIntegrationTests -{ - #region OptimizationNode Tests - - [Fact(Timeout = 120000)] - public async Task OptimizationNode_Constructor_SetsDefaults() - { - var node = new OptimizationNode(); - - Assert.NotNull(node.Id); - Assert.Empty(node.Name); - Assert.Empty(node.Inputs); - Assert.Empty(node.Outputs); - Assert.Empty(node.OutputShape); - Assert.Empty(node.Parameters); - Assert.Empty(node.Metadata); - Assert.True(node.CanEliminate); - Assert.False(node.CanOperateInPlace); - Assert.False(node.IsMarkedForDeletion); - Assert.False(node.IsFused); - Assert.Null(node.ConstantValue); - Assert.Null(node.FusedFrom); - } - - [Fact(Timeout = 120000)] - public async Task OptimizationNode_AddInput_EstablishesConnection() - { - var node1 = new OptimizationNode { Name = "input" }; - var node2 = new OptimizationNode { Name = "output" }; - - node2.AddInput(node1); - - Assert.Single(node2.Inputs); - Assert.Contains(node1, node2.Inputs); - Assert.Single(node1.Outputs); - Assert.Contains(node2, node1.Outputs); - } - - [Fact(Timeout = 120000)] - public async Task OptimizationNode_AddInput_DoesNotDuplicate() - { - var node1 = new OptimizationNode(); - var node2 = new OptimizationNode(); - - node2.AddInput(node1); - node2.AddInput(node1); // Add same node twice - - Assert.Single(node2.Inputs); - Assert.Single(node1.Outputs); - } - - [Fact(Timeout = 120000)] - public async Task OptimizationNode_RemoveInput_RemovesConnection() - { - var node1 = new OptimizationNode(); - var node2 = new OptimizationNode(); - - node2.AddInput(node1); - node2.RemoveInput(node1); - - Assert.Empty(node2.Inputs); - Assert.Empty(node1.Outputs); - } - - [Fact(Timeout = 120000)] - public async Task OptimizationNode_ReplaceInput_UpdatesConnection() - { - var oldInput = new OptimizationNode { Name = "old" }; - var newInput = new OptimizationNode { Name = "new" }; - var node = new OptimizationNode { Name = "consumer" }; - - node.AddInput(oldInput); - node.ReplaceInput(oldInput, newInput); - - Assert.Single(node.Inputs); - Assert.Contains(newInput, node.Inputs); - Assert.DoesNotContain(oldInput, node.Inputs); - Assert.Empty(oldInput.Outputs); - Assert.Single(newInput.Outputs); - } - - [Fact(Timeout = 120000)] - public async Task OptimizationNode_HasConsumers_ReturnsTrueWhenOutputsExist() - { - var node1 = new OptimizationNode(); - var node2 = new OptimizationNode(); - - Assert.False(node1.HasConsumers()); - - node2.AddInput(node1); - - Assert.True(node1.HasConsumers()); - } - - [Fact(Timeout = 120000)] - public async Task OptimizationNode_ConsumerCount_ReturnsCorrectCount() - { - var producer = new OptimizationNode(); - var consumer1 = new OptimizationNode(); - var consumer2 = new OptimizationNode(); - - Assert.Equal(0, producer.ConsumerCount()); - - consumer1.AddInput(producer); - Assert.Equal(1, producer.ConsumerCount()); - - consumer2.AddInput(producer); - Assert.Equal(2, producer.ConsumerCount()); - } - - [Fact(Timeout = 120000)] - public async Task OptimizationNode_Clone_CreatesDeepCopy() - { - var original = new OptimizationNode - { - Name = "test_node", - OperationType = OperationType.Add, - OutputShape = new[] { 1, 2, 3 }, - CanEliminate = false, - IsFused = true - }; - original.Parameters["weight"] = 0.5; - original.Metadata["stride"] = 1; - - var clone = original.Clone(); - - Assert.NotEqual(original.Id, clone.Id); - Assert.Equal("test_node_clone", clone.Name); - Assert.Equal(OperationType.Add, clone.OperationType); - Assert.Equal(original.OutputShape, clone.OutputShape); - Assert.Equal(original.CanEliminate, clone.CanEliminate); - Assert.Equal(original.IsFused, clone.IsFused); - Assert.Equal(original.Parameters["weight"], clone.Parameters["weight"]); - Assert.Equal(original.Metadata["stride"], clone.Metadata["stride"]); - } - - [Fact(Timeout = 120000)] - public async Task OptimizationNode_ToString_ReturnsFormattedString() - { - var node = new OptimizationNode - { - Name = "conv1", - OperationType = OperationType.Convolution2D, - OutputShape = new[] { 1, 64, 32, 32 } - }; - - var str = node.ToString(); - - Assert.Contains("conv1", str); - Assert.Contains("Convolution2D", str); - Assert.Contains("[1, 64, 32, 32]", str); - } - - #endregion - - #region OptimizationGraph Tests - - [Fact(Timeout = 120000)] - public async Task OptimizationGraph_Constructor_InitializesEmptyCollections() - { - var graph = new OptimizationGraph(); - - Assert.Empty(graph.Nodes); - Assert.Empty(graph.InputNodes); - Assert.Empty(graph.OutputNodes); - } - - [Fact(Timeout = 120000)] - public async Task OptimizationGraph_AddNode_AddsToCollection() - { - var graph = new OptimizationGraph(); - var node = new OptimizationNode { Name = "node1" }; - - graph.AddNode(node); - - Assert.Single(graph.Nodes); - Assert.Contains(node, graph.Nodes); - } - - [Fact(Timeout = 120000)] - public async Task OptimizationGraph_AddNode_TracksInputNodes() - { - var graph = new OptimizationGraph(); - var inputNode = new OptimizationNode - { - Name = "input", - OperationType = OperationType.Input - }; - - graph.AddNode(inputNode); - - Assert.Single(graph.InputNodes); - Assert.Contains(inputNode, graph.InputNodes); - } - - [Fact(Timeout = 120000)] - public async Task OptimizationGraph_AddNode_TracksOutputNodes() - { - var graph = new OptimizationGraph(); - var outputNode = new OptimizationNode - { - Name = "output", - OperationType = OperationType.Output - }; - - graph.AddNode(outputNode); - - Assert.Single(graph.OutputNodes); - Assert.Contains(outputNode, graph.OutputNodes); - } - - [Fact(Timeout = 120000)] - public async Task OptimizationGraph_AddNode_ThrowsOnNull() - { - var graph = new OptimizationGraph(); - - Assert.Throws(() => graph.AddNode(null!)); - } - - [Fact(Timeout = 120000)] - public async Task OptimizationGraph_AddNode_DoesNotAddDuplicate() - { - var graph = new OptimizationGraph(); - var node = new OptimizationNode(); - - graph.AddNode(node); - graph.AddNode(node); // Add same node twice - - Assert.Single(graph.Nodes); - } - - [Fact(Timeout = 120000)] - public async Task OptimizationGraph_RemoveNode_RemovesFromCollection() - { - var graph = new OptimizationGraph(); - var node = new OptimizationNode(); - - graph.AddNode(node); - graph.RemoveNode(node); - - Assert.Empty(graph.Nodes); - } - - [Fact(Timeout = 120000)] - public async Task OptimizationGraph_RemoveNode_RemovesConnections() - { - var graph = new OptimizationGraph(); - var node1 = new OptimizationNode(); - var node2 = new OptimizationNode(); - var node3 = new OptimizationNode(); - - graph.AddNode(node1); - graph.AddNode(node2); - graph.AddNode(node3); - - node2.AddInput(node1); - node3.AddInput(node2); - - graph.RemoveNode(node2); - - Assert.Empty(node1.Outputs); - Assert.Empty(node3.Inputs); - } - - [Fact(Timeout = 120000)] - public async Task OptimizationGraph_FindNodeById_ReturnsCorrectNode() - { - var graph = new OptimizationGraph(); - var node = new OptimizationNode { Name = "target" }; - - graph.AddNode(node); - - var found = graph.FindNodeById(node.Id); - - Assert.NotNull(found); - Assert.Equal(node, found); - } - - [Fact(Timeout = 120000)] - public async Task OptimizationGraph_FindNodeById_ReturnsNullForMissing() - { - var graph = new OptimizationGraph(); - - var found = graph.FindNodeById("nonexistent"); - - Assert.Null(found); - } - - [Fact(Timeout = 120000)] - public async Task OptimizationGraph_FindNodesByName_ReturnsMatchingNodes() - { - var graph = new OptimizationGraph(); - var node1 = new OptimizationNode { Name = "conv" }; - var node2 = new OptimizationNode { Name = "conv" }; - var node3 = new OptimizationNode { Name = "relu" }; - - graph.AddNode(node1); - graph.AddNode(node2); - graph.AddNode(node3); - - var found = graph.FindNodesByName("conv"); - - Assert.Equal(2, found.Count); - } - - [Fact(Timeout = 120000)] - public async Task OptimizationGraph_GetTopologicalOrder_ReturnsCorrectOrder() - { - var graph = new OptimizationGraph(); - var input = new OptimizationNode { Name = "input", OperationType = OperationType.Input }; - var middle = new OptimizationNode { Name = "middle", OperationType = OperationType.ReLU }; - var output = new OptimizationNode { Name = "output", OperationType = OperationType.Output }; - - graph.AddNode(input); - graph.AddNode(middle); - graph.AddNode(output); - - middle.AddInput(input); - output.AddInput(middle); - - var order = graph.GetTopologicalOrder(); - - Assert.Equal(3, order.Count); - Assert.True(order.IndexOf(input) < order.IndexOf(middle)); - Assert.True(order.IndexOf(middle) < order.IndexOf(output)); - } - - [Fact(Timeout = 120000)] - public async Task OptimizationGraph_GetTopologicalOrder_ThrowsOnCycle() - { - var graph = new OptimizationGraph(); - var node1 = new OptimizationNode { Name = "node1" }; - var node2 = new OptimizationNode { Name = "node2" }; - - graph.AddNode(node1); - graph.AddNode(node2); - - // Create a cycle - node1.Inputs.Add(node2); - node2.Inputs.Add(node1); - node1.Outputs.Add(node2); - node2.Outputs.Add(node1); - - Assert.Throws(() => graph.GetTopologicalOrder()); - } - - [Fact(Timeout = 120000)] - public async Task OptimizationGraph_Validate_ReturnsTrueForValidGraph() - { - var graph = new OptimizationGraph(); - var input = new OptimizationNode { Name = "input", OperationType = OperationType.Input }; - var output = new OptimizationNode { Name = "output", OperationType = OperationType.Output }; - - graph.AddNode(input); - graph.AddNode(output); - - output.AddInput(input); - - Assert.True(graph.Validate()); - } - - [Fact(Timeout = 120000)] - public async Task OptimizationGraph_Validate_ReturnsFalseForCyclicGraph() - { - var graph = new OptimizationGraph(); - var node1 = new OptimizationNode(); - var node2 = new OptimizationNode(); - - graph.AddNode(node1); - graph.AddNode(node2); - - // Create cycle - node1.Inputs.Add(node2); - node2.Inputs.Add(node1); - node1.Outputs.Add(node2); - node2.Outputs.Add(node1); - - Assert.False(graph.Validate()); - } - - [Fact(Timeout = 120000)] - public async Task OptimizationGraph_Clone_CreatesDeepCopy() - { - var graph = new OptimizationGraph(); - var input = new OptimizationNode { Name = "input", OperationType = OperationType.Input }; - var output = new OptimizationNode { Name = "output", OperationType = OperationType.Output }; - - graph.AddNode(input); - graph.AddNode(output); - output.AddInput(input); - - var clone = graph.Clone(); - - Assert.Equal(2, clone.Nodes.Count); - Assert.Single(clone.InputNodes); - Assert.Single(clone.OutputNodes); - } - - [Fact(Timeout = 120000)] - public async Task OptimizationGraph_GetStatistics_ReturnsCorrectStats() - { - var graph = new OptimizationGraph(); - var input = new OptimizationNode { OperationType = OperationType.Input }; - var relu1 = new OptimizationNode { OperationType = OperationType.ReLU }; - var relu2 = new OptimizationNode { OperationType = OperationType.ReLU, IsFused = true }; - var output = new OptimizationNode { OperationType = OperationType.Output }; - - graph.AddNode(input); - graph.AddNode(relu1); - graph.AddNode(relu2); - graph.AddNode(output); - - var stats = graph.GetStatistics(); - - Assert.Equal(4, stats.TotalNodes); - Assert.Equal(1, stats.InputNodes); - Assert.Equal(1, stats.OutputNodes); - Assert.Equal(1, stats.FusedNodes); - Assert.True(stats.OperationTypeCounts.ContainsKey(OperationType.ReLU)); - Assert.Equal(2, stats.OperationTypeCounts[OperationType.ReLU]); - } - - [Fact(Timeout = 120000)] - public async Task OptimizationGraph_ToString_ReturnsFormattedString() - { - var graph = new OptimizationGraph(); - var input = new OptimizationNode { OperationType = OperationType.Input }; - var output = new OptimizationNode { OperationType = OperationType.Output }; - - graph.AddNode(input); - graph.AddNode(output); - - var str = graph.ToString(); - - Assert.Contains("2 nodes", str); - Assert.Contains("1 inputs", str); - Assert.Contains("1 outputs", str); - } - - #endregion - - #region GraphStatistics Tests - - [Fact(Timeout = 120000)] - public async Task GraphStatistics_ToString_ReturnsFormattedString() - { - var stats = new GraphStatistics - { - TotalNodes = 10, - InputNodes = 2, - OutputNodes = 1, - FusedNodes = 3, - OperationTypeCounts = new Dictionary - { - { OperationType.ReLU, 5 }, - { OperationType.Add, 3 } - } - }; - - var str = stats.ToString(); - - Assert.Contains("10", str); - Assert.Contains("ReLU", str); - } - - #endregion - - #region OptimizationLevel Enum Tests - - [Fact(Timeout = 120000)] - public async Task OptimizationLevel_HasExpectedValues() - { - var levels = (OptimizationLevel[])Enum.GetValues(typeof(OptimizationLevel)); - - Assert.Contains(OptimizationLevel.None, levels); - Assert.Contains(OptimizationLevel.Basic, levels); - Assert.Contains(OptimizationLevel.Standard, levels); - Assert.Contains(OptimizationLevel.Aggressive, levels); - Assert.Contains(OptimizationLevel.Maximum, levels); - } - - [Fact(Timeout = 120000)] - public async Task OptimizationLevel_ValuesAreOrdered() - { - Assert.True((int)OptimizationLevel.None < (int)OptimizationLevel.Basic); - Assert.True((int)OptimizationLevel.Basic < (int)OptimizationLevel.Standard); - Assert.True((int)OptimizationLevel.Standard < (int)OptimizationLevel.Aggressive); - Assert.True((int)OptimizationLevel.Aggressive < (int)OptimizationLevel.Maximum); - } - - #endregion - - #region OptimizationOptions Tests - - [Fact(Timeout = 120000)] - public async Task OptimizationOptions_Constructor_SetsDefaults() - { - var options = new OptimizationOptions(); - - Assert.Equal(OptimizationLevel.Standard, options.Level); - Assert.Equal("NCHW", options.TargetLayout); - Assert.Equal(10, options.MaxIterations); - Assert.True(options.EnableOperatorFusion); - Assert.True(options.EnableConstantFolding); - Assert.True(options.EnableDeadCodeElimination); - Assert.True(options.EnableCSE); - Assert.True(options.EnableLayoutOptimization); - Assert.True(options.EnableInPlaceOptimization); - Assert.True(options.EnableMemoryReuse); - Assert.True(options.EnableAlgebraicSimplification); - Assert.True(options.EnableStrengthReduction); - Assert.False(options.PrintStatistics); - Assert.False(options.ValidateAfterEachPass); - } - - [Fact(Timeout = 120000)] - public async Task OptimizationOptions_FromLevel_None_DisablesAll() - { - var options = OptimizationOptions.FromLevel(OptimizationLevel.None); - - Assert.False(options.EnableOperatorFusion); - Assert.False(options.EnableConstantFolding); - Assert.False(options.EnableDeadCodeElimination); - Assert.False(options.EnableCSE); - Assert.False(options.EnableLayoutOptimization); - Assert.False(options.EnableInPlaceOptimization); - Assert.False(options.EnableMemoryReuse); - Assert.False(options.EnableAlgebraicSimplification); - Assert.False(options.EnableStrengthReduction); - } - - [Fact(Timeout = 120000)] - public async Task OptimizationOptions_FromLevel_Basic_EnablesBasicOnly() - { - var options = OptimizationOptions.FromLevel(OptimizationLevel.Basic); - - Assert.True(options.EnableDeadCodeElimination); - Assert.True(options.EnableConstantFolding); - Assert.False(options.EnableOperatorFusion); - Assert.False(options.EnableCSE); - } - - [Fact(Timeout = 120000)] - public async Task OptimizationOptions_FromLevel_Standard_EnablesStandard() - { - var options = OptimizationOptions.FromLevel(OptimizationLevel.Standard); - - Assert.True(options.EnableOperatorFusion); - Assert.True(options.EnableConstantFolding); - Assert.True(options.EnableDeadCodeElimination); - Assert.True(options.EnableAlgebraicSimplification); - } - - [Fact(Timeout = 120000)] - public async Task OptimizationOptions_FromLevel_Aggressive_EnablesMore() - { - var options = OptimizationOptions.FromLevel(OptimizationLevel.Aggressive); - - Assert.True(options.EnableOperatorFusion); - Assert.True(options.EnableConstantFolding); - Assert.True(options.EnableDeadCodeElimination); - Assert.True(options.EnableCSE); - Assert.True(options.EnableAlgebraicSimplification); - Assert.True(options.EnableStrengthReduction); - Assert.True(options.EnableInPlaceOptimization); - Assert.True(options.EnableMemoryReuse); - } - - [Fact(Timeout = 120000)] - public async Task OptimizationOptions_FromLevel_Maximum_EnablesAll() - { - var options = OptimizationOptions.FromLevel(OptimizationLevel.Maximum); - - Assert.True(options.EnableOperatorFusion); - Assert.True(options.EnableConstantFolding); - Assert.True(options.EnableDeadCodeElimination); - Assert.True(options.EnableCSE); - Assert.True(options.EnableLayoutOptimization); - Assert.True(options.EnableInPlaceOptimization); - Assert.True(options.EnableMemoryReuse); - Assert.True(options.EnableAlgebraicSimplification); - Assert.True(options.EnableStrengthReduction); - } - - #endregion - - #region OptimizationPassType Enum Tests - - [Fact(Timeout = 120000)] - public async Task OptimizationPassType_HasExpectedValues() - { - var passTypes = (OptimizationPassType[])Enum.GetValues(typeof(OptimizationPassType)); - - // Fusion passes - Assert.Contains(OptimizationPassType.OperatorFusion, passTypes); - Assert.Contains(OptimizationPassType.ConvBatchNormFusion, passTypes); - Assert.Contains(OptimizationPassType.ConvBatchNormReLUFusion, passTypes); - Assert.Contains(OptimizationPassType.MatMulBiasFusion, passTypes); - Assert.Contains(OptimizationPassType.ElementwiseFusion, passTypes); - - // Graph optimization passes - Assert.Contains(OptimizationPassType.ConstantFolding, passTypes); - Assert.Contains(OptimizationPassType.DeadCodeElimination, passTypes); - Assert.Contains(OptimizationPassType.CommonSubexpressionElimination, passTypes); - Assert.Contains(OptimizationPassType.LayoutOptimization, passTypes); - - // Memory passes - Assert.Contains(OptimizationPassType.InPlaceOptimization, passTypes); - Assert.Contains(OptimizationPassType.MemoryReuseOptimization, passTypes); - - // Computation passes - Assert.Contains(OptimizationPassType.AlgebraicSimplification, passTypes); - Assert.Contains(OptimizationPassType.StrengthReduction, passTypes); - } - - #endregion - - #region DeadCodeEliminationPass Tests - - [Fact(Timeout = 120000)] - public async Task DeadCodeEliminationPass_Properties() - { - var pass = new DeadCodeEliminationPass(); - - Assert.Equal(OptimizationPassType.DeadCodeElimination, pass.PassType); - Assert.Equal("Dead Code Elimination", pass.Name); - } - - [Fact(Timeout = 120000)] - public async Task DeadCodeEliminationPass_CanApply_ReturnsFalseForEmptyGraph() - { - var pass = new DeadCodeEliminationPass(); - var graph = new OptimizationGraph(); - - Assert.False(pass.CanApply(graph)); - } - - [Fact(Timeout = 120000)] - public async Task DeadCodeEliminationPass_CanApply_ReturnsFalseWhenNoOutputs() - { - var pass = new DeadCodeEliminationPass(); - var graph = new OptimizationGraph(); - var node = new OptimizationNode { OperationType = OperationType.ReLU }; - graph.AddNode(node); - - Assert.False(pass.CanApply(graph)); - } - - [Fact(Timeout = 120000)] - public async Task DeadCodeEliminationPass_CanApply_ReturnsTrueForValidGraph() - { - var pass = new DeadCodeEliminationPass(); - var graph = new OptimizationGraph(); - var output = new OptimizationNode { OperationType = OperationType.Output }; - graph.AddNode(output); - - Assert.True(pass.CanApply(graph)); - } - - [Fact(Timeout = 120000)] - public async Task DeadCodeEliminationPass_Apply_RemovesUnreachableNodes() - { - var pass = new DeadCodeEliminationPass(); - var graph = new OptimizationGraph(); - - var input = new OptimizationNode { OperationType = OperationType.Input }; - var output = new OptimizationNode { OperationType = OperationType.Output }; - var deadNode = new OptimizationNode { OperationType = OperationType.ReLU, CanEliminate = true }; - - graph.AddNode(input); - graph.AddNode(output); - graph.AddNode(deadNode); - output.AddInput(input); - - bool modified = pass.Apply(graph); - - Assert.True(modified); - Assert.DoesNotContain(deadNode, graph.Nodes); - Assert.Equal(2, graph.Nodes.Count); - } - - [Fact(Timeout = 120000)] - public async Task DeadCodeEliminationPass_Apply_PreservesReachableNodes() - { - var pass = new DeadCodeEliminationPass(); - var graph = new OptimizationGraph(); - - var input = new OptimizationNode { OperationType = OperationType.Input }; - var middle = new OptimizationNode { OperationType = OperationType.ReLU }; - var output = new OptimizationNode { OperationType = OperationType.Output }; - - graph.AddNode(input); - graph.AddNode(middle); - graph.AddNode(output); - - middle.AddInput(input); - output.AddInput(middle); - - bool modified = pass.Apply(graph); - - Assert.False(modified); - Assert.Equal(3, graph.Nodes.Count); - } - - [Fact(Timeout = 120000)] - public async Task DeadCodeEliminationPass_Apply_DoesNotRemoveNonEliminableNodes() - { - var pass = new DeadCodeEliminationPass(); - var graph = new OptimizationGraph(); - - var input = new OptimizationNode { OperationType = OperationType.Input }; - var output = new OptimizationNode { OperationType = OperationType.Output }; - var sideEffectNode = new OptimizationNode - { - OperationType = OperationType.Custom, - CanEliminate = false - }; - - graph.AddNode(input); - graph.AddNode(output); - graph.AddNode(sideEffectNode); - output.AddInput(input); - - bool modified = pass.Apply(graph); - - Assert.Contains(sideEffectNode, graph.Nodes); - } - - #endregion - - #region IRDataType Tests - - [Fact(Timeout = 120000)] - public async Task IRDataType_HasExpectedValues() - { - var types = (IRDataType[])Enum.GetValues(typeof(IRDataType)); - - // Floating point - Assert.Contains(IRDataType.Float16, types); - Assert.Contains(IRDataType.Float32, types); - Assert.Contains(IRDataType.Float64, types); - Assert.Contains(IRDataType.BFloat16, types); - - // Integer - Assert.Contains(IRDataType.Int8, types); - Assert.Contains(IRDataType.Int32, types); - Assert.Contains(IRDataType.Int64, types); - Assert.Contains(IRDataType.UInt8, types); - - // Quantized - Assert.Contains(IRDataType.QInt8, types); - Assert.Contains(IRDataType.QUInt8, types); - Assert.Contains(IRDataType.QInt4, types); - - // Other - Assert.Contains(IRDataType.Bool, types); - Assert.Contains(IRDataType.Complex64, types); - Assert.Contains(IRDataType.Decimal, types); - } - - [Fact(Timeout = 120000)] - public async Task IRDataTypeExtensions_IsFloatingPoint() - { - Assert.True(IRDataType.Float16.IsFloatingPoint()); - Assert.True(IRDataType.Float32.IsFloatingPoint()); - Assert.True(IRDataType.Float64.IsFloatingPoint()); - Assert.True(IRDataType.BFloat16.IsFloatingPoint()); - Assert.False(IRDataType.Int32.IsFloatingPoint()); - Assert.False(IRDataType.QInt8.IsFloatingPoint()); - } - - [Fact(Timeout = 120000)] - public async Task IRDataTypeExtensions_IsInteger() - { - Assert.True(IRDataType.Int8.IsInteger()); - Assert.True(IRDataType.Int32.IsInteger()); - Assert.True(IRDataType.UInt8.IsInteger()); - Assert.True(IRDataType.Int64.IsInteger()); - Assert.False(IRDataType.Float32.IsInteger()); - Assert.False(IRDataType.QInt8.IsInteger()); - } - - [Fact(Timeout = 120000)] - public async Task IRDataTypeExtensions_IsQuantized() - { - Assert.True(IRDataType.QInt8.IsQuantized()); - Assert.True(IRDataType.QUInt8.IsQuantized()); - Assert.True(IRDataType.QInt4.IsQuantized()); - Assert.True(IRDataType.QInt2.IsQuantized()); - Assert.False(IRDataType.Float32.IsQuantized()); - Assert.False(IRDataType.Int8.IsQuantized()); - } - - [Fact(Timeout = 120000)] - public async Task IRDataTypeExtensions_ElementSizeInBytes() - { - Assert.Equal(1, IRDataType.Bool.ElementSizeInBytes()); - Assert.Equal(1, IRDataType.Int8.ElementSizeInBytes()); - Assert.Equal(2, IRDataType.Float16.ElementSizeInBytes()); - Assert.Equal(4, IRDataType.Float32.ElementSizeInBytes()); - Assert.Equal(8, IRDataType.Float64.ElementSizeInBytes()); - Assert.Equal(16, IRDataType.Decimal.ElementSizeInBytes()); - } - - [Fact(Timeout = 120000)] - public async Task IRDataTypeExtensions_FromSystemType() - { - Assert.Equal(IRDataType.Float32, IRDataTypeExtensions.FromSystemType(typeof(float))); - Assert.Equal(IRDataType.Float64, IRDataTypeExtensions.FromSystemType(typeof(double))); - Assert.Equal(IRDataType.Int32, IRDataTypeExtensions.FromSystemType(typeof(int))); - Assert.Equal(IRDataType.Int64, IRDataTypeExtensions.FromSystemType(typeof(long))); - Assert.Equal(IRDataType.Bool, IRDataTypeExtensions.FromSystemType(typeof(bool))); - Assert.Equal(IRDataType.Decimal, IRDataTypeExtensions.FromSystemType(typeof(decimal))); - } - - [Fact(Timeout = 120000)] - public async Task IRDataTypeExtensions_ToSystemType() - { - Assert.Equal(typeof(float), IRDataType.Float32.ToSystemType()); - Assert.Equal(typeof(double), IRDataType.Float64.ToSystemType()); - Assert.Equal(typeof(int), IRDataType.Int32.ToSystemType()); - Assert.Equal(typeof(long), IRDataType.Int64.ToSystemType()); - Assert.Equal(typeof(bool), IRDataType.Bool.ToSystemType()); - Assert.Equal(typeof(decimal), IRDataType.Decimal.ToSystemType()); - } - - #endregion - - #region MemoryLayout Tests - - [Fact(Timeout = 120000)] - public async Task MemoryLayout_HasExpectedValues() - { - var layouts = (MemoryLayout[])Enum.GetValues(typeof(MemoryLayout)); - - Assert.Contains(MemoryLayout.RowMajor, layouts); - Assert.Contains(MemoryLayout.ColumnMajor, layouts); - Assert.Contains(MemoryLayout.NCHW, layouts); - Assert.Contains(MemoryLayout.NHWC, layouts); - Assert.Contains(MemoryLayout.Tiled4x4, layouts); - Assert.Contains(MemoryLayout.Blocked, layouts); - } - - #endregion - - #region DeviceType Tests - - [Fact(Timeout = 120000)] - public async Task DeviceType_HasExpectedValues() - { - var devices = (DeviceType[])Enum.GetValues(typeof(DeviceType)); - - Assert.Contains(DeviceType.CPU, devices); - Assert.Contains(DeviceType.GPU, devices); - Assert.Contains(DeviceType.TPU, devices); - Assert.Contains(DeviceType.NPU, devices); - Assert.Contains(DeviceType.FPGA, devices); - Assert.Contains(DeviceType.Auto, devices); - Assert.Contains(DeviceType.Any, devices); - } - - #endregion - - #region QuantizationParams Tests - - [Fact(Timeout = 120000)] - public async Task QuantizationParams_DefaultValues() - { - var qParams = new QuantizationParams(); - - Assert.Equal(1.0, qParams.Scale); - Assert.Equal(0, qParams.ZeroPoint); - Assert.Equal(double.MinValue, qParams.Min); - Assert.Equal(double.MaxValue, qParams.Max); - Assert.False(qParams.PerChannel); - Assert.Equal(-1, qParams.QuantizationAxis); - Assert.Null(qParams.PerChannelScales); - Assert.Null(qParams.PerChannelZeroPoints); - } - - [Fact(Timeout = 120000)] - public async Task QuantizationParams_CanBeConfigured() - { - var qParams = new QuantizationParams - { - Scale = 0.01, - ZeroPoint = 128, - Min = -1.0, - Max = 1.0, - PerChannel = true, - QuantizationAxis = 0, - PerChannelScales = new[] { 0.01, 0.02, 0.03 }, - PerChannelZeroPoints = new[] { 128, 127, 126 } - }; - - Assert.Equal(0.01, qParams.Scale); - Assert.Equal(128, qParams.ZeroPoint); - Assert.True(qParams.PerChannel); - Assert.Equal(3, qParams.PerChannelScales?.Length); - } - - #endregion - - #region TensorType Tests - - [Fact(Timeout = 120000)] - public async Task TensorType_DefaultValues() - { - var tensorType = new TensorType(); - - Assert.Equal(IRDataType.Float32, tensorType.DataType); - Assert.Empty(tensorType.Shape.ToArray()); - Assert.Equal(MemoryLayout.RowMajor, tensorType.Layout); - Assert.Equal(DeviceType.Auto, tensorType.Device); - Assert.Null(tensorType.Quantization); - Assert.Null(tensorType.Strides); - } - - [Fact(Timeout = 120000)] - public async Task TensorType_HasDynamicShape_ReturnsTrueForDynamicDimensions() - { - var staticType = new TensorType { Shape = new[] { 1, 3, 224, 224 } }; - var dynamicType = new TensorType { Shape = new[] { -1, 3, 224, 224 } }; - - Assert.False(staticType.HasDynamicShape); - Assert.True(dynamicType.HasDynamicShape); - } - - [Fact(Timeout = 120000)] - public async Task TensorType_NumElements_CalculatesCorrectly() - { - var scalar = new TensorType { Shape = Array.Empty() }; - var vector = new TensorType { Shape = new[] { 10 } }; - var matrix = new TensorType { Shape = new[] { 3, 4 } }; - var tensor = new TensorType { Shape = new[] { 2, 3, 4 } }; - - Assert.Equal(1, scalar.NumElements); - Assert.Equal(10, vector.NumElements); - Assert.Equal(12, matrix.NumElements); - Assert.Equal(24, tensor.NumElements); - } - - [Fact(Timeout = 120000)] - public async Task TensorType_NumElements_ReturnsMinusOneForDynamic() - { - var dynamicType = new TensorType { Shape = new[] { -1, 3, 224, 224 } }; - - Assert.Equal(-1, dynamicType.NumElements); - } - - [Fact(Timeout = 120000)] - public async Task TensorType_ElementSize_ReturnsCorrectSize() - { - var float32Type = new TensorType { DataType = IRDataType.Float32 }; - var float64Type = new TensorType { DataType = IRDataType.Float64 }; - var int8Type = new TensorType { DataType = IRDataType.Int8 }; - - Assert.Equal(4, float32Type.ElementSize); - Assert.Equal(8, float64Type.ElementSize); - Assert.Equal(1, int8Type.ElementSize); - } - - [Fact(Timeout = 120000)] - public async Task TensorType_TotalBytes_CalculatesCorrectly() - { - var tensorType = new TensorType - { - DataType = IRDataType.Float32, - Shape = new[] { 1, 3, 224, 224 } - }; - - long expectedBytes = 1L * 3 * 224 * 224 * 4; // 4 bytes per float32 - Assert.Equal(expectedBytes, tensorType.TotalBytes); - } - - [Fact(Timeout = 120000)] - public async Task TensorType_IsBroadcastCompatible_ChecksCorrectly() - { - var type1 = new TensorType { Shape = new[] { 1, 3, 1 } }; - var type2 = new TensorType { Shape = new[] { 4, 3, 5 } }; - var type3 = new TensorType { Shape = new[] { 4, 2, 5 } }; - - Assert.True(type1.IsBroadcastCompatible(type2)); - Assert.False(type1.IsBroadcastCompatible(type3)); // 3 != 2 - } - - [Fact(Timeout = 120000)] - public async Task TensorType_Clone_CreatesDeepCopy() - { - var original = new TensorType - { - DataType = IRDataType.Float16, - Shape = new[] { 1, 2, 3 }, - Layout = MemoryLayout.NHWC, - Device = DeviceType.GPU - }; - - var clone = original.Clone(); - - Assert.Equal(original.DataType, clone.DataType); - Assert.Equal(original.Shape.ToArray(), clone.Shape.ToArray()); - Assert.NotSame(original.Shape.ToArray(), clone.Shape.ToArray()); - Assert.Equal(original.Layout, clone.Layout); - Assert.Equal(original.Device, clone.Device); - } - - [Fact(Timeout = 120000)] - public async Task TensorType_ToString_ReturnsFormattedString() - { - var tensorType = new TensorType - { - DataType = IRDataType.Float32, - Shape = new[] { 1, 3, 224, 224 }, - Device = DeviceType.GPU - }; - - var str = tensorType.ToString(); - - Assert.Contains("Float32", str); - Assert.Contains("GPU", str); - } - - #endregion - - #region ConstantFoldingPass Tests - - [Fact(Timeout = 120000)] - public async Task ConstantFoldingPass_Properties() - { - var pass = new ConstantFoldingPass(); - - Assert.Equal(OptimizationPassType.ConstantFolding, pass.PassType); - Assert.Equal("Constant Folding", pass.Name); - } - - [Fact(Timeout = 120000)] - public async Task ConstantFoldingPass_CanApply_ReturnsTrueWhenConstantExists() - { - var pass = new ConstantFoldingPass(); - var graph = new OptimizationGraph(); - var constantNode = new OptimizationNode { OperationType = OperationType.Constant }; - graph.AddNode(constantNode); - - Assert.True(pass.CanApply(graph)); - } - - [Fact(Timeout = 120000)] - public async Task ConstantFoldingPass_CanApply_ReturnsFalseWhenNoConstants() - { - var pass = new ConstantFoldingPass(); - var graph = new OptimizationGraph(); - var node = new OptimizationNode { OperationType = OperationType.ReLU }; - graph.AddNode(node); - - Assert.False(pass.CanApply(graph)); - } - - #endregion - - #region AlgebraicSimplificationPass Tests - - [Fact(Timeout = 120000)] - public async Task AlgebraicSimplificationPass_Properties() - { - var pass = new AlgebraicSimplificationPass(); - - Assert.Equal(OptimizationPassType.AlgebraicSimplification, pass.PassType); - Assert.Equal("Algebraic Simplification", pass.Name); - } - - #endregion - - #region CommonSubexpressionEliminationPass Tests - - [Fact(Timeout = 120000)] - public async Task CommonSubexpressionEliminationPass_Properties() - { - var pass = new CommonSubexpressionEliminationPass(); - - Assert.Equal(OptimizationPassType.CommonSubexpressionElimination, pass.PassType); - Assert.Equal("Common Subexpression Elimination", pass.Name); - } - - [Fact(Timeout = 120000)] - public async Task CommonSubexpressionEliminationPass_EliminatesIdenticalAddOperations() - { - var pass = new CommonSubexpressionEliminationPass(); - var graph = new OptimizationGraph(); - - var input1 = new OptimizationNode { OperationType = OperationType.Input, Name = "a" }; - var input2 = new OptimizationNode { OperationType = OperationType.Input, Name = "b" }; - var add1 = new OptimizationNode { OperationType = OperationType.Add, Name = "add1" }; - var add2 = new OptimizationNode { OperationType = OperationType.Add, Name = "add2" }; - var output = new OptimizationNode { OperationType = OperationType.Output, Name = "out" }; - - graph.AddNode(input1); - graph.AddNode(input2); - graph.AddNode(add1); - graph.AddNode(add2); - graph.AddNode(output); - - // Both add operations use the same inputs (a + b and a + b) - add1.AddInput(input1); - add1.AddInput(input2); - add2.AddInput(input1); - add2.AddInput(input2); - output.AddInput(add1); - - bool modified = pass.Apply(graph); - - // Should eliminate one of the duplicate add operations - Assert.True(modified); - } - - /// - /// BUG TEST: CSE should NOT eliminate non-commutative operations with reversed operands. - /// The current implementation incorrectly sorts input IDs which would merge a-b and b-a. - /// - [Fact(Timeout = 120000)] - public async Task CommonSubexpressionEliminationPass_PreservesNonCommutativeOperations() - { - var pass = new CommonSubexpressionEliminationPass(); - var graph = new OptimizationGraph(); - - var inputA = new OptimizationNode { OperationType = OperationType.Input, Name = "a" }; - var inputB = new OptimizationNode { OperationType = OperationType.Input, Name = "b" }; - var sub1 = new OptimizationNode { OperationType = OperationType.Subtract, Name = "a_minus_b" }; - var sub2 = new OptimizationNode { OperationType = OperationType.Subtract, Name = "b_minus_a" }; - var add = new OptimizationNode { OperationType = OperationType.Add, Name = "result" }; - var output = new OptimizationNode { OperationType = OperationType.Output, Name = "out" }; - - graph.AddNode(inputA); - graph.AddNode(inputB); - graph.AddNode(sub1); - graph.AddNode(sub2); - graph.AddNode(add); - graph.AddNode(output); - - // sub1 = a - b - sub1.AddInput(inputA); - sub1.AddInput(inputB); - - // sub2 = b - a (DIFFERENT from a - b!) - sub2.AddInput(inputB); - sub2.AddInput(inputA); - - // result = (a - b) + (b - a) - add.AddInput(sub1); - add.AddInput(sub2); - output.AddInput(add); - - // Count subtraction nodes before - int subCountBefore = graph.Nodes.Count(n => n.OperationType == OperationType.Subtract); - Assert.Equal(2, subCountBefore); - - bool modified = pass.Apply(graph); - - // Count subtraction nodes after - int subCountAfter = graph.Nodes.Count(n => n.OperationType == OperationType.Subtract); - - // Both subtraction nodes should be preserved because a-b ≠ b-a - Assert.Equal(2, subCountAfter); - Assert.False(modified, "Non-commutative operations with different operand order should NOT be merged"); - } - - /// - /// BUG TEST: CSE should NOT eliminate division operations with reversed operands. - /// a/b ≠ b/a, so these should not be merged. - /// - [Fact(Timeout = 120000)] - public async Task CommonSubexpressionEliminationPass_PreservesDivisionOperandOrder() - { - var pass = new CommonSubexpressionEliminationPass(); - var graph = new OptimizationGraph(); - - var inputA = new OptimizationNode { OperationType = OperationType.Input, Name = "a" }; - var inputB = new OptimizationNode { OperationType = OperationType.Input, Name = "b" }; - var div1 = new OptimizationNode { OperationType = OperationType.Divide, Name = "a_div_b" }; - var div2 = new OptimizationNode { OperationType = OperationType.Divide, Name = "b_div_a" }; - var multiply = new OptimizationNode { OperationType = OperationType.Multiply, Name = "result" }; - var output = new OptimizationNode { OperationType = OperationType.Output, Name = "out" }; - - graph.AddNode(inputA); - graph.AddNode(inputB); - graph.AddNode(div1); - graph.AddNode(div2); - graph.AddNode(multiply); - graph.AddNode(output); - - // div1 = a / b - div1.AddInput(inputA); - div1.AddInput(inputB); - - // div2 = b / a (DIFFERENT from a / b!) - div2.AddInput(inputB); - div2.AddInput(inputA); - - // result = (a / b) * (b / a) - multiply.AddInput(div1); - multiply.AddInput(div2); - output.AddInput(multiply); - - // Count division nodes before - int divCountBefore = graph.Nodes.Count(n => n.OperationType == OperationType.Divide); - Assert.Equal(2, divCountBefore); - - bool modified = pass.Apply(graph); - - // Count division nodes after - int divCountAfter = graph.Nodes.Count(n => n.OperationType == OperationType.Divide); - - // Both division nodes should be preserved because a/b ≠ b/a - Assert.Equal(2, divCountAfter); - Assert.False(modified, "Division operations with different operand order should NOT be merged"); - } - - [Fact(Timeout = 120000)] - public async Task CommonSubexpressionEliminationPass_CanApply_ReturnsTrueForGraphWithMultipleNodes() - { - var pass = new CommonSubexpressionEliminationPass(); - var graph = new OptimizationGraph(); - - var node1 = new OptimizationNode { OperationType = OperationType.Add }; - var node2 = new OptimizationNode { OperationType = OperationType.Add }; - graph.AddNode(node1); - graph.AddNode(node2); - - Assert.True(pass.CanApply(graph)); - } - - [Fact(Timeout = 120000)] - public async Task CommonSubexpressionEliminationPass_CanApply_ReturnsFalseForSingleNodeGraph() - { - var pass = new CommonSubexpressionEliminationPass(); - var graph = new OptimizationGraph(); - - var node = new OptimizationNode(); - graph.AddNode(node); - - Assert.False(pass.CanApply(graph)); - } - - #endregion - - #region StrengthReductionPass Tests - - [Fact(Timeout = 120000)] - public async Task StrengthReductionPass_Properties() - { - var pass = new StrengthReductionPass(); - - Assert.Equal(OptimizationPassType.StrengthReduction, pass.PassType); - Assert.Equal("Strength Reduction", pass.Name); - } - - #endregion - - #region InPlaceOptimizationPass Tests - - [Fact(Timeout = 120000)] - public async Task InPlaceOptimizationPass_Properties() - { - var pass = new InPlaceOptimizationPass(); - - Assert.Equal(OptimizationPassType.InPlaceOptimization, pass.PassType); - Assert.Equal("In-Place Operation Optimization", pass.Name); - } - - #endregion - - #region MemoryReuseOptimizationPass Tests - - [Fact(Timeout = 120000)] - public async Task MemoryReuseOptimizationPass_Properties() - { - var pass = new MemoryReuseOptimizationPass(); - - Assert.Equal(OptimizationPassType.MemoryReuseOptimization, pass.PassType); - Assert.Equal("Memory Reuse Optimization", pass.Name); - } - - #endregion - - #region LayoutOptimizationPass Tests - - [Fact(Timeout = 120000)] - public async Task LayoutOptimizationPass_Properties() - { - var pass = new LayoutOptimizationPass(); - - Assert.Equal(OptimizationPassType.LayoutOptimization, pass.PassType); - Assert.Equal("Layout Optimization", pass.Name); - } - - #endregion - - #region ElementwiseFusionPass Tests - - [Fact(Timeout = 120000)] - public async Task ElementwiseFusionPass_Properties() - { - var pass = new ElementwiseFusionPass(); - - Assert.Equal(OptimizationPassType.ElementwiseFusion, pass.PassType); - Assert.Equal("Elementwise Operation Fusion", pass.Name); - } - - #endregion - - #region ConvBatchNormFusionPass Tests - - [Fact(Timeout = 120000)] - public async Task ConvBatchNormFusionPass_Properties() - { - var pass = new ConvBatchNormFusionPass(); - - Assert.Equal(OptimizationPassType.ConvBatchNormFusion, pass.PassType); - Assert.Equal("Conv + BatchNorm Fusion", pass.Name); - } - - #endregion - - #region ConvBatchNormReLUFusionPass Tests - - [Fact(Timeout = 120000)] - public async Task ConvBatchNormReLUFusionPass_Properties() - { - var pass = new ConvBatchNormReLUFusionPass(); - - Assert.Equal(OptimizationPassType.ConvBatchNormReLUFusion, pass.PassType); - Assert.Equal("Conv + BatchNorm + ReLU Fusion", pass.Name); - } - - #endregion - - #region MatMulBiasFusionPass Tests - - [Fact(Timeout = 120000)] - public async Task MatMulBiasFusionPass_Properties() - { - var pass = new MatMulBiasFusionPass(); - - Assert.Equal(OptimizationPassType.MatMulBiasFusion, pass.PassType); - Assert.Equal("MatMul + Bias Fusion", pass.Name); - } - - #endregion - - #region MatMulBiasActivationFusionPass Tests - - [Fact(Timeout = 120000)] - public async Task MatMulBiasActivationFusionPass_Properties() - { - var pass = new MatMulBiasActivationFusionPass(); - - Assert.Equal(OptimizationPassType.MatMulBiasActivationFusion, pass.PassType); - Assert.Equal("MatMul + Bias + Activation Fusion", pass.Name); - } - - #endregion - - #region MultiHeadAttentionFusionPass Tests - - [Fact(Timeout = 120000)] - public async Task MultiHeadAttentionFusionPass_Properties() - { - var pass = new MultiHeadAttentionFusionPass(); - - Assert.Equal(OptimizationPassType.AttentionFusion, pass.PassType); - Assert.Equal("Multi-Head Attention Fusion", pass.Name); - } - - #endregion - - #region Integration Tests - Graph Construction and Optimization - - [Fact(Timeout = 120000)] - public async Task IntegrationTest_SimpleLinearGraph() - { - // Create a simple Input -> ReLU -> Output graph - var graph = new OptimizationGraph(); - - var input = new OptimizationNode - { - Name = "input", - OperationType = OperationType.Input, - OutputShape = new[] { 1, 784 } - }; - - var relu = new OptimizationNode - { - Name = "relu1", - OperationType = OperationType.ReLU, - OutputShape = new[] { 1, 784 } - }; - - var output = new OptimizationNode - { - Name = "output", - OperationType = OperationType.Output, - OutputShape = new[] { 1, 784 } - }; - - graph.AddNode(input); - graph.AddNode(relu); - graph.AddNode(output); - - relu.AddInput(input); - output.AddInput(relu); - - Assert.True(graph.Validate()); - Assert.Equal(3, graph.Nodes.Count); - Assert.Single(graph.InputNodes); - Assert.Single(graph.OutputNodes); - - var order = graph.GetTopologicalOrder(); - Assert.Equal(input, order[0]); - Assert.Equal(relu, order[1]); - Assert.Equal(output, order[2]); - } - - [Fact(Timeout = 120000)] - public async Task IntegrationTest_ApplyMultiplePasses() - { - var graph = new OptimizationGraph(); - - var input = new OptimizationNode { OperationType = OperationType.Input }; - var relu = new OptimizationNode { OperationType = OperationType.ReLU }; - var output = new OptimizationNode { OperationType = OperationType.Output }; - var deadNode = new OptimizationNode { OperationType = OperationType.Add, CanEliminate = true }; - - graph.AddNode(input); - graph.AddNode(relu); - graph.AddNode(output); - graph.AddNode(deadNode); - - relu.AddInput(input); - output.AddInput(relu); - - // Apply dead code elimination - var dcePass = new DeadCodeEliminationPass(); - bool modified = dcePass.Apply(graph); - - Assert.True(modified); - Assert.Equal(3, graph.Nodes.Count); - Assert.DoesNotContain(deadNode, graph.Nodes); - } - - #endregion -} diff --git a/tests/AiDotNet.Tests/IntegrationTests/LoRA/LoRAIntegrationTests.cs b/tests/AiDotNet.Tests/IntegrationTests/LoRA/LoRAIntegrationTests.cs index bc64f0fdfd..4f52cd466d 100644 --- a/tests/AiDotNet.Tests/IntegrationTests/LoRA/LoRAIntegrationTests.cs +++ b/tests/AiDotNet.Tests/IntegrationTests/LoRA/LoRAIntegrationTests.cs @@ -106,14 +106,6 @@ public async Task LoRALayer_GetSetParameters_RoundTrip() } - [Fact(Timeout = 120000)] - public async Task LoRALayer_SupportsJitCompilation_ReturnsTrueWhenInitialized() - { - var layer = new LoRALayer(InputSize, OutputSize, Rank, Alpha); - - Assert.True(layer.SupportsJitCompilation); - } - #endregion #region StandardLoRAAdapter Tests diff --git a/tests/AiDotNet.Tests/IntegrationTests/MetaLearning/MetaLearningTestModels.cs b/tests/AiDotNet.Tests/IntegrationTests/MetaLearning/MetaLearningTestModels.cs index 3130c19623..a65db64333 100644 --- a/tests/AiDotNet.Tests/IntegrationTests/MetaLearning/MetaLearningTestModels.cs +++ b/tests/AiDotNet.Tests/IntegrationTests/MetaLearning/MetaLearningTestModels.cs @@ -202,13 +202,6 @@ public Dictionary GetFeatureImportance() return importance; } - public bool SupportsJitCompilation => false; - - public ComputationNode ExportComputationGraph(List> inputNodes) - { - throw new NotSupportedException("JIT compilation is not supported."); - } - private string SerializeParameters() { return string.Join(",", _parameters.Select(p => p.ToString("R", CultureInfo.InvariantCulture))); @@ -456,13 +449,6 @@ public Dictionary GetFeatureImportance() return importance; } - public bool SupportsJitCompilation => false; - - public ComputationNode ExportComputationGraph(List> inputNodes) - { - throw new NotSupportedException("JIT compilation is not supported."); - } - private string SerializeParameters() { return string.Join(",", _parameters.Select(p => p.ToString("R", CultureInfo.InvariantCulture))); diff --git a/tests/AiDotNet.Tests/IntegrationTests/MixedPrecision/MixedPrecisionIntegrationTests.cs b/tests/AiDotNet.Tests/IntegrationTests/MixedPrecision/MixedPrecisionIntegrationTests.cs index cb1456ac89..5dc8dc363c 100644 --- a/tests/AiDotNet.Tests/IntegrationTests/MixedPrecision/MixedPrecisionIntegrationTests.cs +++ b/tests/AiDotNet.Tests/IntegrationTests/MixedPrecision/MixedPrecisionIntegrationTests.cs @@ -1646,7 +1646,6 @@ public TestLayer(string? customName = null) public override string LayerName => _customName ?? base.LayerName; - public override bool SupportsJitCompilation => false; public override bool SupportsTraining => true; // Expose protected members for testing @@ -1665,8 +1664,6 @@ public override void UpdateParameters(float learningRate) { } public override Vector GetParameters() => new Vector(0); public override void ResetState() { } - - public override ComputationNode ExportComputationGraph(List> nodes) => null!; } #endregion diff --git a/tests/AiDotNet.Tests/IntegrationTests/Quantization/QuantizationIntegrationTests.cs b/tests/AiDotNet.Tests/IntegrationTests/Quantization/QuantizationIntegrationTests.cs index 552eb75835..67ef48ce29 100644 --- a/tests/AiDotNet.Tests/IntegrationTests/Quantization/QuantizationIntegrationTests.cs +++ b/tests/AiDotNet.Tests/IntegrationTests/Quantization/QuantizationIntegrationTests.cs @@ -1706,14 +1706,6 @@ public void ApplyGradients(Vector gradients, T learningRate) // No-op for test model } - // IJitCompilable - public AiDotNet.Autodiff.ComputationNode ExportComputationGraph(List> inputNodes) - { - throw new NotSupportedException("JIT compilation not supported for test model"); - } - - public bool SupportsJitCompilation => false; - public Vector SanitizeParameters(Vector parameters) => parameters; } diff --git a/tests/AiDotNet.Tests/IntegrationTests/Regression/KernelRegressionIntegrationTests.cs b/tests/AiDotNet.Tests/IntegrationTests/Regression/KernelRegressionIntegrationTests.cs index 76d43398b0..d654cf81c5 100644 --- a/tests/AiDotNet.Tests/IntegrationTests/Regression/KernelRegressionIntegrationTests.cs +++ b/tests/AiDotNet.Tests/IntegrationTests/Regression/KernelRegressionIntegrationTests.cs @@ -496,24 +496,6 @@ public async Task LocallyWeightedRegression_SerializeDeserialize_PreservesModel( } } - [Fact(Timeout = 120000)] - public async Task LocallyWeightedRegression_SoftMode_EnablesJitCompilation() - { - // Arrange - var options = new LocallyWeightedRegressionOptions - { - Bandwidth = 1.0, - UseSoftMode = true - }; - var lwr = new LocallyWeightedRegression(options); - var X = CreateMatrix(new double[,] { { 1 }, { 2 }, { 3 } }); - var y = CreateVector(new double[] { 1, 4, 9 }); - lwr.Train(X, y); - - // Act & Assert - Assert.True(lwr.SupportsJitCompilation); - } - #endregion #region KNearestNeighborsRegression Tests @@ -606,23 +588,6 @@ public async Task KNearestNeighborsRegression_SerializeDeserialize_PreservesMode } } - [Fact(Timeout = 120000)] - public async Task KNearestNeighborsRegression_SoftKNN_EnablesJitCompilation() - { - // Arrange - var options = new KNearestNeighborsOptions { K = 3 }; - var knn = new KNearestNeighborsRegression(options); - var X = CreateMatrix(new double[,] { { 1 }, { 2 }, { 3 } }); - var y = CreateVector(new double[] { 1, 4, 9 }); - knn.Train(X, y); - - // Act - knn.UseSoftKNN = true; - - // Assert - Assert.True(knn.SupportsJitCompilation); - } - [Fact(Timeout = 120000)] public async Task KNearestNeighborsRegression_MultiDimensionalData_WorksCorrectly() { diff --git a/tests/AiDotNet.Tests/IntegrationTests/ReinforcementLearning/BaseClassesIntegrationTests.cs b/tests/AiDotNet.Tests/IntegrationTests/ReinforcementLearning/BaseClassesIntegrationTests.cs index f9c563e711..690dd9219d 100644 --- a/tests/AiDotNet.Tests/IntegrationTests/ReinforcementLearning/BaseClassesIntegrationTests.cs +++ b/tests/AiDotNet.Tests/IntegrationTests/ReinforcementLearning/BaseClassesIntegrationTests.cs @@ -77,16 +77,6 @@ public async Task DeepReinforcementLearningAgentBase_ParameterCount_SumsNetworks Assert.Equal(agent.NetworkParameterCount, agent.ParameterCount); } - [Fact(Timeout = 120000)] - public async Task DeepReinforcementLearningAgentBase_JitRemoved_SupportsJitIsFalse() - { - var agent = new TestDeepAgent(CreateOptions()); - - // After JIT removal, SupportsJitCompilation always returns false - Assert.False(agent.SupportsJitCompilation); - Assert.Throws(() => agent.ExportComputationGraph(new List>())); - } - [Fact(Timeout = 120000)] public async Task ReinforcementLearningAgentBase_DefaultsAndStateRoundTrip_Work() { @@ -96,8 +86,6 @@ public async Task ReinforcementLearningAgentBase_DefaultsAndStateRoundTrip_Work( var action = agent.Predict(state); Assert.Equal(agent.FeatureCount, action.Length); - Assert.False(agent.SupportsJitCompilation); - Assert.Throws(() => agent.ExportComputationGraph(new List>())); Assert.Throws(() => agent.Train(state, action)); var names = agent.FeatureNames; diff --git a/tests/AiDotNet.Tests/IntegrationTests/TransferLearning/TransferLearningAlgorithmsIntegrationTests.cs b/tests/AiDotNet.Tests/IntegrationTests/TransferLearning/TransferLearningAlgorithmsIntegrationTests.cs index 62f60e1aae..cb95bc8edf 100644 --- a/tests/AiDotNet.Tests/IntegrationTests/TransferLearning/TransferLearningAlgorithmsIntegrationTests.cs +++ b/tests/AiDotNet.Tests/IntegrationTests/TransferLearning/TransferLearningAlgorithmsIntegrationTests.cs @@ -550,13 +550,6 @@ public void ApplyGradients(Vector gradients, T learningRate) // No-op for mock } - public bool SupportsJitCompilation => false; - - public ComputationNode ExportComputationGraph(List> inputNodes) - { - throw new NotSupportedException("Mock model does not support JIT compilation"); - } - public Vector SanitizeParameters(Vector parameters) => parameters; } diff --git a/tests/AiDotNet.Tests/MergedPRBugFixTests.cs b/tests/AiDotNet.Tests/MergedPRBugFixTests.cs index 5775f5064e..161e824827 100644 --- a/tests/AiDotNet.Tests/MergedPRBugFixTests.cs +++ b/tests/AiDotNet.Tests/MergedPRBugFixTests.cs @@ -1537,210 +1537,11 @@ public async Task InterpretabilityMetricsHelper_ValidFunctions_ReturnCorrectResu #endregion - #region InferenceOptimization PR #768 - Production Bug Fixes - - [Fact(Timeout = 60000)] - public async Task OptimizationNode_AddInput_ThrowsForNullInputNode() - { - // ARRANGE - var node = new AiDotNet.InferenceOptimization.Core.OptimizationNode(); - AiDotNet.InferenceOptimization.Core.OptimizationNode? nullInput = null; - - // ACT & ASSERT - var ex = Assert.Throws(() => node.AddInput(nullInput!)); - Assert.Equal("inputNode", ex.ParamName); - } - - [Fact(Timeout = 60000)] - public async Task OptimizationNode_RemoveInput_ThrowsForNullInputNode() - { - // ARRANGE - var node = new AiDotNet.InferenceOptimization.Core.OptimizationNode(); - AiDotNet.InferenceOptimization.Core.OptimizationNode? nullInput = null; - - // ACT & ASSERT - var ex = Assert.Throws(() => node.RemoveInput(nullInput!)); - Assert.Equal("inputNode", ex.ParamName); - } - - [Fact(Timeout = 60000)] - public async Task OptimizationNode_ReplaceInput_ThrowsForNullOldInput() - { - // ARRANGE - var node = new AiDotNet.InferenceOptimization.Core.OptimizationNode(); - var newInput = new AiDotNet.InferenceOptimization.Core.OptimizationNode(); - AiDotNet.InferenceOptimization.Core.OptimizationNode? nullInput = null; - - // ACT & ASSERT - var ex = Assert.Throws(() => node.ReplaceInput(nullInput!, newInput)); - Assert.Equal("oldInput", ex.ParamName); - } - - [Fact(Timeout = 60000)] - public async Task OptimizationNode_ReplaceInput_ThrowsForNullNewInput() - { - // ARRANGE - var node = new AiDotNet.InferenceOptimization.Core.OptimizationNode(); - var oldInput = new AiDotNet.InferenceOptimization.Core.OptimizationNode(); - AiDotNet.InferenceOptimization.Core.OptimizationNode? nullInput = null; - - // ACT & ASSERT - var ex = Assert.Throws(() => node.ReplaceInput(oldInput, nullInput!)); - Assert.Equal("newInput", ex.ParamName); - } - - [Fact(Timeout = 60000)] - public async Task OptimizationGraph_FindNodeById_ThrowsForNullId() - { - // ARRANGE - var graph = new AiDotNet.InferenceOptimization.Core.OptimizationGraph(); - string? nullId = null; - - // ACT & ASSERT - var ex = Assert.Throws(() => graph.FindNodeById(nullId!)); - Assert.Equal("id", ex.ParamName); - } - - [Fact(Timeout = 60000)] - public async Task OptimizationGraph_FindNodesByName_ThrowsForNullName() - { - // ARRANGE - var graph = new AiDotNet.InferenceOptimization.Core.OptimizationGraph(); - string? nullName = null; - - // ACT & ASSERT - var ex = Assert.Throws(() => graph.FindNodesByName(nullName!)); - Assert.Equal("name", ex.ParamName); - } - - [Fact(Timeout = 60000)] - public async Task IRDataTypeExtensions_FromSystemType_ThrowsForNullType() - { - // ARRANGE - Type? nullType = null; - - // ACT & ASSERT - var ex = Assert.Throws(() => - AiDotNet.InferenceOptimization.IR.Common.IRDataTypeExtensions.FromSystemType(nullType!)); - Assert.Equal("type", ex.ParamName); - } - - [Fact(Timeout = 60000)] - public async Task TensorType_IsBroadcastCompatible_ThrowsForNullOther() - { - // ARRANGE - var tensorType = new AiDotNet.InferenceOptimization.IR.Common.TensorType - { - Shape = new int[] { 3, 4 } - }; - AiDotNet.InferenceOptimization.IR.Common.TensorType? nullOther = null; - - // ACT & ASSERT - var ex = Assert.Throws(() => tensorType.IsBroadcastCompatible(nullOther!)); - Assert.Equal("other", ex.ParamName); - } - - [Fact(Timeout = 60000)] - public async Task GraphOptimizer_Optimize_ThrowsForNullGraph() - { - // ARRANGE - var optimizer = new AiDotNet.InferenceOptimization.Core.GraphOptimizer(); - AiDotNet.InferenceOptimization.Core.IOptimizationGraph? nullGraph = null; - - // ACT & ASSERT - var ex = Assert.Throws(() => optimizer.Optimize(nullGraph!)); - Assert.Equal("graph", ex.ParamName); - } - - [Fact(Timeout = 60000)] - public async Task GraphOptimizer_AddPass_ThrowsForNullPass() - { - // ARRANGE - var optimizer = new AiDotNet.InferenceOptimization.Core.GraphOptimizer(); - AiDotNet.InferenceOptimization.Passes.IOptimizationPass? nullPass = null; - - // ACT & ASSERT - var ex = Assert.Throws(() => optimizer.AddPass(nullPass!)); - Assert.Equal("pass", ex.ParamName); - } - - [Fact(Timeout = 60000)] - public async Task OptimizationNode_AddInputRemoveInput_ValidInputs_WorksCorrectly() - { - // ARRANGE - var node = new AiDotNet.InferenceOptimization.Core.OptimizationNode { Name = "output" }; - var inputNode = new AiDotNet.InferenceOptimization.Core.OptimizationNode { Name = "input" }; - - // ACT - Add input - node.AddInput(inputNode); - - // ASSERT - Input was added - Assert.Contains(inputNode, node.Inputs); - Assert.Contains(node, inputNode.Outputs); - - // ACT - Remove input - node.RemoveInput(inputNode); - - // ASSERT - Input was removed - Assert.DoesNotContain(inputNode, node.Inputs); - Assert.DoesNotContain(node, inputNode.Outputs); - } - - [Fact(Timeout = 60000)] - public async Task OptimizationGraph_FindNodesByIdAndName_ValidInputs_WorksCorrectly() - { - // ARRANGE - var graph = new AiDotNet.InferenceOptimization.Core.OptimizationGraph(); - var node1 = new AiDotNet.InferenceOptimization.Core.OptimizationNode { Name = "conv1" }; - var node2 = new AiDotNet.InferenceOptimization.Core.OptimizationNode { Name = "conv1" }; - var node3 = new AiDotNet.InferenceOptimization.Core.OptimizationNode { Name = "relu1" }; - graph.AddNode(node1); - graph.AddNode(node2); - graph.AddNode(node3); - - // ACT & ASSERT - FindNodeById - var foundById = graph.FindNodeById(node1.Id); - Assert.Same(node1, foundById); - - // ACT & ASSERT - FindNodesByName - var foundByName = graph.FindNodesByName("conv1"); - Assert.Equal(2, foundByName.Count); - - // ACT & ASSERT - FindNodeById with non-existent ID returns null - var notFound = graph.FindNodeById("non-existent-id"); - Assert.Null(notFound); - } - - [Fact(Timeout = 60000)] - public async Task TensorType_IsBroadcastCompatible_ValidInputs_WorksCorrectly() - { - // ARRANGE - var type1 = new AiDotNet.InferenceOptimization.IR.Common.TensorType { Shape = new int[] { 3, 4 } }; - var type2 = new AiDotNet.InferenceOptimization.IR.Common.TensorType { Shape = new int[] { 3, 4 } }; - var type3 = new AiDotNet.InferenceOptimization.IR.Common.TensorType { Shape = new int[] { 1, 4 } }; - var type4 = new AiDotNet.InferenceOptimization.IR.Common.TensorType { Shape = new int[] { 3, 5 } }; - - // ACT & ASSERT - Assert.True(type1.IsBroadcastCompatible(type2)); // Same shape - Assert.True(type1.IsBroadcastCompatible(type3)); // Broadcastable (1 can become 3) - Assert.False(type1.IsBroadcastCompatible(type4)); // Not compatible (4 != 5) - } - - [Fact(Timeout = 60000)] - public async Task IRDataTypeExtensions_FromSystemType_ValidTypes_ReturnsCorrectDataType() - { - // ACT & ASSERT - Assert.Equal(AiDotNet.InferenceOptimization.IR.Common.IRDataType.Float32, - AiDotNet.InferenceOptimization.IR.Common.IRDataTypeExtensions.FromSystemType(typeof(float))); - Assert.Equal(AiDotNet.InferenceOptimization.IR.Common.IRDataType.Float64, - AiDotNet.InferenceOptimization.IR.Common.IRDataTypeExtensions.FromSystemType(typeof(double))); - Assert.Equal(AiDotNet.InferenceOptimization.IR.Common.IRDataType.Int32, - AiDotNet.InferenceOptimization.IR.Common.IRDataTypeExtensions.FromSystemType(typeof(int))); - Assert.Equal(AiDotNet.InferenceOptimization.IR.Common.IRDataType.Bool, - AiDotNet.InferenceOptimization.IR.Common.IRDataTypeExtensions.FromSystemType(typeof(bool))); - } - - #endregion + // The InferenceOptimization PR #768 bug-fix tests were removed alongside the + // src/InferenceOptimization/ tree. That tree was orphaned JIT IR/passes/kernels + // scaffolding with zero production callers. Compilation now flows through + // AiDotNet.Tensors' AutoTracer → CompiledInferencePlan → CompiledTrainingPlan + // pipeline (enabled by default; opt out via TensorCodecOptions). #region HyperparameterOptimization PR #767 - Production Bug Fixes diff --git a/tests/AiDotNet.Tests/UnitTests/AdversarialRobustness/AdversarialAttackTests.cs b/tests/AiDotNet.Tests/UnitTests/AdversarialRobustness/AdversarialAttackTests.cs index bed6daa534..61aa17ff42 100644 --- a/tests/AiDotNet.Tests/UnitTests/AdversarialRobustness/AdversarialAttackTests.cs +++ b/tests/AiDotNet.Tests/UnitTests/AdversarialRobustness/AdversarialAttackTests.cs @@ -89,8 +89,6 @@ public void SetParameters(Vector parameters) { } public void SetActiveFeatureIndices(IEnumerable featureIndices) => _activeFeatures = featureIndices.ToList(); public bool IsFeatureUsed(int featureIndex) => _activeFeatures.Contains(featureIndex); public Dictionary GetFeatureImportance() => Enumerable.Range(0, _inputDim).ToDictionary(i => $"Feature{i}", i => 1.0 / _inputDim); - public bool SupportsJitCompilation => false; - public ComputationNode ExportComputationGraph(List> inputNodes) => throw new NotSupportedException(); public Vector SanitizeParameters(Vector parameters) => parameters; } diff --git a/tests/AiDotNet.Tests/UnitTests/AdversarialRobustness/AdversarialTrainingTests.cs b/tests/AiDotNet.Tests/UnitTests/AdversarialRobustness/AdversarialTrainingTests.cs index d7ad61b3b1..aca244c358 100644 --- a/tests/AiDotNet.Tests/UnitTests/AdversarialRobustness/AdversarialTrainingTests.cs +++ b/tests/AiDotNet.Tests/UnitTests/AdversarialRobustness/AdversarialTrainingTests.cs @@ -75,8 +75,6 @@ public void SetParameters(Vector parameters) { } public void SetActiveFeatureIndices(IEnumerable featureIndices) => _activeFeatures = featureIndices.ToList(); public bool IsFeatureUsed(int featureIndex) => _activeFeatures.Contains(featureIndex); public Dictionary GetFeatureImportance() => Enumerable.Range(0, _inputDim).ToDictionary(i => $"Feature{i}", i => 1.0 / _inputDim); - public bool SupportsJitCompilation => false; - public ComputationNode ExportComputationGraph(List> inputNodes) => throw new NotSupportedException(); public Vector SanitizeParameters(Vector parameters) => parameters; } diff --git a/tests/AiDotNet.Tests/UnitTests/AdversarialRobustness/RandomizedSmoothingTests.cs b/tests/AiDotNet.Tests/UnitTests/AdversarialRobustness/RandomizedSmoothingTests.cs index 40cf2fccaf..13d887732c 100644 --- a/tests/AiDotNet.Tests/UnitTests/AdversarialRobustness/RandomizedSmoothingTests.cs +++ b/tests/AiDotNet.Tests/UnitTests/AdversarialRobustness/RandomizedSmoothingTests.cs @@ -78,8 +78,6 @@ public void SetParameters(Vector parameters) { } public void SetActiveFeatureIndices(IEnumerable featureIndices) => _activeFeatures = featureIndices.ToList(); public bool IsFeatureUsed(int featureIndex) => _activeFeatures.Contains(featureIndex); public Dictionary GetFeatureImportance() => Enumerable.Range(0, _inputDim).ToDictionary(i => $"Feature{i}", i => 1.0 / _inputDim); - public bool SupportsJitCompilation => false; - public ComputationNode ExportComputationGraph(List> inputNodes) => throw new NotSupportedException(); public Vector SanitizeParameters(Vector parameters) => parameters; } diff --git a/tests/AiDotNet.Tests/UnitTests/ContinualLearning/ElasticWeightConsolidationTests.cs b/tests/AiDotNet.Tests/UnitTests/ContinualLearning/ElasticWeightConsolidationTests.cs index 2400cf03f7..923345be26 100644 --- a/tests/AiDotNet.Tests/UnitTests/ContinualLearning/ElasticWeightConsolidationTests.cs +++ b/tests/AiDotNet.Tests/UnitTests/ContinualLearning/ElasticWeightConsolidationTests.cs @@ -218,14 +218,6 @@ public void ApplyGradients(Vector gradients, double learningRate) } } - // IJitCompilable - public bool SupportsJitCompilation => false; - - public ComputationNode ExportComputationGraph(List> inputNodes) - { - throw new NotSupportedException("MockModel does not support JIT compilation."); - } - public Vector SanitizeParameters(Vector parameters) => parameters; } } diff --git a/tests/AiDotNet.Tests/UnitTests/Genetics/ModelIndividualTests.cs b/tests/AiDotNet.Tests/UnitTests/Genetics/ModelIndividualTests.cs index 20bf6f5c82..9fe32695ad 100644 --- a/tests/AiDotNet.Tests/UnitTests/Genetics/ModelIndividualTests.cs +++ b/tests/AiDotNet.Tests/UnitTests/Genetics/ModelIndividualTests.cs @@ -183,28 +183,6 @@ public void ApplyGradients(Vector gradients, double learningRate) } } - // IJitCompilable implementation - public bool SupportsJitCompilation => true; - - public ComputationNode ExportComputationGraph(List> inputNodes) - { - // Create a simple computation graph for the mock model - var inputShape = new int[] { 1, _parameterCount }; - var inputTensor = new Tensor(inputShape); - var inputNode = TensorOperations.Variable(inputTensor, "input"); - inputNodes.Add(inputNode); - - // Create parameter node - var paramTensor = new Tensor(new int[] { _parameterCount }, _parameters); - var paramNode = TensorOperations.Variable(paramTensor, "parameters"); - inputNodes.Add(paramNode); - - // Compute element-wise multiply and sum - var mulNode = TensorOperations.ElementwiseMultiply(inputNode, paramNode); - var outputNode = TensorOperations.Sum(mulNode); - return outputNode; - } - public Vector SanitizeParameters(Vector parameters) => parameters; } diff --git a/tests/AiDotNet.Tests/UnitTests/Inference/QuantizedAttentionTests.cs b/tests/AiDotNet.Tests/UnitTests/Inference/QuantizedAttentionTests.cs index 23de1a55a5..dd498edacf 100644 --- a/tests/AiDotNet.Tests/UnitTests/Inference/QuantizedAttentionTests.cs +++ b/tests/AiDotNet.Tests/UnitTests/Inference/QuantizedAttentionTests.cs @@ -93,7 +93,6 @@ public async Task QuantizedAttention_IsInferenceOnly() var quantized = new QuantizedAttentionLayer(mha); Assert.False(quantized.SupportsTraining); - Assert.False(quantized.SupportsJitCompilation); Assert.Equal(0, quantized.ParameterCount); Assert.Null(quantized.GetWeights()); Assert.Null(quantized.GetBiases()); diff --git a/tests/AiDotNet.Tests/UnitTests/KnowledgeDistillation/TeacherModelFactoryTests.cs b/tests/AiDotNet.Tests/UnitTests/KnowledgeDistillation/TeacherModelFactoryTests.cs index 5a0c6cb3d2..a2db497a2f 100644 --- a/tests/AiDotNet.Tests/UnitTests/KnowledgeDistillation/TeacherModelFactoryTests.cs +++ b/tests/AiDotNet.Tests/UnitTests/KnowledgeDistillation/TeacherModelFactoryTests.cs @@ -364,22 +364,6 @@ public IFullModel, Vector> WithParameters(Vector< return copy; } - // IJitCompilable implementation - public bool SupportsJitCompilation => true; - - public ComputationNode ExportComputationGraph(List> inputNodes) - { - // Create a computation graph for the mock model - var inputShape = new int[] { 1, _inputDim }; - var inputTensor = new Tensor(inputShape); - var inputNode = TensorOperations.Variable(inputTensor, "input"); - inputNodes.Add(inputNode); - - // Simple computation: sum of input elements normalized - var sumNode = TensorOperations.Sum(inputNode); - return sumNode; - } - public Vector SanitizeParameters(Vector parameters) => parameters; } } diff --git a/tests/AiDotNet.Tests/UnitTests/Layers/AdvancedAlgebraLayerTests.cs b/tests/AiDotNet.Tests/UnitTests/Layers/AdvancedAlgebraLayerTests.cs index 7d9f14fe09..cf675e03f2 100644 --- a/tests/AiDotNet.Tests/UnitTests/Layers/AdvancedAlgebraLayerTests.cs +++ b/tests/AiDotNet.Tests/UnitTests/Layers/AdvancedAlgebraLayerTests.cs @@ -119,9 +119,6 @@ public async Task OctonionLinearLayer_SupportsTraining_IsTrue() Assert.True(layer.SupportsTraining); } - // JIT compilation tests removed — JIT moved to Tensors engine level (Lazy Graph Compiler v0.28.0). - // Per-layer SupportsJitCompilation will be removed in a dedicated PR. - #endregion #region HyperbolicLinearLayer Tests diff --git a/tests/AiDotNet.Tests/UnitTests/MetaLearning/Helpers/MatrixMockModel.cs b/tests/AiDotNet.Tests/UnitTests/MetaLearning/Helpers/MatrixMockModel.cs index c1b8dcaad9..c0fa5de3a1 100644 --- a/tests/AiDotNet.Tests/UnitTests/MetaLearning/Helpers/MatrixMockModel.cs +++ b/tests/AiDotNet.Tests/UnitTests/MetaLearning/Helpers/MatrixMockModel.cs @@ -142,13 +142,5 @@ public void ApplyGradients(Vector gradients, double learningRate) } } - // IJitCompilable implementation - public bool SupportsJitCompilation => false; - - public ComputationNode ExportComputationGraph(List> inputNodes) - { - throw new NotSupportedException("JIT compilation not supported in mock model"); - } - public Vector SanitizeParameters(Vector parameters) => parameters; } diff --git a/tests/AiDotNet.Tests/UnitTests/MetaLearning/Helpers/SimpleMockModel.cs b/tests/AiDotNet.Tests/UnitTests/MetaLearning/Helpers/SimpleMockModel.cs index ff06c9875d..e0bdefc05e 100644 --- a/tests/AiDotNet.Tests/UnitTests/MetaLearning/Helpers/SimpleMockModel.cs +++ b/tests/AiDotNet.Tests/UnitTests/MetaLearning/Helpers/SimpleMockModel.cs @@ -141,28 +141,6 @@ public void ApplyGradients(Vector gradients, double learningRate) } } - // IJitCompilable implementation - public bool SupportsJitCompilation => true; - - public ComputationNode ExportComputationGraph(List> inputNodes) - { - // Create a simple linear computation graph: output = sum(input * parameters) - var inputShape = new int[] { 1, _parameters.Length }; - var inputTensor = new Tensor(inputShape); - var inputNode = TensorOperations.Variable(inputTensor, "input"); - inputNodes.Add(inputNode); - - // Create parameter node - var paramTensor = new Tensor(new int[] { _parameters.Length }, _parameters); - var paramNode = TensorOperations.Variable(paramTensor, "parameters"); - inputNodes.Add(paramNode); - - // Compute element-wise multiply and sum - var mulNode = TensorOperations.ElementwiseMultiply(inputNode, paramNode); - var outputNode = TensorOperations.Sum(mulNode); - return outputNode; - } - // ISecondOrderGradientComputable implementation public Vector ComputeSecondOrderGradients( List<(Tensor input, Tensor target)> adaptationSteps, diff --git a/tests/AiDotNet.Tests/UnitTests/NeuralNetworks/Layers/RotaryPositionalEncodingLayerTests.cs b/tests/AiDotNet.Tests/UnitTests/NeuralNetworks/Layers/RotaryPositionalEncodingLayerTests.cs index 26fa266d13..7fed985b0d 100644 --- a/tests/AiDotNet.Tests/UnitTests/NeuralNetworks/Layers/RotaryPositionalEncodingLayerTests.cs +++ b/tests/AiDotNet.Tests/UnitTests/NeuralNetworks/Layers/RotaryPositionalEncodingLayerTests.cs @@ -16,7 +16,6 @@ public async Task Constructor_ValidParameters_CreatesLayer() var layer = new RotaryPositionalEncodingLayer(128, 64); Assert.NotNull(layer); - Assert.False(layer.SupportsJitCompilation); Assert.True(layer.SupportsTraining); }