
perf: JIT compiler integration for zero-allocation diffusion model forward passes #1015

@ooples

Goal

Complete the AiDotNet JIT compiler integration with TensorWorkspace to enable zero-allocation, fused-kernel execution for diffusion models and all neural networks. The JIT compiler already has 8 optimization passes, 170+ IR operations, and 20+ fused operations, but it is not yet wired into the forward pass of any model.

What Needs to Happen in AiDotNet (not AiDotNet.Tensors)

Graph Capture for Neural Networks

  • LayerBase graph capture mode: Add a mode in which Forward() records operations into a computation graph instead of executing them. The first call captures the graph; subsequent calls execute the compiled graph.
  • NeuralNetworkBase.CompileForward(): A method that exports the full forward pass as a computation graph, runs JIT compilation on it, and stores the compiled function.
  • UNetNoisePredictor compiled forward: Wire ForwardUNet to use the compiled graph when one is available.
  • DiffusionResBlock graph export: ExportComputationGraph currently delegates to conv1 only. It needs to export the full GroupNorm->SiLU->Conv->TimeEmbed->GroupNorm->SiLU->Conv->Residual chain.
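The capture-then-replay behavior described above can be sketched in a few lines. This is a minimal illustration assuming a tracing design: `Recorder`, `CompiledForward` and `block` are hypothetical names, not AiDotNet APIs, and the "compiled" graph here is just the recorded op list with no optimization passes applied.

```python
import math
from typing import Callable, List, Optional, Tuple

Vec = List[float]
Node = Tuple[str, Callable[[Vec], Vec]]  # hypothetical graph node: (op name, function)


class Recorder:
    """Records (name, fn) nodes while a forward function runs in capture mode."""

    def __init__(self) -> None:
        self.nodes: List[Node] = []

    def op(self, name: str, fn: Callable[[Vec], Vec], x: Vec) -> Vec:
        self.nodes.append((name, fn))  # record the op in the graph...
        return fn(x)                   # ...and still run it so the first call returns a result


class CompiledForward:
    """First call traces forward_fn; later calls replay the captured graph."""

    def __init__(self, forward_fn: Callable[[Recorder, Vec], Vec]) -> None:
        self.forward_fn = forward_fn
        self.graph: Optional[List[Node]] = None

    def __call__(self, x: Vec) -> Vec:
        if self.graph is None:
            rec = Recorder()
            y = self.forward_fn(rec, x)
            self.graph = rec.nodes  # a real JIT would optimize and fuse the graph here
            return y
        for _name, fn in self.graph:  # replay: the Recorder is skipped entirely
            x = fn(x)
        return x


def silu(v: Vec) -> Vec:
    return [t / (1.0 + math.exp(-t)) for t in v]


def block(rec: Recorder, x: Vec) -> Vec:
    x = rec.op("scale", lambda v: [2.0 * t for t in v], x)
    x = rec.op("silu", silu, x)
    return rec.op("bias", lambda v: [t + 1.0 for t in v], x)


forward = CompiledForward(block)
first = forward([0.0, 1.0])   # captures the graph and returns a result
second = forward([0.0, 1.0])  # replays the captured graph
```

One design consequence worth noting: tracing records only the path actually taken, so any data-dependent branch inside Forward() is frozen into the captured graph. A capture mode would need to either reject such layers or re-capture when the branch changes.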

Missing IR Operations for Diffusion

  • GroupNormOp: GroupNorm appears in every DiffusionResBlock but has no IR operation, so the JIT compiler cannot optimize or fuse it.
  • FusedGroupNormActivationOp: Fuse GroupNorm + SiLU, the most frequently repeated pattern in diffusion UNets. This eliminates one 41 MB intermediate tensor per ResBlock.
  • FusedConv2DBiasActivationOp: Maps to IEngine.FusedConv2D. The diffusion UNet uses Conv+Bias+Identity (not Conv+BatchNorm), so this op must be distinct from FusedConvBatchNormActivationOp.
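The memory saving from FusedGroupNormActivationOp comes from never materializing the GroupNorm output: the unfused path writes a full normalized tensor and then reads it back for SiLU, while the fused path applies SiLU to each value as soon as it is normalized. A minimal sketch of both paths, assuming batch size 1, a nested-list [channels][spatial] layout, and hypothetical function names (this is not the AiDotNet.Tensors kernel):

```python
import math
from typing import List, Tuple

Tensor = List[List[float]]  # [channels][spatial], batch size 1 for brevity


def _silu(y: float) -> float:
    # Numerically stable y * sigmoid(y): avoids exp() overflow for large |y|.
    if y >= 0.0:
        return y / (1.0 + math.exp(-y))
    e = math.exp(y)
    return y * e / (1.0 + e)


def _group_stats(x: Tensor, chans: range, eps: float) -> Tuple[float, float]:
    # Mean and 1/std over all channels in the group and all spatial positions.
    n = len(chans) * len(x[0])
    mean = sum(v for ch in chans for v in x[ch]) / n
    var = sum((v - mean) ** 2 for ch in chans for v in x[ch]) / n
    return mean, 1.0 / math.sqrt(var + eps)


def group_norm(x: Tensor, groups: int, gamma: List[float], beta: List[float], eps: float = 1e-5) -> Tensor:
    per = len(x) // groups
    out = [[0.0] * len(x[0]) for _ in x]  # full intermediate tensor
    for g in range(groups):
        chans = range(g * per, (g + 1) * per)
        mean, inv_std = _group_stats(x, chans, eps)
        for ch in chans:
            for i, v in enumerate(x[ch]):
                out[ch][i] = (v - mean) * inv_std * gamma[ch] + beta[ch]
    return out


def silu(x: Tensor) -> Tensor:
    return [[_silu(v) for v in row] for row in x]  # second full tensor


def fused_group_norm_silu(x: Tensor, groups: int, gamma: List[float], beta: List[float],
                          out: Tensor, eps: float = 1e-5) -> Tensor:
    """SiLU(GroupNorm(x)) in one pass, written into the caller's `out` buffer."""
    per = len(x) // groups
    for g in range(groups):
        chans = range(g * per, (g + 1) * per)
        mean, inv_std = _group_stats(x, chans, eps)
        for ch in chans:
            for i, v in enumerate(x[ch]):
                y = (v - mean) * inv_std * gamma[ch] + beta[ch]
                out[ch][i] = _silu(y)  # activation applied before y is ever stored
    return out


x = [[1.0, 2.0], [3.0, 4.0], [0.5, -1.0], [2.0, 2.0]]  # 4 channels, 2 positions, 2 groups
gamma, beta = [1.0, 1.0, 0.5, 2.0], [0.0, 0.1, 0.0, -0.2]
reference = silu(group_norm(x, 2, gamma, beta))
buffer = [[0.0, 0.0] for _ in range(4)]
fused = fused_group_norm_silu(x, 2, gamma, beta, buffer)
```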

Missing Fusion Patterns

  • GroupNorm + SiLU: Pattern 11 in OperationFusionPass
  • Conv2D + Bias + SiLU: Pattern 12 (uses FusedConv2D from IEngine)
  • GroupNorm + SiLU + Conv2D: Pattern 13 (3-op fusion for the full DiffusionResBlock half-block)
  • Residual Add + GroupNorm: Pattern 14 (fuses the skip connection add with the next norm)
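A fusion pattern such as Pattern 11 (GroupNorm + SiLU) is essentially a peephole rewrite over the IR. The sketch below uses a toy IR; `Node`, the op strings and `fuse_groupnorm_silu` are hypothetical and do not mirror OperationFusionPass. The key legality check is that the GroupNorm result has a single consumer and is not a graph output; otherwise fusing would discard a tensor that something else still reads.

```python
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class Node:
    op: str
    inputs: List[str]
    output: str
    attrs: Dict[str, object] = field(default_factory=dict)


def fuse_groupnorm_silu(nodes: List[Node], graph_outputs: List[str]) -> List[Node]:
    """Rewrite GroupNorm -> SiLU into one FusedGroupNormActivation node.

    Legal only when the GroupNorm result has exactly one consumer (the SiLU)
    and is not a graph output; otherwise the intermediate is still needed.
    """
    uses: Dict[str, int] = {name: 1 for name in graph_outputs}  # graph outputs count as a use
    for n in nodes:
        for name in n.inputs:
            uses[name] = uses.get(name, 0) + 1
    producer = {n.output: n for n in nodes}

    replaced: Dict[int, Node] = {}  # id(SiLU node) -> fused node
    removed = set()                 # ids of GroupNorm nodes folded into a fused node
    for n in nodes:
        src = producer.get(n.inputs[0]) if n.op == "SiLU" else None
        if src is not None and src.op == "GroupNorm" and uses[src.output] == 1:
            attrs = {**src.attrs, "activation": "SiLU"}
            replaced[id(n)] = Node("FusedGroupNormActivation", list(src.inputs), n.output, attrs)
            removed.add(id(src))

    # The fused node takes the SiLU's position, so topological order is preserved.
    return [replaced.get(id(n), n) for n in nodes if id(n) not in removed]


graph = [
    Node("Conv2D", ["x", "w1"], "c1"),
    Node("GroupNorm", ["c1"], "g1", {"groups": 32}),
    Node("SiLU", ["g1"], "s1"),
    Node("Conv2D", ["s1", "w2"], "c2"),
]
fused_graph = fuse_groupnorm_silu(graph, ["c2"])
```

Longer patterns such as Pattern 13 (GroupNorm + SiLU + Conv2D) would apply the same single-consumer check to every intermediate in the chain.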

Missing Optimization Passes

  • Memory planning pass: Analyze tensor lifetimes in the compiled graph, let tensors whose lifetimes do not overlap share workspace slots (reusing the slots of dead tensors), and compute the minimum memory footprint. This is what will let us run production SD15 (860M parameters) without running out of memory.
  • Tile scheduling pass: Partition Conv2D and MatMul into tiles sized to fit the L1/L2 caches. PyTorch's TorchInductor does this; we need it for competitive CPU performance.
  • Operator reordering pass: Schedule independent operations to maximize cache reuse between producers and their consumers.
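The memory planning pass can be viewed as an interval-assignment problem: each tensor is live from the step that produces it to the last step that reads it, and tensors whose intervals do not overlap can share a workspace slot. Below is a greedy best-fit sketch under those assumptions (`TensorLife` and `plan_memory` are hypothetical, sizes are abstract bytes, and in-place ops are not modeled); it is not guaranteed to find the true minimum footprint.

```python
from dataclasses import dataclass
from typing import Dict, List, Tuple


@dataclass
class TensorLife:
    name: str
    size: int   # bytes
    start: int  # step of the node that produces the tensor
    end: int    # step of the last node that reads it


def plan_memory(tensors: List[TensorLife]) -> Tuple[Dict[str, int], List[int]]:
    """Greedy slot assignment: tensors with disjoint lifetimes may share a slot.

    Returns (tensor name -> slot index, size of each slot). Not optimal in general.
    """
    slot_size: List[int] = []
    busy_until: List[int] = []  # last step at which the slot's current occupant is read
    assignment: Dict[str, int] = {}
    for t in sorted(tensors, key=lambda t: (t.start, -t.size)):
        # Strict '<': an output never aliases a tensor read at the same step (no in-place ops).
        free = [s for s in range(len(slot_size)) if busy_until[s] < t.start]
        fitting = [s for s in free if slot_size[s] >= t.size]
        if fitting:
            s = min(fitting, key=lambda i: slot_size[i])  # best fit
        elif free:
            s = max(free, key=lambda i: slot_size[i])     # grow the largest free slot
            slot_size[s] = t.size
        else:
            slot_size.append(t.size)                      # open a new slot
            busy_until.append(t.end)
            s = len(slot_size) - 1
        busy_until[s] = t.end
        assignment[t.name] = s
    return assignment, slot_size


# ResBlock-like chain: x is the skip input, read again by the residual Add at step 6.
chain = [
    TensorLife("x", 100, 0, 6),
    TensorLife("gn1", 100, 1, 2),
    TensorLife("silu1", 100, 2, 3),
    TensorLife("conv1", 100, 3, 4),
    TensorLife("gn2", 100, 4, 5),
    TensorLife("conv2", 100, 5, 6),
    TensorLife("out", 100, 6, 6),
]
slots, sizes = plan_memory(chain)
```

On this chain the seven equally sized tensors fit into three slots (300 bytes instead of 700), because at most three of them, the skip input plus a producer/consumer pair, are live at any step.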

Benchmark Infrastructure

  • benchmarks/ project: BenchmarkDotNet project comparing against PyTorch (via TorchSharp)
  • Conv2D benchmark: Various channel counts (64, 128, 256, 512, 1280) and spatial sizes (8x8, 16x16, 32x32, 64x64)
  • DiffusionResBlock benchmark: Full ResBlock forward (GroupNorm->SiLU->Conv->GroupNorm->SiLU->Conv+skip) at paper dimensions
  • UNet benchmark: Full SD15 UNet forward pass (input [1,4,64,64], base channels 320, channel multipliers [1,2,4,4])
  • Memory benchmark: Track peak allocation, GC pauses, and total allocated bytes during a 50-step Predict
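For the Conv2D benchmark, raw timings are easier to compare against PyTorch when converted to throughput. A small helper for the benchmark grid, assuming square inputs, a 3x3 stride-1 "same" convolution with equal input and output channels, batch size 1, float32 activations, and the common convention that one multiply-add counts as two FLOPs (all of these are assumptions; the issue does not fix kernel size or dtype):

```python
CHANNELS = [64, 128, 256, 512, 1280]
SPATIAL = [8, 16, 32, 64]


def conv2d_flops(c_in: int, c_out: int, h: int, w: int, k: int = 3) -> int:
    """FLOPs of a stride-1, 'same'-padded k x k Conv2D (multiply-add = 2 FLOPs)."""
    return 2 * c_in * c_out * k * k * h * w


def activation_bytes(c: int, h: int, w: int, batch: int = 1, dtype_bytes: int = 4) -> int:
    """Size of one [batch, c, h, w] activation tensor (float32 by default)."""
    return batch * c * h * w * dtype_bytes


def gflops_per_second(flops: int, seconds: float) -> float:
    """Convert a FLOP count and a measured mean time into GFLOP/s."""
    return flops / seconds / 1e9


if __name__ == "__main__":
    for c in CHANNELS:
        for s in SPATIAL:
            flops = conv2d_flops(c, c, s, s)
            mib = activation_bytes(c, s, s) / 2**20
            print(f"C={c:4d} {s:2d}x{s:<2d} {flops / 1e9:8.3f} GFLOP  output {mib:6.2f} MiB")
```

Dividing these FLOP counts by the BenchmarkDotNet mean time gives GFLOP/s figures that stay comparable across AiDotNet and TorchSharp runs even when absolute runtimes differ widely between configurations.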

Integration with ConvolutionalLayer

  • ConvolutionalLayer.Forward uses JIT when compiled: If a compiled graph is available, execute it instead of the interpreted Conv2DInto + BroadcastAddInPlace + ApplyActivation path.
  • FusedConv2D selection: ConvolutionalLayer already calls GetFusedActivationType() and Engine.FusedConv2D. The JIT should produce the same fused call but with workspace memory.
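The intended ConvolutionalLayer behavior (use the compiled, fused path when one exists, otherwise fall back to the interpreted path) can be sketched as follows. To keep it short, this assumes a 1x1 convolution on nested lists; `ConvLayerSketch`, `compile` and the `workspace` field are hypothetical stand-ins for the Conv2DInto + BroadcastAddInPlace + ApplyActivation path and Engine.FusedConv2D, not their real signatures.

```python
from typing import Callable, List, Optional

Mat = List[List[float]]  # [channels][spatial], batch size 1


def relu(v: float) -> float:
    return v if v > 0.0 else 0.0


class ConvLayerSketch:
    """Hypothetical 1x1 conv layer with an interpreted path and a compiled fused path."""

    def __init__(self, weight: Mat, bias: List[float], act: Callable[[float], float]) -> None:
        self.weight, self.bias, self.act = weight, bias, act  # weight: [c_out][c_in]
        self.compiled: Optional[Callable[[Mat], Mat]] = None
        self.workspace: Optional[Mat] = None

    def _interpreted(self, x: Mat) -> Mat:
        # Three separate passes, each allocating a new tensor: conv, bias add, activation.
        n = len(x[0])
        conv = [[sum(w[ci] * x[ci][i] for ci in range(len(x))) for i in range(n)] for w in self.weight]
        biased = [[v + b for v in row] for row, b in zip(conv, self.bias)]
        return [[self.act(v) for v in row] for row in biased]

    def compile(self, spatial: int) -> None:
        # Allocate the output once; every compiled call overwrites it in place.
        self.workspace = [[0.0] * spatial for _ in self.weight]
        out = self.workspace

        def fused(x: Mat) -> Mat:
            for co, w in enumerate(self.weight):
                for i in range(spatial):
                    acc = self.bias[co]  # bias folded into the accumulator
                    for ci in range(len(x)):
                        acc += w[ci] * x[ci][i]
                    out[co][i] = self.act(acc)  # activation applied before the store
            return out

        self.compiled = fused

    def forward(self, x: Mat) -> Mat:
        if self.compiled is not None:
            return self.compiled(x)  # compiled path: no new output tensor per call
        return self._interpreted(x)


layer = ConvLayerSketch([[1.0, -1.0], [0.5, 2.0]], [0.0, -1.0], relu)
x = [[1.0, 2.0, 3.0], [0.0, 4.0, 1.0]]
interpreted = layer.forward(x)  # [[1.0, 0.0, 2.0], [0.0, 8.0, 2.5]]
layer.compile(spatial=3)
a = layer.forward(x)
b = layer.forward(x)  # same buffer object as `a`
```

The compiled path returns the same workspace buffer on every call, which is what makes it allocation-free, but it also means a caller that keeps an output across calls must copy it first. A workspace-backed JIT has to account for that aliasing when deciding which outputs may live in reused slots.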

Quick Wins (Do Now for Diffusion)

  1. Add GroupNormOp to IR operations
  2. Add FusedGroupNormActivationOp to IR operations
  3. Add GroupNorm+SiLU fusion pattern to OperationFusionPass
  4. Fix DiffusionResBlock.ExportComputationGraph to export the full chain
  5. Create the benchmark project with Conv2D and ResBlock benchmarks

Related

  • AiDotNet.Tensors#38: TensorWorkspace + JIT compiler integration
  • AiDotNet.Tensors#37: In-place IEngine ops + TensorWorkspace foundation

Metadata


Assignees: No one assigned
Labels: enhancement (New feature or request)
Projects: No projects
Milestone: No milestone
Relationships: None yet
Development: No branches or pull requests
