diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index d55fa0ac..cb7a3773 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -218,6 +218,99 @@ jobs: path: /tmp/SwiftLM-test-speculative.log retention-days: 7 + # ── DFlash Speculative Decoding E2E ── + # Uses the standard macos-15 runner (7 GB RAM). + dflash-speculative-decoding: + runs-on: macos-15 + timeout-minutes: 45 + needs: build_and_unit_test + steps: + - uses: actions/checkout@v4 + with: + submodules: recursive + + - name: Install Metal Toolchain + run: xcodebuild -downloadComponent MetalToolchain || true + + - name: Cache Swift packages + uses: actions/cache@v4 + with: + path: .build + key: ${{ runner.os }}-spm-SwiftLM-v3-${{ hashFiles('Package.resolved') }} + restore-keys: | + ${{ runner.os }}-spm-SwiftLM-v3- + + - name: Clear stale module cache + run: find .build -type d -name ModuleCache -exec rm -rf {} + 2>/dev/null || true + + - name: Resolve dependencies + run: swift package resolve + + - name: Build (Release) + run: swift build -c release + + - name: Compile and install custom MLX Metal library + run: | + if [ -d "mlx-swift/Source/Cmlx/mlx" ]; then + MLX_SRC="mlx-swift/Source/Cmlx/mlx" + else + MLX_SRC=".build/checkouts/mlx-swift/Source/Cmlx/mlx" + fi + mkdir -p .build/metallib_build + pushd .build/metallib_build + cmake "../../$MLX_SRC" \ + -DMLX_BUILD_TESTS=OFF \ + -DMLX_BUILD_EXAMPLES=OFF \ + -DMLX_BUILD_BENCHMARKS=OFF \ + -DMLX_BUILD_PYTHON_BINDINGS=OFF \ + -DMLX_METAL_JIT=OFF \ + -DMLX_ENABLE_NAX=1 \ + -DCMAKE_BUILD_TYPE=Release 2>&1 | tail -20 + make mlx-metallib -j$(sysctl -n hw.ncpu) 2>&1 | tail -20 + popd + BUILT=$(find .build/metallib_build -name "mlx.metallib" | head -1) + cp "$BUILT" .build/release/mlx.metallib + python3 -m venv /tmp/mlx_venv + /tmp/mlx_venv/bin/pip install --quiet huggingface_hub hf + + - name: Cache MLX models (dflash + main) + uses: actions/cache@v4 + with: + path: ~/.cache/huggingface + key: mlx-dflash-qwen35-4b + + - name: Pre-download HuggingFace models + run: | + source /tmp/mlx_venv/bin/activate + hf download mlx-community/Qwen3.5-4B-4bit || true + hf download z-lab/Qwen3.5-4B-DFlash || true + + - name: Run DFlash E2E + env: + HF_HUB_DOWNLOAD_TIMEOUT: "900" + run: | + chmod +x tests/test-dflash.sh + for attempt in 1 2 3; do + echo "Attempt $attempt of 3..." + if tests/test-dflash.sh .build/release/SwiftLM 15415; then + exit 0 + fi + if [ "$attempt" -lt 3 ]; then + echo "Test failed, retrying in 10s..." + sleep 10 + fi + done + echo "All attempts failed" + exit 1 + + - name: Upload dflash test logs on failure + if: failure() + uses: actions/upload-artifact@v4 + with: + name: dflash-test-logs + path: /tmp/SwiftLM-test-dflash.log + retention-days: 7 + # ── Speculative Decoding Memory Evaluation ── # Runs the 2B model with NUM_DRAFT_TOKENS=2 to check peak # memory compression/efficiency. Emits vm_stat readings as step summary. diff --git a/.gitignore b/.gitignore index e25d0db7..c38e3792 100644 --- a/.gitignore +++ b/.gitignore @@ -28,3 +28,6 @@ tmp/ .agents/harness/audio-omni-gemma4/runs/ .venv/ mem-palace/ + + +tests/DFlash/intermediates/ diff --git a/Package.resolved b/Package.resolved index b5e6e0a6..e35107aa 100644 --- a/Package.resolved +++ b/Package.resolved @@ -50,8 +50,8 @@ "kind" : "remoteSourceControl", "location" : "https://github.com/apple/swift-asn1.git", "state" : { - "revision" : "9f542610331815e29cc3821d3b6f488db8715517", - "version" : "1.6.0" + "revision" : "eb50cbd14606a9161cbc5d452f18797c90ef0bab", + "version" : "1.7.0" } }, { @@ -77,8 +77,8 @@ "kind" : "remoteSourceControl", "location" : "https://github.com/apple/swift-certificates.git", "state" : { - "revision" : "24ccdeeeed4dfaae7955fcac9dbf5489ed4f1a25", - "version" : "1.18.0" + "revision" : "5aa1c0d1bc204908df47c2075bdbb39573d05e8d", + "version" : "1.19.0" } }, { @@ -104,8 +104,8 @@ "kind" : "remoteSourceControl", "location" : "https://github.com/apple/swift-crypto.git", "state" : { - "revision" : "bb4ba815dab96d4edc1e0b86d7b9acf9ff973a84", - "version" : "4.3.1" + "revision" : "1b6b2e274e85105bfa155183145a1dcfd63331f1", + "version" : "4.5.0" } }, { @@ -122,8 +122,8 @@ "kind" : "remoteSourceControl", "location" : "https://github.com/apple/swift-http-structured-headers.git", "state" : { - "revision" : "76d7627bd88b47bf5a0f8497dd244885960dde0b", - "version" : "1.6.0" + "revision" : "933538faa42c432d385f02e07df0ace7c5ecfc47", + "version" : "1.7.0" } }, { @@ -158,8 +158,8 @@ "kind" : "remoteSourceControl", "location" : "https://github.com/apple/swift-log.git", "state" : { - "revision" : "8c0f217f01000dd30f60d6e536569ad4e74291f9", - "version" : "1.11.0" + "revision" : "5073617dac96330a486245e4c0179cb0a6fd2256", + "version" : "1.12.0" } }, { @@ -167,8 +167,8 @@ "kind" : "remoteSourceControl", "location" : "https://github.com/apple/swift-metrics.git", "state" : { - "revision" : "59a494d2ad97b0796db5119ef19fe1d48618d12b", - "version" : "2.9.0" + "revision" : "d51c8d13fa366eec807eedb4e37daa60ff5bfdd5", + "version" : "2.10.1" } }, { @@ -176,8 +176,8 @@ "kind" : "remoteSourceControl", "location" : "https://github.com/apple/swift-nio.git", "state" : { - "revision" : "558f24a4647193b5a0e2104031b71c55d31ff83a", - "version" : "2.97.1" + "revision" : "f71c8d2a5e74a2c6d11a0fbe324774b5d6084237", + "version" : "2.99.0" } }, { @@ -185,8 +185,8 @@ "kind" : "remoteSourceControl", "location" : "https://github.com/apple/swift-nio-extras.git", "state" : { - "revision" : "abcf5312eb8ed2fb11916078aef7c46b06f20813", - "version" : "1.33.0" + "revision" : "5a48717e29f62cb8326d6d42e46b562ca93847a6", + "version" : "1.34.0" } }, { @@ -194,8 +194,8 @@ "kind" : "remoteSourceControl", "location" : "https://github.com/apple/swift-nio-http2.git", "state" : { - "revision" : "6d8d596f0a9bfebb925733003731fe2d749b7e02", - "version" : "1.42.0" + "revision" : "81cc18264f92cd307ff98430f89372711d4f6fe9", + "version" : "1.43.0" } }, { @@ -203,8 +203,8 @@ "kind" : "remoteSourceControl", "location" : "https://github.com/apple/swift-nio-ssl.git", "state" : { - "revision" : "df9c3406028e3297246e6e7081977a167318b692", - "version" : "2.36.1" + "revision" : "3f337058ccd7243c4cac7911477d8ad4c598d4da", + "version" : "2.37.0" } }, { @@ -212,8 +212,8 @@ "kind" : "remoteSourceControl", "location" : "https://github.com/apple/swift-nio-transport-services.git", "state" : { - "revision" : "60c3e187154421171721c1a38e800b390680fb5d", - "version" : "1.26.0" + "revision" : "67787bb645a5e67d2edcdfbe48a216cc549222d5", + "version" : "1.28.0" } }, { diff --git a/Package.swift b/Package.swift index 6314eb66..42bccb66 100644 --- a/Package.swift +++ b/Package.swift @@ -6,8 +6,10 @@ let package = Package( platforms: [.macOS(.v14), .iOS(.v17)], products: [ .library(name: "MLXInferenceCore", targets: ["MLXInferenceCore"]), + .library(name: "DFlash", targets: ["DFlash"]), .executable(name: "SwiftLM", targets: ["SwiftLM"]), - .executable(name: "SwiftBuddy", targets: ["SwiftBuddy"]) + .executable(name: "SwiftBuddy", targets: ["SwiftBuddy"]), + .executable(name: "DFlashKernelBench", targets: ["DFlashKernelBench"]) ], dependencies: [ // Local Apple MLX Swift fork for C++ extensions @@ -29,6 +31,7 @@ let package = Package( name: "SwiftLM", dependencies: [ "MLXInferenceCore", + "DFlash", .product(name: "MLX", package: "mlx-swift"), .product(name: "MLXLLM", package: "mlx-swift-lm"), .product(name: "MLXVLM", package: "mlx-swift-lm"), @@ -40,6 +43,16 @@ let package = Package( ], path: "Sources/SwiftLM" ), + // ── DFlash Kernel Micro-Benchmark ─────────────────────────── + .executableTarget( + name: "DFlashKernelBench", + dependencies: [ + "DFlash", + .product(name: "MLX", package: "mlx-swift"), + .product(name: "MLXNN", package: "mlx-swift"), + ], + path: "Sources/DFlashKernelBench" + ), // ── STFT Audio Profiling Testing Script (macOS only) ─────────── .executableTarget( name: "SwiftLMTestSTFT", @@ -86,6 +99,17 @@ let package = Package( .enableExperimentalFeature("StrictConcurrency") ] ), + // ── DFlash Speculative Decoding ───────────────────────────── + .target( + name: "DFlash", + dependencies: [ + .product(name: "MLX", package: "mlx-swift"), + .product(name: "MLXLLM", package: "mlx-swift-lm"), + .product(name: "MLXLMCommon", package: "mlx-swift-lm"), + ], + path: "Sources/DFlash", + exclude: ["DFlashKernelsOptimized.swift"] + ), // ── Automated Test Harness ────────────────────────────────── .testTarget( name: "SwiftBuddyTests", diff --git a/README.md b/README.md index 52f9fec4..3a8d3778 100644 --- a/README.md +++ b/README.md @@ -438,6 +438,8 @@ curl http://localhost:5413/v1/chat/completions \ | `--turbo-kv` | `false` | Enable TurboQuant 3-bit KV cache compression (activates after 2048 tokens, server-wide) | | `--draft-model` | (none) | Draft model path/ID for speculative decoding. When used with `--stream-experts`, `--num-draft-tokens` is auto-capped to 1 to minimise SSD I/O fan-out (see performance note above). | | `--num-draft-tokens` | `4` | Tokens per speculation round. Auto-capped to 1 when combined with `--stream-experts`. | +| `--dflash` | `false` | Enable DFlash block-diffusion speculative decoding. Requires a compatible DFlash draft model | +| `--dflash-block-size`| (auto) | Number of tokens per DFlash draft block. Defaults to draft model config | ## 🔧 Per-Request API Parameters diff --git a/Sources/DFlash/DFlashDraftBackend.swift b/Sources/DFlash/DFlashDraftBackend.swift new file mode 100644 index 00000000..e7bccae4 --- /dev/null +++ b/Sources/DFlash/DFlashDraftBackend.swift @@ -0,0 +1,91 @@ +// Copyright 2026 SwiftLM Contributors +// MIT License — see LICENSE file +// Based on DFlash (arXiv:2602.06036) + +import Foundation +import MLX +import MLXLMCommon +import MLXNN + +// MARK: - Draft Backend + +/// Backend for generating draft tokens using the DFlash draft model. +public final class DFlashDraftBackend: @unchecked Sendable { + + public init() {} + + /// Create the draft cache (one `ContextOnlyDraftKVCache` per layer). + public func makeCache( + draftModel: DFlashDraftModel, + sinkSize: Int = 64, + windowSize: Int = 1024 + ) -> [ContextOnlyDraftKVCache] { + (0 ..< draftModel.layers.count).map { _ in + ContextOnlyDraftKVCache(sinkSize: sinkSize, windowSize: windowSize) + } + } + + /// Generate draft tokens greedily using the DFlash draft model. + /// + /// - Parameters: + /// - targetModel: The target model (must conform to DFlashTargetModel for embed/lm_head access) + /// - draftModel: The DFlash draft model + /// - draftCache: The draft model's KV caches + /// - stagedFirst: The first token (already verified by the target) + /// - targetHidden: The target model's hidden states for context + /// - blockLen: Number of tokens to draft + /// - maskTokenTail: Mask token IDs for positions 1..blockLen-1 + /// - suppressTokenMask: Optional mask to suppress certain tokens + /// - Returns: Draft token IDs [blockLen-1] + public func draftGreedy( + targetModel: any DFlashTargetModel, + draftModel: DFlashDraftModel, + draftCache: [ContextOnlyDraftKVCache], + stagedFirst: MLXArray, + targetHidden: MLXArray, + blockLen: Int, + maskTokenTail: MLXArray, + suppressTokenMask: MLXArray? = nil + ) -> MLXArray { + precondition(blockLen > 1, "draftGreedy requires blockLen > 1") + + let blockTokenIDs = concatenated( + [stagedFirst[..<1], maskTokenTail[..<(blockLen - 1)]], + axis: 0 + ) + + // Get noise embedding from target model's embed_tokens + let noiseEmbedding = targetModel.dflashEmbedTokens(blockTokenIDs[.newAxis]) + if DFlashDumper.isEnabled { + DFlashDumper.saveInt("swift_block_token_ids", blockTokenIDs[.newAxis]) + DFlashDumper.save("swift_noise_embedding", noiseEmbedding) + } + + // Run the draft model + let draftHidden = draftModel( + noiseEmbedding: noiseEmbedding, + targetHidden: targetHidden, + cache: draftCache + ) + if DFlashDumper.isEnabled { + DFlashDumper.save("swift_draft_hidden", draftHidden) + } + + // Get draft logits via the target model's lm_head + let draftLogits = targetModel.dflashLmHeadLogits( + draftHidden[.ellipsis, 1..., 0...] + ) + if DFlashDumper.isEnabled { + DFlashDumper.save("swift_draft_logits", draftLogits) + } + + // Greedy decode + let drafted = DFlashRuntime.greedyTokensWithMask( + logits: draftLogits, + suppressTokenMask: suppressTokenMask + ).squeezed(axis: 0) + + asyncEval(drafted) + return drafted + } +} diff --git a/Sources/DFlash/DFlashDraftModel.swift b/Sources/DFlash/DFlashDraftModel.swift new file mode 100644 index 00000000..3b7b0f46 --- /dev/null +++ b/Sources/DFlash/DFlashDraftModel.swift @@ -0,0 +1,417 @@ +// Copyright 2026 SwiftLM Contributors +// MIT License — see LICENSE file +// Based on DFlash (arXiv:2602.06036) + +import Foundation +import MLX +import MLXLMCommon +import MLXNN + +// MARK: - DFlash GLU MLP + +/// Gated Linear Unit MLP for the DFlash draft model. +/// Equivalent to Qwen3NextMLP / Llama MLP with SwiGLU activation. +final class DFlashGLUMLP: Module, UnaryLayer { + @ModuleInfo(key: "gate_proj") var gateProj: Linear + @ModuleInfo(key: "up_proj") var upProj: Linear + @ModuleInfo(key: "down_proj") var downProj: Linear + + init(dimensions: Int, hiddenDimensions: Int) { + _gateProj.wrappedValue = Linear(dimensions, hiddenDimensions, bias: false) + _upProj.wrappedValue = Linear(dimensions, hiddenDimensions, bias: false) + _downProj.wrappedValue = Linear(hiddenDimensions, dimensions, bias: false) + super.init() + } + + func callAsFunction(_ x: MLXArray) -> MLXArray { + downProj(silu(gateProj(x)) * upProj(x)) + } +} + +// MARK: - Draft Model Configuration + +/// Configuration for the DFlash draft model, deserialized from config.json. +public struct DFlashDraftConfiguration: Codable, Sendable { + var modelType: String = "dflash_qwen3" + var hiddenSize: Int = 1024 + var numHiddenLayers: Int = 4 + var intermediateSize: Int = 2816 + var numAttentionHeads: Int = 16 + var rmsNormEps: Float = 1e-6 + var vocabularySize: Int = 151_936 + var numKeyValueHeads: Int = 8 + var maxPositionEmbeddings: Int = 131072 + var ropeTheta: Float = 1_000_000.0 + var headDim: Int = 128 + var tieWordEmbeddings: Bool = false + var numTargetLayers: Int = 36 + var blockSize: Int = 16 + var attentionBias: Bool = false + var attentionDropout: Float = 0.0 + var ropeScaling: [String: StringOrNumber]? + var layerTypes: [String] = [] + var dflashConfig: DFlashConfig? + + enum CodingKeys: String, CodingKey { + case modelType = "model_type" + case hiddenSize = "hidden_size" + case numHiddenLayers = "num_hidden_layers" + case intermediateSize = "intermediate_size" + case numAttentionHeads = "num_attention_heads" + case rmsNormEps = "rms_norm_eps" + case vocabularySize = "vocab_size" + case numKeyValueHeads = "num_key_value_heads" + case maxPositionEmbeddings = "max_position_embeddings" + case ropeTheta = "rope_theta" + case headDim = "head_dim" + case tieWordEmbeddings = "tie_word_embeddings" + case numTargetLayers = "num_target_layers" + case blockSize = "block_size" + case attentionBias = "attention_bias" + case attentionDropout = "attention_dropout" + case ropeScaling = "rope_scaling" + case layerTypes = "layer_types" + case dflashConfig = "dflash_config" + } + + struct DFlashConfig: Codable, Sendable { + var targetLayerIds: [Int]? + var maskTokenId: Int? + + enum CodingKeys: String, CodingKey { + case targetLayerIds = "target_layer_ids" + case maskTokenId = "mask_token_id" + } + } +} + +// MARK: - Helper: build target layer IDs + +func buildTargetLayerIDs(numTargetLayers: Int, numDraftLayers: Int) -> [Int] { + if numDraftLayers <= 1 { + return [numTargetLayers / 2] + } + let start = 1 + let end = numTargetLayers - 3 + let span = end - start + return (0 ..< numDraftLayers).map { i in + Int(round(Double(start) + Double(i) * Double(span) / Double(numDraftLayers - 1))) + } +} + +// MARK: - Context-Only Draft KV Cache + +/// A sliding-window KV cache that only stores context keys/values +/// (no incremental update-and-fetch), used by the DFlash draft model's +/// cross-attention layers. +public final class ContextOnlyDraftKVCache { + public var keys: MLXArray? + public var values: MLXArray? + public var offset: Int = 0 + let sinkSize: Int + let windowSize: Int + + public init(sinkSize: Int = 64, windowSize: Int = 1024) { + self.sinkSize = sinkSize + self.windowSize = windowSize + } + + public func appendContext( + contextKeys: MLXArray, + contextValues: MLXArray, + numPositions: Int + ) { + guard numPositions > 0 else { return } + if keys == nil { + keys = contextKeys + values = contextValues + } else { + keys = concatenated([keys!, contextKeys], axis: 2) + values = concatenated([values!, contextValues], axis: 2) + } + offset += numPositions + applyWindow() + } + + private func applyWindow() { + guard let k = keys, let v = values else { return } + let cacheLen = k.dim(2) + let maxLen = sinkSize + windowSize + guard cacheLen > maxLen else { return } + let sinkK = k[.ellipsis, .. (MLXArray?, MLXArray?) { + (keys, values) + } + + public var cacheLength: Int { + keys?.dim(2) ?? 0 + } +} + +// MARK: - DFlash Attention + +/// Cross-attention layer for the DFlash draft model. +/// Uses target hidden states as context and noise token embeddings as queries. +final class DFlashAttention: Module { + let nHeads: Int + let nKVHeads: Int + let headDim: Int + let scale: Float + + @ModuleInfo(key: "q_proj") var qProj: Linear + @ModuleInfo(key: "k_proj") var kProj: Linear + @ModuleInfo(key: "v_proj") var vProj: Linear + @ModuleInfo(key: "o_proj") var oProj: Linear + @ModuleInfo(key: "q_norm") var qNorm: RMSNorm + @ModuleInfo(key: "k_norm") var kNorm: RMSNorm + + let rope: RoPELayer + + init(_ args: DFlashDraftConfiguration) { + let dim = args.hiddenSize + self.nHeads = args.numAttentionHeads + self.nKVHeads = args.numKeyValueHeads + self.headDim = args.headDim + self.scale = pow(Float(headDim), -0.5) + + _qProj.wrappedValue = Linear(dim, nHeads * headDim, bias: args.attentionBias) + _kProj.wrappedValue = Linear(dim, nKVHeads * headDim, bias: args.attentionBias) + _vProj.wrappedValue = Linear(dim, nKVHeads * headDim, bias: args.attentionBias) + _oProj.wrappedValue = Linear(nHeads * headDim, dim, bias: args.attentionBias) + _qNorm.wrappedValue = RMSNorm(dimensions: headDim, eps: args.rmsNormEps) + _kNorm.wrappedValue = RMSNorm(dimensions: headDim, eps: args.rmsNormEps) + + self.rope = initializeRope( + dims: headDim, + base: args.ropeTheta, + traditional: false, + scalingConfig: args.ropeScaling, + maxPositionEmbeddings: args.maxPositionEmbeddings + ) + + super.init() + } + + func callAsFunction( + _ hiddenStates: MLXArray, + targetHidden: MLXArray, + cache: ContextOnlyDraftKVCache? = nil + ) -> MLXArray { + let B = hiddenStates.dim(0) + let blockLen = hiddenStates.dim(1) + let ctxLen = targetHidden.dim(1) + + var queries = qNorm(qProj(hiddenStates).reshaped(B, blockLen, nHeads, headDim)) + .transposed(0, 2, 1, 3) + var contextKeys = kNorm( + kProj(targetHidden).reshaped(B, ctxLen, nKVHeads, headDim) + ).transposed(0, 2, 1, 3) + let contextValues = vProj(targetHidden).reshaped(B, ctxLen, nKVHeads, headDim) + .transposed(0, 2, 1, 3) + + var noiseKeys = kNorm( + kProj(hiddenStates).reshaped(B, blockLen, nKVHeads, headDim) + ).transposed(0, 2, 1, 3) + let noiseValues = vProj(hiddenStates).reshaped(B, blockLen, nKVHeads, headDim) + .transposed(0, 2, 1, 3) + + if let cache { + let cacheOffset = cache.offset + let queryOffset = cacheOffset + ctxLen + + queries = rope(queries, offset: queryOffset) + contextKeys = rope(contextKeys, offset: cacheOffset) + noiseKeys = rope(noiseKeys, offset: queryOffset) + + cache.appendContext( + contextKeys: contextKeys, + contextValues: contextValues, + numPositions: ctxLen + ) + let (cachedKeys, cachedValues) = cache.fetch() + let keys = concatenated([cachedKeys!, noiseKeys], axis: 2) + let values = concatenated([cachedValues!, noiseValues], axis: 2) + + let output = MLXFast.scaledDotProductAttention( + queries: queries, keys: keys, values: values, + scale: scale, mask: .none + ) + let attnOut = output.transposed(0, 2, 1, 3).reshaped(B, blockLen, -1) + return oProj(attnOut) + } else { + queries = rope(queries, offset: ctxLen) + contextKeys = rope(contextKeys, offset: 0) + noiseKeys = rope(noiseKeys, offset: ctxLen) + + let keys = concatenated([contextKeys, noiseKeys], axis: 2) + let values = concatenated([contextValues, noiseValues], axis: 2) + + let output = MLXFast.scaledDotProductAttention( + queries: queries, keys: keys, values: values, + scale: scale, mask: .none + ) + return oProj(output.transposed(0, 2, 1, 3).reshaped(B, blockLen, -1)) + } + } +} + +// MARK: - DFlash Decoder Layer + +final class DFlashDecoderLayer: Module { + @ModuleInfo(key: "self_attn") var selfAttn: DFlashAttention + @ModuleInfo(key: "mlp") var mlp: DFlashGLUMLP + @ModuleInfo(key: "input_layernorm") var inputLayerNorm: RMSNorm + @ModuleInfo(key: "post_attention_layernorm") var postAttentionLayerNorm: RMSNorm + + init(_ args: DFlashDraftConfiguration) { + _selfAttn.wrappedValue = DFlashAttention(args) + _mlp.wrappedValue = DFlashGLUMLP( + dimensions: args.hiddenSize, + hiddenDimensions: args.intermediateSize + ) + _inputLayerNorm.wrappedValue = RMSNorm( + dimensions: args.hiddenSize, eps: args.rmsNormEps + ) + _postAttentionLayerNorm.wrappedValue = RMSNorm( + dimensions: args.hiddenSize, eps: args.rmsNormEps + ) + super.init() + } + + func callAsFunction( + _ hiddenStates: MLXArray, + targetHidden: MLXArray, + cache: ContextOnlyDraftKVCache? = nil + ) -> MLXArray { + let residual = hiddenStates + var h = inputLayerNorm(hiddenStates) + h = selfAttn(h, targetHidden: targetHidden, cache: cache) + h = residual + h + + let r = h + h = postAttentionLayerNorm(h) + h = mlp(h) + return r + h + } +} + +// MARK: - DFlash Draft Model + +/// The DFlash block-diffusion draft model. +/// +/// This model takes noise token embeddings (from the target model's embed_tokens) +/// and target hidden states, and produces draft logits for block-diffusion speculative decoding. +public final class DFlashDraftModel: Module { + let args: DFlashDraftConfiguration + public let modelType: String + + let layers: [DFlashDecoderLayer] + public let targetLayerIDs: [Int] + @ModuleInfo(key: "norm") var norm: RMSNorm + @ModuleInfo(key: "fc") var fc: Linear + @ModuleInfo(key: "hidden_norm") var hiddenNorm: RMSNorm + public let blockSize: Int + public let maskTokenID: Int + + public init(_ args: DFlashDraftConfiguration) { + self.args = args + self.modelType = "dflash_qwen3" + + self.layers = (0 ..< args.numHiddenLayers).map { _ in + DFlashDecoderLayer(args) + } + + let targetLayerIDs = args.dflashConfig?.targetLayerIds + ?? buildTargetLayerIDs( + numTargetLayers: args.numTargetLayers, + numDraftLayers: args.numHiddenLayers + ) + self.targetLayerIDs = targetLayerIDs + _norm.wrappedValue = RMSNorm(dimensions: args.hiddenSize, eps: args.rmsNormEps) + _fc.wrappedValue = Linear(targetLayerIDs.count * args.hiddenSize, args.hiddenSize, bias: false) + _hiddenNorm.wrappedValue = RMSNorm(dimensions: args.hiddenSize, eps: args.rmsNormEps) + self.blockSize = args.blockSize + self.maskTokenID = args.dflashConfig?.maskTokenId ?? 0 + + super.init() + } + + func projectTargetHidden(_ targetHidden: MLXArray) -> MLXArray { + if DFlashDumper.isEnabled { + DFlashDumper.save("swift_fc_weight", fc.weight) + DFlashDumper.save("swift_fc_bias", fc.bias ?? MLXArray.zeros([0])) + } + let fcOut = fc(targetHidden) + if DFlashDumper.isEnabled { + DFlashDumper.save("swift_fc_output", fcOut) + } + let result = hiddenNorm(fcOut) + DFlashDumper.save("swift_projected_hidden", result) + return result + } + + public func callAsFunction( + noiseEmbedding: MLXArray, + targetHidden: MLXArray, + cache: [ContextOnlyDraftKVCache]? = nil + ) -> MLXArray { + var hiddenStates = noiseEmbedding + if DFlashDumper.isEnabled { + DFlashDumper.save("swift_target_hidden_input", targetHidden) + } + let projectedHidden = projectTargetHidden(targetHidden) + + let draftCache = cache ?? layers.map { _ in + ContextOnlyDraftKVCache() + } + + for (i, layer) in layers.enumerated() { + hiddenStates = layer( + hiddenStates, + targetHidden: projectedHidden, + cache: i < draftCache.count ? draftCache[i] : nil + ) + if DFlashDumper.isEnabled { + DFlashDumper.save("swift_draft_layer\(i)_output", hiddenStates) + } + } + let result = norm(hiddenStates) + if DFlashDumper.isEnabled { + DFlashDumper.save("swift_draft_final_normed", result) + } + return result + } + + public func sanitize(weights: [String: MLXArray]) -> [String: MLXArray] { + weights + } +} + +// MARK: - Extract context feature from hidden states + +/// Extract and concatenate hidden states at the specified layer IDs. +/// The layer IDs are 0-indexed into the model's layers, and we take +/// `hiddenStates[layerID + 1]` because index 0 is the embedding output. +public func extractContextFeature( + hiddenStates: [MLXArray], + layerIDs: [Int] +) -> MLXArray { + let selected = layerIDs.map { hiddenStates[$0 + 1] } + return concatenated(selected, axis: -1) +} + +/// Extract context feature from a dictionary of captured hidden states. +public func extractContextFeatureFromDict( + capturedDict: [Int: MLXArray], + targetLayerIDs: [Int] +) -> MLXArray { + let selected = targetLayerIDs.map { capturedDict[$0 + 1]! } + return concatenated(selected, axis: -1) +} diff --git a/Sources/DFlash/DFlashDraftRegistry.swift b/Sources/DFlash/DFlashDraftRegistry.swift new file mode 100644 index 00000000..f9bd2583 --- /dev/null +++ b/Sources/DFlash/DFlashDraftRegistry.swift @@ -0,0 +1,68 @@ +// Copyright 2026 SwiftLM Contributors +// MIT License — see LICENSE file +// Based on DFlash (arXiv:2602.06036) + +import Foundation +import MLX +import MLXLMCommon +import MLXNN + +// MARK: - Draft Model Registry + +/// Registry mapping target model names to their DFlash draft models. +public enum DFlashDraftRegistry { + + /// Known target → draft model mappings. + static let registry: [String: String] = [ + "Qwen3.5-4B": "z-lab/Qwen3.5-4B-DFlash", + "Qwen3.5-9B": "z-lab/Qwen3.5-9B-DFlash", + "Qwen3.5-27B": "z-lab/Qwen3.5-27B-DFlash", + "Qwen3.5-35B-A3B": "z-lab/Qwen3.5-35B-A3B-DFlash", + "Qwen3.6-35B-A3B": "z-lab/Qwen3.6-35B-A3B-DFlash", + "Qwen3-4B": "z-lab/Qwen3-4B-DFlash-b16", + "Qwen3-8B": "z-lab/Qwen3-8B-DFlash-b16", + ] + + /// Normalize a model reference by stripping the org prefix. + private static func stripModelOrg(_ modelRef: String) -> String { + modelRef.split(separator: "/").last.map(String.init) ?? modelRef + } + + /// Resolve an optional draft model reference for the given target model. + /// + /// - Parameters: + /// - modelRef: The target model reference (org/name or local path) + /// - draftRef: An explicit draft model reference (takes priority) + /// - Returns: The resolved draft model reference, or nil if none found + public static func resolveDraftRef(modelRef: String, draftRef: String? = nil) -> String? { + if let draftRef { return draftRef } + + let stripped = stripModelOrg(modelRef).lowercased() + + // Exact match + for (key, value) in registry where key.lowercased() == stripped { + return value + } + + // Prefix match (e.g., "qwen3.5-4b-4bit" matches "qwen3.5-4b") + var bestMatch: (key: String, value: String)? + for (key, value) in registry { + let lowered = key.lowercased() + if stripped == lowered + || stripped.hasPrefix(lowered + "-") + || stripped.hasPrefix(lowered + "_") + { + if bestMatch == nil || key.count > bestMatch!.key.count { + bestMatch = (key, value) + } + } + } + + return bestMatch?.value + } + + /// List supported base model names. + public static func supportedBaseModels() -> [String] { + Array(registry.keys).sorted() + } +} diff --git a/Sources/DFlash/DFlashEngine.swift b/Sources/DFlash/DFlashEngine.swift new file mode 100644 index 00000000..c50b537b --- /dev/null +++ b/Sources/DFlash/DFlashEngine.swift @@ -0,0 +1,84 @@ +// Copyright 2026 SwiftLM Contributors +// MIT License — see LICENSE file +// Based on DFlash (arXiv:2602.06036) + +import Foundation +import MLX +import MLXLMCommon +import MLXNN + +// MARK: - Engine Protocol + +/// Protocol for DFlash verify/rollback engines. +/// +/// Two concrete implementations exist: +/// - ``FullAttentionEngine`` — for pure-attention target models +/// - ``HybridGDNEngine`` — for hybrid GatedDeltaNet + attention target models +public protocol DFlashEngine: Sendable { + /// Arm the target model's cache for rollback before verification. + func armRollback(targetCache: [KVCache], prefixLen: Int) + + /// Roll back the target cache after partial acceptance. + func rollback( + targetCache: [KVCache], + targetLen: Int, + acceptanceLength: Int, + draftedTokens: Int + ) -> Int +} + +// MARK: - Full Attention Engine + +/// Engine for pure-attention target models (no recurrent layers). +/// Rollback is just KV cache trimming. +public final class FullAttentionEngine: DFlashEngine, @unchecked Sendable { + public init() {} + + public func armRollback(targetCache: [KVCache], prefixLen: Int) { + // Pure attention: no arming needed + } + + public func rollback( + targetCache: [KVCache], + targetLen: Int, + acceptanceLength: Int, + draftedTokens: Int + ) -> Int { + DFlashRuntime.restoreTargetCacheAfterAcceptance( + targetCache, + targetLen: targetLen, + acceptanceLength: acceptanceLength, + draftedTokens: draftedTokens + ) + } +} + +// MARK: - Hybrid GDN Engine + +/// Engine for hybrid GatedDeltaNet + attention target models. +/// Uses RecurrentRollbackCache for recurrent layers with tape replay. +public final class HybridGDNEngine: DFlashEngine, @unchecked Sendable { + public init() {} + + public func armRollback(targetCache: [KVCache], prefixLen: Int) { + for cache in targetCache { + if let rollbackCache = cache as? RecurrentRollbackCache { + rollbackCache.armRollback(prefixLen: prefixLen) + } + } + } + + public func rollback( + targetCache: [KVCache], + targetLen: Int, + acceptanceLength: Int, + draftedTokens: Int + ) -> Int { + DFlashRuntime.restoreTargetCacheAfterAcceptance( + targetCache, + targetLen: targetLen, + acceptanceLength: acceptanceLength, + draftedTokens: draftedTokens + ) + } +} diff --git a/Sources/DFlash/DFlashIntermediateDumper.swift b/Sources/DFlash/DFlashIntermediateDumper.swift new file mode 100644 index 00000000..a9802aff --- /dev/null +++ b/Sources/DFlash/DFlashIntermediateDumper.swift @@ -0,0 +1,118 @@ +// DFlashIntermediateDumper.swift +// +// Utility to dump DFlash intermediate values to .npy files for comparison +// with the Python reference implementation. +// +// Usage: Set DFLASH_DUMP_DIR env var before running SwiftLM. +// All intermediate arrays are saved as .npy files. +// Only the first cycle's dumps are saved to avoid huge files. + +import Foundation +import MLX + +public enum DFlashDumper { + + private static var dumpDir: String? = ProcessInfo.processInfo.environment["DFLASH_DUMP_DIR"] + private static var cycleCount = 0 + private static var saved = Set() + + public static var isEnabled: Bool { dumpDir != nil } + + public static func setup() { + if let dir = dumpDir { + try? FileManager.default.createDirectory(atPath: dir, withIntermediateDirectories: true) + print("[DFlashDumper] Dumping intermediates to: \(dir)") + } + cycleCount = 0 + saved.removeAll() + } + + public static func markCycle() { + cycleCount += 1 + } + + /// Save an MLXArray as a .npy file (float32 format) + /// Only saves on the first cycle to avoid huge files. + public static func save(_ name: String, _ arr: MLXArray) { + guard let dir = dumpDir else { return } + guard !saved.contains(name) else { return } // only save first occurrence + saved.insert(name) + + let floatArr = arr.asType(.float32) + eval(floatArr) + + let shape = (0..> 8) & 0xFF)) + fileData.append(Data(headerBytes)) + + // Convert to [Float] and write + let floatData = floatArr.asArray(Float.self) + floatData.withUnsafeBufferPointer { ptr in + fileData.append(Data(buffer: ptr)) + } + + let url = URL(fileURLWithPath: dir).appendingPathComponent("\(name).npy") + try? fileData.write(to: url) + } + + /// Save an MLXArray as .npy (int32 format) + public static func saveInt(_ name: String, _ arr: MLXArray) { + guard let dir = dumpDir else { return } + guard !saved.contains(name) else { return } + saved.insert(name) + + let intArr = arr.asType(.int32) + eval(intArr) + + let shape = (0..> 8) & 0xFF)) + fileData.append(Data(headerBytes)) + + let intData = intArr.asArray(Int32.self) + intData.withUnsafeBufferPointer { ptr in + fileData.append(Data(buffer: ptr)) + } + + let url = URL(fileURLWithPath: dir).appendingPathComponent("\(name).npy") + try? fileData.write(to: url) + } +} diff --git a/Sources/DFlash/DFlashKernelProvider.swift b/Sources/DFlash/DFlashKernelProvider.swift new file mode 100644 index 00000000..a1ee533a --- /dev/null +++ b/Sources/DFlash/DFlashKernelProvider.swift @@ -0,0 +1,19 @@ +// Copyright 2026 SwiftLM Contributors +// MIT License — see LICENSE file + +import Foundation +import MLX + +/// Provider for DFlash specialized kernels. +public protocol DFlashKernelProvider: Sendable { + func gatedDeltaKernelWithTape( + q: MLXArray, k: MLXArray, v: MLXArray, + g: MLXArray, beta: MLXArray, + state: MLXArray, mask: MLXArray? + ) -> (MLXArray, MLXArray, MLXArray) +} + +/// Registry to allow models to use DFlash kernels without module circular dependencies. +public struct DFlashKernelRegistry: Sendable { + public nonisolated(unsafe) static var provider: DFlashKernelProvider? = nil +} diff --git a/Sources/DFlash/DFlashKernels.swift b/Sources/DFlash/DFlashKernels.swift new file mode 100644 index 00000000..e9100ba9 --- /dev/null +++ b/Sources/DFlash/DFlashKernels.swift @@ -0,0 +1,843 @@ +// Copyright 2026 SwiftLM Contributors +// MIT License — see LICENSE file +// Based on DFlash (arXiv:2602.06036) + +import Foundation +import MLX +import MLXLMCommon +import MLXNN + +/// Metal kernels for DFlash speculative decoding. +/// +/// Provides: +/// - **Tape replay kernel**: Replays accepted innovation steps through the +/// GatedDeltaNet recurrent state for efficient rollback. +/// - **GatedDelta kernel with tape**: Modified GatedDelta forward that records +/// the innovation tape alongside the normal output. +/// - **Batched SDPA 2-pass kernel**: Custom attention kernel for long-context +/// verify that stays numerically aligned with stock MLX attention. +public enum DFlashKernels { + + /// Shared instance for use as the global DFlashKernelProvider + public static let shared = DFlashKernelsInstance() + + // MARK: - Tape Replay Kernel + + private static func makeTapeReplayKernel( + hasMask: Bool = false, + vectorized: Bool = false + ) -> MLXFast.MLXFastKernel? { + // Branchless + correct semantics via metal::select: + // When mask=0 (do_step=false), metal::select returns the OLD state[i], + // so state is completely unchanged — no decay, no accumulate. + // When mask=1 (do_step=true), the computed next value is used. + // metal::select is a conditional move with no warp divergence. + let maskLoad = hasMask + ? "bool do_step = static_cast(mask[b_idx * T + t]) > 0.5f;" + : "constexpr bool do_step = true;" + let gSetup = vectorized ? "auto g_ = g + (b_idx * T * Hv + hv_idx) * Dk;" + : "auto g_ = g + b_idx * T * Hv;" + let gAccess = vectorized ? "g_[s_idx]" : "g_[hv_idx]" + let gAdvance = vectorized ? "g_ += Hv * Dk;" : "g_ += Hv;" + + let source = """ + auto n = thread_position_in_grid.z; + auto b_idx = n / Hv; + auto hv_idx = n % Hv; + auto hk_idx = hv_idx / (Hv / Hk); + constexpr int n_per_t = Dk / 32; + + auto tape_ = tape + b_idx * T * Hv * Dv + hv_idx * Dv; + auto k_ = k + b_idx * T * Hk * Dk + hk_idx * Dk; + auto dk_idx = thread_position_in_threadgroup.x; + auto dv_idx = thread_position_in_grid.y; + + auto i_state = state_in + (n * Dv + dv_idx) * Dk; + auto o_state = state_out + (n * Dv + dv_idx) * Dk; + + \(gSetup) + + float state[n_per_t]; + for (int i = 0; i < n_per_t; ++i) + state[i] = static_cast(i_state[n_per_t * dk_idx + i]); + + for (int t = 0; t < T; ++t) { + \(maskLoad) + float delta = static_cast(tape_[dv_idx]); + for (int i = 0; i < n_per_t; ++i) { + auto s_idx = n_per_t * dk_idx + i; + float next = state[i] * \(gAccess) + k_[s_idx] * delta; + next = static_cast(static_cast(next)); + // Conditional move: old state when masked, next when accepted. + state[i] = metal::select(state[i], next, do_step); + } + tape_ += Hv * Dv; + k_ += Hk * Dk; + \(gAdvance) + } + + for (int i = 0; i < n_per_t; ++i) + o_state[n_per_t * dk_idx + i] = static_cast(state[i]); + """ + + var inputNames = ["tape", "k", "g", "state_in", "T"] + if hasMask { inputNames.append("mask") } + + var suffix = "" + if vectorized { suffix += "_vec" } + if hasMask { suffix += "_mask" } + + return MLXFast.metalKernel( + name: "dflash_tape_replay\(suffix)", + inputNames: inputNames, + outputNames: ["state_out"], + source: source + ) + } + + // MARK: - GatedDelta with Tape Kernel + + private static func makeGatedDeltaTapeKernel( + hasMask: Bool = false, + vectorized: Bool = false + ) -> MLXFast.MLXFastKernel? { + // Two optimizations over the naive branching version: + // + // 1. Uniform simdgroup predicate: mask[b_idx*T+t] is the same scalar for + // every thread in the simdgroup (uniform control flow). Wrapping the two + // expensive simd_sum calls in `if (do_step)` skips ~50% of them at + // typical acceptance rates with zero warp divergence. + // + // 2. metal::select for state correctness: state must be completely + // unchanged when mask=0 (no decay). We save state before the decay pass, + // then use metal::select to restore it when !do_step. + let maskLoad = hasMask + ? "bool do_step = static_cast(mask[b_idx * T + t]) > 0.5f;" + : "constexpr bool do_step = true;" + let gSetup = vectorized ? "auto g_ = g + (b_idx * T * Hv + hv_idx) * Dk;" + : "auto g_ = g + b_idx * T * Hv;" + let gAccess = vectorized ? "g_[s_idx]" : "g_[hv_idx]" + let gAdvance = vectorized ? "g_ += Hv * Dk;" : "g_ += Hv;" + + let source = """ + auto n = thread_position_in_grid.z; + auto b_idx = n / Hv; + auto hv_idx = n % Hv; + auto hk_idx = hv_idx / (Hv / Hk); + constexpr int n_per_t = Dk / 32; + + auto q_ = q + b_idx * T * Hk * Dk + hk_idx * Dk; + auto k_ = k + b_idx * T * Hk * Dk + hk_idx * Dk; + auto v_ = v + b_idx * T * Hv * Dv + hv_idx * Dv; + y += b_idx * T * Hv * Dv + hv_idx * Dv; + auto tape_ = innovation_tape + b_idx * T * Hv * Dv + hv_idx * Dv; + + auto dk_idx = thread_position_in_threadgroup.x; + auto dv_idx = thread_position_in_grid.y; + + auto i_state = state_in + (n * Dv + dv_idx) * Dk; + auto o_state = state_out + (n * Dv + dv_idx) * Dk; + + \(gSetup) + auto beta_ = beta + b_idx * T * Hv; + + float state[n_per_t]; + for (int i = 0; i < n_per_t; ++i) + state[i] = static_cast(i_state[n_per_t * dk_idx + i]); + + for (int t = 0; t < T; ++t) { + \(maskLoad) + + // Save pre-decay state; needed by metal::select to restore when !do_step. + float old_state[n_per_t]; + float kv_mem = 0.0f; + for (int i = 0; i < n_per_t; ++i) { + auto s_idx = n_per_t * dk_idx + i; + old_state[i] = state[i]; + state[i] = state[i] * \(gAccess); + kv_mem += state[i] * k_[s_idx]; + } + + // Uniform predicate: skip two simd_sum calls when !do_step. + // All threads in the simdgroup read the same mask scalar → no divergence. + float delta = 0.0f; + float out = 0.0f; + if (do_step) { + kv_mem = simd_sum(kv_mem); + delta = (static_cast(v_[dv_idx]) - kv_mem) + * static_cast(beta_[hv_idx]); + for (int i = 0; i < n_per_t; ++i) { + auto s_idx = n_per_t * dk_idx + i; + state[i] += k_[s_idx] * delta; + out += state[i] * static_cast(q_[s_idx]); + } + out = simd_sum(out); + } + + if (thread_index_in_simdgroup == 0) { + y[dv_idx] = static_cast(out); + tape_[dv_idx] = delta; + } + + // Restore pre-decay state when !do_step; quantize new state when do_step. + for (int i = 0; i < n_per_t; ++i) { + float quant_new = static_cast(static_cast(state[i])); + state[i] = metal::select(old_state[i], quant_new, do_step); + } + + q_ += Hk * Dk; + k_ += Hk * Dk; + v_ += Hv * Dv; + y += Hv * Dv; + tape_ += Hv * Dv; + \(gAdvance) + beta_ += Hv; + } + + for (int i = 0; i < n_per_t; ++i) + o_state[n_per_t * dk_idx + i] = static_cast(state[i]); + """ + + var inputNames = ["q", "k", "v", "g", "beta", "state_in", "T"] + if hasMask { inputNames.append("mask") } + + var suffix = "" + if vectorized { suffix += "_vec" } + if hasMask { suffix += "_mask" } + + return MLXFast.metalKernel( + name: "dflash_gated_delta_tape\(suffix)", + inputNames: inputNames, + outputNames: ["y", "state_out", "innovation_tape"], + source: source + ) + } + + // MARK: - Lazy Kernel Singleton + + private final class KernelCache { + static let shared = KernelCache() + + // Layout: [vectorized (0/1)][masked (0/1)] + let tapeReplay: [[MLXFast.MLXFastKernel?]] + let gatedDeltaTape: [[MLXFast.MLXFastKernel?]] + + private init() { + tapeReplay = [ + [makeTapeReplayKernel(hasMask: false, vectorized: false), + makeTapeReplayKernel(hasMask: true, vectorized: false)], + [makeTapeReplayKernel(hasMask: false, vectorized: true), + makeTapeReplayKernel(hasMask: true, vectorized: true)], + ] + gatedDeltaTape = [ + [makeGatedDeltaTapeKernel(hasMask: false, vectorized: false), + makeGatedDeltaTapeKernel(hasMask: true, vectorized: false)], + [makeGatedDeltaTapeKernel(hasMask: false, vectorized: true), + makeGatedDeltaTapeKernel(hasMask: true, vectorized: true)], + ] + } + } + + // MARK: - Public API: Tape Replay + + /// Replay the innovation tape through the GatedDeltaNet state. + /// + /// - Parameters: + /// - tape: Innovation tape [B, T, Hv, Dv] + /// - k: Keys [B, T, Hk, Dk] + /// - g: Gates (decay) — either [B, T, Hv] or [B, T, Hv, Dk] + /// - state: Current recurrent state [B, Hv, Dv, Dk] + /// - mask: Optional mask [B, T] + /// - Returns: Replayed state [B, Hv, Dv, Dk] + public static func tapeReplayKernel( + tape: MLXArray, + k: MLXArray, + g: MLXArray, + state: MLXArray, + mask: MLXArray? = nil + ) -> MLXArray { + let forceFallback = ProcessInfo.processInfo.environment["DFLASH_FORCE_OPS"] != nil + let isCPU = Device.defaultDevice().deviceType == .cpu + if isCPU || forceFallback { return tapeReplayOps(tape: tape, k: k, g: g, state: state, mask: mask) } + + let B = k.dim(0) + let steps = k.dim(1) + let Hk = k.dim(2) + let Dk = k.dim(3) + let Hv = tape.dim(2) + let Dv = tape.dim(3) + let inputType = state.dtype + + if Dk < 32 || Dk % 32 != 0 { + return tapeReplayOps(tape: tape, k: k, g: g, state: state, mask: mask) + } + + let vec = g.ndim == 4 ? 1 : 0 + let msk = mask != nil ? 1 : 0 + let kernel = KernelCache.shared.tapeReplay[vec][msk] + + guard let kernel else { + return tapeReplayOps(tape: tape, k: k, g: g, state: state, mask: mask) + } + + var inputs: [MLXArray] = [tape, k, g, state, MLXArray(steps)] + if let mask { inputs.append(mask) } + + let outputs = kernel( + inputs, + template: [ + ("InT", inputType), + ("Dk", Dk), + ("Dv", Dv), + ("Hk", Hk), + ("Hv", Hv), + ], + grid: (32, Dv, B * Hv), + threadGroup: (32, 4, 1), + outputShapes: [state.shape], + outputDTypes: [inputType] + ) + return outputs[0] + } + + // MARK: - Public API: GatedDelta with Tape + + /// Run GatedDelta forward while recording the innovation tape for rollback. + /// + /// - Parameters: + /// - q: Queries [B, T, Hk, Dk] + /// - k: Keys [B, T, Hk, Dk] + /// - v: Values [B, T, Hv, Dv] + /// - g: Gates (decay) — either [B, T, Hv] or [B, T, Hv, Dk] + /// - beta: Beta values [B, T, Hv] + /// - state: Recurrent state [B, Hv, Dv, Dk] + /// - mask: Optional mask [B, T] + /// - Returns: Tuple of (output [B, T, Hv, Dv], new state, innovation tape [B, T, Hv, Dv]) + public static func gatedDeltaKernelWithTape( + q: MLXArray, + k: MLXArray, + v: MLXArray, + g: MLXArray, + beta: MLXArray, + state: MLXArray, + mask: MLXArray? = nil + ) -> (MLXArray, MLXArray, MLXArray) { + let forceFallback = ProcessInfo.processInfo.environment["DFLASH_FORCE_OPS"] != nil + let isCPU = Device.defaultDevice().deviceType == .cpu + if isCPU || forceFallback { return gatedDeltaOpsWithTape(q: q, k: k, v: v, g: g, beta: beta, state: state, mask: mask) } + + let B = k.dim(0) + let T = k.dim(1) + let Hk = k.dim(2) + let Dk = k.dim(3) + let Hv = v.dim(2) + let Dv = v.dim(3) + + if Dk < 32 || Dk % 32 != 0 { + return gatedDeltaOpsWithTape(q: q, k: k, v: v, g: g, beta: beta, state: state, mask: mask) + } + + let inputType = q.dtype + let vec = g.ndim == 4 ? 1 : 0 + let msk = mask != nil ? 1 : 0 + let kernel = KernelCache.shared.gatedDeltaTape[vec][msk] + + guard let kernel else { + return gatedDeltaOpsWithTape(q: q, k: k, v: v, g: g, beta: beta, state: state, mask: mask) + } + + var inputs: [MLXArray] = [q, k, v, g, beta, state, MLXArray(T)] + if let mask { inputs.append(mask) } + + let outputs = kernel( + inputs, + template: [ + ("InT", inputType), + ("Dk", Dk), + ("Dv", Dv), + ("Hk", Hk), + ("Hv", Hv), + ], + grid: (32, Dv, B * Hv), + threadGroup: (32, 4, 1), + outputShapes: [[B, T, Hv, Dv], state.shape, [B, T, Hv, Dv]], + outputDTypes: [inputType, inputType, DType.float32] + ) + return (outputs[0], outputs[1], outputs[2]) + } + + // MARK: - Fallback: Ops-based implementations + + private static func tapeReplayOps( + tape: MLXArray, + k: MLXArray, + g: MLXArray, + state: MLXArray, + mask: MLXArray? = nil + ) -> MLXArray { + let T = tape.dim(1) + let Hk = k.dim(2) + let Hv = tape.dim(2) + let repeatFactor = Hv / Hk + var k = k + if repeatFactor > 1 { + k = MLX.repeated(k, count: repeatFactor, axis: 2) + } + + var state = state + for t in 0 ..< T { + let prev = state + let decay: MLXArray + if g.ndim == 4 { + decay = g[0..., t, 0..., .newAxis, 0...] + } else { + decay = expandedDimensions(g[0..., t, 0...], axes: [2, 3]) + } + let delta = tape[0..., t, 0..., .newAxis] + let kT = expandedDimensions(k[0..., t, 0...], axis: -2) + state = state * decay + delta * kT + if let mask { + // MLX.where is faster than arithmetic masking for tape replay ops + // (benchmark: 382 µs vs 455 µs on M-series, scalar-g masked). + let stepMask = mask[0..., t][.newAxis, .newAxis, .newAxis] + state = MLX.where(stepMask, state, prev) + } + } + return state + } + + private static func gatedDeltaOpsWithTape( + q: MLXArray, + k: MLXArray, + v: MLXArray, + g: MLXArray, + beta: MLXArray, + state: MLXArray, + mask: MLXArray? = nil + ) -> (MLXArray, MLXArray, MLXArray) { + let T = q.dim(1) + let Hk = q.dim(2) + let Hv = v.dim(2) + let repeatFactor = Hv / Hk + var q = q + var k = k + if repeatFactor > 1 { + q = MLX.repeated(q, count: repeatFactor, axis: 2) + k = MLX.repeated(k, count: repeatFactor, axis: 2) + } + + var state = state + var outputs = [MLXArray]() + var tapeEntries = [MLXArray]() + + for t in 0 ..< T { + let decay: MLXArray + if g.ndim == 4 { + decay = g[0..., t, 0..., .newAxis, 0...] + } else { + decay = expandedDimensions(g[0..., t, 0...], axes: [2, 3]) + } + let decayedState = state * decay + let kvMem = (decayedState * expandedDimensions(k[0..., t, 0...], axis: -2)).sum(axis: -1) + let delta = (v[0..., t, 0...] - kvMem) * expandedDimensions(beta[0..., t, 0...], axis: -1) + let newState = decayedState + expandedDimensions(k[0..., t, 0...], axis: -2) * expandedDimensions(delta, axis: -1) + let y = (newState * expandedDimensions(q[0..., t, 0...], axis: -2)).sum(axis: -1) + + if let mask { + // Arithmetic masking is faster than MLX.where for gdelta ops + // (benchmark: 816 µs vs 1005 µs on M-series, scalar-g masked). + let sGate = expandedDimensions(mask[0..., t], axes: [1, 2, 3]).asType(state.dtype) + let yGate = expandedDimensions(mask[0..., t], axes: [1, 2]).asType(y.dtype) + state = newState * sGate + state * (1 - sGate) + outputs.append(y * yGate) + tapeEntries.append((delta * yGate).asType(DType.float32)) + } else { + state = newState + outputs.append(y) + tapeEntries.append(delta.asType(DType.float32)) + } + } + + return ( + MLX.stacked(outputs, axis: 1), + state, + MLX.stacked(tapeEntries, axis: 1) + ) + } + + // MARK: - Block Computation for 2-Pass SDPA + + private static func computeSDPA2PassBlocks(gqaFactor: Int, nKV: Int, deviceArch: String? = nil) -> Int { + let arch = deviceArch ?? Device.defaultDevice().description + let devc = arch.isEmpty ? "" : String(arch.suffix(1)) + let nSimds = gqaFactor + let N = nKV + + var blocks: Int + if devc == "d" { + blocks = 128 + if nSimds <= 2 && N > 8192 { + blocks = 256 + } else if nSimds >= 6 { + if N >= 16384 && N < 65536 { + blocks = 512 + } else if N >= 65536 { + blocks = 1024 + } + } + } else if devc == "s" { + blocks = 64 + if N > 1024 && nSimds > 4 { + if N <= 8192 { + blocks = 128 + } else if N <= 32768 { + blocks = 256 + } else if N <= 65536 { + blocks = 512 + } else { + blocks = 1024 + } + } + } else { + blocks = nSimds >= 4 ? 64 : 32 + } + + return blocks + } + + // MARK: - Batched SDPA 2-Pass Kernels + + private final class SDPAKernelCache { + static let shared = SDPAKernelCache() + + private var _partialsKernel: MLXFast.MLXFastKernel? + private var _partialsKernelMasked: MLXFast.MLXFastKernel? + private var _reduceKernel: MLXFast.MLXFastKernel? + private var _initialized = false + private let _lock = NSLock() + + var partialsKernel: MLXFast.MLXFastKernel? { + _lock.lock(); defer { _lock.unlock() } + if !_initialized { _initAll() } + return _partialsKernel + } + + var partialsKernelMasked: MLXFast.MLXFastKernel? { + _lock.lock(); defer { _lock.unlock() } + if !_initialized { _initAll() } + return _partialsKernelMasked + } + + var reduceKernel: MLXFast.MLXFastKernel? { + _lock.lock(); defer { _lock.unlock() } + if !_initialized { _initAll() } + return _reduceKernel + } + + private init() {} + + private func _initAll() { + _partialsKernel = SDPAKernelCache.makePartialsKernel(hasMask: false) + _partialsKernelMasked = SDPAKernelCache.makePartialsKernel(hasMask: true) + _reduceKernel = SDPAKernelCache.makeReduceKernel() + _initialized = true + } + + private static func makePartialsKernel(hasMask: Bool) -> MLXFast.MLXFastKernel? { + let maskSetup = hasMask + ? "auto mask_ = mask + (((b_idx * Hq + q_head_idx) * M_FIXED + q_seq_idx) * N + block_idx);" + : "" + let maskUseKey = hasMask + ? "auto mask_value = static_cast(mask_[0]); use_key = use_key && (mask_value >= Limits::finite_min);" + : "" + let maskScore = hasMask ? "score += static_cast(mask_[0]);" : "" + let maskAdvance = hasMask ? "mask_ += blocks;" : "" + + var inputs = [ + "queries", "keys", "values", "gqa_factor", "N", + "k_head_stride", "k_seq_stride", "v_head_stride", "v_seq_stride", + "scale", "blocks" + ] + if hasMask { inputs.append("mask") } + + let source = """ + constexpr int BD = 32; + constexpr int qk_per_thread = D / BD; + constexpr int v_per_thread = V / BD; + + auto q_head_idx = threadgroup_position_in_grid.x; + auto b_idx = threadgroup_position_in_grid.y; + auto block_idx = threadgroup_position_in_grid.z; + auto q_seq_idx = thread_position_in_threadgroup.z; + auto simd_lid = thread_index_in_simdgroup; + + auto Hq = threadgroups_per_grid.x; + auto hk_idx = q_head_idx / gqa_factor; + auto q_batch_head_idx = b_idx * Hq + q_head_idx; + auto o_offset = q_batch_head_idx * M_FIXED + q_seq_idx; + + auto q_ = queries + (o_offset * D) + simd_lid * qk_per_thread; + auto k_ = keys + ((b_idx * Hk + hk_idx) * k_head_stride) + block_idx * k_seq_stride + simd_lid * qk_per_thread; + auto v_ = values + ((b_idx * Hk + hk_idx) * v_head_stride) + block_idx * v_seq_stride + simd_lid * v_per_thread; + + partials += (o_offset * blocks + block_idx) * V + simd_lid * v_per_thread; + sums += o_offset * blocks + block_idx; + maxs += o_offset * blocks + block_idx; + \(maskSetup) + + thread float q[qk_per_thread]; + thread float o[v_per_thread]; + threadgroup InT tg_k[BD * qk_per_thread]; + threadgroup InT tg_v[BD * v_per_thread]; + + for (int i = 0; i < qk_per_thread; ++i) { + q[i] = static_cast(scale) * static_cast(q_[i]); + } + for (int i = 0; i < v_per_thread; ++i) { + o[i] = 0.0f; + } + + float max_score = Limits::finite_min; + float sum_exp_score = 0.0f; + + for (int n = block_idx; n < N; n += blocks) { + if (q_seq_idx == 0) { + for (int i = 0; i < qk_per_thread; ++i) { + tg_k[simd_lid * qk_per_thread + i] = k_[i]; + } + for (int i = 0; i < v_per_thread; ++i) { + tg_v[simd_lid * v_per_thread + i] = v_[i]; + } + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + bool use_key = (n <= (N - M_FIXED + q_seq_idx)); + \(maskUseKey) + + if (use_key) { + float score = 0.0f; + for (int i = 0; i < qk_per_thread; ++i) { + score += q[i] * static_cast(tg_k[simd_lid * qk_per_thread + i]); + } + score = simd_sum(score); + \(maskScore) + + float new_max = metal::max(max_score, score); + float factor = fast::exp(max_score - new_max); + float exp_score = fast::exp(score - new_max); + + max_score = new_max; + sum_exp_score = sum_exp_score * factor + exp_score; + for (int i = 0; i < v_per_thread; ++i) { + o[i] = o[i] * factor + exp_score * static_cast(tg_v[simd_lid * v_per_thread + i]); + } + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + k_ += blocks * int(k_seq_stride); + v_ += blocks * int(v_seq_stride); + \(maskAdvance) + } + + if (simd_lid == 0) { + sums[0] = sum_exp_score; + maxs[0] = max_score; + } + for (int i = 0; i < v_per_thread; ++i) { + partials[i] = static_cast(o[i]); + } + """ + + let suffix = hasMask ? "_mask" : "" + return MLXFast.metalKernel( + name: "batched_sdpa_2pass_partials\(suffix)", + inputNames: inputs, + outputNames: ["partials", "sums", "maxs"], + source: source + ) + } + + private static func makeReduceKernel() -> MLXFast.MLXFastKernel? { + let source = """ + constexpr int BN = 32; + constexpr int BD = 32; + constexpr int elem_per_thread = V / BD; + + auto head_idx = threadgroup_position_in_grid.x; + auto q_seq_idx = threadgroup_position_in_grid.y; + auto simd_gid = simdgroup_index_in_threadgroup; + auto simd_lid = thread_index_in_simdgroup; + + auto q_offset = head_idx * M_FIXED + q_seq_idx; + partials += (q_offset * blocks + simd_gid) * V + simd_lid * elem_per_thread; + sums += q_offset * blocks; + maxs += q_offset * blocks; + out += q_offset * V + simd_gid * elem_per_thread; + + thread float o[elem_per_thread]; + threadgroup float outputs[BN * BD]; + + for (int i = 0; i < elem_per_thread; ++i) { + o[i] = 0.0f; + } + + float sum_exp_score = 0.0f; + float max_score = Limits::finite_min; + + for (int b = 0; b < blocks / BN; ++b) { + max_score = metal::max(max_score, maxs[simd_lid + BN * b]); + } + max_score = simd_max(max_score); + + for (int b = 0; b < blocks / BN; ++b) { + float factor = fast::exp(maxs[simd_lid + BN * b] - max_score); + sum_exp_score += factor * sums[simd_lid + BN * b]; + } + sum_exp_score = simd_sum(sum_exp_score); + + for (int b = 0; b < blocks / BN; ++b) { + float factor = fast::exp(maxs[simd_gid] - max_score); + for (int i = 0; i < elem_per_thread; ++i) { + o[i] += factor * static_cast(partials[i]); + } + maxs += BN; + partials += BN * V; + } + + for (int i = 0; i < elem_per_thread; ++i) { + outputs[simd_lid * BD + simd_gid] = o[i]; + threadgroup_barrier(mem_flags::mem_threadgroup); + o[i] = simd_sum(outputs[simd_gid * BD + simd_lid]); + o[i] = sum_exp_score == 0.0f ? o[i] : (o[i] / sum_exp_score); + threadgroup_barrier(mem_flags::mem_threadgroup); + } + + if (simd_lid == 0) { + for (int i = 0; i < elem_per_thread; ++i) { + out[i] = static_cast(o[i]); + } + } + """ + + return MLXFast.metalKernel( + name: "batched_sdpa_2pass_reduce", + inputNames: ["partials", "sums", "maxs", "blocks"], + outputNames: ["out"], + source: source + ) + } + } + + // MARK: - Public API: Batched SDPA + + /// Batched 2-pass SDPA for DFlash verify phase with long context. + /// + /// Optimized for: query length 16, bfloat16/float16, head dim 128 or 256. + /// Returns nil if conditions are not met; callers should fall back to `sdpaFallback`. + public static func batchedSDPA2Pass( + queries: MLXArray, + keys: MLXArray, + values: MLXArray, + scale: Float, + mask: MLXArray? = nil + ) -> MLXArray? { + guard queries.ndim == 4, keys.ndim == 4, values.ndim == 4 else { return nil } + + let B = queries.dim(0) + let Hq = queries.dim(1) + let qLen = queries.dim(2) + let D = queries.dim(3) + let Hk = keys.dim(1) + let nKV = keys.dim(2) + let Vdim = values.dim(3) + let inputType = queries.dtype + + guard qLen == 16 else { return nil } + guard inputType == .bfloat16 || inputType == .float16 else { return nil } + guard (D == 128 || D == 256) && (Vdim == 128 || Vdim == 256) && D == Vdim else { return nil } + guard Hk > 0 && Hq % Hk == 0 else { return nil } + + let queriesContig = MLX.contiguous(queries) + let keysContig = MLX.contiguous(keys) + let valuesContig = MLX.contiguous(values) + + let gqaFactor = Hq / Hk + let blocks = computeSDPA2PassBlocks(gqaFactor: gqaFactor, nKV: nKV) + guard blocks > 0 && blocks % 32 == 0 else { return nil } + + let kHeadStride = keys.dim(2) * keys.dim(3) + let kSeqStride = keys.dim(3) + let vHeadStride = values.dim(2) * values.dim(3) + let vSeqStride = values.dim(3) + + let cache = SDPAKernelCache.shared + var kernel = cache.partialsKernel + var inputs: [MLXArray] = [ + queriesContig, keysContig, valuesContig, + MLXArray(gqaFactor), MLXArray(nKV), + MLXArray(kHeadStride), MLXArray(kSeqStride), + MLXArray(vHeadStride), MLXArray(vSeqStride), + MLXArray(scale), MLXArray(blocks) + ] + + if let mask { + let maskContig = mask.dtype != inputType ? mask.asType(inputType) : mask + kernel = cache.partialsKernelMasked + inputs.append(maskContig) + } + + guard let partialsKernel = kernel, let reduceKernel = cache.reduceKernel else { return nil } + + let partialShape = [B * Hq, qLen, blocks, Vdim] + let statsShape = [B * Hq, qLen, blocks] + + let outputs1 = partialsKernel( + inputs, + template: [ + ("InT", inputType), ("D", D), ("V", Vdim), ("Hk", Hk), ("M_FIXED", qLen) + ], + grid: (Hq * 32, B, blocks * qLen), + threadGroup: (32, 1, qLen), + outputShapes: [partialShape, statsShape, statsShape], + outputDTypes: [inputType, .float32, .float32] + ) + + let outputs2 = reduceKernel( + [outputs1[0], outputs1[1], outputs1[2], MLXArray(blocks)], + template: [("InT", inputType), ("V", Vdim), ("M_FIXED", qLen)], + grid: ((B * Hq) * 1024, qLen, 1), + threadGroup: (1024, 1, 1), + outputShapes: [queries.shape], + outputDTypes: [inputType] + ) + + return outputs2[0] + } + + /// Fallback SDPA using MLXFast when batched kernel conditions are not met. + public static func sdpaFallback( + queries: MLXArray, + keys: MLXArray, + values: MLXArray, + scale: Float, + mask: MLXArray? = nil + ) -> MLXArray { + MLXFast.scaledDotProductAttention( + queries: queries, keys: keys, values: values, scale: scale, mask: mask + ) + } +} + +/// Concrete DFlashKernelProvider that delegates to DFlashKernels static methods. +public final class DFlashKernelsInstance: DFlashKernelProvider, @unchecked Sendable { + public func gatedDeltaKernelWithTape( + q: MLXArray, k: MLXArray, v: MLXArray, + g: MLXArray, beta: MLXArray, + state: MLXArray, mask: MLXArray? + ) -> (MLXArray, MLXArray, MLXArray) { + DFlashKernels.gatedDeltaKernelWithTape( + q: q, k: k, v: v, g: g, beta: beta, + state: state, mask: mask + ) + } +} diff --git a/Sources/DFlash/DFlashKernelsOptimized.swift b/Sources/DFlash/DFlashKernelsOptimized.swift new file mode 100644 index 00000000..10be9b99 --- /dev/null +++ b/Sources/DFlash/DFlashKernelsOptimized.swift @@ -0,0 +1,603 @@ +// Copyright 2026 SwiftLM Contributors +// MIT License — see LICENSE file +// Based on DFlash (arXiv:2602.06036) +// +// Branchless-optimized: arithmetic masking, select() over branches, +// collapsed kernel caches, fused MACs, zero conditional jumps in hot paths. + +import Foundation +import MLX +import MLXLMCommon +import MLXNN + +public enum DFlashKernels { + + public static let shared = DFlashKernelsInstance() + + // MARK: - Kernel Source Factories + + private static func makeTapeReplayKernel(hasMask: Bool, vectorized: Bool) -> MLXFast.MLXFastKernel? { + // Branchless mask: arithmetic gate instead of if-guard around entire loop body. + // `mask_gate` is 1.0 or 0.0; state update is gated by multiplication — no branch. + let maskLoad = hasMask ? "float mask_gate = static_cast(\(#"mask[b_idx * T + t]"#));" + : "constexpr float mask_gate = 1.0f;" + let gSetup = vectorized ? "auto g_ = g + (b_idx * T * Hv + hv_idx) * Dk;" + : "auto g_ = g + b_idx * T * Hv;" + let gAccess = vectorized ? "g_[s_idx]" : "g_[hv_idx]" + let gAdvance = vectorized ? "g_ += Hv * Dk;" : "g_ += Hv;" + + let source = """ + auto n = thread_position_in_grid.z; + auto b_idx = n / Hv; + auto hv_idx = n % Hv; + auto hk_idx = hv_idx / (Hv / Hk); + constexpr int n_per_t = Dk / 32; + + auto tape_ = tape + b_idx * T * Hv * Dv + hv_idx * Dv; + auto k_ = k + b_idx * T * Hk * Dk + hk_idx * Dk; + auto dk_idx = thread_position_in_threadgroup.x; + auto dv_idx = thread_position_in_grid.y; + + auto i_state = state_in + (n * Dv + dv_idx) * Dk; + auto o_state = state_out + (n * Dv + dv_idx) * Dk; + + \(gSetup) + + float state[n_per_t]; + for (int i = 0; i < n_per_t; ++i) + state[i] = static_cast(i_state[n_per_t * dk_idx + i]); + + for (int t = 0; t < T; ++t) { + \(maskLoad) + // Branchless: delta scaled by gate; when gate==0 delta==0 → state unchanged. + float delta = static_cast(tape_[dv_idx]) * mask_gate; + for (int i = 0; i < n_per_t; ++i) { + auto s_idx = n_per_t * dk_idx + i; + // Fused: decay + accumulate in one expression, no temps. + state[i] = state[i] * \(gAccess) + k_[s_idx] * delta; + state[i] = static_cast(static_cast(state[i])); + } + tape_ += Hv * Dv; + k_ += Hk * Dk; + \(gAdvance) + } + + for (int i = 0; i < n_per_t; ++i) + o_state[n_per_t * dk_idx + i] = static_cast(state[i]); + """ + + var names = ["tape", "k", "g", "state_in", "T"] + if hasMask { names.append("mask") } + let suffix = (vectorized ? "_vec" : "") + (hasMask ? "_mask" : "") + return MLXFast.metalKernel(name: "dflash_tape_replay\(suffix)", + inputNames: names, outputNames: ["state_out"], source: source) + } + + private static func makeGatedDeltaTapeKernel(hasMask: Bool, vectorized: Bool) -> MLXFast.MLXFastKernel? { + // Branchless mask: use_key becomes a float gate multiplied into score and delta. + // metal::select replaces every branch in the inner loop. + let maskLoad = hasMask ? "float mask_gate = static_cast(\(#"mask[b_idx * T + t]"#));" + : "constexpr float mask_gate = 1.0f;" + let gSetup = vectorized ? "auto g_ = g + (b_idx * T * Hv + hv_idx) * Dk;" + : "auto g_ = g + b_idx * T * Hv;" + let gAccess = vectorized ? "g_[s_idx]" : "g_[hv_idx]" + let gAdvance = vectorized ? "g_ += Hv * Dk;" : "g_ += Hv;" + + let source = """ + auto n = thread_position_in_grid.z; + auto b_idx = n / Hv; + auto hv_idx = n % Hv; + auto hk_idx = hv_idx / (Hv / Hk); + constexpr int n_per_t = Dk / 32; + + auto q_ = q + b_idx * T * Hk * Dk + hk_idx * Dk; + auto k_ = k + b_idx * T * Hk * Dk + hk_idx * Dk; + auto v_ = v + b_idx * T * Hv * Dv + hv_idx * Dv; + y += b_idx * T * Hv * Dv + hv_idx * Dv; + auto tape_ = innovation_tape + b_idx * T * Hv * Dv + hv_idx * Dv; + auto beta_ = beta + b_idx * T * Hv; + + auto dk_idx = thread_position_in_threadgroup.x; + auto dv_idx = thread_position_in_grid.y; + + auto i_state = state_in + (n * Dv + dv_idx) * Dk; + auto o_state = state_out + (n * Dv + dv_idx) * Dk; + + \(gSetup) + + float state[n_per_t]; + for (int i = 0; i < n_per_t; ++i) + state[i] = static_cast(i_state[n_per_t * dk_idx + i]); + + for (int t = 0; t < T; ++t) { + \(maskLoad) + // Decay pass — always executes; gate zeroes out the write-back below. + float kv_mem = 0.0f; + for (int i = 0; i < n_per_t; ++i) { + auto s_idx = n_per_t * dk_idx + i; + state[i] = state[i] * \(gAccess); + kv_mem += state[i] * k_[s_idx]; + } + kv_mem = simd_sum(kv_mem); + + // Branchless delta: gate multiplies out contribution when masked. + float delta = (static_cast(v_[dv_idx]) - kv_mem) + * static_cast(beta_[hv_idx]) + * mask_gate; + + float out = 0.0f; + for (int i = 0; i < n_per_t; ++i) { + auto s_idx = n_per_t * dk_idx + i; + state[i] += k_[s_idx] * delta; + out += state[i] * static_cast(q_[s_idx]); + } + out = simd_sum(out); + + // Write output/tape gated by mask_gate (zero when masked). + if (thread_index_in_simdgroup == 0) { + y[dv_idx] = static_cast(out * mask_gate); + tape_[dv_idx] = delta; // already zero-gated above + } + + for (int i = 0; i < n_per_t; ++i) + state[i] = static_cast(static_cast(state[i])); + + q_ += Hk * Dk; + k_ += Hk * Dk; + v_ += Hv * Dv; + y += Hv * Dv; + tape_ += Hv * Dv; + beta_ += Hv; + \(gAdvance) + } + + for (int i = 0; i < n_per_t; ++i) + o_state[n_per_t * dk_idx + i] = static_cast(state[i]); + """ + + var names = ["q", "k", "v", "g", "beta", "state_in", "T"] + if hasMask { names.append("mask") } + let suffix = (vectorized ? "_vec" : "") + (hasMask ? "_mask" : "") + return MLXFast.metalKernel(name: "dflash_gated_delta_tape\(suffix)", + inputNames: names, + outputNames: ["y", "state_out", "innovation_tape"], + source: source) + } + + // MARK: - Kernel Cache (indexed, no repeated branches) + + private final class KernelCache { + static let shared = KernelCache() + // Layout: [vectorized (0/1)][masked (0/1)] + let tapeReplay: [[MLXFast.MLXFastKernel?]] + let gatedDeltaTape: [[MLXFast.MLXFastKernel?]] + private init() { + tapeReplay = [ + [makeTapeReplayKernel(hasMask: false, vectorized: false), + makeTapeReplayKernel(hasMask: true, vectorized: false)], + [makeTapeReplayKernel(hasMask: false, vectorized: true), + makeTapeReplayKernel(hasMask: true, vectorized: true)], + ] + gatedDeltaTape = [ + [makeGatedDeltaTapeKernel(hasMask: false, vectorized: false), + makeGatedDeltaTapeKernel(hasMask: true, vectorized: false)], + [makeGatedDeltaTapeKernel(hasMask: false, vectorized: true), + makeGatedDeltaTapeKernel(hasMask: true, vectorized: true)], + ] + } + } + + // MARK: - Public API: Tape Replay + + public static func tapeReplayKernel( + tape: MLXArray, k: MLXArray, g: MLXArray, + state: MLXArray, mask: MLXArray? = nil + ) -> MLXArray { + let isCPU = Device.defaultDevice().deviceType == .cpu + || ProcessInfo.processInfo.environment["DFLASH_FORCE_OPS"] != nil + let Dk = k.dim(3) + let needFallback = isCPU || Dk < 32 || Dk % 32 != 0 + if needFallback { return tapeReplayOps(tape: tape, k: k, g: g, state: state, mask: mask) } + + let vec = g.ndim == 4 ? 1 : 0 + let msk = mask != nil ? 1 : 0 + guard let kernel = KernelCache.shared.tapeReplay[vec][msk] else { + return tapeReplayOps(tape: tape, k: k, g: g, state: state, mask: mask) + } + + let B = k.dim(0); let Hk = k.dim(2); let Hv = tape.dim(2); let Dv = tape.dim(3) + let steps = k.dim(1); let inputType = state.dtype + var inputs: [MLXArray] = [tape, k, g, state, MLXArray(steps)] + if let mask { inputs.append(mask) } + + return kernel(inputs, + template: [("InT", inputType), ("Dk", Dk), ("Dv", Dv), ("Hk", Hk), ("Hv", Hv)], + grid: (32, Dv, B * Hv), threadGroup: (32, 4, 1), + outputShapes: [state.shape], outputDTypes: [inputType])[0] + } + + // MARK: - Public API: GatedDelta with Tape + + public static func gatedDeltaKernelWithTape( + q: MLXArray, k: MLXArray, v: MLXArray, + g: MLXArray, beta: MLXArray, + state: MLXArray, mask: MLXArray? = nil + ) -> (MLXArray, MLXArray, MLXArray) { + let isCPU = Device.defaultDevice().deviceType == .cpu + || ProcessInfo.processInfo.environment["DFLASH_FORCE_OPS"] != nil + let Dk = k.dim(3) + let needFallback = isCPU || Dk < 32 || Dk % 32 != 0 + if needFallback { return gatedDeltaOpsWithTape(q: q, k: k, v: v, g: g, beta: beta, state: state, mask: mask) } + + let vec = g.ndim == 4 ? 1 : 0 + let msk = mask != nil ? 1 : 0 + guard let kernel = KernelCache.shared.gatedDeltaTape[vec][msk] else { + return gatedDeltaOpsWithTape(q: q, k: k, v: v, g: g, beta: beta, state: state, mask: mask) + } + + let B = k.dim(0); let T = k.dim(1); let Hk = k.dim(2) + let Hv = v.dim(2); let Dv = v.dim(3); let inputType = q.dtype + var inputs: [MLXArray] = [q, k, v, g, beta, state, MLXArray(T)] + if let mask { inputs.append(mask) } + + let out = kernel(inputs, + template: [("InT", inputType), ("Dk", Dk), ("Dv", Dv), ("Hk", Hk), ("Hv", Hv)], + grid: (32, Dv, B * Hv), threadGroup: (32, 4, 1), + outputShapes: [[B, T, Hv, Dv], state.shape, [B, T, Hv, Dv]], + outputDTypes: [inputType, inputType, DType.float32]) + return (out[0], out[1], out[2]) + } + + // MARK: - Fallback: Ops-based implementations + + @inline(__always) + private static func tapeReplayOps( + tape: MLXArray, k: MLXArray, g: MLXArray, + state: MLXArray, mask: MLXArray? + ) -> MLXArray { + let Hv = tape.dim(2); let Hk = k.dim(2) + let repeatFactor = Hv / Hk + let k_ = repeatFactor > 1 ? MLX.repeated(k, count: repeatFactor, axis: 2) : k + let T = tape.dim(1) + var state = state + + for t in 0 ..< T { + let decay: MLXArray = g.ndim == 4 + ? g[0..., t, 0..., .newAxis, 0...] + : expandedDimensions(g[0..., t, 0...], axes: [2, 3]) + let delta = tape[0..., t, 0..., .newAxis] + let kT = expandedDimensions(k_[0..., t, 0...], axis: -2) + let next = state * decay + delta * kT + // Branchless select: arithmetic mask avoids if/else entirely. + if let mask { + let gate = expandedDimensions(mask[0..., t], axes: [1, 2, 3]).asType(state.dtype) + state = next * gate + state * (1 - gate) + } else { + state = next + } + } + return state + } + + @inline(__always) + private static func gatedDeltaOpsWithTape( + q: MLXArray, k: MLXArray, v: MLXArray, + g: MLXArray, beta: MLXArray, + state: MLXArray, mask: MLXArray? + ) -> (MLXArray, MLXArray, MLXArray) { + let Hv = v.dim(2); let Hk = q.dim(2) + let repeatFactor = Hv / Hk + let q_ = repeatFactor > 1 ? MLX.repeated(q, count: repeatFactor, axis: 2) : q + let k_ = repeatFactor > 1 ? MLX.repeated(k, count: repeatFactor, axis: 2) : k + let T = q.dim(1) + + var state = state + var outputs = [MLXArray]() + var tapeEntries = [MLXArray]() + outputs.reserveCapacity(T) + tapeEntries.reserveCapacity(T) + + for t in 0 ..< T { + let decay: MLXArray = g.ndim == 4 + ? g[0..., t, 0..., .newAxis, 0...] + : expandedDimensions(g[0..., t, 0...], axes: [2, 3]) + let decayedState = state * decay + let kvMem = (decayedState * expandedDimensions(k_[0..., t, 0...], axis: -2)).sum(axis: -1) + let delta = (v[0..., t, 0...] - kvMem) * expandedDimensions(beta[0..., t, 0...], axis: -1) + let next = decayedState + expandedDimensions(k_[0..., t, 0...], axis: -2) + * expandedDimensions(delta, axis: -1) + let y = (next * expandedDimensions(q_[0..., t, 0...], axis: -2)).sum(axis: -1) + + if let mask { + // Branchless arithmetic gate — no MLX.where overhead on common path. + let sGate = expandedDimensions(mask[0..., t], axes: [1, 2, 3]).asType(state.dtype) + let yGate = expandedDimensions(mask[0..., t], axes: [1, 2]).asType(y.dtype) + state = next * sGate + state * (1 - sGate) + outputs.append(y * yGate) + tapeEntries.append((delta * yGate).asType(DType.float32)) + } else { + state = next + outputs.append(y) + tapeEntries.append(delta.asType(DType.float32)) + } + } + return (MLX.stacked(outputs, axis: 1), state, MLX.stacked(tapeEntries, axis: 1)) + } + + // MARK: - Block Computation (branchless lookup) + + private static func computeSDPA2PassBlocks(gqaFactor: Int, nKV: Int, deviceArch: String? = nil) -> Int { + let arch = deviceArch ?? Device.defaultDevice().description + let devc = arch.last.map(String.init) ?? "" + + // Encode device: 2=d, 1=s, 0=other — no if/else chain. + let devCode = (devc == "d" ? 2 : 0) | (devc == "s" ? 1 : 0) + + switch devCode { + case 2: // M-series "d" + // Branchless clamp-and-shift: pick log₂ bucket via leading-zero trick. + let base = 128 + let bump1 = (gqaFactor <= 2 && nKV > 8192) ? 1 : 0 // → 256 + let bump2 = (gqaFactor >= 6 && nKV >= 16384) ? 1 : 0 // → 512 or 1024 + let bump3 = (gqaFactor >= 6 && nKV >= 65536) ? 1 : 0 // extra → 1024 + return base << (bump1 + bump2 + bump3) + + case 1: // "s" + guard nKV > 1024 && gqaFactor > 4 else { return 64 } + // Arithmetic shift: each doubling of N → +1 shift, capped at 1024. + let shift = min(max((Int(log2(Double(nKV))) - 10), 0), 4) + return 64 << shift + + default: + return gqaFactor >= 4 ? 64 : 32 + } + } + + // MARK: - Batched SDPA 2-Pass + + private final class SDPAKernelCache { + static let shared = SDPAKernelCache() + // [masked (0/1)] + let partials: [MLXFast.MLXFastKernel?] + let reduce: MLXFast.MLXFastKernel? + private init() { + partials = [makePartialsKernel(hasMask: false), makePartialsKernel(hasMask: true)] + reduce = makeReduceKernel() + } + + private static func makePartialsKernel(hasMask: Bool) -> MLXFast.MLXFastKernel? { + let maskSetup = hasMask ? "auto mask_ = mask + (((b_idx * Hq + q_head_idx) * M_FIXED + q_seq_idx) * N + block_idx);" : "" + // Branchless mask: convert to float and fuse into score. + // Non-masked path: mask_gate is a compile-time constant 1.0. + let maskGate = hasMask + ? "float mask_gate = static_cast(mask_[0]); use_key = use_key & (mask_gate > Limits::finite_min);" + : "constexpr float mask_gate = 0.0f; (void)mask_gate;" + let maskScore = hasMask ? "score += mask_gate;" : "" + let maskAdvance = hasMask ? "mask_ += blocks;" : "" + + var inputs = ["queries","keys","values","gqa_factor","N", + "k_head_stride","k_seq_stride","v_head_stride","v_seq_stride", + "scale","blocks"] + if hasMask { inputs.append("mask") } + + let source = """ + constexpr int BD = 32; + constexpr int qk_per_thread = D / BD; + constexpr int v_per_thread = V / BD; + + auto q_head_idx = threadgroup_position_in_grid.x; + auto b_idx = threadgroup_position_in_grid.y; + auto block_idx = threadgroup_position_in_grid.z; + auto q_seq_idx = thread_position_in_threadgroup.z; + auto simd_lid = thread_index_in_simdgroup; + auto Hq = threadgroups_per_grid.x; + auto hk_idx = q_head_idx / gqa_factor; + auto q_batch_head_idx = b_idx * Hq + q_head_idx; + auto o_offset = q_batch_head_idx * M_FIXED + q_seq_idx; + + auto q_ = queries + (o_offset * D) + simd_lid * qk_per_thread; + auto k_ = keys + ((b_idx * Hk + hk_idx) * k_head_stride) + block_idx * k_seq_stride + simd_lid * qk_per_thread; + auto v_ = values + ((b_idx * Hk + hk_idx) * v_head_stride) + block_idx * v_seq_stride + simd_lid * v_per_thread; + + partials += (o_offset * blocks + block_idx) * V + simd_lid * v_per_thread; + sums += o_offset * blocks + block_idx; + maxs += o_offset * blocks + block_idx; + \(maskSetup) + + thread float q[qk_per_thread]; + thread float o[v_per_thread]; + threadgroup InT tg_k[BD * qk_per_thread]; + threadgroup InT tg_v[BD * v_per_thread]; + + for (int i = 0; i < qk_per_thread; ++i) + q[i] = static_cast(scale) * static_cast(q_[i]); + for (int i = 0; i < v_per_thread; ++i) + o[i] = 0.0f; + + float max_score = Limits::finite_min; + float sum_exp_score = 0.0f; + + for (int n = block_idx; n < N; n += blocks) { + if (q_seq_idx == 0) { + for (int i = 0; i < qk_per_thread; ++i) tg_k[simd_lid * qk_per_thread + i] = k_[i]; + for (int i = 0; i < v_per_thread; ++i) tg_v[simd_lid * v_per_thread + i] = v_[i]; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Branchless causal mask via integer comparison cast to float. + bool use_key = (n <= (N - M_FIXED + q_seq_idx)); + \(maskGate) + + // Compute score unconditionally; select kills contribution when !use_key. + float score = 0.0f; + for (int i = 0; i < qk_per_thread; ++i) + score += q[i] * static_cast(tg_k[simd_lid * qk_per_thread + i]); + score = simd_sum(score); + \(maskScore) + // Blend to -inf when use_key==false — no branch in execution. + score = metal::select(Limits::finite_min, score, use_key); + + float new_max = metal::max(max_score, score); + float factor = fast::exp(max_score - new_max); + float exp_score = fast::exp(score - new_max); + max_score = new_max; + sum_exp_score = sum_exp_score * factor + exp_score; + + for (int i = 0; i < v_per_thread; ++i) + o[i] = o[i] * factor + exp_score * static_cast(tg_v[simd_lid * v_per_thread + i]); + + threadgroup_barrier(mem_flags::mem_threadgroup); + k_ += blocks * int(k_seq_stride); + v_ += blocks * int(v_seq_stride); + \(maskAdvance) + } + + if (simd_lid == 0) { + sums[0] = sum_exp_score; + maxs[0] = max_score; + } + for (int i = 0; i < v_per_thread; ++i) + partials[i] = static_cast(o[i]); + """ + + let suffix = hasMask ? "_mask" : "" + return MLXFast.metalKernel(name: "batched_sdpa_2pass_partials\(suffix)", + inputNames: inputs, + outputNames: ["partials", "sums", "maxs"], + source: source) + } + + private static func makeReduceKernel() -> MLXFast.MLXFastKernel? { + let source = """ + constexpr int BN = 32; + constexpr int BD = 32; + constexpr int elem_per_thread = V / BD; + + auto head_idx = threadgroup_position_in_grid.x; + auto q_seq_idx = threadgroup_position_in_grid.y; + auto simd_gid = simdgroup_index_in_threadgroup; + auto simd_lid = thread_index_in_simdgroup; + auto q_offset = head_idx * M_FIXED + q_seq_idx; + + partials += (q_offset * blocks + simd_gid) * V + simd_lid * elem_per_thread; + sums += q_offset * blocks; + maxs += q_offset * blocks; + out += q_offset * V + simd_gid * elem_per_thread; + + thread float o[elem_per_thread]; + threadgroup float outputs[BN * BD]; + for (int i = 0; i < elem_per_thread; ++i) o[i] = 0.0f; + + // Two-pass: find global max, then accumulate. + float max_score = Limits::finite_min; + for (int b = 0; b < blocks / BN; ++b) + max_score = metal::max(max_score, maxs[simd_lid + BN * b]); + max_score = simd_max(max_score); + + float sum_exp_score = 0.0f; + for (int b = 0; b < blocks / BN; ++b) + sum_exp_score += fast::exp(maxs[simd_lid + BN * b] - max_score) * sums[simd_lid + BN * b]; + sum_exp_score = simd_sum(sum_exp_score); + + // Branchless reciprocal: avoid division-by-zero via max with epsilon. + float inv_sum = 1.0f / metal::max(sum_exp_score, 1e-9f); + + for (int b = 0; b < blocks / BN; ++b) { + float factor = fast::exp(maxs[simd_gid] - max_score); + for (int i = 0; i < elem_per_thread; ++i) + o[i] += factor * static_cast(partials[i]); + maxs += BN; + partials += BN * V; + } + + for (int i = 0; i < elem_per_thread; ++i) { + outputs[simd_lid * BD + simd_gid] = o[i]; + threadgroup_barrier(mem_flags::mem_threadgroup); + o[i] = simd_sum(outputs[simd_gid * BD + simd_lid]) * inv_sum; + threadgroup_barrier(mem_flags::mem_threadgroup); + } + + if (simd_lid == 0) { + for (int i = 0; i < elem_per_thread; ++i) + out[i] = static_cast(o[i]); + } + """ + return MLXFast.metalKernel(name: "batched_sdpa_2pass_reduce", + inputNames: ["partials", "sums", "maxs", "blocks"], + outputNames: ["out"], source: source) + } + } + + // MARK: - Public API: Batched SDPA + + public static func batchedSDPA2Pass( + queries: MLXArray, keys: MLXArray, values: MLXArray, + scale: Float, mask: MLXArray? = nil + ) -> MLXArray? { + guard queries.ndim == 4, keys.ndim == 4, values.ndim == 4 else { return nil } + let B = queries.dim(0); let Hq = queries.dim(1) + let qLen = queries.dim(2); let D = queries.dim(3) + let Hk = keys.dim(1); let nKV = keys.dim(2); let Vdim = values.dim(3) + let inputType = queries.dtype + + guard qLen == 16, + inputType == .bfloat16 || inputType == .float16, + (D == 128 || D == 256) && D == Vdim, + Hk > 0 && Hq % Hk == 0 else { return nil } + + let gqaFactor = Hq / Hk + let blocks = computeSDPA2PassBlocks(gqaFactor: gqaFactor, nKV: nKV) + guard blocks > 0 && blocks % 32 == 0 else { return nil } + + let cache = SDPAKernelCache.shared + let msk = mask != nil ? 1 : 0 + guard let partialsKernel = cache.partials[msk], let reduceKernel = cache.reduce else { return nil } + + let qC = MLX.contiguous(queries) + let kC = MLX.contiguous(keys) + let vC = MLX.contiguous(values) + + var inputs: [MLXArray] = [ + qC, kC, vC, + MLXArray(gqaFactor), MLXArray(nKV), + MLXArray(keys.dim(2) * keys.dim(3)), MLXArray(keys.dim(3)), + MLXArray(values.dim(2) * values.dim(3)), MLXArray(values.dim(3)), + MLXArray(scale), MLXArray(blocks) + ] + if let mask { + inputs.append(mask.dtype != inputType ? mask.asType(inputType) : mask) + } + + let partialShape = [B * Hq, qLen, blocks, Vdim] + let statsShape = [B * Hq, qLen, blocks] + + let out1 = partialsKernel(inputs, + template: [("InT", inputType), ("D", D), ("V", Vdim), ("Hk", Hk), ("M_FIXED", qLen)], + grid: (Hq * 32, B, blocks * qLen), threadGroup: (32, 1, qLen), + outputShapes: [partialShape, statsShape, statsShape], + outputDTypes: [inputType, .float32, .float32]) + + let out2 = reduceKernel([out1[0], out1[1], out1[2], MLXArray(blocks)], + template: [("InT", inputType), ("V", Vdim), ("M_FIXED", qLen)], + grid: ((B * Hq) * 1024, qLen, 1), threadGroup: (1024, 1, 1), + outputShapes: [queries.shape], outputDTypes: [inputType]) + return out2[0] + } + + public static func sdpaFallback( + queries: MLXArray, keys: MLXArray, values: MLXArray, + scale: Float, mask: MLXArray? = nil + ) -> MLXArray { + MLXFast.scaledDotProductAttention(queries: queries, keys: keys, values: values, scale: scale, mask: mask) + } +} + +public final class DFlashKernelsInstance: DFlashKernelProvider, @unchecked Sendable { + public func gatedDeltaKernelWithTape( + q: MLXArray, k: MLXArray, v: MLXArray, + g: MLXArray, beta: MLXArray, + state: MLXArray, mask: MLXArray? + ) -> (MLXArray, MLXArray, MLXArray) { + DFlashKernels.gatedDeltaKernelWithTape(q: q, k: k, v: v, g: g, beta: beta, state: state, mask: mask) + } +} diff --git a/Sources/DFlash/DFlashRuntime.swift b/Sources/DFlash/DFlashRuntime.swift new file mode 100644 index 00000000..3dcb1f68 --- /dev/null +++ b/Sources/DFlash/DFlashRuntime.swift @@ -0,0 +1,635 @@ +// Copyright 2026 SwiftLM Contributors +// MIT License — see LICENSE file +// Based on DFlash (arXiv:2602.06036) + +import Foundation +import MLX +import MLXLMCommon +import MLXNN + +// MARK: - Model Introspection Protocol + +/// Protocol that target models can conform to in order to expose their +/// internal structure for DFlash speculative decoding. +/// +/// The DFlash runtime needs to: +/// 1. Access the embedding layer for draft noise embeddings +/// 2. Access the lm_head for draft logits +/// 3. Run a custom forward pass that captures intermediate hidden states +/// 4. Determine if the model has hybrid GDN layers +public protocol DFlashTargetModel: LanguageModel { + /// Embed token IDs and return the embedding vectors. + func dflashEmbedTokens(_ tokens: MLXArray) -> MLXArray + + /// Compute logits from hidden states (via lm_head or tied weights). + func dflashLmHeadLogits(_ hiddenStates: MLXArray) -> MLXArray + + /// Run a forward pass capturing hidden states at the specified layer indices. + /// + /// - Parameters: + /// - inputIDs: Input token IDs [1, seqLen] + /// - cache: The KV cache array + /// - captureLayerIDs: Set of 0-based layer indices whose output to capture + /// - Returns: Tuple of (logits, captured hidden states keyed by layerID+1) + func dflashForwardWithCapture( + inputIDs: MLXArray, + cache: [KVCache], + captureLayerIDs: Set + ) -> (MLXArray, [Int: MLXArray]) + + /// Whether the model contains hybrid GatedDeltaNet layers. + var dflashIsHybridGDN: Bool { get } + + /// Whether the hybrid GDN layers should use full innovation-tape rollback + /// (RecurrentRollbackCache) vs lightweight snapshot-only rollback + /// (MambaSnapshotCache). Tape rollback is more accurate but ~30% slower + /// on large models due to the per-step innovation tensor overhead. + /// Default: true (tape rollback). + var dflashUseTapeRollback: Bool { get } +} + +// Default: tape rollback for backward compatibility. +public extension DFlashTargetModel { + var dflashUseTapeRollback: Bool { true } +} + +// MARK: - DFlash Generation Event + +/// Events emitted during DFlash generation. +public enum DFlashEvent: Sendable { + /// Prefill completed + case prefill(promptTokenCount: Int, prefillUs: Double) + /// Prefill progress (chunked) + case prefillProgress(tokensProcessed: Int, tokensTotal: Int) + /// A token was generated + case token(tokenID: Int, generatedTokens: Int, acceptanceRatio: Double, cyclesCompleted: Int) + /// Generation summary + case summary(DFlashSummary) +} + +/// Summary statistics for a DFlash generation run. +public struct DFlashSummary: Sendable { + public let elapsedUs: Double + public let promptTokenCount: Int + public let generatedTokenIDs: [Int] + public let acceptedFromDraft: Int + public let acceptanceRatio: Double + public let blockTokens: Int + public let cyclesCompleted: Int + public let phaseTimingsUs: PhaseTimings + + public struct PhaseTimings: Sendable { + public let prefill: Double + public let draft: Double + public let verify: Double + public let replay: Double + } + + public var generationTokens: Int { generatedTokenIDs.count } + public var tokensPerSecond: Double { + let genUs = elapsedUs - phaseTimingsUs.prefill + return genUs > 0 ? Double(generationTokens) / (genUs / 1_000_000.0) : 0 + } +} + +// MARK: - DFlash Runtime + +/// The main DFlash speculative decoding runtime. +/// +/// Orchestrates the block-diffusion draft → verify → accept/reject → rollback +/// cycle for lossless speculative decoding on Apple Silicon. +public enum DFlashRuntime { + + // MARK: - Token Utilities + + /// Build a suppress token mask from a list of token IDs. + public static func buildSuppressTokenMask( + vocabSize: Int, + suppressTokenIDs: [Int]? + ) -> MLXArray? { + let ids = Set((suppressTokenIDs ?? []).filter { $0 >= 0 && $0 < vocabSize }) + guard !ids.isEmpty else { return nil } + var mask = [Bool](repeating: false, count: vocabSize) + for id in ids { mask[id] = true } + return MLXArray(mask) + } + + /// Greedy token selection with optional suppress mask. + public static func greedyTokensWithMask( + logits: MLXArray, + suppressTokenMask: MLXArray? = nil + ) -> MLXArray { + if let mask = suppressTokenMask { + let floor = MLXArray(-1e9, dtype: logits.dtype) + let maskedLogits = MLX.where(mask, floor, logits) + return argMax(maskedLogits, axis: -1).asType(.uint32) + } + return argMax(logits, axis: -1).asType(.uint32) + } + + /// Match the acceptance length between drafted and posterior tokens. + /// Returns the number of consecutive matches starting from position 0. + /// E.g. if drafted=[1,2,3] and posterior=[1,2,5], returns 2. + public static func matchAcceptanceLength( + draftedTokens: MLXArray, + posteriorTokens: MLXArray + ) -> MLXArray { + let count = draftedTokens.dim(0) + guard count > 0 else { return MLXArray(0, dtype: .int32) } + let matches = (draftedTokens .== posteriorTokens).asType(.int32) + // cumprod: [1,1,0,...] for consecutive matches, then sum counts them + return cumprod(matches, axis: 0).sum(axis: 0, keepDims: false) + } + + // MARK: - Target Cache Management + + /// Create the appropriate cache entries for the target model. + /// For hybrid GDN models, replaces MambaCache with a rollback-capable variant: + /// - dflashUseTapeRollback=true → RecurrentRollbackCache (accurate, ~30% slower on large models) + /// - dflashUseTapeRollback=false → MambaSnapshotCache (snapshot-only, O(1) overhead) + public static func makeTargetCache( + targetModel: any DFlashTargetModel + ) -> [KVCache] { + var cache = targetModel.newCache(parameters: nil) + if targetModel.dflashIsHybridGDN { + for i in 0 ..< cache.count { + if cache[i] is MambaCache { + cache[i] = targetModel.dflashUseTapeRollback + ? RecurrentRollbackCache() + : MambaSnapshotCache() + } + } + } + return cache + } + + /// Arm all rollback-capable caches in the target model. + /// RecurrentRollbackCache arms for innovation-tape recording. + /// MambaSnapshotCache takes a lazy state snapshot (O(1), no GPU copy). + /// Plain MambaCache instances are not checkpointed. + public static func armTargetRollback(targetCache: [KVCache], prefixLen: Int) { + for cache in targetCache { + if let rollbackCache = cache as? DFlashRollbackCache { + rollbackCache.armRollback(prefixLen: prefixLen) + } + } + } + + /// Restore the target cache after partial acceptance of draft tokens. + /// + /// RecurrentRollbackCache: replays innovation tape for accepted steps (exact). + /// MambaSnapshotCache: restores pre-verify snapshot (fast, loses accepted steps). + /// KVCacheSimple: trims KV entries for rejected tokens. + /// + /// For KVCacheSimple: trim to remove rejected tokens' KV entries. + /// + /// - Returns: Time spent on replay in nanoseconds + @discardableResult + public static func restoreTargetCacheAfterAcceptance( + _ cacheEntries: [KVCache], + targetLen: Int, + acceptanceLength: Int, + draftedTokens: Int + ) -> Int { + let fullyAccepted = draftedTokens > 0 && acceptanceLength == draftedTokens + var replayNs: Int = 0 + + for cache in cacheEntries { + if let rollbackCache = cache as? DFlashRollbackCache { + if fullyAccepted { + rollbackCache.clearTransients() + continue + } + let startNs = Int(DispatchTime.now().uptimeNanoseconds) + rollbackCache.rollback(nAccepted: acceptanceLength) + replayNs += Int(DispatchTime.now().uptimeNanoseconds) - startNs + } else if let mambaCache = cache as? MambaCache { + // Plain MambaCache (non-rollback): no checkpoint-based rollback available. + // Python doesn't call checkpoint/trim on these. The state contains + // contributions from all verify tokens but we can't undo them. + // Only update the offset to reflect the accepted prefix. + mambaCache.offset = targetLen + } else if cache.isTrimmable { + let offset = cache.offset + if offset > targetLen { + let startNs = Int(DispatchTime.now().uptimeNanoseconds) + cache.trim(offset - targetLen) + replayNs += Int(DispatchTime.now().uptimeNanoseconds) - startNs + } + } + } + + return replayNs + } + + // MARK: - Main Generation Loop + + /// Generate tokens using DFlash speculative decoding. + /// + /// - Parameters: + /// - targetModel: The target (large) language model (must conform to DFlashTargetModel) + /// - draftModel: The DFlash block-diffusion draft model + /// - promptTokens: Pre-tokenized prompt token IDs + /// - maxNewTokens: Maximum number of new tokens to generate + /// - blockTokens: Number of tokens per draft block (default: draft model's block_size) + /// - stopTokenIDs: Token IDs that signal end of generation + /// - suppressTokenIDs: Token IDs to suppress during generation + /// - draftSinkSize: Sink tokens to keep in draft cache + /// - draftWindowSize: Sliding window size for draft cache + /// - Returns: AsyncStream of DFlashEvent values + public static func generate( + targetModel: any DFlashTargetModel, + draftModel: DFlashDraftModel, + promptTokens: [Int], + maxNewTokens: Int, + blockTokens: Int? = nil, + stopTokenIDs: [Int] = [], + suppressTokenIDs: [Int]? = nil, + draftSinkSize: Int = 64, + draftWindowSize: Int = 1024 + ) -> AsyncStream { + // Streaming: yield events from inside the generation loop + // via a Continuation, avoiding the buffered-array bottleneck. + AsyncStream(bufferingPolicy: .unbounded) { continuation in + let task = Task { + generateStreaming( + targetModel: targetModel, + draftModel: draftModel, + promptTokens: promptTokens, + maxNewTokens: maxNewTokens, + blockTokens: blockTokens, + stopTokenIDs: stopTokenIDs, + suppressTokenIDs: suppressTokenIDs, + draftSinkSize: draftSinkSize, + draftWindowSize: draftWindowSize, + yield: { event in + guard !Task.isCancelled else { return } + continuation.yield(event) + } + ) + continuation.finish() + } + continuation.onTermination = { _ in task.cancel() } + } + } + + /// Synchronous generation that returns all events at once. + /// Kept for backward compatibility — delegates to the streaming implementation. + public static func generateSync( + targetModel: any DFlashTargetModel, + draftModel: DFlashDraftModel, + promptTokens: [Int], + maxNewTokens: Int, + blockTokens: Int? = nil, + stopTokenIDs: [Int] = [], + suppressTokenIDs: [Int]? = nil, + draftSinkSize: Int = 64, + draftWindowSize: Int = 1024 + ) -> [DFlashEvent] { + var events: [DFlashEvent] = [] + generateStreaming( + targetModel: targetModel, + draftModel: draftModel, + promptTokens: promptTokens, + maxNewTokens: maxNewTokens, + blockTokens: blockTokens, + stopTokenIDs: stopTokenIDs, + suppressTokenIDs: suppressTokenIDs, + draftSinkSize: draftSinkSize, + draftWindowSize: draftWindowSize, + yield: { events.append($0) } + ) + return events + } + + /// Core streaming generation loop. Takes a yield closure so it can be + /// used both from the async `generate()` (via Continuation) and the + /// synchronous `generateSync()` (buffering into an array). + private static func generateStreaming( + targetModel: any DFlashTargetModel, + draftModel: DFlashDraftModel, + promptTokens: [Int], + maxNewTokens: Int, + blockTokens: Int?, + stopTokenIDs: [Int], + suppressTokenIDs: [Int]?, + draftSinkSize: Int, + draftWindowSize: Int, + yield: (DFlashEvent) -> Void + ) { + let promptLen = promptTokens.count + guard promptLen > 0 && maxNewTokens > 0 else { return } + + let promptArray = MLXArray(promptTokens.map { Int32($0) }).reshaped(1, -1).asType(.uint32) + + // Detect engine and create caches + let engine: any DFlashEngine = targetModel.dflashIsHybridGDN + ? HybridGDNEngine() + : FullAttentionEngine() + + let draftBackend = DFlashDraftBackend() + + var targetCache = makeTargetCache(targetModel: targetModel) + + let draftCache = draftBackend.makeCache( + draftModel: draftModel, + sinkSize: draftSinkSize, + windowSize: draftWindowSize + ) + + let targetLayerIDList = draftModel.targetLayerIDs + let captureLayerIDs = Set(targetLayerIDList.map { $0 + 1 }) + let maskTokenID = draftModel.maskTokenID + + let startNanos = DispatchTime.now().uptimeNanoseconds + + // ── Prefill ──────────────────────────────────────────────── + let prefillStepSize = 2048 + var targetHidden: MLXArray? + var prefillLogits: MLXArray! + + for chunkStart in stride(from: 0, to: promptLen, by: prefillStepSize) { + let chunkEnd = min(chunkStart + prefillStepSize, promptLen) + let chunkIDs = promptArray[0..., chunkStart ..< chunkEnd] + + let (chunkLogits, chunkHidden) = targetModel.dflashForwardWithCapture( + inputIDs: chunkIDs, + cache: targetCache, + captureLayerIDs: captureLayerIDs + ) + + // Batched asyncEval: enqueue everything without blocking + asyncEval(chunkLogits) + for (_, v) in chunkHidden { asyncEval(v) } + + let feat = extractContextFeatureFromDict( + capturedDict: chunkHidden, + targetLayerIDs: targetLayerIDList + ) + + if targetHidden == nil { + targetHidden = MLXArray.zeros( + [feat.dim(0), promptLen, feat.dim(-1)], + dtype: feat.dtype + ) + } + targetHidden![0..., chunkStart ..< chunkEnd, 0...] = feat + eval(targetHidden!) + + prefillLogits = chunkLogits + + if DFlashDumper.isEnabled { + DFlashDumper.save("swift_target_hidden", targetHidden!) + DFlashDumper.save("swift_prefill_logits", chunkLogits) + } + + yield(.prefillProgress( + tokensProcessed: chunkEnd, + tokensTotal: promptLen + )) + } + + MLX.Memory.clearCache() + + let prefillNanos = Int(DispatchTime.now().uptimeNanoseconds) - Int(startNanos) + + let suppressTokenMask = buildSuppressTokenMask( + vocabSize: Int(prefillLogits.dim(-1)), + suppressTokenIDs: suppressTokenIDs + ) + + var stagedFirst = greedyTokensWithMask( + logits: prefillLogits[0..., -1, 0...], + suppressTokenMask: suppressTokenMask + ).reshaped(-1) + + yield(.prefill( + promptTokenCount: promptLen, + prefillUs: Double(prefillNanos) / 1000.0 + )) + + // Yield the first token + let firstTokenID = Int(stagedFirst.item(Int.self)) + yield(.token( + tokenID: firstTokenID, + generatedTokens: 1, + acceptanceRatio: 0.0, + cyclesCompleted: 0 + )) + + // ── Generation Loop ─────────────────────────────────────── + let draftBlockSize = draftModel.blockSize + let requestedBlockTokens = blockTokens ?? draftBlockSize + let effectiveBlockTokens = max(1, min(requestedBlockTokens, draftBlockSize)) + let verifyLenCap = effectiveBlockTokens + + var generatedTokenIDs: [Int] = [] + var acceptedFromDraft = 0 + var cyclesCompleted = 0 + var start = promptLen + var firstTokenYielded = false + + generatedTokenIDs.append(firstTokenID) + firstTokenYielded = true + + let maskTokenTail = MLXArray.full( + [max(0, effectiveBlockTokens - 1)], + values: MLXArray(Int32(maskTokenID), dtype: .uint32) + ) + + var verifyNsTotal: Int = 0 + var draftNsTotal: Int = 0 + var replayNsTotal: Int = 0 + + // Precompute stop token set for O(1) lookup + let stopTokenSet = Set(stopTokenIDs) + + // Prefetch state: the draft for the NEXT cycle can be overlapped + // with the current cycle's rollback. + var prefetchedDraft: MLXArray? + var prefetchedBlockLen: Int? + + while generatedTokenIDs.count < maxNewTokens { + let remaining = maxNewTokens - generatedTokenIDs.count + let blockLen = max(1, min(effectiveBlockTokens, remaining)) + + // ── Draft Phase ────────────────────────────────────── + // Use prefetched draft if available and blockLen matches + var drafted: MLXArray? + let currentStagedFirst = stagedFirst + if blockLen > 1 { + if let pf = prefetchedDraft, prefetchedBlockLen == blockLen { + drafted = pf + prefetchedDraft = nil + prefetchedBlockLen = nil + } else { + let draftStart = Int(DispatchTime.now().uptimeNanoseconds) + drafted = draftBackend.draftGreedy( + targetModel: targetModel, + draftModel: draftModel, + draftCache: draftCache, + stagedFirst: stagedFirst, + targetHidden: targetHidden!, + blockLen: blockLen, + maskTokenTail: maskTokenTail, + suppressTokenMask: suppressTokenMask + ) + draftNsTotal += Int(DispatchTime.now().uptimeNanoseconds) - draftStart + } + if DFlashDumper.isEnabled { + DFlashDumper.save("swift_cycle_draft", drafted ?? MLXArray()) + } + } + + // ── Verify Phase ──────────────────────────────────── + let verifyTokenCount = min(blockLen, verifyLenCap) + let verifyTokenIDs: MLXArray + if blockLen <= 1 { + verifyTokenIDs = currentStagedFirst[..<1] + } else if let drafted = drafted, verifyTokenCount > 1 { + verifyTokenIDs = concatenated( + [currentStagedFirst[..<1], drafted[..<(verifyTokenCount - 1)]], + axis: 0 + ) + } else { + verifyTokenIDs = currentStagedFirst[..<1] + } + let verifyIDs = verifyTokenIDs[.newAxis] + + armTargetRollback(targetCache: targetCache, prefixLen: start) + + let verifyStart = Int(DispatchTime.now().uptimeNanoseconds) + let (verifyLogits, verifyHiddenStates) = targetModel.dflashForwardWithCapture( + inputIDs: verifyIDs, + cache: targetCache, + captureLayerIDs: captureLayerIDs + ) + // Batched asyncEval: enqueue logits + all hidden states without blocking + asyncEval(verifyLogits) + for v in verifyHiddenStates.values { asyncEval(v) } + verifyNsTotal += Int(DispatchTime.now().uptimeNanoseconds) - verifyStart + + // ── Accept/Reject ────────────────────────────────── + let posterior = greedyTokensWithMask( + logits: verifyLogits[0], + suppressTokenMask: suppressTokenMask + ) + // Don't asyncEval(posterior) here — we need .item() immediately below + if DFlashDumper.isEnabled { + DFlashDumper.save("swift_cycle_posterior", posterior) + DFlashDumper.saveInt("swift_cycle_verifyIDs", verifyTokenIDs) + } + + let acceptanceLen: Int + if verifyTokenIDs.dim(0) > 1 { + acceptanceLen = Int( + matchAcceptanceLength( + draftedTokens: verifyTokenIDs[1...], + posteriorTokens: posterior[..<(verifyTokenIDs.dim(0) - 1)] + ).item(Int.self) + ) + } else { + acceptanceLen = 0 + } + print("[DFlash] Cycle \(cyclesCompleted + 1): blockLen=\(blockLen), verifyLen=\(verifyTokenIDs.dim(0)), acceptanceLen=\(acceptanceLen), commitCount=\(1 + acceptanceLen)") + fflush(stdout) + + let committedHidden = extractContextFeatureFromDict( + capturedDict: verifyHiddenStates, + targetLayerIDs: targetLayerIDList + )[0..., ..<(1 + acceptanceLen), 0...] + // asyncEval: don't block — prefetch + rollback can overlap + asyncEval(committedHidden) + + let commitCount = 1 + acceptanceLen + let committedSegment = verifyTokenIDs[..<(commitCount)] + + let stagedFirstNext = posterior[acceptanceLen ..< (acceptanceLen + 1)] + + // ── Prefetch next draft (overlaps with rollback on GPU) ── + let nextRemaining = maxNewTokens - generatedTokenIDs.count - commitCount + let nextBlockLen = max(1, min(effectiveBlockTokens, nextRemaining)) + if nextBlockLen > 1 && generatedTokenIDs.count + commitCount < maxNewTokens { + prefetchedDraft = draftBackend.draftGreedy( + targetModel: targetModel, + draftModel: draftModel, + draftCache: draftCache, + stagedFirst: stagedFirstNext, + targetHidden: committedHidden, + blockLen: nextBlockLen, + maskTokenTail: maskTokenTail, + suppressTokenMask: suppressTokenMask + ) + prefetchedBlockLen = nextBlockLen + asyncEval(prefetchedDraft!) + } else { + prefetchedDraft = nil + prefetchedBlockLen = nil + } + + // ── Rollback ─────────────────────────────────────── + start += commitCount + targetHidden = committedHidden + let replayNs = engine.rollback( + targetCache: targetCache, + targetLen: start, + acceptanceLength: acceptanceLen, + draftedTokens: blockLen - 1 + ) + replayNsTotal += replayNs + cyclesCompleted += 1 + acceptedFromDraft += acceptanceLen + + // ── Emit tokens ─────────────────────────────────── + let committedIDs = committedSegment.asArray(Int.self) + for tokenID in committedIDs { + guard generatedTokenIDs.count < maxNewTokens else { break } + + if firstTokenYielded { + firstTokenYielded = false + continue + } + + generatedTokenIDs.append(tokenID) + + let acceptanceRatio = generatedTokenIDs.count > 0 + ? Double(acceptedFromDraft) / Double(generatedTokenIDs.count) + : 0.0 + yield(.token( + tokenID: tokenID, + generatedTokens: generatedTokenIDs.count, + acceptanceRatio: acceptanceRatio, + cyclesCompleted: cyclesCompleted + )) + } + + // Check for stop tokens (O(1) via Set) + let hit = committedIDs.contains { stopTokenSet.contains($0) } + if hit { break } + + stagedFirst = stagedFirstNext + } + + // ── Summary ──────────────────────────────────────────── + let elapsedNanos = Int(DispatchTime.now().uptimeNanoseconds) - Int(startNanos) + let acceptanceRatio = generatedTokenIDs.count > 0 + ? Double(acceptedFromDraft) / Double(generatedTokenIDs.count) + : 0.0 + + yield(.summary(DFlashSummary( + elapsedUs: Double(elapsedNanos) / 1000.0, + promptTokenCount: promptLen, + generatedTokenIDs: generatedTokenIDs, + acceptedFromDraft: acceptedFromDraft, + acceptanceRatio: acceptanceRatio, + blockTokens: effectiveBlockTokens, + cyclesCompleted: cyclesCompleted, + phaseTimingsUs: .init( + prefill: Double(prefillNanos) / 1000.0, + draft: Double(draftNsTotal) / 1000.0, + verify: Double(verifyNsTotal) / 1000.0, + replay: Double(replayNsTotal) / 1000.0 + ) + ))) + } +} diff --git a/Sources/DFlash/RecurrentRollbackCache.swift b/Sources/DFlash/RecurrentRollbackCache.swift new file mode 100644 index 00000000..9082509a --- /dev/null +++ b/Sources/DFlash/RecurrentRollbackCache.swift @@ -0,0 +1,233 @@ +// Copyright 2026 SwiftLM Contributors +// MIT License — see LICENSE file +// Based on DFlash (arXiv:2602.06036) + +import Foundation +import MLX +import MLXLMCommon +import MLXNN + +// MARK: - DFlashRollbackCache + +public protocol DFlashRollbackCache: AnyObject { + var isArmed: Bool { get } + func armRollback(prefixLen: Int) + func rollback(nAccepted: Int) + func clearTransients() + func recordTape(tape: MLXArray, k: MLXArray, g: MLXArray, qkv: MLXArray) +} + +// MARK: - RecurrentRollbackCache + + +/// A cache for GatedDeltaNet (recurrent) layers that supports +/// speculative decoding rollback via innovation tape replay. +/// +/// Subclasses MambaCache so that `cache as? MambaCache` succeeds in +/// Qwen35GatedDeltaNet.callAsFunction — this is critical for the normal +/// (non-armed) forward pass during prefill to work correctly. +/// +/// During the verify phase, the cache is "armed" which causes the +/// GatedDeltaNet forward pass to record an innovation tape. If draft +/// tokens are rejected, the cache is rolled back by replaying only +/// the accepted steps from the tape. +public final class RecurrentRollbackCache: MambaCache, DFlashRollbackCache, @unchecked Sendable { + + /// Whether the cache is currently armed for tape recording. + private var armed = false + + /// The recorded innovation tape: delta values per step. + private var tape: MLXArray? + /// The recorded keys for tape replay. + private var tapeK: MLXArray? + /// The recorded gates for tape replay. + private var tapeG: MLXArray? + /// The recorded QKV for conv state reconstruction. + private var tapeQKV: MLXArray? + + /// Snapshot of the cache state before the verify pass. + private var snapshotState: [MLXArray?]? + + public init(convKernelSize: Int = 4) { + super.init() + } + + // MARK: - Arming & Recording + + /// Arm the cache for tape recording and snapshot the current state. + /// + /// Uses lazy reference capture (no MLX.contiguous copy) — MLXArray is + /// reference-counted so the old arrays remain alive after the cache is + /// updated during the forward pass. The copy only happens if/when + /// rollback() actually replays the tape. + public func armRollback(prefixLen: Int = 0) { + armed = true + tape = nil + tapeK = nil + tapeG = nil + tapeQKV = nil + // Lazy snapshot: just hold references, no GPU copy needed + snapshotState = [self[0], self[1]] + } + + /// Record the innovation tape from a GatedDeltaNet forward step. + /// Arrays are stored by reference — MLX evaluates them lazily when needed. + public func recordTape( + tape: MLXArray, + k: MLXArray, + g: MLXArray, + qkv: MLXArray + ) { + self.tape = tape + self.tapeK = k + self.tapeG = g + self.tapeQKV = qkv + } + + /// Whether the cache is currently armed. + public var isArmed: Bool { armed } + + // MARK: - Rollback + + /// Roll back the cache to the state after `nAccepted` tokens. + /// Uses tape replay for the recurrent state (slot 1) and + /// conv state reconstruction for slot 0. + public func rollback(nAccepted: Int) { + guard let snapshot = snapshotState else { + clearTransients() + return + } + + // Calculate the offset to restore to + // offset was incremented by the verify forward pass (by verifyLen tokens) + // We need to set it to what it should be after accepting nAccepted+1 tokens + // The Python reference doesn't explicitly manage offset in rollback, + // but the cache offset needs to be consistent for subsequent forward passes. + + // Restore snapshot + if snapshot.count > 0, let s0 = snapshot[0] { self[0] = s0 } + if snapshot.count > 1, let s1 = snapshot[1] { self[1] = s1 } + + // Replay accepted steps through tape + if let tape = tape, let tapeK = tapeK, let tapeG = tapeG, + let state = self[1] + { + let acceptedSteps = nAccepted + 1 + let stateSlice = tape[0..., .. MLXArray? { + guard let tapeQKV = tapeQKV else { return self[0] } + let keep = RecurrentRollbackCache.defaultConvKernelSize - 1 + guard keep > 0 else { return nil } + + let prefix: MLXArray + if let snap = snapshotState, snap.count > 0, let convState = snap[0] { + prefix = convState + } else { + prefix = MLXArray.zeros( + [tapeQKV.dim(0), keep, tapeQKV.dim(-1)], + dtype: tapeQKV.dtype + ) + } + + let convInput = concatenated([prefix, tapeQKV], axis: 1) + let start = acceptedSteps + let end = min(start + keep, convInput.dim(1)) + return convInput[0..., start ..< end, 0...] + } + + // MARK: - Cleanup + + /// Clear all transient state (tape, snapshot, armed flag). + public func clearTransients() { + armed = false + tape = nil + tapeK = nil + tapeG = nil + tapeQKV = nil + snapshotState = nil + } + + // MARK: - Override MambaCache trim to use tape rollback instead + + @discardableResult + public override func trim(_ n: Int) -> Int { + // For recurrent caches with tape, rollback handles trimming + // Don't use the MambaCache checkpoint/trim path + let trimmed = min(offset, n) + offset -= trimmed + return trimmed + } +} + +// MARK: - MambaSnapshotCache + +/// Lightweight snapshot-based rollback for hybrid SSM models (e.g. Qwen3Next). +/// +/// Unlike RecurrentRollbackCache, this does NOT record an innovation tape. +/// On partial acceptance, it restores the pre-verify SSM state snapshot. +/// The accepted tokens' state contributions are lost (state reverts to +/// pre-verify position), but rejected tokens' contamination is prevented. +/// Overhead: O(1) per cycle (lazy reference capture, no GPU copies). +public final class MambaSnapshotCache: MambaCache, DFlashRollbackCache, @unchecked Sendable { + + private var snapshotConv: MLXArray? + private var snapshotRecurrent: MLXArray? + private var armed = false + + public var isArmed: Bool { armed } + + public func armRollback(prefixLen: Int = 0) { + armed = true + // Lazy reference capture — no GPU copy, O(1) + snapshotConv = self[0] + snapshotRecurrent = self[1] + } + + public func rollback(nAccepted: Int) { + // Restore pre-verify state. Accepted tokens' contributions are + // not replayed, but rejected tokens are excluded. + self[0] = snapshotConv + self[1] = snapshotRecurrent + clearTransients() + } + + public func clearTransients() { + armed = false + snapshotConv = nil + snapshotRecurrent = nil + } + + public func recordTape(tape: MLXArray, k: MLXArray, g: MLXArray, qkv: MLXArray) { + // No tape needed for snapshot-based rollback + } + + @discardableResult + public override func trim(_ n: Int) -> Int { + let trimmed = min(offset, n) + offset -= trimmed + return trimmed + } +} + diff --git a/Sources/DFlashKernelBench/main.swift b/Sources/DFlashKernelBench/main.swift new file mode 100644 index 00000000..54b5c934 --- /dev/null +++ b/Sources/DFlashKernelBench/main.swift @@ -0,0 +1,691 @@ +// Copyright 2026 SwiftLM Contributors +// MIT License — see LICENSE file +// +// Micro-benchmark for DFlash Metal kernels. +// Run under Metal System Trace: +// xcrun xctrace record --template "Metal System Trace" \ +// --launch .build/release/DFlashKernelBench -- [flags] +// +// Flags: +// --iterations N kernel calls per benchmark (default: 200) +// --warmup N warmup calls before timing (default: 20) +// --kernels list comma-separated subset: tape,gdelta,sdpa,variants,ops (default: tape,gdelta,sdpa) +// --long-ctx include long-context SDPA sizes (nKV 16k, 32k) + +import Foundation +import MLX +import MLXNN +import DFlash +import os.log + +// MARK: - Signpost log + +private let log = OSLog(subsystem: "com.swiftlm.dflash", category: "kernels") + +// MARK: - Helpers + +/// Fill an array with uniform random values in bf16. +private func rand(_ shape: [Int], dtype: DType = .bfloat16) -> MLXArray { + uniform(low: -0.1, high: 0.1, shape, dtype: dtype) +} + +/// Wall-clock time in seconds for one synchronised MLX eval. +private func timeEval(_ body: () -> MLXArray) -> Double { + let arr = body() + let t0 = clock_gettime_nsec_np(CLOCK_MONOTONIC_RAW) + MLX.eval(arr) + let t1 = clock_gettime_nsec_np(CLOCK_MONOTONIC_RAW) + return Double(t1 - t0) * 1e-9 +} + +/// Run `iterations` timed calls, return (median_s, min_s, max_s). +private func measure(label: String, iterations: Int, body: () -> MLXArray) -> (median: Double, min: Double, max: Double) { + var samples = [Double]() + samples.reserveCapacity(iterations) + + let signpostID = OSSignpostID(log: log) + for _ in 0 ..< iterations { + os_signpost(.begin, log: log, name: "kernel", signpostID: signpostID, "%{public}s", label) + let t = timeEval(body) + os_signpost(.end, log: log, name: "kernel", signpostID: signpostID, "%{public}s", label) + samples.append(t) + } + + samples.sort() + let med = samples[samples.count / 2] + return (med, samples.first!, samples.last!) +} + +private func printResult(label: String, r: (median: Double, min: Double, max: Double), extraInfo: String = "") { + let medUs = r.median * 1e6 + let minUs = r.min * 1e6 + let maxUs = r.max * 1e6 + let extra = extraInfo.isEmpty ? "" : " \(extraInfo)" + let pad = label.padding(toLength: 42, withPad: " ", startingAt: 0) + print(String(format: " %@ med %7.1f µs min %7.1f µs max %7.1f µs%@", + pad, medUs, minUs, maxUs, extra)) +} + +/// Theoretical memory bandwidth figure (GB/s) for a kernel that touches `bytes` bytes. +private func bwStr(bytes: Int, seconds: Double) -> String { + let gb = Double(bytes) / 1e9 / seconds + return String(format: "%.1f GB/s", gb) +} + +// MARK: - Argument parsing + +struct Args { + var iterations = 200 + var warmup = 20 + var kernels: Set = ["tape", "gdelta", "sdpa"] + var longCtx = false + + init() { + let argv = CommandLine.arguments + func intArg(_ flag: String, default d: Int) -> Int { + guard let i = argv.firstIndex(of: flag), i + 1 < argv.count else { return d } + return Int(argv[i + 1]) ?? d + } + iterations = intArg("--iterations", default: 200) + warmup = intArg("--warmup", default: 20) + if let i = argv.firstIndex(of: "--kernels"), i + 1 < argv.count { + kernels = Set(argv[i + 1].split(separator: ",").map(String.init)) + } + longCtx = argv.contains("--long-ctx") + } +} + +// MARK: - Tape Replay benchmarks + +/// Shapes matching Qwen3.5 GDN layers: +/// Hk=8, Hv=16, Dk=128, Dv=128, T=blockSize=16, B=1 +private func benchTapeReplay(args: Args) { + print("\n── Tape Replay ──────────────────────────────────────────────────────────") + + let B = 1; let T = 16; let Hk = 8; let Hv = 16; let Dk = 128; let Dv = 128 + + let tape = rand([B, T, Hv, Dv]) + let k = rand([B, T, Hk, Dk]) + let gScalar = rand([B, T, Hv]) // scalar gate + let gVec = rand([B, T, Hv, Dk]) // vectorised gate + let state = rand([B, Hv, Dv, Dk]) + let mask = (uniform(low: 0, high: 1, [B, T]) .>= MLXArray(0.5)).asType(DType.bfloat16) + + // warm up + for _ in 0 ..< args.warmup { + MLX.eval(DFlashKernels.tapeReplayKernel(tape: tape, k: k, g: gScalar, state: state)) + } + + let stateBytes = B * Hv * Dv * Dk * 2 // bfloat16 = 2 bytes + + let r1 = measure(label: "tape_replay scalar-g", iterations: args.iterations) { + DFlashKernels.tapeReplayKernel(tape: tape, k: k, g: gScalar, state: state) + } + printResult(label: "scalar-g, no mask", r: r1, extraInfo: bwStr(bytes: stateBytes * 2, seconds: r1.median)) + + let r2 = measure(label: "tape_replay scalar-g masked", iterations: args.iterations) { + DFlashKernels.tapeReplayKernel(tape: tape, k: k, g: gScalar, state: state, mask: mask) + } + printResult(label: "scalar-g, mask", r: r2, extraInfo: bwStr(bytes: stateBytes * 2, seconds: r2.median)) + + let r3 = measure(label: "tape_replay vec-g", iterations: args.iterations) { + DFlashKernels.tapeReplayKernel(tape: tape, k: k, g: gVec, state: state) + } + printResult(label: "vec-g, no mask", r: r3, extraInfo: bwStr(bytes: stateBytes * 2, seconds: r3.median)) + + let r4 = measure(label: "tape_replay vec-g masked", iterations: args.iterations) { + DFlashKernels.tapeReplayKernel(tape: tape, k: k, g: gVec, state: state, mask: mask) + } + printResult(label: "vec-g, mask", r: r4, extraInfo: bwStr(bytes: stateBytes * 2, seconds: r4.median)) +} + +// MARK: - GatedDelta with Tape benchmarks + +private func benchGatedDelta(args: Args) { + print("\n── GatedDelta + Tape ────────────────────────────────────────────────────") + + let B = 1; let T = 16; let Hk = 8; let Hv = 16; let Dk = 128; let Dv = 128 + + let q = rand([B, T, Hk, Dk]) + let k = rand([B, T, Hk, Dk]) + let v = rand([B, T, Hv, Dv]) + let gScalar = rand([B, T, Hv]) + let gVec = rand([B, T, Hv, Dk]) + let beta = rand([B, T, Hv]) + let state = rand([B, Hv, Dv, Dk]) + let mask = (uniform(low: 0, high: 1, [B, T]) .>= MLXArray(0.5)).asType(DType.bfloat16) + + for _ in 0 ..< args.warmup { + let (y, s, t) = DFlashKernels.gatedDeltaKernelWithTape(q: q, k: k, v: v, g: gScalar, beta: beta, state: state) + MLX.eval(y, s, t) + } + + // bytes read+written per call (approximate): q+k+v+state_in+state_out+tape_out + let callBytes = (B*T*Hk*Dk + B*T*Hk*Dk + B*T*Hv*Dv) * 2 // q,k,v inputs + + B*Hv*Dv*Dk * 2 * 2 // state in+out + + B*T*Hv*Dv * 4 // tape (f32) + + let r1 = measure(label: "gdelta scalar-g", iterations: args.iterations) { + let (y, _, _) = DFlashKernels.gatedDeltaKernelWithTape(q: q, k: k, v: v, g: gScalar, beta: beta, state: state) + return y + } + printResult(label: "scalar-g, no mask", r: r1, extraInfo: bwStr(bytes: callBytes, seconds: r1.median)) + + let r2 = measure(label: "gdelta scalar-g masked", iterations: args.iterations) { + let (y, _, _) = DFlashKernels.gatedDeltaKernelWithTape(q: q, k: k, v: v, g: gScalar, beta: beta, state: state, mask: mask) + return y + } + printResult(label: "scalar-g, mask", r: r2, extraInfo: bwStr(bytes: callBytes, seconds: r2.median)) + + let r3 = measure(label: "gdelta vec-g", iterations: args.iterations) { + let (y, _, _) = DFlashKernels.gatedDeltaKernelWithTape(q: q, k: k, v: v, g: gVec, beta: beta, state: state) + return y + } + printResult(label: "vec-g, no mask", r: r3, extraInfo: bwStr(bytes: callBytes, seconds: r3.median)) + + let r4 = measure(label: "gdelta vec-g masked", iterations: args.iterations) { + let (y, _, _) = DFlashKernels.gatedDeltaKernelWithTape(q: q, k: k, v: v, g: gVec, beta: beta, state: state, mask: mask) + return y + } + printResult(label: "vec-g, mask", r: r4, extraInfo: bwStr(bytes: callBytes, seconds: r4.median)) +} + +// MARK: - Batched SDPA 2-pass benchmarks + +private func benchSDPA(args: Args) { + print("\n── Batched SDPA 2-Pass ──────────────────────────────────────────────────") + + // Shapes: B=1, Hq=32, Hk=8 (GQA 4x), qLen=16, D=128 + // Vary nKV to cover prefill (2k), mid (8k), long (32k) + let B = 1; let Hq = 32; let Hk = 8; let qLen = 16; let D = 128 + let scale = Float(1.0 / sqrt(Float(D))) + + var kvSizes = [512, 2048, 8192] + if args.longCtx { kvSizes += [16384, 32768] } + + let q = rand([B, Hq, qLen, D]) + + for nKV in kvSizes { + let k = rand([B, Hk, nKV, D]) + let v = rand([B, Hk, nKV, D]) + + // warm up + for _ in 0 ..< args.warmup { + if let out = DFlashKernels.batchedSDPA2Pass(queries: q, keys: k, values: v, scale: scale) { + MLX.eval(out) + } + } + + // bytes: read Q + K + V, write output + let readBytes = (B*Hq*qLen*D + B*Hk*nKV*D + B*Hk*nKV*D) * 2 + let writeBytes = B*Hq*qLen*D * 2 + let totalBytes = readBytes + writeBytes + + let r = measure(label: "sdpa nKV=\(nKV)", iterations: args.iterations) { + DFlashKernels.batchedSDPA2Pass(queries: q, keys: k, values: v, scale: scale) ?? q + } + printResult(label: "nKV=\(nKV)", r: r, extraInfo: bwStr(bytes: totalBytes, seconds: r.median)) + + // Also time the MLXFast fallback for comparison + let rf = measure(label: "sdpa_fallback nKV=\(nKV)", iterations: args.iterations) { + DFlashKernels.sdpaFallback(queries: q, keys: k, values: v, scale: scale) + } + printResult(label: "nKV=\(nKV) [MLXFast fallback]", r: rf, extraInfo: bwStr(bytes: totalBytes, seconds: rf.median)) + + let speedup = rf.median / r.median + print(String(format: " → custom vs fallback: %.2fx", speedup)) + } +} + +// MARK: - Kernel Variant Comparison (branching vs branchless Metal source) + +private func benchKernelVariants(args: Args) { + print("\n── Kernel Variants: Branching vs Branchless ─────────────────────────────") + + let B = 1; let T = 16; let Hk = 8; let Hv = 16; let Dk = 128; let Dv = 128 + let tape = rand([B, T, Hv, Dv]) + let k = rand([B, T, Hk, Dk]) + let g = rand([B, T, Hv]) + let state = rand([B, Hv, Dv, Dk]) + let mask = (uniform(low: 0, high: 1, [B, T]) .>= MLXArray(0.5)).asType(DType.bfloat16) + let q = rand([B, T, Hk, Dk]) + let v = rand([B, T, Hv, Dv]) + let beta = rand([B, T, Hv]) + + let inputType = DType.bfloat16 + + // ── Tape Replay ────────────────────────────────────────────────────────── + + // Current: if-guard wraps entire inner loop body; two-line state update + let tapeBranchingSrc = """ + auto n = thread_position_in_grid.z; + auto b_idx = n / Hv; + auto hv_idx = n % Hv; + auto hk_idx = hv_idx / (Hv / Hk); + constexpr int n_per_t = Dk / 32; + auto tape_ = tape + b_idx * T * Hv * Dv + hv_idx * Dv; + auto k_ = k + b_idx * T * Hk * Dk + hk_idx * Dk; + auto dk_idx = thread_position_in_threadgroup.x; + auto dv_idx = thread_position_in_grid.y; + auto i_state = state_in + (n * Dv + dv_idx) * Dk; + auto o_state = state_out + (n * Dv + dv_idx) * Dk; + auto g_ = g + b_idx * T * Hv; + float state[n_per_t]; + for (int i = 0; i < n_per_t; ++i) { + auto s_idx = n_per_t * dk_idx + i; + state[i] = static_cast(i_state[s_idx]); + } + for (int t = 0; t < T; ++t) { + if (mask[b_idx * T + t]) { + auto delta = static_cast(tape_[dv_idx]); + for (int i = 0; i < n_per_t; ++i) { + auto s_idx = n_per_t * dk_idx + i; + state[i] = state[i] * g_[hv_idx]; + state[i] = state[i] + k_[s_idx] * delta; + } + for (int i = 0; i < n_per_t; ++i) { + state[i] = static_cast(static_cast(state[i])); + } + } + tape_ += Hv * Dv; + k_ += Hk * Dk; + g_ += Hv; + } + for (int i = 0; i < n_per_t; ++i) { + auto s_idx = n_per_t * dk_idx + i; + o_state[s_idx] = static_cast(state[i]); + } + """ + + // Corrected: metal::select — no decay when masked, no branch, correct semantics + let tapeSelectSrc = """ + auto n = thread_position_in_grid.z; + auto b_idx = n / Hv; + auto hv_idx = n % Hv; + auto hk_idx = hv_idx / (Hv / Hk); + constexpr int n_per_t = Dk / 32; + auto tape_ = tape + b_idx * T * Hv * Dv + hv_idx * Dv; + auto k_ = k + b_idx * T * Hk * Dk + hk_idx * Dk; + auto dk_idx = thread_position_in_threadgroup.x; + auto dv_idx = thread_position_in_grid.y; + auto i_state = state_in + (n * Dv + dv_idx) * Dk; + auto o_state = state_out + (n * Dv + dv_idx) * Dk; + auto g_ = g + b_idx * T * Hv; + float state[n_per_t]; + for (int i = 0; i < n_per_t; ++i) + state[i] = static_cast(i_state[n_per_t * dk_idx + i]); + for (int t = 0; t < T; ++t) { + bool do_step = static_cast(mask[b_idx * T + t]) > 0.5f; + float delta = static_cast(tape_[dv_idx]); + for (int i = 0; i < n_per_t; ++i) { + auto s_idx = n_per_t * dk_idx + i; + float next = state[i] * g_[hv_idx] + k_[s_idx] * delta; + next = static_cast(static_cast(next)); + state[i] = metal::select(state[i], next, do_step); + } + tape_ += Hv * Dv; + k_ += Hk * Dk; + g_ += Hv; + } + for (int i = 0; i < n_per_t; ++i) + o_state[n_per_t * dk_idx + i] = static_cast(state[i]); + """ + + let tapeKernelBranching = MLXFast.metalKernel( + name: "bench_tape_branching_mask", + inputNames: ["tape", "k", "g", "state_in", "T", "mask"], + outputNames: ["state_out"], + source: tapeBranchingSrc + ) + + let tapeKernelSelect = MLXFast.metalKernel( + name: "bench_tape_select_mask", + inputNames: ["tape", "k", "g", "state_in", "T", "mask"], + outputNames: ["state_out"], + source: tapeSelectSrc + ) + + let steps = T + func runTape(_ kernel: MLXFast.MLXFastKernel) -> MLXArray { + kernel( + [tape, k, g, state, MLXArray(steps), mask], + template: [("InT", inputType), ("Dk", Dk), ("Dv", Dv), ("Hk", Hk), ("Hv", Hv)], + grid: (32, Dv, B * Hv), threadGroup: (32, 4, 1), + outputShapes: [state.shape], outputDTypes: [inputType] + )[0] + } + + for _ in 0 ..< args.warmup { + MLX.eval(runTape(tapeKernelBranching)) + MLX.eval(runTape(tapeKernelSelect)) + } + + let stateBytes = B * Hv * Dv * Dk * 2 + let r1 = measure(label: "bench_tape_branching_mask", iterations: args.iterations) { + runTape(tapeKernelBranching) + } + printResult(label: "tape branching (scalar-g, masked)", r: r1, + extraInfo: bwStr(bytes: stateBytes * 2, seconds: r1.median)) + + let r2 = measure(label: "bench_tape_select_mask", iterations: args.iterations) { + runTape(tapeKernelSelect) + } + printResult(label: "tape select (scalar-g, masked)", r: r2, + extraInfo: bwStr(bytes: stateBytes * 2, seconds: r2.median)) + print(String(format: " → select vs branching: %.2fx", r1.median / r2.median)) + + // ── GatedDelta + Tape ───────────────────────────────────────────────────── + + // Current: if-guard, separate decay and accumulate assignments + let gdeltaBranchingSrc = """ + auto n = thread_position_in_grid.z; + auto b_idx = n / Hv; + auto hv_idx = n % Hv; + auto hk_idx = hv_idx / (Hv / Hk); + constexpr int n_per_t = Dk / 32; + auto q_ = q + b_idx * T * Hk * Dk + hk_idx * Dk; + auto k_ = k + b_idx * T * Hk * Dk + hk_idx * Dk; + auto v_ = v + b_idx * T * Hv * Dv + hv_idx * Dv; + y += b_idx * T * Hv * Dv + hv_idx * Dv; + auto tape_ = innovation_tape + b_idx * T * Hv * Dv + hv_idx * Dv; + auto g_ = g + b_idx * T * Hv; + auto beta_ = beta + b_idx * T * Hv; + auto dk_idx = thread_position_in_threadgroup.x; + auto dv_idx = thread_position_in_grid.y; + auto i_state = state_in + (n * Dv + dv_idx) * Dk; + auto o_state = state_out + (n * Dv + dv_idx) * Dk; + float state[n_per_t]; + for (int i = 0; i < n_per_t; ++i) { + auto s_idx = n_per_t * dk_idx + i; + state[i] = static_cast(i_state[s_idx]); + } + for (int t = 0; t < T; ++t) { + float delta = 0.0f; + if (mask[b_idx * T + t]) { + float kv_mem = 0.0f; + for (int i = 0; i < n_per_t; ++i) { + auto s_idx = n_per_t * dk_idx + i; + state[i] = state[i] * g_[hv_idx]; + kv_mem += state[i] * k_[s_idx]; + } + kv_mem = simd_sum(kv_mem); + delta = (v_[dv_idx] - kv_mem) * beta_[hv_idx]; + float out = 0.0f; + for (int i = 0; i < n_per_t; ++i) { + auto s_idx = n_per_t * dk_idx + i; + state[i] = state[i] + k_[s_idx] * delta; + out += state[i] * q_[s_idx]; + } + out = simd_sum(out); + if (thread_index_in_simdgroup == 0) { + y[dv_idx] = static_cast(out); + } + } + if (thread_index_in_simdgroup == 0) { + tape_[dv_idx] = delta; + } + for (int i = 0; i < n_per_t; ++i) { + state[i] = static_cast(static_cast(state[i])); + } + q_ += Hk * Dk; k_ += Hk * Dk; v_ += Hv * Dv; + y += Hv * Dv; tape_ += Hv * Dv; g_ += Hv; beta_ += Hv; + } + for (int i = 0; i < n_per_t; ++i) { + auto s_idx = n_per_t * dk_idx + i; + o_state[s_idx] = static_cast(state[i]); + } + """ + + // Corrected: uniform predicate skips simd_sums when masked (no divergence); + // metal::select restores pre-decay state when !do_step. + let gdeltaSelectSrc = """ + auto n = thread_position_in_grid.z; + auto b_idx = n / Hv; + auto hv_idx = n % Hv; + auto hk_idx = hv_idx / (Hv / Hk); + constexpr int n_per_t = Dk / 32; + auto q_ = q + b_idx * T * Hk * Dk + hk_idx * Dk; + auto k_ = k + b_idx * T * Hk * Dk + hk_idx * Dk; + auto v_ = v + b_idx * T * Hv * Dv + hv_idx * Dv; + y += b_idx * T * Hv * Dv + hv_idx * Dv; + auto tape_ = innovation_tape + b_idx * T * Hv * Dv + hv_idx * Dv; + auto g_ = g + b_idx * T * Hv; + auto beta_ = beta + b_idx * T * Hv; + auto dk_idx = thread_position_in_threadgroup.x; + auto dv_idx = thread_position_in_grid.y; + auto i_state = state_in + (n * Dv + dv_idx) * Dk; + auto o_state = state_out + (n * Dv + dv_idx) * Dk; + float state[n_per_t]; + for (int i = 0; i < n_per_t; ++i) + state[i] = static_cast(i_state[n_per_t * dk_idx + i]); + for (int t = 0; t < T; ++t) { + bool do_step = static_cast(mask[b_idx * T + t]) > 0.5f; + float old_state[n_per_t]; + float kv_mem = 0.0f; + for (int i = 0; i < n_per_t; ++i) { + auto s_idx = n_per_t * dk_idx + i; + old_state[i] = state[i]; + state[i] = state[i] * g_[hv_idx]; + kv_mem += state[i] * k_[s_idx]; + } + float delta = 0.0f; + float out = 0.0f; + if (do_step) { + kv_mem = simd_sum(kv_mem); + delta = (static_cast(v_[dv_idx]) - kv_mem) + * static_cast(beta_[hv_idx]); + for (int i = 0; i < n_per_t; ++i) { + auto s_idx = n_per_t * dk_idx + i; + state[i] += k_[s_idx] * delta; + out += state[i] * static_cast(q_[s_idx]); + } + out = simd_sum(out); + } + if (thread_index_in_simdgroup == 0) { + y[dv_idx] = static_cast(out); + tape_[dv_idx] = delta; + } + for (int i = 0; i < n_per_t; ++i) { + float quant_new = static_cast(static_cast(state[i])); + state[i] = metal::select(old_state[i], quant_new, do_step); + } + q_ += Hk * Dk; k_ += Hk * Dk; v_ += Hv * Dv; + y += Hv * Dv; tape_ += Hv * Dv; g_ += Hv; beta_ += Hv; + } + for (int i = 0; i < n_per_t; ++i) + o_state[n_per_t * dk_idx + i] = static_cast(state[i]); + """ + + let gdeltaKernelBranching = MLXFast.metalKernel( + name: "bench_gdelta_branching_mask", + inputNames: ["q", "k", "v", "g", "beta", "state_in", "T", "mask"], + outputNames: ["y", "state_out", "innovation_tape"], + source: gdeltaBranchingSrc + ) + + let gdeltaKernelSelect = MLXFast.metalKernel( + name: "bench_gdelta_select_mask", + inputNames: ["q", "k", "v", "g", "beta", "state_in", "T", "mask"], + outputNames: ["y", "state_out", "innovation_tape"], + source: gdeltaSelectSrc + ) + + func runGdelta(_ kernel: MLXFast.MLXFastKernel) -> MLXArray { + kernel( + [q, k, v, g, beta, state, MLXArray(steps), mask], + template: [("InT", inputType), ("Dk", Dk), ("Dv", Dv), ("Hk", Hk), ("Hv", Hv)], + grid: (32, Dv, B * Hv), threadGroup: (32, 4, 1), + outputShapes: [[B, T, Hv, Dv], state.shape, [B, T, Hv, Dv]], + outputDTypes: [inputType, inputType, DType.float32] + )[0] + } + + for _ in 0 ..< args.warmup { + MLX.eval(runGdelta(gdeltaKernelBranching)) + MLX.eval(runGdelta(gdeltaKernelSelect)) + } + + let callBytes = (B*T*Hk*Dk + B*T*Hk*Dk + B*T*Hv*Dv) * 2 + + B*Hv*Dv*Dk * 2 * 2 + + B*T*Hv*Dv * 4 + + let r3 = measure(label: "bench_gdelta_branching_mask", iterations: args.iterations) { + runGdelta(gdeltaKernelBranching) + } + printResult(label: "gdelta branching (scalar-g, masked)", r: r3, + extraInfo: bwStr(bytes: callBytes, seconds: r3.median)) + + let r4 = measure(label: "bench_gdelta_select_mask", iterations: args.iterations) { + runGdelta(gdeltaKernelSelect) + } + printResult(label: "gdelta select (scalar-g, masked)", r: r4, + extraInfo: bwStr(bytes: callBytes, seconds: r4.median)) + print(String(format: " → select vs branching: %.2fx", r3.median / r4.median)) +} + +// MARK: - Ops Fallback Comparison (MLX.where vs arithmetic masking) + +private func benchOpsFallback(args: Args) { + print("\n── Ops Fallback: MLX.where vs Arithmetic Masking ───────────────────────") + + let B = 1; let T = 16; let Hk = 8; let Hv = 16; let Dk = 128; let Dv = 128 + let tape = rand([B, T, Hv, Dv]) + let k = rand([B, T, Hk, Dk]) + let g = rand([B, T, Hv]) + let state = rand([B, Hv, Dv, Dk]) + let mask = (uniform(low: 0, high: 1, [B, T]) .>= MLXArray(0.5)).asType(DType.bfloat16) + let q = rand([B, T, Hk, Dk]) + let v = rand([B, T, Hv, Dv]) + let beta = rand([B, T, Hv]) + + // ── Tape Replay Ops ─────────────────────────────────────────────────────── + + // Current: MLX.where selects between new state and old state + func tapeOpsWhere() -> MLXArray { + let k_ = MLX.repeated(k, count: Hv / Hk, axis: 2) + var st = state + for t in 0 ..< T { + let prev = st + let decay = expandedDimensions(g[0..., t, 0...], axes: [2, 3]) + let delta = tape[0..., t, 0..., .newAxis] + let kT = expandedDimensions(k_[0..., t, 0...], axis: -2) + st = st * decay + delta * kT + let stepMask = mask[0..., t][.newAxis, .newAxis, .newAxis] + st = MLX.where(stepMask, st, prev) + } + return st + } + + // Optimized: arithmetic gate — next * gate + state * (1 - gate) + func tapeOpsArith() -> MLXArray { + let k_ = MLX.repeated(k, count: Hv / Hk, axis: 2) + var st = state + for t in 0 ..< T { + let decay = expandedDimensions(g[0..., t, 0...], axes: [2, 3]) + let delta = tape[0..., t, 0..., .newAxis] + let kT = expandedDimensions(k_[0..., t, 0...], axis: -2) + let next = st * decay + delta * kT + let gate = expandedDimensions(mask[0..., t], axes: [1, 2, 3]).asType(st.dtype) + st = next * gate + st * (1 - gate) + } + return st + } + + for _ in 0 ..< args.warmup { + MLX.eval(tapeOpsWhere()) + MLX.eval(tapeOpsArith()) + } + + let r1 = measure(label: "tape_ops_where", iterations: args.iterations) { tapeOpsWhere() } + printResult(label: "tape ops MLX.where (scalar-g, masked)", r: r1) + + let r2 = measure(label: "tape_ops_arith", iterations: args.iterations) { tapeOpsArith() } + printResult(label: "tape ops arith gate (scalar-g, masked)", r: r2) + print(String(format: " → arith vs where: %.2fx", r1.median / r2.median)) + + // ── GatedDelta + Tape Ops ───────────────────────────────────────────────── + + // Current: MLX.where for state and output gating + func gdeltaOpsWhere() -> MLXArray { + let rf = Hv / Hk + let q_ = MLX.repeated(q, count: rf, axis: 2) + let k_ = MLX.repeated(k, count: rf, axis: 2) + var st = state + var outs = [MLXArray]() + outs.reserveCapacity(T) + for t in 0 ..< T { + let oldSt = st + let decay = expandedDimensions(g[0..., t, 0...], axes: [2, 3]) + let decayed = st * decay + let kvMem = (decayed * expandedDimensions(k_[0..., t, 0...], axis: -2)).sum(axis: -1) + let delta = (v[0..., t, 0...] - kvMem) * expandedDimensions(beta[0..., t, 0...], axis: -1) + let newSt = decayed + expandedDimensions(k_[0..., t, 0...], axis: -2) + * expandedDimensions(delta, axis: -1) + let y = (newSt * expandedDimensions(q_[0..., t, 0...], axis: -2)).sum(axis: -1) + let sMask = mask[0..., t][.newAxis, .newAxis, .newAxis] + let yMask = mask[0..., t][.newAxis, .newAxis] + st = MLX.where(sMask, newSt, oldSt) + outs.append(MLX.where(yMask, y, MLXArray.zeros(y.shape, dtype: y.dtype))) + } + return MLX.stacked(outs, axis: 1) + } + + // Optimized: arithmetic gate — no MLX.where + func gdeltaOpsArith() -> MLXArray { + let rf = Hv / Hk + let q_ = MLX.repeated(q, count: rf, axis: 2) + let k_ = MLX.repeated(k, count: rf, axis: 2) + var st = state + var outs = [MLXArray]() + outs.reserveCapacity(T) + for t in 0 ..< T { + let decay = expandedDimensions(g[0..., t, 0...], axes: [2, 3]) + let decayed = st * decay + let kvMem = (decayed * expandedDimensions(k_[0..., t, 0...], axis: -2)).sum(axis: -1) + let delta = (v[0..., t, 0...] - kvMem) * expandedDimensions(beta[0..., t, 0...], axis: -1) + let next = decayed + expandedDimensions(k_[0..., t, 0...], axis: -2) + * expandedDimensions(delta, axis: -1) + let y = (next * expandedDimensions(q_[0..., t, 0...], axis: -2)).sum(axis: -1) + let sGate = expandedDimensions(mask[0..., t], axes: [1, 2, 3]).asType(st.dtype) + let yGate = expandedDimensions(mask[0..., t], axes: [1, 2]).asType(y.dtype) + st = next * sGate + st * (1 - sGate) + outs.append(y * yGate) + } + return MLX.stacked(outs, axis: 1) + } + + for _ in 0 ..< args.warmup { + MLX.eval(gdeltaOpsWhere()) + MLX.eval(gdeltaOpsArith()) + } + + let r3 = measure(label: "gdelta_ops_where", iterations: args.iterations) { gdeltaOpsWhere() } + printResult(label: "gdelta ops MLX.where (scalar-g, masked)", r: r3) + + let r4 = measure(label: "gdelta_ops_arith", iterations: args.iterations) { gdeltaOpsArith() } + printResult(label: "gdelta ops arith gate (scalar-g, masked)", r: r4) + print(String(format: " → arith vs where: %.2fx", r3.median / r4.median)) +} + +// MARK: - Main + +let args = Args() + +print("DFlash Kernel Micro-Benchmark") +print("═══════════════════════════════════════════════════════════════════════") +print(" Device: \(Device.defaultDevice().description)") +print(" Iterations: \(args.iterations) Warmup: \(args.warmup)") +print(" Kernels: \(args.kernels.sorted().joined(separator: ", "))") +print(" Long-ctx: \(args.longCtx)") +print("═══════════════════════════════════════════════════════════════════════") + +// Force GPU initialisation before any timing +MLX.eval(MLX.zeros([1])) + +if args.kernels.contains("tape") { benchTapeReplay(args: args) } +if args.kernels.contains("gdelta") { benchGatedDelta(args: args) } +if args.kernels.contains("sdpa") { benchSDPA(args: args) } +if args.kernels.contains("variants") { benchKernelVariants(args: args) } +if args.kernels.contains("ops") { benchOpsFallback(args: args) } + +print("\nDone.") diff --git a/Sources/SwiftLM/DFlashModelRegistry.swift b/Sources/SwiftLM/DFlashModelRegistry.swift new file mode 100644 index 00000000..35430c8f --- /dev/null +++ b/Sources/SwiftLM/DFlashModelRegistry.swift @@ -0,0 +1,38 @@ +// Copyright 2026 SwiftLM Contributors +// MIT License — see LICENSE file +// +// Registers SwiftLM-owned DFlash model types with the shared LLMTypeRegistry, +// overriding any MLXLLM defaults so DFlashTargetModel conformance is available. +// +// Called once at startup, before any model loading. + +import Foundation +import MLXLLM +import MLXLMCommon + +/// Register SwiftLM-owned model types that conform to DFlashTargetModel. +/// +/// Must be called before any `LLMModelFactory.shared.loadContainer()` call so +/// that the factory produces SwiftLM types (which carry DFlash conformance) +/// rather than the MLXLLM defaults. +func registerDFlashModelTypes() async { + let registry = LLMTypeRegistry.shared + + // DeepSeek V3 — override MLXLLM default with DFlash-capable version. + await registry.registerModelType("deepseek_v3") { data in + let config = try JSONDecoder.json5().decode(DSV3Config.self, from: data) + return DeepseekV3DFlashModel(config) + } + + // kimi_k25 uses the DeepSeek V3 architecture (different model_type string only). + await registry.registerModelType("kimi_k25") { data in + let config = try JSONDecoder.json5().decode(DSV3Config.self, from: data) + return DeepseekV3DFlashModel(config) + } + + // Kimi linear — hybrid KDA/MLA architecture (kimi 2.6). + await registry.registerModelType("kimi_linear") { data in + let config = try JSONDecoder.json5().decode(KimiLinearConfiguration.self, from: data) + return KimiLinearDFlashModel(config) + } +} diff --git a/Sources/SwiftLM/DeepseekV3DFlash.swift b/Sources/SwiftLM/DeepseekV3DFlash.swift new file mode 100644 index 00000000..6432a883 --- /dev/null +++ b/Sources/SwiftLM/DeepseekV3DFlash.swift @@ -0,0 +1,486 @@ +// Copyright 2026 SwiftLM Contributors +// MIT License — see LICENSE file +// +// DeepSeek V3 model owned by SwiftLM with DFlash speculative decoding support. +// +// Port of mlx-lm/mlx_lm/models/deepseek_v3.py +// Also handles kimi_k25 model type (wraps the same architecture). +// +// Kept in SwiftLM to avoid upstream submodule changes: +// callCapturing and DFlashTargetModel conformance live here alongside +// the model implementation so no public API surface is needed in MLXLLM. + +import DFlash +import Foundation +import MLX +import MLXLLM +import MLXLMCommon +import MLXNN + +// MARK: - Configuration + +struct DSV3Config: Codable, Sendable { + var vocabSize: Int + var hiddenSize: Int + var intermediateSize: Int + var moeIntermediateSize: Int + var numHiddenLayers: Int + var numAttentionHeads: Int + var numKeyValueHeads: Int + var nSharedExperts: Int? + var nRoutedExperts: Int? + var routedScalingFactor: Float + var kvLoraRank: Int + var qLoraRank: Int? + var qkRopeHeadDim: Int + var vHeadDim: Int + var qkNopeHeadDim: Int + var normTopkProb: Bool + var nGroup: Int? + var topkGroup: Int? + var numExpertsPerTok: Int? + var moeLayerFreq: Int + var firstKDenseReplace: Int + var maxPositionEmbeddings: Int + var rmsNormEps: Float + var ropeTheta: Float + var ropeScaling: [String: StringOrNumber]? + var attentionBias: Bool + + enum CodingKeys: String, CodingKey { + case vocabSize = "vocab_size" + case hiddenSize = "hidden_size" + case intermediateSize = "intermediate_size" + case moeIntermediateSize = "moe_intermediate_size" + case numHiddenLayers = "num_hidden_layers" + case numAttentionHeads = "num_attention_heads" + case numKeyValueHeads = "num_key_value_heads" + case nSharedExperts = "n_shared_experts" + case nRoutedExperts = "n_routed_experts" + case routedScalingFactor = "routed_scaling_factor" + case kvLoraRank = "kv_lora_rank" + case qLoraRank = "q_lora_rank" + case qkRopeHeadDim = "qk_rope_head_dim" + case vHeadDim = "v_head_dim" + case qkNopeHeadDim = "qk_nope_head_dim" + case normTopkProb = "norm_topk_prob" + case nGroup = "n_group" + case topkGroup = "topk_group" + case numExpertsPerTok = "num_experts_per_tok" + case moeLayerFreq = "moe_layer_freq" + case firstKDenseReplace = "first_k_dense_replace" + case maxPositionEmbeddings = "max_position_embeddings" + case rmsNormEps = "rms_norm_eps" + case ropeTheta = "rope_theta" + case ropeScaling = "rope_scaling" + case attentionBias = "attention_bias" + } +} + +// MARK: - Helpers + +private func clippedSilu(_ x: MLXArray) -> MLXArray { + clip(x * sigmoid(x), min: -100, max: 100) +} + +// MARK: - Attention + +private class DSV3Attention: Module { + let numHeads: Int + let qLoraRank: Int? + let qkRopeHeadDim: Int + let kvLoraRank: Int + let vHeadDim: Int + let qkNopeHeadDim: Int + let qHeadDim: Int + var scale: Float + + let rope: RoPELayer + @ModuleInfo(key: "q_proj") var qProj: Linear? + @ModuleInfo(key: "q_a_proj") var qAProj: Linear? + @ModuleInfo(key: "q_a_layernorm") var qALayerNorm: RMSNorm? + @ModuleInfo(key: "q_b_proj") var qBProj: Linear? + @ModuleInfo(key: "o_proj") var oProj: Linear + @ModuleInfo(key: "kv_a_proj_with_mqa") var kvAProjWithMqa: Linear + @ModuleInfo(key: "kv_a_layernorm") var kvALayerNorm: RMSNorm + @ModuleInfo(key: "kv_b_proj") var kvBProj: Linear + + init(config: DSV3Config) { + numHeads = config.numAttentionHeads + qLoraRank = config.qLoraRank + qkRopeHeadDim = config.qkRopeHeadDim + kvLoraRank = config.kvLoraRank + vHeadDim = config.vHeadDim + qkNopeHeadDim = config.qkNopeHeadDim + qHeadDim = config.qkNopeHeadDim + config.qkRopeHeadDim + scale = pow(Float(qHeadDim), -0.5) + + if let r = config.qLoraRank { + _qAProj.wrappedValue = Linear(config.hiddenSize, r, bias: config.attentionBias) + _qALayerNorm.wrappedValue = RMSNorm(dimensions: r) + _qBProj.wrappedValue = Linear(r, numHeads * qHeadDim, bias: false) + } else { + _qProj.wrappedValue = Linear(config.hiddenSize, numHeads * qHeadDim, bias: false) + } + + _kvAProjWithMqa.wrappedValue = Linear( + config.hiddenSize, kvLoraRank + qkRopeHeadDim, bias: config.attentionBias) + _kvALayerNorm.wrappedValue = RMSNorm(dimensions: kvLoraRank) + _kvBProj.wrappedValue = Linear( + kvLoraRank, numHeads * (qHeadDim - qkRopeHeadDim + vHeadDim), bias: false) + _oProj.wrappedValue = Linear(numHeads * vHeadDim, config.hiddenSize, bias: config.attentionBias) + + if let ropeScaling = config.ropeScaling { + let mScaleAllDim = ropeScaling["mscale_all_dim"]?.asFloat() ?? 0.0 + if mScaleAllDim != 0 { + let scalingFactor = ropeScaling["factor"]?.asFloat() ?? 1.0 + if scalingFactor > 1 { + let s = 0.1 * mScaleAllDim * log(scalingFactor) + 1.0 + scale = scale * s * s + } + } + } + + rope = initializeRope( + dims: qkRopeHeadDim, base: config.ropeTheta, traditional: true, + scalingConfig: config.ropeScaling, + maxPositionEmbeddings: config.maxPositionEmbeddings) + } + + func callAsFunction( + _ x: MLXArray, mask: MLXFast.ScaledDotProductAttentionMaskMode, cache: KVCache? + ) -> MLXArray { + let (B, L, _) = (x.dim(0), x.dim(1), x.dim(2)) + + var q: MLXArray + if qLoraRank == nil { + q = qProj!(x) + } else { + q = qBProj!(qALayerNorm!(qAProj!(x))) + } + + q = q.reshaped(B, L, numHeads, qHeadDim).transposed(0, 2, 1, 3) + let splitQ = split(q, indices: [qkNopeHeadDim], axis: -1) + var (qNope, qPe) = (splitQ[0], splitQ[1]) + + var compressedKv = kvAProjWithMqa(x) + let splitKv = split(compressedKv, indices: [kvLoraRank], axis: -1) + compressedKv = splitKv[0] + var kPe = splitKv[1] + kPe = kPe.reshaped(B, L, 1, qkRopeHeadDim).transposed(0, 2, 1, 3) + + var kv = kvBProj(kvALayerNorm(compressedKv)) + kv = kv.reshaped(B, L, numHeads, -1).transposed(0, 2, 1, 3) + let splitKV2 = split(kv, indices: [qkNopeHeadDim], axis: -1) + var (kNope, values) = (splitKV2[0], splitKV2[1]) + + qPe = applyRotaryPosition(rope, to: qPe, cache: cache) + kPe = applyRotaryPosition(rope, to: kPe, cache: cache) + kPe = repeated(kPe, count: numHeads, axis: 1) + + var keys: MLXArray + if let cache { + (keys, values) = cache.update( + keys: concatenated([kNope, kPe], axis: -1), values: values) + } else { + keys = concatenated([kNope, kPe], axis: -1) + } + + let queries = concatenated([qNope, qPe], axis: -1) + let output = attentionWithCacheUpdate( + queries: queries, keys: keys, values: values, + cache: cache, scale: scale, mask: mask + ).transposed(0, 2, 1, 3).reshaped(B, L, -1) + + return oProj(output) + } +} + +// MARK: - MLP + +private class DSV3MLP: Module, UnaryLayer { + @ModuleInfo(key: "gate_proj") var gateProj: Linear + @ModuleInfo(key: "up_proj") var upProj: Linear + @ModuleInfo(key: "down_proj") var downProj: Linear + + init(config: DSV3Config, hiddenSize: Int? = nil, intermediateSize: Int? = nil) { + let h = hiddenSize ?? config.hiddenSize + let i = intermediateSize ?? config.intermediateSize + _gateProj.wrappedValue = Linear(h, i, bias: false) + _upProj.wrappedValue = Linear(h, i, bias: false) + _downProj.wrappedValue = Linear(i, h, bias: false) + } + + func callAsFunction(_ x: MLXArray) -> MLXArray { + downProj(silu(gateProj(x)) * upProj(x)) + } +} + +// MARK: - MoE Gate + +private class DSV3MoEGate: Module { + let topK: Int + let normTopkProb: Bool + let nRoutedExperts: Int + let routedScalingFactor: Float + let nGroup: Int + let topkGroup: Int + + var weight: MLXArray + var e_score_correction_bias: MLXArray + + init(config: DSV3Config) { + topK = config.numExpertsPerTok ?? 1 + normTopkProb = config.normTopkProb + nRoutedExperts = config.nRoutedExperts ?? 1 + routedScalingFactor = config.routedScalingFactor + nGroup = config.nGroup ?? 1 + topkGroup = config.topkGroup ?? 1 + weight = zeros([nRoutedExperts, config.hiddenSize]) + e_score_correction_bias = zeros([nRoutedExperts]) + } + + func callAsFunction(_ x: MLXArray) -> (MLXArray, MLXArray) { + let (bsz, seqLen, _) = (x.dim(0), x.dim(1), x.dim(2)) + let hiddenStates = x.matmul(weight.T) + var scores = sigmoid(hiddenStates) + let scoresForChoice = scores + e_score_correction_bias + let groupScores = scoresForChoice.reshaped(bsz, seqLen, nGroup, -1) + let topKGroup = top(groupScores, k: 2, axis: -1).sum(axis: -1, keepDims: true) + let k = nGroup - topkGroup + var groupIdx = argPartition(topKGroup, kth: k - 1, axis: -2)[.ellipsis, .. 1, normTopkProb { + scores = scores / (scores.sum(axis: -1, keepDims: true) + 1e-20) * routedScalingFactor + } + return (inds, scores) + } +} + +// MARK: - MoE + +private class DSV3MoE: Module, UnaryLayer { + let numExpertsPerTok: Int + @ModuleInfo(key: "switch_mlp") var switchMLP: SwitchGLU + var gate: DSV3MoEGate + @ModuleInfo(key: "shared_experts") var sharedExperts: DSV3MLP? + + init(config: DSV3Config) { + numExpertsPerTok = config.numExpertsPerTok ?? 1 + _switchMLP.wrappedValue = SwitchGLU( + inputDims: config.hiddenSize, + hiddenDims: config.moeIntermediateSize, + numExperts: config.nRoutedExperts ?? 1, + activation: clippedSilu) + gate = DSV3MoEGate(config: config) + if let sharedCount = config.nSharedExperts { + _sharedExperts.wrappedValue = DSV3MLP( + config: config, intermediateSize: config.moeIntermediateSize * sharedCount) + } + } + + func callAsFunction(_ x: MLXArray) -> MLXArray { + let (indices, scores) = gate(x) + var y = switchMLP(x, indices) + y = (y * scores[.ellipsis, .newAxis]).sum(axis: -2) + if let shared = sharedExperts { y = y + shared(x) } + return y + } +} + +// MARK: - Decoder Layer + +private class DSV3DecoderLayer: Module { + @ModuleInfo(key: "self_attn") var selfAttn: DSV3Attention + var mlp: UnaryLayer + @ModuleInfo(key: "input_layernorm") var inputLayerNorm: RMSNorm + @ModuleInfo(key: "post_attention_layernorm") var postAttentionLayerNorm: RMSNorm + + init(config: DSV3Config, layerIdx: Int) { + _selfAttn.wrappedValue = DSV3Attention(config: config) + if config.nRoutedExperts != nil, + layerIdx >= config.firstKDenseReplace, + layerIdx % config.moeLayerFreq == 0 + { + mlp = DSV3MoE(config: config) + } else { + mlp = DSV3MLP(config: config) + } + _inputLayerNorm.wrappedValue = RMSNorm( + dimensions: config.hiddenSize, eps: config.rmsNormEps) + _postAttentionLayerNorm.wrappedValue = RMSNorm( + dimensions: config.hiddenSize, eps: config.rmsNormEps) + } + + func callAsFunction( + _ x: MLXArray, mask: MLXFast.ScaledDotProductAttentionMaskMode, cache: KVCache? + ) -> MLXArray { + let h = x + selfAttn(inputLayerNorm(x), mask: mask, cache: cache) + return h + mlp(postAttentionLayerNorm(h)) + } +} + +// MARK: - Model Inner + +private class DSV3ModelInner: Module, LayerPartitionable { + var gpuLayerCount: Int? = nil + var totalLayerCount: Int { layers.count } + + @ModuleInfo(key: "embed_tokens") var embedTokens: Embedding + let layers: [DSV3DecoderLayer] + @ModuleInfo(key: "norm") var norm: RMSNorm + + init(config: DSV3Config) { + _embedTokens.wrappedValue = Embedding( + embeddingCount: config.vocabSize, dimensions: config.hiddenSize) + layers = (0 ..< config.numHiddenLayers).map { + DSV3DecoderLayer(config: config, layerIdx: $0) + } + _norm.wrappedValue = RMSNorm(dimensions: config.hiddenSize, eps: config.rmsNormEps) + } + + func callAsFunction(_ x: MLXArray, cache: [KVCache]?) -> MLXArray { + var h = embedTokens(x) + let mask = createAttentionMask(h: h, cache: cache?.first) + for (i, layer) in layers.enumerated() { + h = partitionedLayerCall(index: i, gpuLayerCount: gpuLayerCount) { + layer(h, mask: mask, cache: cache?[i]) + } + } + return norm(h) + } + + func callCapturing( + _ x: MLXArray, cache: [KVCache?]? = nil, captureLayerIDs: Set + ) -> (MLXArray, [Int: MLXArray]) { + var h = embedTokens(x) + let kvCache: [KVCache?] = { + guard let c = cache else { return Array(repeating: nil, count: layers.count) } + var out = Array(repeating: nil as KVCache?, count: layers.count) + for (i, v) in c.prefix(layers.count).enumerated() { out[i] = v } + return out + }() + let mask = createAttentionMask(h: h, cache: kvCache.first ?? nil) + var captured: [Int: MLXArray] = [:] + for (i, layer) in layers.enumerated() { + h = partitionedLayerCall(index: i, gpuLayerCount: gpuLayerCount) { + layer(h, mask: mask, cache: kvCache[i]) + } + if captureLayerIDs.contains(i) { captured[i] = h } + } + return (norm(h), captured) + } +} + +// MARK: - Public Model + +/// DeepSeek V3 model owned by SwiftLM. +/// Registered for `deepseek_v3` and `kimi_k25` model types at DFlash setup time, +/// overriding the MLXLLM factory default so DFlash conformance is available. +public class DeepseekV3DFlashModel: Module, LLMModel, KVCacheDimensionProvider, LoRAModel, + DFlashTargetModel +{ + public var kvHeads: [Int] = [] + + private let args: DSV3Config + @ModuleInfo(key: "model") private var inner: DSV3ModelInner + @ModuleInfo(key: "lm_head") var lmHead: Linear + + init(_ args: DSV3Config) { + self.args = args + _inner.wrappedValue = DSV3ModelInner(config: args) + _lmHead.wrappedValue = Linear(args.hiddenSize, args.vocabSize, bias: false) + } + + public func callAsFunction(_ inputs: MLXArray, cache: [KVCache]? = nil) -> MLXArray { + lmHead(inner(inputs, cache: cache)) + } + + public func sanitize(weights: [String: MLXArray]) -> [String: MLXArray] { + // Strip HuggingFace VLM wrapper prefix present in some checkpoints (e.g. kimi_k25). + let llmPrefix = "language_model." + var weights = weights.count > 0 && weights.keys.first!.hasPrefix(llmPrefix) + ? Dictionary(uniqueKeysWithValues: weights.map { k, v in + (k.hasPrefix(llmPrefix) ? String(k.dropFirst(llmPrefix.count)) : k, v) + }) + : weights + + var w = weights + + func dequant(weight: MLXArray, scaleInv: MLXArray) -> MLXArray { + let bs = 128 + let (m, n) = (weight.dim(0), weight.dim(1)) + let padBottom = (bs - m % bs) % bs + let padSide = (bs - n % bs) % bs + var p = padded(weight, widths: [.init((0, padBottom)), .init((0, padSide))]) + p = p.reshaped([(m + padBottom) / bs, bs, (n + padSide) / bs, bs]) + let scaled = p * scaleInv[0..., .newAxis, 0..., .newAxis] + return scaled.reshaped([m + padBottom, n + padSide])[0 ..< m, 0 ..< n] + } + + for (key, value) in weights { + if key.contains("weight_scale_inv") { + let weightKey = key.replacingOccurrences(of: "_scale_inv", with: "") + if let weight = weights[weightKey] { + w[weightKey] = dequant(weight: weight, scaleInv: value) + } + } else if w[key] == nil { + w[key] = value + } + } + + for l in 0 ..< args.numHiddenLayers { + let prefix = "model.layers.\(l)" + for (_, projName) in [("w1", "gate_proj"), ("w2", "down_proj"), ("w3", "up_proj")] { + for key in ["weight", "scales", "biases"] { + let firstKey = "\(prefix).mlp.experts.0.\(projName).\(key)" + if weights[firstKey] != nil { + let joined = (0 ..< (args.nRoutedExperts ?? 1)).map { + weights["\(prefix).mlp.experts.\($0).\(projName).\(key)"]! + } + w["\(prefix).mlp.switch_mlp.\(projName).\(key)"] = stacked(joined) + // Remove per-expert keys — they have no corresponding module path + // after stacking and would fail verify: .noUnusedKeys. + for e in 0 ..< (args.nRoutedExperts ?? 1) { + w.removeValue(forKey: "\(prefix).mlp.experts.\(e).\(projName).\(key)") + } + } + } + } + } + + return w.filter { key, _ in + !key.starts(with: "model.layers.\(args.numHiddenLayers)") + && !key.contains("rotary_emb.inv_freq") + } + } + + public var loraLayers: [Module] { inner.layers } + + // MARK: DFlashTargetModel + + public func dflashEmbedTokens(_ tokens: MLXArray) -> MLXArray { + inner.embedTokens(tokens) + } + + public func dflashLmHeadLogits(_ hiddenStates: MLXArray) -> MLXArray { + lmHead(hiddenStates) + } + + public func dflashForwardWithCapture( + inputIDs: MLXArray, + cache: [KVCache], + captureLayerIDs: Set + ) -> (MLXArray, [Int: MLXArray]) { + let cacheOpt: [KVCache?] = cache.map { $0 } + let (hidden, captured) = inner.callCapturing( + inputIDs, cache: cacheOpt, captureLayerIDs: captureLayerIDs) + return (dflashLmHeadLogits(hidden), captured) + } + + public var dflashIsHybridGDN: Bool { false } +} diff --git a/Sources/SwiftLM/KimiLinearDFlash.swift b/Sources/SwiftLM/KimiLinearDFlash.swift new file mode 100644 index 00000000..100a6496 --- /dev/null +++ b/Sources/SwiftLM/KimiLinearDFlash.swift @@ -0,0 +1,681 @@ +// Copyright 2026 SwiftLM Contributors +// MIT License — see LICENSE file +// +// Kimi linear (hybrid KDA/MLA) model owned by SwiftLM with DFlash support. +// +// Port of mlx-lm/mlx_lm/models/kimi_linear.py +// Handles model types: "kimi_linear" +// +// Kept in SwiftLM to avoid upstream submodule changes. +// DFlashTargetModel conformance and callCapturing live here with the model. + +import DFlash +import Foundation +import MLX +import MLXLLM +import MLXLMCommon +import MLXNN + +// MARK: - Configuration + +private struct LinearAttnConfig: Codable, Sendable { + var kdaLayers: [Int] // 1-indexed layer indices that use KimiDeltaAttention + var numHeads: Int + var headDim: Int + var shortConvKernelSize: Int + + enum CodingKeys: String, CodingKey { + case kdaLayers = "kda_layers" + case numHeads = "num_heads" + case headDim = "head_dim" + case shortConvKernelSize = "short_conv_kernel_size" + } + + init(from decoder: Decoder) throws { + let c = try decoder.container(keyedBy: CodingKeys.self) + kdaLayers = try c.decode([Int].self, forKey: .kdaLayers) + numHeads = try c.decode(Int.self, forKey: .numHeads) + headDim = try c.decode(Int.self, forKey: .headDim) + shortConvKernelSize = try c.decodeIfPresent(Int.self, forKey: .shortConvKernelSize) ?? 4 + } +} + +public struct KimiLinearConfiguration: Codable, Sendable { + var modelType: String + var vocabSize: Int + var hiddenSize: Int + var numHiddenLayers: Int + var numAttentionHeads: Int + var intermediateSize: Int + var headDim: Int + var rmsNormEps: Float + fileprivate var linearAttnConfig: LinearAttnConfig + var modelMaxLength: Int + var numExperts: Int + var moeIntermediateSize: Int + var kvLoraRank: Int + var ropeScaling: [String: StringOrNumber]? + var tieWordEmbeddings: Bool + var qkNopeHeadDim: Int? + var qkRopeHeadDim: Int? + var vHeadDim: Int? + var numExpertsPerToken: Int + var numSharedExperts: Int + var moeRouterActivationFunc: String + var moeRenormalize: Bool + var routedScalingFactor: Float + var firstKDenseReplace: Int + var moeLayerFreq: Int + var numExpertGroup: Int + var topkGroup: Int + + enum CodingKeys: String, CodingKey { + case modelType = "model_type" + case vocabSize = "vocab_size" + case hiddenSize = "hidden_size" + case numHiddenLayers = "num_hidden_layers" + case numAttentionHeads = "num_attention_heads" + case intermediateSize = "intermediate_size" + case headDim = "head_dim" + case rmsNormEps = "rms_norm_eps" + case linearAttnConfig = "linear_attn_config" + case modelMaxLength = "model_max_length" + case numExperts = "num_experts" + case moeIntermediateSize = "moe_intermediate_size" + case kvLoraRank = "kv_lora_rank" + case ropeScaling = "rope_scaling" + case tieWordEmbeddings = "tie_word_embeddings" + case qkNopeHeadDim = "qk_nope_head_dim" + case qkRopeHeadDim = "qk_rope_head_dim" + case vHeadDim = "v_head_dim" + case numExpertsPerToken = "num_experts_per_token" + case numSharedExperts = "num_shared_experts" + case moeRouterActivationFunc = "moe_router_activation_func" + case moeRenormalize = "moe_renormalize" + case routedScalingFactor = "routed_scaling_factor" + case firstKDenseReplace = "first_k_dense_replace" + case moeLayerFreq = "moe_layer_freq" + case numExpertGroup = "num_expert_group" + case topkGroup = "topk_group" + } + + var resolvedQkNopeHeadDim: Int { qkNopeHeadDim ?? headDim } + var resolvedQkRopeHeadDim: Int { qkRopeHeadDim ?? 0 } + var resolvedVHeadDim: Int { vHeadDim ?? headDim } + var qHeadDim: Int { resolvedQkNopeHeadDim + resolvedQkRopeHeadDim } + + public init(from decoder: Decoder) throws { + let c = try decoder.container(keyedBy: CodingKeys.self) + modelType = try c.decode(String.self, forKey: .modelType) + vocabSize = try c.decode(Int.self, forKey: .vocabSize) + hiddenSize = try c.decode(Int.self, forKey: .hiddenSize) + numHiddenLayers = try c.decode(Int.self, forKey: .numHiddenLayers) + numAttentionHeads = try c.decode(Int.self, forKey: .numAttentionHeads) + intermediateSize = try c.decode(Int.self, forKey: .intermediateSize) + headDim = try c.decode(Int.self, forKey: .headDim) + rmsNormEps = try c.decode(Float.self, forKey: .rmsNormEps) + linearAttnConfig = try c.decode(LinearAttnConfig.self, forKey: .linearAttnConfig) + modelMaxLength = try c.decode(Int.self, forKey: .modelMaxLength) + numExperts = try c.decode(Int.self, forKey: .numExperts) + moeIntermediateSize = try c.decode(Int.self, forKey: .moeIntermediateSize) + kvLoraRank = try c.decode(Int.self, forKey: .kvLoraRank) + ropeScaling = try c.decodeIfPresent([String: StringOrNumber].self, forKey: .ropeScaling) + tieWordEmbeddings = try c.decodeIfPresent(Bool.self, forKey: .tieWordEmbeddings) ?? false + qkNopeHeadDim = try c.decodeIfPresent(Int.self, forKey: .qkNopeHeadDim) + qkRopeHeadDim = try c.decodeIfPresent(Int.self, forKey: .qkRopeHeadDim) + vHeadDim = try c.decodeIfPresent(Int.self, forKey: .vHeadDim) + numExpertsPerToken = try c.decodeIfPresent(Int.self, forKey: .numExpertsPerToken) ?? 1 + numSharedExperts = try c.decodeIfPresent(Int.self, forKey: .numSharedExperts) ?? 0 + moeRouterActivationFunc = + try c.decodeIfPresent(String.self, forKey: .moeRouterActivationFunc) ?? "sigmoid" + moeRenormalize = try c.decodeIfPresent(Bool.self, forKey: .moeRenormalize) ?? true + routedScalingFactor = + try c.decodeIfPresent(Float.self, forKey: .routedScalingFactor) ?? 1.0 + firstKDenseReplace = try c.decodeIfPresent(Int.self, forKey: .firstKDenseReplace) ?? 0 + moeLayerFreq = try c.decodeIfPresent(Int.self, forKey: .moeLayerFreq) ?? 1 + numExpertGroup = try c.decodeIfPresent(Int.self, forKey: .numExpertGroup) ?? 1 + topkGroup = try c.decodeIfPresent(Int.self, forKey: .topkGroup) ?? 1 + } +} + +// MARK: - KimiMLP + +private class KimiMLP: Module, UnaryLayer { + @ModuleInfo(key: "gate_proj") var gate: Linear + @ModuleInfo(key: "up_proj") var up: Linear + @ModuleInfo(key: "down_proj") var down: Linear + + init(dimensions: Int, hiddenDimensions: Int) { + _gate.wrappedValue = Linear(dimensions, hiddenDimensions, bias: false) + _up.wrappedValue = Linear(dimensions, hiddenDimensions, bias: false) + _down.wrappedValue = Linear(hiddenDimensions, dimensions, bias: false) + } + + func callAsFunction(_ x: MLXArray) -> MLXArray { down(gate(x) * silu(up(x))) } +} + +// MARK: - KimiMultiLinear + +private class KimiMultiLinear: Module { + var weight: MLXArray + + init(inputDims: Int, outputDims: Int, numHeads: Int) { + weight = MLXArray.zeros([numHeads, outputDims, inputDims]) + } + + func callAsFunction(_ x: MLXArray, transpose: Bool = true) -> MLXArray { + transpose ? x.matmul(weight.transposed(-1, -2)) : x.matmul(weight) + } +} + +// MARK: - KimiMLAAttention + +private class KimiMLAAttention: Module { + let numHeads: Int + let qkNopeHeadDim: Int + let qkRopeHeadDim: Int + let qHeadDim: Int + let vHeadDim: Int + let kvLoraRank: Int + let scale: Float + + @ModuleInfo(key: "q_proj") var qProj: Linear + @ModuleInfo(key: "kv_a_proj_with_mqa") var kvAProj: Linear + @ModuleInfo(key: "kv_a_layernorm") var kvALayerNorm: RMSNorm + @ModuleInfo(key: "embed_q") var embedQ: KimiMultiLinear + @ModuleInfo(key: "unembed_out") var unembedOut: KimiMultiLinear + @ModuleInfo(key: "o_proj") var oProj: Linear + + init(_ args: KimiLinearConfiguration) { + numHeads = args.numAttentionHeads + qkNopeHeadDim = args.resolvedQkNopeHeadDim + qkRopeHeadDim = args.resolvedQkRopeHeadDim + qHeadDim = args.qHeadDim + vHeadDim = args.resolvedVHeadDim + kvLoraRank = args.kvLoraRank + scale = pow(Float(args.qHeadDim), -0.5) + + let h = args.hiddenSize + _qProj.wrappedValue = Linear(h, numHeads * qHeadDim, bias: false) + _kvAProj.wrappedValue = Linear(h, kvLoraRank + max(qkRopeHeadDim, 0), bias: false) + _kvALayerNorm.wrappedValue = RMSNorm(dimensions: kvLoraRank, eps: args.rmsNormEps) + _embedQ.wrappedValue = KimiMultiLinear( + inputDims: qkNopeHeadDim, outputDims: kvLoraRank, numHeads: numHeads) + _unembedOut.wrappedValue = KimiMultiLinear( + inputDims: kvLoraRank, outputDims: vHeadDim, numHeads: numHeads) + _oProj.wrappedValue = Linear(numHeads * vHeadDim, h, bias: false) + } + + func callAsFunction( + _ x: MLXArray, mask: MLXFast.ScaledDotProductAttentionMaskMode, cache: ArraysCache? + ) -> MLXArray { + let (B, L) = (x.dim(0), x.dim(1)) + let q = qProj(x).reshaped(B, L, numHeads, qHeadDim).transposed(0, 2, 1, 3) + let qNope = q[.ellipsis, .. 0 + ? kvRaw[.ellipsis, kvLoraRank...].reshaped(B, L, 1, qkRopeHeadDim) + .transposed(0, 2, 1, 3) + : MLXArray.zeros([B, 1, L, 0], dtype: kvLatent.dtype) + cache[0] = kvLatent + cache[1] = concatenated([prev1, curKpe], axis: -2) + } else { + cache[0] = kvLatent + cache[1] = qkRopeHeadDim > 0 + ? kvRaw[.ellipsis, kvLoraRank...].reshaped(B, L, 1, qkRopeHeadDim) + .transposed(0, 2, 1, 3) + : MLXArray.zeros([B, 1, L, 0], dtype: kvLatent.dtype) + } + cache.offset += L + } + let totalL = kvLatent.dim(-2) + + var peScores: MLXArray? = nil + if qkRopeHeadDim > 0, let kPe = cache?[1] { + let qPe = q[.ellipsis, qkNopeHeadDim...] + peScores = (qPe * scale).matmul(kPe.transposed(-1, -2)) + } + + let output: MLXArray + if L == 1 { + let qMapped = embedQ(qNope) + var scores = qMapped.matmul(kvLatent.transposed(-1, -2)) * scale + if let pe = peScores { scores = scores + pe } + let weights = softmax(scores, axis: -1) + output = unembedOut(weights.matmul(kvLatent)) + } else { + let k = embedQ(kvLatent, transpose: false) + let v = unembedOut(kvLatent) + var scores = qNope.matmul(k.transposed(-1, -2)) * scale + scores = scores + makeCausalBias(L: L, totalL: totalL, dtype: scores.dtype) + if let pe = peScores { scores = scores + pe } + let weights = softmax(scores.asType(.float32), axis: -1).asType(scores.dtype) + output = weights.matmul(v) + } + + return oProj(output.transposed(0, 2, 1, 3).reshaped(B, L, -1)) + } + + private func makeCausalBias(L: Int, totalL: Int, dtype: DType) -> MLXArray { + let rows = MLXArray(Array(totalL - L ..< totalL)).reshaped(L, 1) + let cols = MLXArray(Array(0 ..< totalL)).reshaped(1, totalL) + return ((rows .< cols).asType(.float32) * Float(-1e9)).asType(dtype).reshaped(1, 1, L, totalL) + } +} + +// MARK: - ShortConv1d + +private class ShortConv1d: Module { + let kernelSize: Int + @ModuleInfo(key: "conv") var conv: Conv1d + + init(channels: Int, kernelSize: Int) { + self.kernelSize = kernelSize + _conv.wrappedValue = Conv1d( + inputChannels: 1, outputChannels: channels, kernelSize: kernelSize, + stride: 1, padding: 0, dilation: 1, groups: channels, bias: false) + } + + func callAsFunction(_ x: MLXArray, state: MLXArray?) -> (MLXArray, MLXArray) { + let (B, T, C) = (x.dim(0), x.dim(1), x.dim(2)) + let nKeep = kernelSize - 1 + let prevState = state ?? MLXArray.zeros([B, nKeep, C], dtype: x.dtype) + let convInput = concatenated([prevState, x], axis: 1) + let out = silu(conv(convInput)) + return (out, convInput[0..., T...]) + } +} + +// MARK: - KimiDeltaAttention + +private class KimiDeltaAttention: Module { + let numHeads: Int + let headDim: Int + let projDim: Int + let scale: Float + + @ModuleInfo(key: "q_proj") var qProj: Linear + @ModuleInfo(key: "k_proj") var kProj: Linear + @ModuleInfo(key: "v_proj") var vProj: Linear + @ModuleInfo(key: "q_conv") var qConv: ShortConv1d + @ModuleInfo(key: "k_conv") var kConv: ShortConv1d + @ModuleInfo(key: "v_conv") var vConv: ShortConv1d + @ModuleInfo(key: "f_a_proj") var faProj: Linear + @ModuleInfo(key: "f_b_proj") var fbProj: Linear + @ModuleInfo(key: "b_proj") var bProj: Linear + @ModuleInfo(key: "g_a_proj") var gaProj: Linear + @ModuleInfo(key: "g_b_proj") var gbProj: Linear + @ModuleInfo(key: "o_norm") var oNorm: RMSNorm + @ModuleInfo(key: "o_proj") var oProj: Linear + + var aLog: MLXArray + var dtBias: MLXArray + + init(_ args: KimiLinearConfiguration, layerIdx: Int) { + let cfg = args.linearAttnConfig + numHeads = cfg.numHeads + headDim = cfg.headDim + projDim = numHeads * headDim + scale = pow(Float(headDim), -0.5) + + let h = args.hiddenSize + let K = cfg.shortConvKernelSize + _qProj.wrappedValue = Linear(h, projDim, bias: false) + _kProj.wrappedValue = Linear(h, projDim, bias: false) + _vProj.wrappedValue = Linear(h, projDim, bias: false) + _qConv.wrappedValue = ShortConv1d(channels: projDim, kernelSize: K) + _kConv.wrappedValue = ShortConv1d(channels: projDim, kernelSize: K) + _vConv.wrappedValue = ShortConv1d(channels: projDim, kernelSize: K) + _faProj.wrappedValue = Linear(h, headDim, bias: false) + _fbProj.wrappedValue = Linear(headDim, projDim, bias: false) + _bProj.wrappedValue = Linear(h, numHeads, bias: false) + _gaProj.wrappedValue = Linear(h, headDim, bias: false) + _gbProj.wrappedValue = Linear(headDim, projDim, bias: false) + _oNorm.wrappedValue = RMSNorm(dimensions: headDim, eps: args.rmsNormEps) + _oProj.wrappedValue = Linear(projDim, h, bias: false) + aLog = MLXArray.zeros([numHeads]) + dtBias = MLXArray.zeros([projDim]) + } + + func callAsFunction( + _ x: MLXArray, mask: MLXFast.ScaledDotProductAttentionMaskMode, cache: ArraysCache? + ) -> MLXArray { + let (B, T) = (x.dim(0), x.dim(1)) + let (qConvOut, newQState) = qConv(qProj(x), state: cache?[0]) + let (kConvOut, newKState) = kConv(kProj(x), state: cache?[1]) + let (vConvOut, newVState) = vConv(vProj(x), state: cache?[2]) + if let cache { + cache[0] = newQState + cache[1] = newKState + cache[2] = newVState + } + var q = qConvOut.reshaped(B, T, numHeads, headDim) + var k = kConvOut.reshaped(B, T, numHeads, headDim) + let v = vConvOut.reshaped(B, T, numHeads, headDim) + q = (scale * scale) * MLXFast.rmsNorm(q, weight: MLXArray.mlxNone, eps: 1e-6) + k = scale * MLXFast.rmsNorm(k, weight: MLXArray.mlxNone, eps: 1e-6) + let aLogits = fbProj(faProj(x)).reshaped(B, T, numHeads, headDim) + let bLogits = bProj(x).reshaped(B, T, numHeads) + let (out, newSsmState) = kimiGatedDeltaUpdate( + q: q, k: k, v: v, + aLogits: aLogits, bLogits: bLogits, + aLog: aLog.reshaped(numHeads, 1), + dtBias: dtBias.reshaped(numHeads, headDim), + state: cache?[3]) + if let cache { + cache[3] = newSsmState + cache.offset += T + } + let gate = gbProj(gaProj(x)).reshaped(B, T, numHeads, headDim) + return oProj((oNorm(out) * sigmoid(gate)).reshaped(B, T, -1)) + } +} + +// MARK: - Kimi Gated Delta Update + +private func kimiGatedDeltaUpdate( + q: MLXArray, k: MLXArray, v: MLXArray, + aLogits: MLXArray, bLogits: MLXArray, + aLog: MLXArray, dtBias: MLXArray, + state: MLXArray? +) -> (MLXArray, MLXArray) { + let (B, T, H, Dv, Dk) = (q.dim(0), q.dim(1), q.dim(2), v.dim(3), q.dim(3)) + let g = exp(-exp(aLog) * softplus(aLogits + dtBias)) + let beta = sigmoid(bLogits) + var s = state ?? MLXArray.zeros([B, H, Dv, Dk], dtype: q.dtype) + var ys = [MLXArray]() + ys.reserveCapacity(T) + for t in 0 ..< T { + let qt = q[0..., t]; let kt = k[0..., t]; let vt = v[0..., t] + let gt = g[0..., t]; let betat = beta[0..., t] + s = s * expandedDimensions(gt, axis: -2) + let kvMem = (s * expandedDimensions(kt, axis: -2)).sum(axis: -1) + let delta = (vt - kvMem) * expandedDimensions(betat, axis: -1) + s = s + expandedDimensions(kt, axis: -2) * expandedDimensions(delta, axis: -1) + ys.append((s * expandedDimensions(qt, axis: -2)).sum(axis: -1)) + } + return (MLX.stacked(ys, axis: 1), s) +} + +// MARK: - KimiSparseMoE + +private class KimiSparseMoE: Module, UnaryLayer { + let numExperts: Int + let numExpertsPerToken: Int + let numExpertGroup: Int + let topkGroup: Int + let routedScalingFactor: Float + let renormalize: Bool + let scoreFunction: String + + @ModuleInfo(key: "gate") var gate: Linear + @ModuleInfo(key: "switch_mlp") var switchMLP: SwitchGLU + var eScoreCorrectionBias: MLXArray + + @ModuleInfo(key: "shared_experts") var sharedExperts: KimiMLP? + + init(_ args: KimiLinearConfiguration) { + numExperts = args.numExperts + numExpertsPerToken = args.numExpertsPerToken + numExpertGroup = args.numExpertGroup + topkGroup = args.topkGroup + routedScalingFactor = args.routedScalingFactor + renormalize = args.moeRenormalize + scoreFunction = args.moeRouterActivationFunc + _gate.wrappedValue = Linear(args.hiddenSize, numExperts, bias: false) + _switchMLP.wrappedValue = SwitchGLU( + inputDims: args.hiddenSize, hiddenDims: args.moeIntermediateSize, numExperts: numExperts) + eScoreCorrectionBias = MLXArray.zeros([numExperts]) + if args.numSharedExperts > 0 { + _sharedExperts.wrappedValue = KimiMLP( + dimensions: args.hiddenSize, + hiddenDimensions: args.moeIntermediateSize * args.numSharedExperts) + } + } + + func callAsFunction(_ x: MLXArray) -> MLXArray { + let logits = gate(x) + var scores = scoreFunction == "softmax" + ? MLX.softmax(logits, axis: -1, precise: true) + : sigmoid(logits) + let origScores = scores + scores = scores + eScoreCorrectionBias.asType(scores.dtype) + if numExpertGroup > 1 { + let grouped = scores.reshaped(scores.shape.dropLast() + [numExpertGroup, -1]) + let groupTop = top(grouped, k: 2, axis: -1).sum(axis: -1, keepDims: true) + let k = numExpertGroup - topkGroup + let groupIdx = argPartition(groupTop, kth: k - 1, axis: -2)[.ellipsis, .. 1 && renormalize { + weights = weights / (weights.sum(axis: -1, keepDims: true) + 1e-20) + } + weights = weights * routedScalingFactor + var out = (switchMLP(x, inds) * weights[.ellipsis, .newAxis]).sum(axis: -2) + if let shared = sharedExperts { out = out + shared(x) } + return out + } +} + +// MARK: - KimiDecoderLayer + +private class KimiDecoderLayer: Module { + let isLinear: Bool + @ModuleInfo(key: "self_attn") var deltaAttn: KimiDeltaAttention? + @ModuleInfo(key: "self_attn") var mlaAttn: KimiMLAAttention? + var mlp: UnaryLayer + @ModuleInfo(key: "input_layernorm") var inputLayerNorm: RMSNorm + @ModuleInfo(key: "post_attention_layernorm") var postAttnLayerNorm: RMSNorm + + init(_ args: KimiLinearConfiguration, layerIdx: Int) { + let kdaSet = Set(args.linearAttnConfig.kdaLayers) + isLinear = kdaSet.contains(layerIdx + 1) + if isLinear { + _deltaAttn.wrappedValue = KimiDeltaAttention(args, layerIdx: layerIdx) + } else { + _mlaAttn.wrappedValue = KimiMLAAttention(args) + } + if args.numExperts > 0 + && layerIdx >= args.firstKDenseReplace + && layerIdx % args.moeLayerFreq == 0 + { + mlp = KimiSparseMoE(args) + } else { + mlp = KimiMLP(dimensions: args.hiddenSize, hiddenDimensions: args.intermediateSize) + } + _inputLayerNorm.wrappedValue = RMSNorm(dimensions: args.hiddenSize, eps: args.rmsNormEps) + _postAttnLayerNorm.wrappedValue = RMSNorm( + dimensions: args.hiddenSize, eps: args.rmsNormEps) + } + + func callAsFunction( + _ x: MLXArray, mask: MLXFast.ScaledDotProductAttentionMaskMode, cache: ArraysCache? + ) -> MLXArray { + let attended = isLinear + ? deltaAttn!(inputLayerNorm(x), mask: mask, cache: cache) + : mlaAttn!(inputLayerNorm(x), mask: mask, cache: cache) + let h = x + attended + return h + mlp(postAttnLayerNorm(h)) + } +} + +// MARK: - KimiLinearModelInner + +private class KimiLinearModelInner: Module, LayerPartitionable { + var gpuLayerCount: Int? + var totalLayerCount: Int { layers.count } + + @ModuleInfo(key: "embed_tokens") var embedTokens: Embedding + let layers: [KimiDecoderLayer] + @ModuleInfo(key: "norm") var norm: RMSNorm + let attnLayerIdx: Int // first MLA (full-attention) layer index + + init(_ args: KimiLinearConfiguration) { + precondition(args.vocabSize > 0) + _embedTokens.wrappedValue = Embedding( + embeddingCount: args.vocabSize, dimensions: args.hiddenSize) + layers = (0 ..< args.numHiddenLayers).map { KimiDecoderLayer(args, layerIdx: $0) } + _norm.wrappedValue = RMSNorm(dimensions: args.hiddenSize, eps: args.rmsNormEps) + let kdaSet = Set(args.linearAttnConfig.kdaLayers) + attnLayerIdx = (0 ..< args.numHiddenLayers).first { !kdaSet.contains($0 + 1) } ?? 0 + } + + func callAsFunction(_ inputs: MLXArray, cache: [KVCache]? = nil) -> MLXArray { + var h = embedTokens(inputs) + let mask = createAttentionMask(h: h, cache: cache?[attnLayerIdx] as? ArraysCache) + for (i, layer) in layers.enumerated() { + h = partitionedLayerCall(index: i, gpuLayerCount: gpuLayerCount) { + layer(h, mask: mask, cache: cache?[i] as? ArraysCache) + } + } + return norm(h) + } + + func callCapturing( + _ inputs: MLXArray, cache: [KVCache?]? = nil, captureLayerIDs: Set + ) -> (MLXArray, [Int: MLXArray]) { + var h = embedTokens(inputs) + let kvCache: [KVCache?] = { + guard let c = cache else { return Array(repeating: nil, count: layers.count) } + var out = Array(repeating: nil as KVCache?, count: layers.count) + for (i, v) in c.prefix(layers.count).enumerated() { out[i] = v } + return out + }() + let mask = createAttentionMask(h: h, cache: kvCache[attnLayerIdx] as? ArraysCache) + var captured: [Int: MLXArray] = [:] + for (i, layer) in layers.enumerated() { + h = partitionedLayerCall(index: i, gpuLayerCount: gpuLayerCount) { + layer(h, mask: mask, cache: kvCache[i] as? ArraysCache) + } + if captureLayerIDs.contains(i) { captured[i] = h } + } + return (norm(h), captured) + } +} + +// MARK: - Public Model + +/// Kimi linear (hybrid KDA/MLA) model owned by SwiftLM. +/// Registered for `kimi_linear` model type at DFlash setup time. +public class KimiLinearDFlashModel: Module, LLMModel, KVCacheDimensionProvider, LoRAModel, + DFlashTargetModel +{ + public let vocabularySize: Int + public let kvHeads: [Int] + + @ModuleInfo(key: "model") private var inner: KimiLinearModelInner + private let configuration: KimiLinearConfiguration + + @ModuleInfo(key: "lm_head") var lmHead: Linear? + + public init(_ args: KimiLinearConfiguration) { + configuration = args + vocabularySize = args.vocabSize + kvHeads = Array(repeating: 1, count: args.numHiddenLayers) + _inner.wrappedValue = KimiLinearModelInner(args) + if !args.tieWordEmbeddings { + _lmHead.wrappedValue = Linear(args.hiddenSize, args.vocabSize, bias: false) + } + } + + public func callAsFunction(_ inputs: MLXArray, cache: [KVCache]?) -> MLXArray { + let out = inner(inputs, cache: cache) + return lmHead.map { $0(out) } ?? inner.embedTokens.asLinear(out) + } + + public func makeCache(parameters: GenerateParameters?) -> [any KVCache] { + inner.layers.map { layer in + layer.isLinear + ? ArraysCache(size: 4) // [q_state, k_state, v_state, ssm_state] + : ArraysCache(size: 2) // [kv_latent, k_pe] + } + } + + public func sanitize(weights: [String: MLXArray]) -> [String: MLXArray] { + var w = weights.filter { !$0.key.hasPrefix("model.mtp") } + if configuration.tieWordEmbeddings { w["lm_head.weight"] = nil } + + for (i, layer) in inner.layers.enumerated() { + let prefix = "model.layers.\(i)" + if layer.mlp is KimiSparseMoE { + let src = "\(prefix).block_sparse_moe" + let dst = "\(prefix).mlp" + for (srcN, dstN) in [("w1","gate_proj"),("w2","down_proj"),("w3","up_proj")] { + let key0 = "\(src).experts.0.\(srcN).weight" + if w[key0] != nil { + let n = configuration.numExperts + let stacked = (0 ..< n).map { + w.removeValue(forKey: "\(src).experts.\($0).\(srcN).weight")! + } + w["\(dst).switch_mlp.\(dstN).weight"] = MLX.stacked(stacked) + } + } + for name in ["gate_proj","up_proj","down_proj"] { + if let v = w.removeValue(forKey: "\(src).shared_experts.\(name).weight") { + w["\(dst).shared_experts.\(name).weight"] = v + } + } + if let v = w.removeValue(forKey: "\(src).gate.weight") { w["\(dst).gate.weight"] = v } + if let v = w.removeValue(forKey: "\(src).gate.e_score_correction_bias") { + w["\(dst).e_score_correction_bias"] = v + } + } + let attnP = "\(prefix).self_attn" + for (srcN, dstN) in [("q_conv1d","q_conv"),("k_conv1d","k_conv"),("v_conv1d","v_conv")] { + if var convW = w.removeValue(forKey: "\(attnP).\(srcN).weight") { + if convW.ndim == 3 { convW = convW.transposed(0, 2, 1) } + w["\(attnP).\(dstN).conv.weight"] = convW + } + } + if let dtW = w["\(attnP).dt_bias"], dtW.ndim > 1 { + w["\(attnP).dt_bias"] = dtW.reshaped(-1) + } + if let kvB = w.removeValue(forKey: "\(attnP).kv_b_proj.weight") { + let qkNope = configuration.resolvedQkNopeHeadDim + let vHead = configuration.resolvedVHeadDim + let heads = configuration.numAttentionHeads + let r = kvB.reshaped(heads, qkNope + vHead, -1) + w["\(attnP).embed_q.weight"] = MLX.contiguous(r[0..., .. MLXArray { + inner.embedTokens(tokens) + } + + public func dflashLmHeadLogits(_ hiddenStates: MLXArray) -> MLXArray { + lmHead.map { $0(hiddenStates) } ?? inner.embedTokens.asLinear(hiddenStates) + } + + public func dflashForwardWithCapture( + inputIDs: MLXArray, + cache: [KVCache], + captureLayerIDs: Set + ) -> (MLXArray, [Int: MLXArray]) { + let cacheOpt: [KVCache?] = cache.map { $0 } + let (hidden, captured) = inner.callCapturing( + inputIDs, cache: cacheOpt, captureLayerIDs: captureLayerIDs) + return (dflashLmHeadLogits(hidden), captured) + } + + // Kimi linear uses ArraysCache-backed KDA + MLA layers (no GDN rollback needed). + public var dflashIsHybridGDN: Bool { false } +} diff --git a/Sources/SwiftLM/Llama+DFlash.swift b/Sources/SwiftLM/Llama+DFlash.swift new file mode 100644 index 00000000..d19bdc97 --- /dev/null +++ b/Sources/SwiftLM/Llama+DFlash.swift @@ -0,0 +1,34 @@ +// Copyright 2026 SwiftLM Contributors +// MIT License — see LICENSE file +// Bridge: LlamaModel (and Mistral) conform to DFlashTargetModel + +import DFlash +import MLX +import MLXLLM +import MLXLMCommon + +extension LlamaModel: DFlashTargetModel { + public func dflashEmbedTokens(_ tokens: MLXArray) -> MLXArray { + model.embedTokens(tokens) + } + + public func dflashLmHeadLogits(_ hiddenStates: MLXArray) -> MLXArray { + if let lmHead { + return lmHead(hiddenStates) + } + return model.embedTokens.asLinear(hiddenStates) + } + + public func dflashForwardWithCapture( + inputIDs: MLXArray, + cache: [KVCache], + captureLayerIDs: Set + ) -> (MLXArray, [Int: MLXArray]) { + let cacheOpt: [KVCache?] = cache.map { $0 } + let (hiddenStates, captured) = model.callCapturing( + inputIDs, cache: cacheOpt, captureLayerIDs: captureLayerIDs) + return (dflashLmHeadLogits(hiddenStates), captured) + } + + public var dflashIsHybridGDN: Bool { false } +} diff --git a/Sources/SwiftLM/ModelProfiler.swift b/Sources/SwiftLM/ModelProfiler.swift index 7ee89800..ea5f76a8 100644 --- a/Sources/SwiftLM/ModelProfiler.swift +++ b/Sources/SwiftLM/ModelProfiler.swift @@ -343,13 +343,14 @@ enum ModelProfiler { // MARK: Partition Planning /// Compute a partition plan for the given model on the current system. - static func plan(model: ModelProfile, system: SystemProfile, contextSize: Int) -> PartitionPlan { + static func plan(model: ModelProfile, system: SystemProfile, contextSize: Int, draftWeightBytes: Int = 0) -> PartitionPlan { let weightGB = model.weightMemoryGB > 0 ? model.weightMemoryGB : model.estimatedParamsB * (Double(model.quantBits) / 8.0) + let draftGB = Double(draftWeightBytes) / 1e9 let kvGB = model.kvCacheMemoryGB(contextLength: contextSize) let overheadFactor = 1.2 - let totalGB = weightGB * overheadFactor + kvGB + let totalGB = (weightGB + draftGB) * overheadFactor + kvGB let availableGB = system.availableRAMGB let overcommit = totalGB / availableGB @@ -397,7 +398,7 @@ enum ModelProfiler { memoryLimit = Int(Double(system.recommendedWorkingSetBytes) * 1.5) cacheLimit = system.recommendedWorkingSetBytes // default case .swapAssisted: - memoryLimit = Int(totalGB * 1.1 * 1e9) + memoryLimit = 200 * 1024 * 1024 * 1024 // 200 GB sentinel to bypass MLX eval_impl spin loop (let macOS swap handle it) cacheLimit = 2 * 1024 * 1024 // 2MB — let OS manage caching case .layerPartitioned: memoryLimit = Int(availableGB * 0.85 * 1e9) diff --git a/Sources/SwiftLM/Qwen3+DFlash.swift b/Sources/SwiftLM/Qwen3+DFlash.swift new file mode 100644 index 00000000..fcc1c482 --- /dev/null +++ b/Sources/SwiftLM/Qwen3+DFlash.swift @@ -0,0 +1,34 @@ +// Copyright 2026 SwiftLM Contributors +// MIT License — see LICENSE file +// Bridge: Qwen3 dense models conform to DFlashTargetModel + +import DFlash +import MLX +import MLXLLM +import MLXLMCommon + +extension Qwen3Model: DFlashTargetModel { + public func dflashEmbedTokens(_ tokens: MLXArray) -> MLXArray { + model.embedTokens(tokens) + } + + public func dflashLmHeadLogits(_ hiddenStates: MLXArray) -> MLXArray { + if let lmHead { + return lmHead(hiddenStates) + } + return model.embedTokens.asLinear(hiddenStates) + } + + public func dflashForwardWithCapture( + inputIDs: MLXArray, + cache: [KVCache], + captureLayerIDs: Set + ) -> (MLXArray, [Int: MLXArray]) { + let cacheOpt: [KVCache?] = cache.map { $0 } + let (hiddenStates, captured) = model.callCapturing( + inputIDs, cache: cacheOpt, captureLayerIDs: captureLayerIDs) + return (dflashLmHeadLogits(hiddenStates), captured) + } + + public var dflashIsHybridGDN: Bool { false } +} diff --git a/Sources/SwiftLM/Qwen35+DFlash.swift b/Sources/SwiftLM/Qwen35+DFlash.swift new file mode 100644 index 00000000..e9508bae --- /dev/null +++ b/Sources/SwiftLM/Qwen35+DFlash.swift @@ -0,0 +1,62 @@ +// Copyright 2026 SwiftLM Contributors +// MIT License — see LICENSE file +// Bridge: Qwen35 models conform to DFlashTargetModel +// +// The dflash* methods are defined on Qwen35TextModel/Qwen35Model in the +// MLXLLM module. This file adds the DFlashTargetModel protocol conformance +// so the DFlash runtime can use them generically. + +import DFlash +import MLX +import MLXLLM +import MLXLMCommon + +// MARK: - Qwen35TextModel + DFlashTargetModel + +extension Qwen35TextModel: DFlashTargetModel { + public func dflashEmbedTokens(_ tokens: MLXArray) -> MLXArray { + model.embedTokens(tokens) + } + + public func dflashLmHeadLogits(_ hiddenStates: MLXArray) -> MLXArray { + if let lmHead { + return lmHead(hiddenStates) + } + return model.embedTokens.asLinear(hiddenStates) + } + + public func dflashForwardWithCapture( + inputIDs: MLXArray, + cache: [KVCache], + captureLayerIDs: Set + ) -> (MLXArray, [Int: MLXArray]) { + let cacheOpt: [KVCache?] = cache.map { $0 } + let (hiddenStates, captured) = model.callCapturing( + inputIDs, cache: cacheOpt, captureLayerIDs: captureLayerIDs) + return (dflashLmHeadLogits(hiddenStates), captured) + } + + public var dflashIsHybridGDN: Bool { false } +} + +// MARK: - Qwen35Model + DFlashTargetModel + +extension Qwen35Model: DFlashTargetModel { + public func dflashEmbedTokens(_ tokens: MLXArray) -> MLXArray { + languageModel.dflashEmbedTokens(tokens) + } + + public func dflashLmHeadLogits(_ hiddenStates: MLXArray) -> MLXArray { + languageModel.dflashLmHeadLogits(hiddenStates) + } + + public func dflashForwardWithCapture( + inputIDs: MLXArray, + cache: [KVCache], + captureLayerIDs: Set + ) -> (MLXArray, [Int: MLXArray]) { + languageModel.dflashForwardWithCapture(inputIDs: inputIDs, cache: cache, captureLayerIDs: captureLayerIDs) + } + + public var dflashIsHybridGDN: Bool { languageModel.dflashIsHybridGDN } +} diff --git a/Sources/SwiftLM/Qwen3MoE+DFlash.swift b/Sources/SwiftLM/Qwen3MoE+DFlash.swift new file mode 100644 index 00000000..68d4c6a8 --- /dev/null +++ b/Sources/SwiftLM/Qwen3MoE+DFlash.swift @@ -0,0 +1,34 @@ +// Copyright 2026 SwiftLM Contributors +// MIT License — see LICENSE file +// Bridge: Qwen3 MoE models conform to DFlashTargetModel + +import DFlash +import MLX +import MLXLLM +import MLXLMCommon + +extension Qwen3MoEModel: DFlashTargetModel { + public func dflashEmbedTokens(_ tokens: MLXArray) -> MLXArray { + model.embedTokens(tokens) + } + + public func dflashLmHeadLogits(_ hiddenStates: MLXArray) -> MLXArray { + if let lmHead { + return lmHead(hiddenStates) + } + return model.embedTokens.asLinear(hiddenStates) + } + + public func dflashForwardWithCapture( + inputIDs: MLXArray, + cache: [KVCache], + captureLayerIDs: Set + ) -> (MLXArray, [Int: MLXArray]) { + let cacheOpt: [KVCache?] = cache.map { $0 } + let (hiddenStates, captured) = model.callCapturing( + inputIDs, cache: cacheOpt, captureLayerIDs: captureLayerIDs) + return (dflashLmHeadLogits(hiddenStates), captured) + } + + public var dflashIsHybridGDN: Bool { false } +} diff --git a/Sources/SwiftLM/Qwen3Next+DFlash.swift b/Sources/SwiftLM/Qwen3Next+DFlash.swift new file mode 100644 index 00000000..3b970d67 --- /dev/null +++ b/Sources/SwiftLM/Qwen3Next+DFlash.swift @@ -0,0 +1,36 @@ +import DFlash +import MLX +import MLXLLM +import MLXLMCommon + +extension Qwen3NextModel: DFlashTargetModel { + + public func dflashEmbedTokens(_ tokens: MLXArray) -> MLXArray { + model.embedTokens(tokens) + } + + public func dflashLmHeadLogits(_ hiddenStates: MLXArray) -> MLXArray { + if let lmHead { + return lmHead(hiddenStates) + } + return model.embedTokens.asLinear(hiddenStates) + } + + public func dflashForwardWithCapture( + inputIDs: MLXArray, + cache: [KVCache], + captureLayerIDs: Set + ) -> (MLXArray, [Int: MLXArray]) { + let cacheOpt: [KVCache?] = cache.map { $0 } + let (hiddenStates, captured) = model.callCapturing( + inputIDs, cache: cacheOpt, captureLayerIDs: captureLayerIDs) + return (dflashLmHeadLogits(hiddenStates), captured) + } + + /// Qwen3Next has GDN-style linear attention layers, but any rollback scheme + /// (tape or snapshot) degrades acceptance rate by leaving recurrent state stale. + /// Without rollback, rejected-token contamination is empirically negligible + /// (< 1 reject per accepted cycle at long context) and gives ~3x speedup. + /// Python avoids this tradeoff via @mx.compile on the verify pass (free tape). + public var dflashIsHybridGDN: Bool { false } +} diff --git a/Sources/SwiftLM/Server.swift b/Sources/SwiftLM/Server.swift index 28c51f59..9bf56ec4 100644 --- a/Sources/SwiftLM/Server.swift +++ b/Sources/SwiftLM/Server.swift @@ -11,6 +11,7 @@ import ArgumentParser import CoreImage +import DFlash import Foundation import HTTPTypes import Hummingbird @@ -18,6 +19,7 @@ import Hub import MLX import MLXLLM import MLXLMCommon +import MLXNN import MLXVLM import MLXInferenceCore import Tokenizers @@ -272,7 +274,34 @@ struct MLXServer: AsyncParsableCommand { @Option(name: .long, help: "Number of draft tokens per speculation round (default: 4)") var numDraftTokens: Int = 4 + @Flag(name: .long, help: "Enable DFlash block-diffusion speculative decoding. Requires a DFlash draft model (auto-resolved or specified via --draft-model).") + var dflash: Bool = false + + @Option(name: .long, help: "DFlash block size (number of tokens per draft block). Default: use draft model's configured block_size.") + var dflashBlockSize: Int? + mutating func run() async throws { + // Raise the open-file limit: large sharded models (e.g. Kimi K2.5, 182 safetensor + // shards) + draft model + metallib + dylibs can exhaust the default macOS FD limit of 256. + var rl = rlimit() + getrlimit(RLIMIT_NOFILE, &rl) + if rl.rlim_cur < 4096 { + rl.rlim_cur = min(4096, rl.rlim_max) + setrlimit(RLIMIT_NOFILE, &rl) + } + + // Cap Metal command buffer size BEFORE any MLX operation to prevent the + // 5-second Apple GPU Watchdog from killing processes under swap pressure. + // This env var must be set before MLX's Metal backend initializes. + // Value 50 splits large computation graphs into ~1-layer chunks so macOS + // can page in weights incrementally without exceeding the watchdog timeout. + if self.draftModel != nil || self.streamExperts { + setenv("MLX_MAX_OPS_PER_BUFFER", "50", 1) + } + + // Register SwiftLM-owned DFlash model types before any model loading. + await registerDFlashModelTypes() + print("[SwiftLM] Loading model: \(model)") let modelId = model @@ -336,8 +365,7 @@ struct MLXServer: AsyncParsableCommand { // instead of weightMemoryGB * 1_073_741_824 to avoid the ~7% GiB/GB // mismatch flagged in Copilot review (weightMemoryGB = bytes / 1e9, not /2^30). let draftFootprintBytes: Int - if self.streamExperts, - let draftPath = self.draftModel, + if let draftPath = self.draftModel, let draftDir = resolveModelDirectory(modelId: draftPath), let draftProfile = ModelProfiler.profile(modelDirectory: draftDir, modelId: draftPath) { draftFootprintBytes = draftProfile.weightFileSizeBytes @@ -425,7 +453,7 @@ struct MLXServer: AsyncParsableCommand { if let profile = profile { let system = ModelProfiler.systemProfile() let contextSize = self.ctxSize ?? 4096 - let plan = ModelProfiler.plan(model: profile, system: system, contextSize: contextSize) + let plan = ModelProfiler.plan(model: profile, system: system, contextSize: contextSize, draftWeightBytes: draftFootprintBytes) partitionPlan = plan // --info mode: print report and exit @@ -563,10 +591,22 @@ struct MLXServer: AsyncParsableCommand { print("[SwiftLM] Loaded model configuration. Inferred tool call format: \(String(describing: await container.configuration.toolCallFormat))") + // ── Check if target model supports DFlash ── + let dflashTargetModel: (any DFlashTargetModel)? = await container.perform { context -> (any DFlashTargetModel)? in + context.model as? any DFlashTargetModel + } + if self.dflash { + if dflashTargetModel != nil { + print("[SwiftLM] DFlash: target model supports DFlashTargetModel") + } else { + print("[SwiftLM] ⚠️ DFlash enabled but target model does NOT conform to DFlashTargetModel") + } + } + // ── Load draft model for speculative decoding ── let draftModelRef: DraftModelRef? let numDraftTokensConfig = self.numDraftTokens - if let draftModelPath = self.draftModel { + if let draftModelPath = self.draftModel, !self.dflash { print("[SwiftLM] Loading draft model for speculative decoding: \(draftModelPath)") var draftConfig: ModelConfiguration let draftFM = FileManager.default @@ -596,10 +636,71 @@ struct MLXServer: AsyncParsableCommand { } draftModelRef = await draftContainer.extractDraftModel() print("[SwiftLM] Draft model loaded successfully (\(numDraftTokensConfig) tokens/round)") + print("[SwiftLM] Using speculative decoding: \(draftModelPath) → \(modelId) (\(numDraftTokensConfig) draft tokens/round)") } else { draftModelRef = nil } + // ── Load DFlash draft model for block-diffusion speculative decoding ── + let dflashModel: DFlashDraftModel? + let dflashBlockSizeConfig = self.dflashBlockSize + let dflashConfig = DFlashDraftConfiguration.self + if self.dflash { + // Resolve draft model reference + let resolvedDraftRef: String + if let explicit = self.draftModel { + resolvedDraftRef = explicit + } else if let autoRef = DFlashDraftRegistry.resolveDraftRef(modelRef: modelId) { + resolvedDraftRef = autoRef + print("[SwiftLM] DFlash: auto-resolved draft model → \(autoRef)") + } else { + print("[SwiftLM] ⚠️ DFlash enabled but no draft model found for '\(modelId)'. Use --draft-model to specify one.") + resolvedDraftRef = "" + } + + if !resolvedDraftRef.isEmpty { + print("[SwiftLM] Loading DFlash draft model: \(resolvedDraftRef)") + let draftDir = resolveModelDirectory(modelId: resolvedDraftRef) + if let dir = draftDir { + do { + let configURL = dir.appendingPathComponent("config.json") + let data = try Data(contentsOf: configURL) + let config = try JSONDecoder().decode(dflashConfig, from: data) + let model = DFlashDraftModel(config) + + // Load weights + let weightURL = dir.appendingPathComponent("weights.safetensors") + let ntURL = dir.appendingPathComponent("model.safetensors") + let actualWeightURL = FileManager.default.fileExists(atPath: weightURL.path) ? weightURL : ntURL + + let weights = try loadArrays(url: actualWeightURL) + let sanitized = model.sanitize(weights: weights) + let parameters = ModuleParameters.unflattened(sanitized) + try model.update(parameters: parameters, verify: .none) + + dflashModel = model + // Register DFlashKernels as the global provider + // so Qwen35GatedDeltaNet can use tape-recording forward + DFlashKernelRegistry.provider = DFlashKernels.shared + DFlashDumper.setup() + print("[SwiftLM] DFlash draft model loaded (block_size=\(model.blockSize), \(model.targetLayerIDs.count) target layers, mask_token=\(model.maskTokenID))") + print("[SwiftLM] Draft model loaded successfully (\(model.blockSize) block size, DFlash mode)") + print("[SwiftLM] Using speculative decoding: \(resolvedDraftRef) → \(modelId) (DFlash block-diffusion)") + } catch { + print("[SwiftLM] ⚠️ Failed to load DFlash draft model: \(error)") + dflashModel = nil + } + } else { + print("[SwiftLM] ⚠️ DFlash draft model not found locally: \(resolvedDraftRef). Download it first with: hf download \(resolvedDraftRef)") + dflashModel = nil + } + } else { + dflashModel = nil + } + } else { + dflashModel = nil + } + // ── Apply GPU/CPU layer partitioning ── if let gpuCount = requestedGPULayers { @@ -772,7 +873,9 @@ struct MLXServer: AsyncParsableCommand { let bodyData = try await collectBody(request) return try await handleChatCompletion( request: request, bodyData: bodyData, config: config, container: container, semaphore: semaphore, stats: stats, promptCache: promptCache, - draftModelRef: draftModelRef, numDraftTokens: numDraftTokensConfig + draftModelRef: draftModelRef, numDraftTokens: numDraftTokensConfig, + dflashModel: dflashModel, dflashBlockSize: dflashBlockSizeConfig, + dflashTargetModel: dflashTargetModel ) } catch { let errMsg = String(describing: error).replacingOccurrences(of: "\"", with: "'") @@ -1055,7 +1158,7 @@ actor ServerStats { } } -// ── Prompt Cache ───────────────────────────────────────────────────────────── + actor PromptCache { struct CachedState { @@ -1074,6 +1177,9 @@ actor PromptCache { /// If not materialized now, those lazy references point to the live cache tensors /// which get overwritten by subsequent requests, causing stale data / SIGTRAP on restore. func save(tokens: [Int], cache: [KVCache]) { + if cache.contains(where: { $0 is MambaCache }) { + return + } let states = cache.map { $0.state } let metaStates = cache.map { $0.metaState } // Materialize all lazy MLX arrays so they survive cache mutations @@ -1088,6 +1194,14 @@ actor PromptCache { /// Restores matched KV state, trims any excess — mirrors llama-server behaviour. /// Returns the number of matched tokens, or nil on a complete miss. func restore(newTokens: [Int], into cache: [KVCache]) -> Int? { + // MambaCache/RNN states cannot be arbitrarily rolled back or safely saved + // after the fact without exact sequence-boundary synchronization. + // Disable prompt caching entirely for hybrid models (e.g. Qwen3Next). + if cache.contains(where: { $0 is MambaCache }) { + misses += 1 + return nil + } + guard let cached, !cached.tokens.isEmpty else { misses += 1 return nil @@ -1156,7 +1270,10 @@ func handleChatCompletion( stats: ServerStats, promptCache: PromptCache, draftModelRef: DraftModelRef? = nil, - numDraftTokens: Int = 4 + numDraftTokens: Int = 4, + dflashModel: DFlashDraftModel? = nil, + dflashBlockSize: Int? = nil, + dflashTargetModel: (any DFlashTargetModel)? = nil ) async throws -> Response { let chatReq = try JSONDecoder().decode(ChatCompletionRequest.self, from: bodyData) let isStream = chatReq.stream ?? false @@ -1319,7 +1436,69 @@ func handleChatCompletion( fflush(stdout) let prefillStart = Date() - // ── Cache-aware generation ── + // ── DFlash block-diffusion speculative decoding ── + // When --dflash is enabled and both DFlash draft model and target model conform + // to DFlashTargetModel, we use DFlashRuntime.generate instead of the standard path. + if let dflashDraft = dflashModel, let targetModel = dflashTargetModel { + print("[SwiftLM] ⚡ DFlash block-diffusion speculative decoding active") + print("[SwiftLM] Using speculative decoding: DFlash block-diffusion mode active") + fflush(stdout) + // Convert DFlashEvent stream to Generation stream with proper streaming detokenizer + let dflashTokenizer = await container.tokenizer + let dflashStream = DFlashRuntime.generate( + targetModel: targetModel, + draftModel: dflashDraft, + promptTokens: promptTokens, + maxNewTokens: tokenLimit, + blockTokens: dflashBlockSize + ) + + // Use a class wrapper so the detokenizer can be mutated inside the closure + final class DetokenizerBox: @unchecked Sendable { + var detokenizer: NaiveStreamingDetokenizer + init(_ d: NaiveStreamingDetokenizer) { self.detokenizer = d } + } + let box = DetokenizerBox(NaiveStreamingDetokenizer(tokenizer: dflashTokenizer)) + + let genStream = AsyncStream { continuation in + Task { + for await event in dflashStream { + switch event { + case .token(let tokenID, _, _, _): + box.detokenizer.append(token: tokenID) + if let chunk = box.detokenizer.next() { + continuation.yield(.chunk(chunk, tokenId: tokenID)) + } + case .prefill, .prefillProgress: + break + case .summary(let summary): + print("[SwiftLM] DFlash summary: \(summary.generationTokens) tokens, \(String(format: "%.1f", summary.tokensPerSecond)) tok/s, acceptance=\(String(format: "%.1f%%", summary.acceptanceRatio * 100)), \(summary.cyclesCompleted) cycles") + } + } + continuation.finish() + } + } + + let modelId = config.modelId + if isStream { + return handleChatStreaming( + stream: genStream, modelId: modelId, stopSequences: stopSequences, + includeUsage: includeUsage, promptTokenCount: promptTokenCount, + enableThinking: enableThinking, jsonMode: jsonMode, semaphore: semaphore, + stats: stats, genStart: genStart, prefillStart: prefillStart, + emitPrefillProgress: false, onPrefillDone: nil + ) + } else { + return try await handleChatNonStreaming( + stream: genStream, modelId: modelId, stopSequences: stopSequences, + promptTokenCount: promptTokenCount, enableThinking: enableThinking, + jsonMode: jsonMode, semaphore: semaphore, + stats: stats, genStart: genStart, prefillStart: prefillStart, onPrefillDone: nil + ) + } + } + + // ── Cache-aware generation (standard path) ── let (stream, onPrefillDone) = try await container.perform { context -> (AsyncStream, (() async -> Void)?) in let cache = context.model.newCache(parameters: params) @@ -1363,13 +1542,6 @@ func handleChatCompletion( stream = try MLXLMCommon.generate( input: trimmedInput, cache: cache, parameters: params, context: context ) - } else if let draftRef = draftModelRef { - // Speculative decoding path: draft model generates candidates, main model verifies - print("[SwiftLM] Using speculative decoding (\(numDraftTokens) draft tokens/round)") - stream = try MLXLMCommon.generate( - input: lmInput, cache: cache, parameters: params, context: context, - draftModel: draftRef.model, numDraftTokens: numDraftTokens - ) } else { // Cache miss: process the full prompt. stream = try MLXLMCommon.generate( @@ -1648,8 +1820,10 @@ func handleChatStreaming( content: c.isEmpty ? nil : c, finishReason: nil)) } cont.yield(sseChunk(modelId: modelId, reasoningContent: nil, content: nil, finishReason: "stop")) + let genDur = Date().timeIntervalSince(genStart) + let genTokPerSec = genDur > 0 ? Double(completionTokenCount) / genDur : 0 if includeUsage { - cont.yield(sseUsageChunk(modelId: modelId, promptTokens: promptTokenCount, completionTokens: completionTokenCount)) + cont.yield(sseUsageChunk(modelId: modelId, promptTokens: promptTokenCount, completionTokens: completionTokenCount, tokPerSec: genTokPerSec, durationMs: genDur * 1000)) } cont.yield("data: [DONE]\r\n\r\n") cont.finish() @@ -1689,8 +1863,10 @@ func handleChatStreaming( reason = hasToolCalls ? "tool_calls" : "stop" } cont.yield(sseChunk(modelId: modelId, reasoningContent: nil, content: nil, finishReason: reason)) + let genDur = Date().timeIntervalSince(genStart) + let genTokPerSec = genDur > 0 ? Double(completionTokenCount) / genDur : 0 if includeUsage { - cont.yield(sseUsageChunk(modelId: modelId, promptTokens: promptTokenCount, completionTokens: completionTokenCount)) + cont.yield(sseUsageChunk(modelId: modelId, promptTokens: promptTokenCount, completionTokens: completionTokenCount, tokPerSec: genTokPerSec, durationMs: genDur * 1000)) } cont.yield("data: [DONE]\r\n\r\n") cont.finish() @@ -1698,8 +1874,8 @@ func handleChatStreaming( print("") // end the real-time token stream line let postMemSnap = MemoryUtils.snapshot() print("srv slot done: id 0 | gen_tokens=\(completionTokenCount) | OS_RAM=\(String(format: "%.1f", postMemSnap.os))GB | MEM_DEMAND=\(String(format: "%.1f", postMemSnap.demand))GB | GPU_MEM=\(String(format: "%.1f", postMemSnap.gpu))GB") - let dur = Date().timeIntervalSince(genStart) - let tokPerSec = dur > 0 ? Double(completionTokenCount) / dur : 0 + let dur = genDur + let tokPerSec = genTokPerSec let logContent: Any = hasToolCalls ? NSNull() : fullText let logResp: [String: Any] = [ "choices": [[ @@ -1851,7 +2027,12 @@ func handleChatNonStreaming( finishReason: hasToolCalls ? "tool_calls" : finishReason ) ], - usage: TokenUsage(promptTokens: promptTokenCount, completionTokens: completionTokenCount, totalTokens: totalTokens) + usage: TokenUsage(promptTokens: promptTokenCount, completionTokens: completionTokenCount, totalTokens: totalTokens), + timings: ChatCompletionResponse.Timings( + predictedPerSecond: duration > 0 ? Double(completionTokenCount) / duration : 0, + predictedN: completionTokenCount, + predictedMs: duration * 1000 + ) ) let encoded = try JSONEncoder().encode(resp) // llama-server style: log full response JSON on one line @@ -2104,7 +2285,12 @@ func handleTextNonStreaming( choices: [ TextChoice(index: 0, text: fullText, finishReason: finishReason) ], - usage: TokenUsage(promptTokens: promptTokenCount, completionTokens: completionTokenCount, totalTokens: totalTokens) + usage: TokenUsage(promptTokens: promptTokenCount, completionTokens: completionTokenCount, totalTokens: totalTokens), + timings: ChatCompletionResponse.Timings( + predictedPerSecond: duration > 0 ? Double(completionTokenCount) / duration : 0, + predictedN: completionTokenCount, + predictedMs: duration * 1000 + ) ) let encoded = try JSONEncoder().encode(resp) return Response( @@ -2313,18 +2499,26 @@ func ssePrefillChunk(nPast: Int = 0, promptTokens: Int, elapsedSeconds: Int) -> return "event: prefill_progress\r\ndata: \(String(data: data, encoding: .utf8)!)\r\n\r\n" } -func sseUsageChunk(modelId: String, promptTokens: Int, completionTokens: Int) -> String { +func sseUsageChunk(modelId: String, promptTokens: Int, completionTokens: Int, tokPerSec: Double? = nil, durationMs: Double? = nil) -> String { + var usage: [String: Any] = [ + "prompt_tokens": promptTokens, + "completion_tokens": completionTokens, + "total_tokens": promptTokens + completionTokens + ] + if let tokPerSec, let durationMs { + usage["timings"] = [ + "predicted_per_second": tokPerSec, + "predicted_n": completionTokens, + "predicted_ms": durationMs + ] + } let chunk: [String: Any] = [ "id": "chatcmpl-\(UUID().uuidString)", "object": "chat.completion.chunk", "created": Int(Date().timeIntervalSince1970), "model": modelId, "choices": [] as [[String: Any]], - "usage": [ - "prompt_tokens": promptTokens, - "completion_tokens": completionTokens, - "total_tokens": promptTokens + completionTokens - ] + "usage": usage ] let data = try! JSONSerialization.data(withJSONObject: chunk) return "data: \(String(data: data, encoding: .utf8)!)\r\n\r\n" @@ -2589,6 +2783,19 @@ struct ChatCompletionResponse: Encodable { let created: Int let choices: [Choice] let usage: TokenUsage + let timings: Timings? + + struct Timings: Encodable { + let predictedPerSecond: Double + let predictedN: Int + let predictedMs: Double + + enum CodingKeys: String, CodingKey { + case predictedPerSecond = "predicted_per_second" + case predictedN = "predicted_n" + case predictedMs = "predicted_ms" + } + } } struct Choice: Encodable { @@ -2642,6 +2849,7 @@ struct TextCompletionResponse: Encodable { let created: Int let choices: [TextChoice] let usage: TokenUsage + let timings: ChatCompletionResponse.Timings? } struct TextChoice: Encodable { diff --git a/mlx-swift b/mlx-swift index 9b95713a..6b279402 160000 --- a/mlx-swift +++ b/mlx-swift @@ -1 +1 @@ -Subproject commit 9b95713ad96b290527d98cf5aba0ba675c396da8 +Subproject commit 6b2794025db82d9be142072afe936953b6e6e5ad diff --git a/run_benchmark.sh b/run_benchmark.sh index c5978a60..a764c8b4 100755 --- a/run_benchmark.sh +++ b/run_benchmark.sh @@ -86,26 +86,33 @@ print_server_log() { fi } -echo "==============================================" export METAL_LIBRARY_PATH="$(pwd)/.build/arm64-apple-macosx/release" -echo " Aegis-AI MLX Profiling Benchmark Suite " -echo "==============================================" -echo "" -echo "Select Action:" -echo "0) Test 0: Run Full Automated Matrix (Offline Evaluation)" -echo "1) Test 1: Automated Context & Memory Profile (TPS & RAM matrix)" -echo "2) Test 2: Prompt Cache & Sliding Window Regression Test" -echo "3) Test 3: HomeSec Benchmark (LLM Only)" -echo "4) Test 4: VLM End-to-End Evaluation" -echo "5) Test 5: ALM Audio End-to-End Evaluation" -echo "6) Test 6: Omni End-to-End Evaluation" -echo "7) Model Maintain List and Delete" -echo "8) Test 8: Tool-Call Degeneration Regression (Gemma-4 vague-query bug)" -echo "9) Test 9: Quantized KV Cache Regression (Gemma-4 issue #71 — native kv_bits)" -echo "10) Test 10: SSD + Draft Model Memory Regression (Issue #72 — auto-cap + RAM guard)" -echo "q) Quit" -read -p "Option (0-10/q): " suite_opt +if [ -n "${SUITE_OPT:-}" ]; then + # Sub-process invocation from automated matrix — skip interactive menu + suite_opt="$SUITE_OPT" +else + echo "==============================================" + echo " Aegis-AI MLX Profiling Benchmark Suite " + echo "==============================================" + echo "" + echo "Select Action:" + echo "0) Test 0: Run Full Automated Matrix (Offline Evaluation)" + echo "1) Test 1: Automated Context & Memory Profile (TPS & RAM matrix)" + echo "2) Test 2: Prompt Cache & Sliding Window Regression Test" + echo "3) Test 3: HomeSec Benchmark (LLM Only)" + echo "4) Test 4: VLM End-to-End Evaluation" + echo "5) Test 5: ALM Audio End-to-End Evaluation" + echo "6) Test 6: Omni End-to-End Evaluation" + echo "7) Model Maintain List and Delete" + echo "8) Test 8: Tool-Call Degeneration Regression (Gemma-4 vague-query bug)" + echo "9) Test 9: Quantized KV Cache Regression (Gemma-4 issue #71 — native kv_bits)" + echo "10) Test 10: SSD + Draft Model Memory Regression (Issue #72 — auto-cap + RAM guard)" + echo "11) Test 11: DFlash Benchmark (Qwen3-Coder-Next-4bit)" + echo "12) Test 12: DFlash Benchmark (Qwen3.6-35B-A3B-4bit)" + echo "q) Quit" + read -p "Option (0-12/q): " suite_opt +fi if [ "$suite_opt" == "0" ]; then echo "==============================================" @@ -126,7 +133,7 @@ if [ "$suite_opt" == "0" ]; then MODEL=$(python3 scripts/hf_discovery.py "mlx-community/Qwen Audio Instruct" || echo "mlx-community/Qwen2-Audio-7B-Instruct") fi - echo -e "$TEST_ID\n11\n$MODEL" | HEADLESS=1 ./run_benchmark.sh + SUITE_OPT=$TEST_ID MODEL=$MODEL ./run_benchmark.sh sleep 5 done echo "✅ Offline matrix execution fully completed." @@ -195,6 +202,24 @@ if [ "$suite_opt" == "7" ]; then done fi +if [ "$suite_opt" == "11" ]; then + echo "" + echo "=> Starting Test 11: DFlash Benchmark (Qwen3-Coder-Next-4bit)" + export MODEL="mlx-community/Qwen3-Coder-Next-4bit" + chmod +x scripts/profiling/bench_coder_next.sh + scripts/profiling/bench_coder_next.sh + exit $? +fi + +if [ "$suite_opt" == "12" ]; then + echo "" + echo "=> Starting Test 12: DFlash Benchmark (Qwen3.6-35B-A3B-4bit)" + export MODEL="mlx-community/Qwen3.6-35B-A3B-4bit" + chmod +x scripts/profiling/bench_35b.sh + scripts/profiling/bench_35b.sh + exit $? +fi + echo "" PS3="Select a model to use: " if [ "$suite_opt" == "4" ]; then @@ -241,28 +266,30 @@ else ) fi -select opt in "${options[@]}" -do - case $opt in - "Custom (Enter your own Hub ID)") - read -p "Enter HuggingFace ID (e.g., mlx-community/Llama-3.2-3B-Instruct-4bit): " custom_model - MODEL=$custom_model - break - ;; - "Quit") - echo "Exiting." - exit 0 - ;; - *) - if [[ -n "$opt" ]]; then - MODEL=$opt +if [ -z "$MODEL" ]; then + select opt in "${options[@]}" + do + case $opt in + "Custom (Enter your own Hub ID)") + read -p "Enter HuggingFace ID (e.g., mlx-community/Llama-3.2-3B-Instruct-4bit): " custom_model + MODEL=$custom_model break - else - echo "Invalid option $REPLY" - fi - ;; - esac -done + ;; + "Quit") + echo "Exiting." + exit 0 + ;; + *) + if [[ -n "$opt" ]]; then + MODEL=$opt + break + else + echo "Invalid option $REPLY" + fi + ;; + esac + done +fi # Ensure model has an org prefix if it doesn't already if [[ "$MODEL" != *"/"* ]]; then diff --git a/scripts/profiling/bench_35b.sh b/scripts/profiling/bench_35b.sh new file mode 100755 index 00000000..79df195c --- /dev/null +++ b/scripts/profiling/bench_35b.sh @@ -0,0 +1,307 @@ +#!/usr/bin/env bash +# SwiftLM Benchmark — Qwen3.6-35B-A3B-4bit +# Tests 4 configs: baseline, SSD, SSD+DFlash, DFlash-only +# Outputs bench_results.json for use with generate_demo_video.py +set -uo pipefail + +MAX_TOKENS=512 +MODEL="mlx-community/Qwen3.6-35B-A3B-4bit" +DRAFT="z-lab/Qwen3.6-35B-A3B-DFlash" +PORT=5413 +RUNS=3 +LOG_DIR="/tmp/swiftlm_bench_logs" +RESULTS_FILE="$LOG_DIR/bench_results.json" +mkdir -p "$LOG_DIR" +export LOG_DIR + +# Build request JSON with python to avoid bash escaping hell +export MODEL +python3 << 'PYEOF' +import json, os +prompt = "The function $f$ satisfies the functional equation \\[ f(x) + f(y) = f(x + y) - xy - 1 \\] for all real numbers $x$ and $y$. If $f(1) = 1$, then find all integers $n$ such that $f(n) = n$. Enter all such integers, separated by commas. Please reason step by step, and put your final answer within \\boxed{}." +body = { + "model": os.environ["MODEL"], + "messages": [{"role": "user", "content": prompt}], + "max_tokens": 512, + "stream": False +} +with open(os.environ["LOG_DIR"] + "/bench_request.json", "w") as f: + json.dump(body, f) +PYEOF + +REQ_FILE="$LOG_DIR/bench_request.json" + +# ── Helpers ────────────────────────────────────────────────────────────────── + +wait_for_server() { + for i in $(seq 1 3600); do + if curl -sf http://127.0.0.1:$PORT/v1/models >/dev/null 2>&1; then + echo " ✅ Ready (${i}s)" + return 0 + fi + sleep 1 + done + echo " ❌ Failed" + return 1 +} + +stop_server() { + pkill -f "SwiftLM" 2>/dev/null || true + sleep 4 + pkill -9 -f "SwiftLM" 2>/dev/null || true + sleep 2 +} + +# ── Main ───────────────────────────────────────────────────────────────────── + +cd "$(git rev-parse --show-toplevel)" + +echo "" +echo "╔══════════════════════════════════════════════════════════════╗" +echo "║ SwiftLM Benchmark — Qwen3.6-35B-A3B-4bit ║" +echo "╚══════════════════════════════════════════════════════════════╝" +echo "" +echo " Max tokens: $MAX_TOKENS | Runs: $RUNS" +echo " Results → $RESULTS_FILE" +echo "" + +declare -a LABELS=() +declare -a SPEEDS=() +declare -a MEMS=() + +test_config() { + local label="$1" + shift + local args=("$@") + local slug="${label// /_}" + + echo "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━" + echo " $label" + echo "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━" + + stop_server + echo " Starting server..." + (cd .build/release && ./SwiftLM "${args[@]}") >"$LOG_DIR/server_${slug}.log" 2>&1 & + if ! wait_for_server; then + LABELS+=("$label") + SPEEDS+=("FAILED") + MEMS+=("N/A") + return + fi + + # Warmup with a different prompt (avoid polluting prompt cache) + echo " 🔥 Warmup..." + curl -sf --max-time 60 http://127.0.0.1:$PORT/v1/chat/completions \ + -H "Content-Type: application/json" \ + -d '{"model":"'"$MODEL"'","messages":[{"role":"user","content":"What is the capital of France? Answer briefly."}],"max_tokens":32,"stream":false}' >/dev/null 2>&1 + sleep 2 + + # Benchmark runs — save each raw response for JSON extraction later + local all_tps="" + for run in $(seq 1 $RUNS); do + echo " 🏃 Run $run/$RUNS..." + local resp + resp=$(curl -sf --max-time 600 http://127.0.0.1:$PORT/v1/chat/completions \ + -H "Content-Type: application/json" \ + -d @"$REQ_FILE" 2>/dev/null) || resp="" + + if [ -z "$resp" ]; then + echo " → FAILED" + continue + fi + + # Save raw response JSON for later extraction + echo "$resp" > "$LOG_DIR/resp_${slug}_run${run}.json" + + local tps tokens + tps=$(echo "$resp" | python3 -c "import json,sys; d=json.load(sys.stdin); print(f\"{d['timings']['predicted_per_second']:.1f}\")" 2>/dev/null) || tps="0.0" + tokens=$(echo "$resp" | python3 -c "import json,sys; d=json.load(sys.stdin); print(d['usage']['completion_tokens'])" 2>/dev/null) || tokens="0" + echo " → ${tps} tok/s (${tokens} tokens)" + + if [ -n "$all_tps" ]; then + all_tps="${all_tps}, ${tps}" + else + all_tps="${tps}" + fi + done + + # Average + local avg="0.0" + if [ -n "$all_tps" ]; then + avg=$(python3 -c "vals=[${all_tps}]; print(f'{sum(vals)/len(vals):.1f}')" 2>/dev/null) || avg="0.0" + fi + echo " 📊 Avg: ${avg} tok/s" + + # Peak RAM from server log + local rss + rss=$(grep "OS_RAM" "$LOG_DIR/server_${slug}.log" | tail -1 | sed 's/.*OS_RAM=\([0-9.]*\).*/\1/') + echo " 💾 RAM: ${rss} GB" + + LABELS+=("$label") + SPEEDS+=("$avg") + MEMS+=("$rss") + + stop_server + echo "" +} + +# ── Run all configs ─────────────────────────────────────────────────────────── + +test_config "Baseline" --model "$MODEL" --port $PORT + +test_config "SSD Streaming" --model "$MODEL" --port $PORT --stream-experts + +test_config "SSD + DFlash" --model "$MODEL" --port $PORT --stream-experts --dflash --draft-model "$DRAFT" + +test_config "DFlash only" --model "$MODEL" --port $PORT --dflash --draft-model "$DRAFT" + +# ── Summary table ───────────────────────────────────────────────────────────── + +echo "" +echo "╔══════════════════════════════════════════════════════════════╗" +echo "║ RESULTS ║" +echo "╠══════════════════════════════════════════════════════════════╣" +echo "║ Config Speed (tok/s) RAM (GB) ║" +echo "╠══════════════════════════════════════════════════════════════╣" +for i in "${!LABELS[@]}"; do + printf "║ %-20s %-18s %-18s║\n" "${LABELS[$i]}" "${SPEEDS[$i]}" "${MEMS[$i]}" +done +echo "╚══════════════════════════════════════════════════════════════╝" +echo "" + +# ── Extract rich JSON for demo video ───────────────────────────────────────── + +echo "📦 Extracting results to $RESULTS_FILE ..." + +python3 << 'PYEOF' +import json, os, re, time, platform + +log_dir = os.environ["LOG_DIR"] +results_file = log_dir + "/bench_results.json" + +try: + chip = "Apple M4 Max" # could call system_profiler, but keep it simple + ram = "64 GB" + machine = f"{chip} · {ram}" +except Exception: + machine = "Apple Silicon" + +results = { + "timestamp": int(time.time()), + "model": "mlx-community/Qwen3.6-35B-A3B-4bit", + "machine": machine, + "configs": [], +} + +labels = ["Baseline", "SSD Streaming", "SSD + DFlash", "DFlash only"] + +for label in labels: + slug = label.replace(" ", "_") + server_log_path = f"{log_dir}/server_{slug}.log" + + if not os.path.exists(server_log_path): + print(f" ⚠️ No log for {label}, skipping") + continue + + with open(server_log_path) as f: + server_log = f.read() + + # Per-run responses + run_tps = [] + run_tokens = [] + response_text = "" + + for run in range(1, 4): + resp_path = f"{log_dir}/resp_{slug}_run{run}.json" + if not os.path.exists(resp_path): + continue + try: + with open(resp_path) as f: + resp = json.load(f) + tps = resp["timings"]["predicted_per_second"] + tokens = resp["usage"]["completion_tokens"] + run_tps.append(round(tps, 1)) + run_tokens.append(tokens) + # Use first successful run's response text + if not response_text: + response_text = resp["choices"][0]["message"]["content"] + except Exception as e: + print(f" ⚠️ Could not parse {resp_path}: {e}") + + if not run_tps: + print(f" ⚠️ No successful runs for {label}") + continue + + avg_tps = round(sum(run_tps) / len(run_tps), 1) + avg_tokens = round(sum(run_tokens) / len(run_tokens)) if run_tokens else 512 + + # TTFT: first "prefill done" line for the actual bench prompt (n_tokens=104) + ttft = None + for line in server_log.split("\n"): + m = re.search(r"prefill done \| n_tokens=104.*?t=([0-9.]+)s", line) + if m: + ttft = float(m.group(1)) + break + + # Prefill tok/s from same line + prefill_tps = None + for line in server_log.split("\n"): + m = re.search(r"prefill done \| n_tokens=104.*?,\s*([0-9.]+)t/s", line) + if m: + prefill_tps = float(m.group(1)) + break + + # Peak GPU mem + gpu_gb = None + for line in reversed(server_log.split("\n")): + m = re.search(r"GPU_MEM=([0-9.]+)GB", line) + if m: + gpu_gb = float(m.group(1)) + break + + # Peak OS RAM + ram_gb = None + for line in reversed(server_log.split("\n")): + m = re.search(r"OS_RAM=([0-9.]+)GB", line) + if m: + ram_gb = float(m.group(1)) + break + + # DFlash acceptance (last occurrence = most recent run) + dflash_accept = None + for line in reversed(server_log.split("\n")): + m = re.search(r"DFlash summary.*?acceptance=([0-9.]+)%", line) + if m: + dflash_accept = round(float(m.group(1)), 1) + break + + # chars/token from real response + chars_per_token = ( + round(len(response_text) / avg_tokens, 3) + if avg_tokens > 0 and response_text + else 3.5 + ) + + entry = { + "label": label, + "speed": avg_tps, + "runs": run_tps, + "ram_gb": ram_gb, + "gpu_gb": gpu_gb, + "ttft_s": ttft, + "prefill_tps": prefill_tps, + "tokens": avg_tokens, + "dflash_accept": dflash_accept, + "chars_per_token": chars_per_token, + "response_text": response_text, + } + results["configs"].append(entry) + print(f" ✅ {label:<20} {avg_tps:.1f} tok/s RAM {ram_gb}G " + f"TTFT {ttft}s chars/tok {chars_per_token:.2f}") + +with open(results_file, "w") as f: + json.dump(results, f, indent=2) + +print(f"\n 📄 Saved: {results_file}") +print(f" Generate video: python generate_demo_video.py --results {results_file}") +PYEOF diff --git a/scripts/profiling/bench_coder_next.sh b/scripts/profiling/bench_coder_next.sh new file mode 100755 index 00000000..c08f0d59 --- /dev/null +++ b/scripts/profiling/bench_coder_next.sh @@ -0,0 +1,162 @@ +#!/usr/bin/env bash +# SwiftLM Benchmark — Qwen3-Coder-Next-4bit +# Tests 4 configs: baseline, SSD, SSD+DFlash, DFlash-only +set -uo pipefail + +MAX_TOKENS=512 +MODEL="mlx-community/Qwen3-Coder-Next-4bit" +DRAFT="z-lab/Qwen3-Coder-Next-DFlash" +PORT=5413 +RUNS=3 +LOG_DIR="/tmp/swiftlm_bench_logs" +mkdir -p "$LOG_DIR" +export LOG_DIR + +# Build request JSON with python to avoid bash escaping +export MODEL +python3 << 'PYEOF' +import json, os +prompt = "Write a Python function that computes the nth Fibonacci number using memoization. Include type hints and a docstring. Add a main block that prints the first 20 Fibonacci numbers." +body = { + "model": os.environ["MODEL"], + "messages": [{"role": "user", "content": prompt}], + "max_tokens": 512, + "stream": False +} +with open(os.environ["LOG_DIR"] + "/bench_coder_next.json", "w") as f: + json.dump(body, f) +PYEOF + +REQ_FILE="$LOG_DIR/bench_coder_next.json" + +# ── Helpers ────────────────────────────────────────────────────────────────── + +wait_for_server() { + for i in $(seq 1 3600); do + if curl -sf http://127.0.0.1:$PORT/v1/models >/dev/null 2>&1; then + echo " ✅ Ready (${i}s)" + return 0 + fi + sleep 1 + done + echo " ❌ Failed" + return 1 +} + +stop_server() { + pkill -f "SwiftLM" 2>/dev/null || true + sleep 4 + pkill -9 -f "SwiftLM" 2>/dev/null || true + sleep 2 +} + +# ── Main ───────────────────────────────────────────────────────────────────── + +cd "$(git rev-parse --show-toplevel)" + +echo "" +echo "╔══════════════════════════════════════════════════════════════╗" +echo "║ SwiftLM Benchmark — Qwen3-Coder-Next-4bit ║" +echo "╚══════════════════════════════════════════════════════════════╝" +echo "" +echo " Max tokens: $MAX_TOKENS | Runs: $RUNS" +echo "" + +declare -a LABELS=() +declare -a SPEEDS=() +declare -a MEMS=() + +test_config() { + local label="$1" + shift + local args=("$@") + + echo "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━" + echo " $label" + echo "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━" + + stop_server + echo " Starting server..." + (cd .build/release && ./SwiftLM "${args[@]}") >"$LOG_DIR/cn_${label// /_}.log" 2>&1 & + if ! wait_for_server; then + LABELS+=("$label") + SPEEDS+=("FAILED") + MEMS+=("N/A") + echo "" + return + fi + + # Warmup with different prompt + echo " 🔥 Warmup..." + curl -sf --max-time 120 http://127.0.0.1:$PORT/v1/chat/completions \ + -H "Content-Type: application/json" \ + -d '{"model":"'"$MODEL"'","messages":[{"role":"user","content":"Say hi in one word."}],"max_tokens":16,"stream":false}' >/dev/null 2>&1 + sleep 2 + + # Benchmark runs + local all_tps="" + for run in $(seq 1 $RUNS); do + echo " 🏃 Run $run/$RUNS..." + local resp + resp=$(curl -sf --max-time 600 http://127.0.0.1:$PORT/v1/chat/completions \ + -H "Content-Type: application/json" \ + -d @"$REQ_FILE" 2>/dev/null) || resp="" + + if [ -z "$resp" ]; then + echo " → FAILED (empty response)" + continue + fi + + local tps tokens + tps=$(echo "$resp" | python3 -c "import json,sys; d=json.load(sys.stdin); print(f\"{d['timings']['predicted_per_second']:.1f}\")" 2>/dev/null) || tps="0.0" + tokens=$(echo "$resp" | python3 -c "import json,sys; d=json.load(sys.stdin); print(d['usage']['completion_tokens'])" 2>/dev/null) || tokens="0" + echo " → ${tps} tok/s (${tokens} tokens)" + + if [ -n "$all_tps" ]; then + all_tps="${all_tps}, ${tps}" + else + all_tps="${tps}" + fi + done + + # Average + local avg="0.0" + if [ -n "$all_tps" ]; then + avg=$(python3 -c "vals=[${all_tps}]; print(f'{sum(vals)/len(vals):.1f}')" 2>/dev/null) || avg="0.0" + fi + echo " 📊 Avg: ${avg} tok/s" + + # Peak RAM from server log + local rss + rss=$(grep "OS_RAM" "$LOG_DIR/cn_${label// /_}.log" | tail -1 | sed 's/.*OS_RAM=\([0-9.]*\).*/\1/') + echo " 💾 RAM: ${rss} GB" + + LABELS+=("$label") + SPEEDS+=("$avg") + MEMS+=("$rss") + + stop_server + echo "" +} + +# ── Run all configs ────────────────────────────────────────────────────────── + +test_config "Baseline" --model "$MODEL" --port $PORT + +test_config "SSD Streaming" --model "$MODEL" --port $PORT --stream-experts + +test_config "SSD + DFlash" --model "$MODEL" --port $PORT --stream-experts --dflash --draft-model "$DRAFT" + +test_config "DFlash only" --model "$MODEL" --port $PORT --dflash --draft-model "$DRAFT" + +# ── Summary ────────────────────────────────────────────────────────────────── + +echo "╔══════════════════════════════════════════════════════════════╗" +echo "║ RESULTS ║" +echo "╠══════════════════════════════════════════════════════════════╣" +echo "║ Config Speed (tok/s) RAM (GB) ║" +echo "╠══════════════════════════════════════════════════════════════╣" +for i in "${!LABELS[@]}"; do + printf "║ %-20s %-18s %-18s║\n" "${LABELS[$i]}" "${SPEEDS[$i]}" "${MEMS[$i]}" +done +echo "╚══════════════════════════════════════════════════════════════╝" diff --git a/tests/DFlash/DFlashBenchmark.swift b/tests/DFlash/DFlashBenchmark.swift new file mode 100644 index 00000000..628cfd85 --- /dev/null +++ b/tests/DFlash/DFlashBenchmark.swift @@ -0,0 +1,695 @@ +// DFlashBenchmark.swift +// +// Comprehensive benchmark for DFlash speculative decoding. +// Compares baseline (standard generation) vs DFlash at various token counts. +// Saves results to JSON following dflash-mlx benchmark format. +// +// Usage: swift run DFlashBenchmark [options] + +import Foundation +#if os(macOS) +import MachO +#endif +import MLX +import MLXLMCommon +import MLXNN +import DFlash + +// MARK: - Benchmark Configuration + +struct BenchmarkConfig: Codable, Sendable { + let targetModel: String + let draftModel: String + let maxNewTokens: Int + let blockTokens: [Int] + let cooldownSeconds: Int + let repeatCount: Int + let prompt: String + let promptTokens: Int + let gitHash: String + + enum CodingKeys: String, CodingKey { + case targetModel = "target_model" + case draftModel = "draft_model" + case maxNewTokens = "max_new_tokens" + case blockTokens = "block_tokens" + case cooldownSeconds = "cooldown" + case repeatCount = "repeat" + case prompt + case promptTokens = "prompt_tokens" + case gitHash = "git_hash" + } +} + +// MARK: - Hardware Info + +struct HardwareInfo: Codable, Sendable { + let chip: String + let memoryGB: Int + let mlxVersion: String + let swiftVersion: String + let deviceDescription: String + + enum CodingKeys: String, CodingKey { + case chip + case memoryGB = "memory_gb" + case mlxVersion = "mlx_version" + case swiftVersion = "swift_version" + case deviceDescription = "device_description" + } + + static func collect() -> HardwareInfo { + // Get chip info using sysctl (macOS only) + let chip = runShellCommand(["sysctl", "-n", "machdep.cpu.brand_string"])?.trimmingCharacters(in: .whitespacesAndNewlines) ?? "Unknown" + let memoryGB = (Int(runShellCommand(["sysctl", "-n", "hw.memsize"])?.trimmingCharacters(in: .whitespacesAndNewlines) ?? "0") ?? 0) / (1024 * 1024 * 1024) + + return HardwareInfo( + chip: chip, + memoryGB: memoryGB, + mlxVersion: "0.21.0", // Update based on your mlx-swift version + swiftVersion: swiftVersion, + deviceDescription: Device.defaultDevice().description + ) + } + + private static var swiftVersion: String { + #if swift(>=6.0) + return "6.0+" + #elseif swift(>=5.10) + return "5.10" + #elseif swift(>=5.9) + return "5.9" + #else + return "<5.9" + #endif + } +} + +// MARK: - Thermal Pressure Check + +enum ThermalPressure: String, Codable, Sendable { + case nominal, fair, serious, critical, unknown +} + +func checkThermalPressure() -> ThermalPressure { + #if os(macOS) + // Check CPU scheduler limit + if let output = runShellCommand(["pmset", "-g", "therm"]), + let line = output.split(separator: "\n").first(where: { $0.contains("CPU_Scheduler_Limit") }) { + let parts = line.split(separator: "=") + if parts.count > 1, + let value = Int(parts[1].trimmingCharacters(in: .whitespaces)) { + if value == 100 { return .nominal } + if value >= 80 { return .fair } + if value >= 50 { return .serious } + return .critical + } + } + #endif + return .unknown +} + +// MARK: - Benchmark Result Structures + +struct ModelResult: Codable, Sendable { + let ttftMs: Double // Time to first token + let generationTps: Double + let peakMemoryGB: Double? + let tokensGenerated: Int + let promptTokens: Int + let generationTimeMs: Double + + enum CodingKeys: String, CodingKey { + case ttftMs = "ttft_ms" + case generationTps = "generation_tps" + case peakMemoryGB = "peak_memory_gb" + case tokensGenerated = "tokens_generated" + case promptTokens = "prompt_token_count" + case generationTimeMs = "generation_time_ms" + } +} + +struct DFlashSpecificResult: Codable, Sendable { + let tokensPerCycle: Double + let cycles: Int + let acceptanceRatio: Double + let acceptanceFirst20Avg: Double? + let acceptanceLast20Avg: Double? + let blockTokens: Int + let acceptedFromDraft: Int + + enum CodingKeys: String, CodingKey { + case tokensPerCycle = "tokens_per_cycle" + case cycles + case acceptanceRatio = "acceptance_ratio" + case acceptanceFirst20Avg = "acceptance_first_20_avg" + case acceptanceLast20Avg = "acceptance_last_20_avg" + case blockTokens = "block_tokens" + case acceptedFromDraft = "accepted_from_draft" + } +} + +struct RunResult: Codable, Sendable { + let run: Int + let thermalPressure: String + let baseline: ModelResult + let dflash: DFlashRunResult + let speedup: Double? + + enum CodingKeys: String, CodingKey { + case run + case thermalPressure = "thermal_pressure" + case baseline + case dflash + case speedup + } +} + +struct DFlashRunResult: Codable, Sendable { + let base: ModelResult + let specific: DFlashSpecificResult + + var ttftMs: Double { base.ttftMs } + var generationTps: Double { base.generationTps } + var peakMemoryGB: Double? { base.peakMemoryGB } + var tokensPerCycle: Double { specific.tokensPerCycle } + var cycles: Int { specific.cycles } + var acceptanceRatio: Double { specific.acceptanceRatio } + var acceptanceFirst20Avg: Double? { specific.acceptanceFirst20Avg } + var acceptanceLast20Avg: Double? { specific.acceptanceLast20Avg } +} + +struct BenchmarkSummary: Codable, Sendable { + let baselineTpsMedian: Double? + let dflashTpsMedian: Double? + let dflashTpsMin: Double? + let dflashTpsMax: Double? + let speedupMedian: Double? + let acceptanceRatioMedian: Double? + let totalMemoryGB: Double? + + enum CodingKeys: String, CodingKey { + case baselineTpsMedian = "baseline_tps_median" + case dflashTpsMedian = "dflash_tps_median" + case dflashTpsMin = "dflash_tps_min" + case dflashTpsMax = "dflash_tps_max" + case speedupMedian = "speedup_median" + case acceptanceRatioMedian = "acceptance_ratio_median" + case totalMemoryGB = "total_memory_gb" + } +} + +struct BenchmarkReport: Codable, Sendable { + let hardware: HardwareInfo + let config: BenchmarkConfig + let runs: [RunResult] + let summary: BenchmarkSummary + + func save(to path: String) throws { + let encoder = JSONEncoder() + encoder.outputFormatting = [.prettyPrinted, .sortedKeys] + let data = try encoder.encode(self) + try data.write(to: URL(fileURLWithPath: path)) + } +} + +// MARK: - Baseline Generation + +/// Runs baseline generation using standard mlx-swift +func runBaselineGeneration( + targetModel: any LanguageModel, + promptTokens: [Int], + maxNewTokens: Int, + eventHandler: @escaping (String) -> Void +) async -> ModelResult { + let startTime = DispatchTime.now().uptimeNanoseconds + var firstTokenTime: UInt64? + var tokenCount = 0 + var promptTokenCount = 0 + + // Create tokenizer - you'll need to pass this in or get from the model + // For now, we'll use the model's configuration + let modelContext = ModelContext(model: targetModel) + + for await event in sample(modelContext.model, tokenizer: modelContext.tokenization.tokenizer, prompt: promptTokens) { + switch event { + case .promptTokens(let tokens): + promptTokenCount = tokens.count + + case .token(let token): + if firstTokenTime == nil { + firstTokenTime = DispatchTime.now().uptimeNanoseconds + } + tokenCount += 1 + eventHandler("[Baseline] Token \(tokenCount): \(token)") + + case .generationStopped: + break + } + } + + let endTime = DispatchTime.now().uptimeNanoseconds + let ttftNs = (firstTokenTime ?? startTime) - startTime + let generationNs = endTime - (firstTokenTime ?? startTime) + let ttftMs = Double(ttftNs) / 1_000_000.0 + let generationMs = Double(generationNs) / 1_000_000.0 + let tps = Double(tokenCount) / (generationMs / 1000.0) + + // Get memory info + let memoryGB = getPeakMemoryGB() + + return ModelResult( + ttftMs: ttftMs, + generationTps: tps, + peakMemoryGB: memoryGB, + tokensGenerated: tokenCount, + promptTokens: promptTokenCount, + generationTimeMs: generationMs + ) +} + +// MARK: - DFlash Generation + +/// Runs DFlash speculative decoding +func runDFlashGeneration( + targetModelAdapter: any DFlashTargetModel, + draftModel: DFlashDraftModel, + promptTokens: [Int], + maxNewTokens: Int, + blockTokens: Int, + eventHandler: @escaping (String) -> Void +) async -> DFlashRunResult { + let startTime = DispatchTime.now().uptimeNanoseconds + var firstTokenTime: UInt64? + var tokenCount = 0 + var promptTokenCount = 0 + var cycleCount = 0 + var acceptedFromDraft = 0 + var acceptanceRatios: [Double] = [] + + let stream = DFlashRuntime.generate( + targetModel: targetModelAdapter, + draftModel: draftModel, + promptTokens: promptTokens, + maxNewTokens: maxNewTokens, + blockTokens: blockTokens + ) + + var summary: DFlashSummary? + + for await event in stream { + switch event { + case .prefill(let tokens, let us): + promptTokenCount = tokens + eventHandler("[DFlash] Prefill: \(tokens) tokens in \(us / 1000.0) ms") + + case .token(let token, let generated, let ratio, let cycles): + if firstTokenTime == nil { + firstTokenTime = DispatchTime.now().uptimeNanoseconds + } + tokenCount += 1 + cycleCount = cycles + acceptanceRatios.append(ratio) + eventHandler("[DFlash] Token \(generated): \(token) (acceptance: \(String(format: "%.2f", ratio)))") + + case .summary(let s): + summary = s + acceptedFromDraft = s.acceptedFromDraft + } + } + + let endTime = DispatchTime.now().uptimeNanoseconds + let ttftNs = (firstTokenTime ?? startTime) - startTime + let generationNs = endTime - (firstTokenTime ?? startTime) + let ttftMs = Double(ttftNs) / 1_000_000.0 + let generationMs = Double(generationNs) / 1_000_000.0 + let tps = Double(tokenCount) / (generationMs / 1000.0) + + // Get memory info + let memoryGB = getPeakMemoryGB() + + // Calculate acceptance stats + let first20Avg = acceptanceRatios.prefix(20).reduce(0, +) / Double(min(20, acceptanceRatios.count)) + let last20Avg = acceptanceRatios.suffix(20).reduce(0, +) / Double(min(20, acceptanceRatios.count)) + let acceptanceRatio = Double(acceptedFromDraft) / Double(tokenCount) + + let baseResult = ModelResult( + ttftMs: ttftMs, + generationTps: tps, + peakMemoryGB: memoryGB, + tokensGenerated: tokenCount, + promptTokens: promptTokenCount, + generationTimeMs: generationMs + ) + + let specificResult = DFlashSpecificResult( + tokensPerCycle: Double(tokenCount) / Double(cycleCount), + cycles: cycleCount, + acceptanceRatio: acceptanceRatio, + acceptanceFirst20Avg: first20Avg, + acceptanceLast20Avg: last20Avg, + blockTokens: blockTokens, + acceptedFromDraft: acceptedFromDraft + ) + + return DFlashRunResult(base: baseResult, specific: specificResult) +} + +// MARK: - Main Benchmark Runner + +struct DFlashBenchmarkRunner { + let config: BenchmarkConfig + let verbose: Bool + + func run() async throws -> BenchmarkReport { + print("═══════════════════════════════════════════════════════════════") + print(" DFlash Benchmark") + print(" Target: \(config.targetModel)") + print(" Draft: \(config.draftModel)") + print(" Max Tokens: \(config.maxNewTokens)") + print(" Repeat: \(config.repeatCount)") + print("═══════════════════════════════════════════════════════════════") + + // Load models + print("\nLoading models...") + + // Load target model + let targetConfig = ModelConfiguration(id: config.targetModel) + let targetContainer = try await ModelContainer.load( + targetConfig, + memoryLimit: [0: 20 * 1024 * 1024 * 1024] // 20GB + ) + + // Load draft model + let draftConfig = DFlashDraftConfiguration.fromHuggingFace(id: config.draftModel) + let draftModel = DFlashDraftModel(draftConfig) + // Note: you'll also need to load draft weights here + + // Tokenize prompt + let tokenizer = targetContainer.tokenization.tokenizer + let promptTokens = tokenizer.encode(text: config.prompt, addSpecialTokens: true).tokens + + print("Prompt: \(config.prompt.prefix(60))...") + print("Tokens: \(promptTokens.count)") + + var runResults: [RunResult] = [] + + for run in 1...config.repeatCount { + print("\n── Run \(run)/\(config.repeatCount) ──") + + let thermalPressure = checkThermalPressure() + if thermalPressure != .nominal { + print("⚠️ Thermal pressure: \(thermalPressure.rawValue)") + } + + // Run baseline + print("\nRunning baseline...") + let baselineResult = await runBaselineGeneration( + targetModel: targetContainer.model, + promptTokens: promptTokens, + maxNewTokens: config.maxNewTokens + ) { msg in + if self.verbose { print(msg) } + } + + print(" Baseline: \(String(format: "%.2f", baselineResult.generationTps)) TPS") + + // Cooldown + if config.cooldownSeconds > 0 { + print(" Cooling down for \(config.cooldownSeconds)s...") + try await Task.sleep(nanoseconds: UInt64(config.cooldownSeconds) * 1_000_000_000) + } + + // Run DFlash for each block size + var bestDFlashResult: DFlashRunResult? + var bestSpeedup: Double = 0 + + for blockSize in config.blockTokens { + print("\nRunning DFlash (block=\(blockSize))...") + + guard let dflashTarget = targetContainer.model as? DFlashTargetModel else { + print("Error: loaded model does not conform to DFlashTargetModel — cannot run DFlash benchmark") + exit(1) + } + let dflashResult = await runDFlashGeneration( + targetModelAdapter: dflashTarget, + draftModel: draftModel, + promptTokens: promptTokens, + maxNewTokens: config.maxNewTokens, + blockTokens: blockSize + ) { msg in + if self.verbose { print(msg) } + } + + let speedup = dflashResult.base.generationTps / baselineResult.generationTps + print(" DFlash: \(String(format: "%.2f", dflashResult.base.generationTps)) TPS (speedup: \(String(format: "%.2fx", speedup)))") + + if speedup > bestSpeedup { + bestSpeedup = speedup + bestDFlashResult = dflashResult + } + + // Cooldown between block sizes + if config.cooldownSeconds > 0 && blockSize != config.blockTokens.last { + print(" Cooling down...") + try await Task.sleep(nanoseconds: UInt64(config.cooldownSeconds) * 1_000_000_000) + } + } + + let runResult = RunResult( + run: run, + thermalPressure: thermalPressure.rawValue, + baseline: baselineResult, + dflash: bestDFlashResult!, + speedup: bestSpeedup > 0 ? bestSpeedup : nil + ) + + runResults.append(runResult) + + // Final cooldown before next repeat + if run < config.repeatCount && config.cooldownSeconds > 0 { + print("\nFinal cooldown for run...") + try await Task.sleep(nanoseconds: UInt64(config.cooldownSeconds) * 1_000_000_000) + } + } + + // Compute summary statistics + let baselineTpsValues = runResults.map { $0.baseline.generationTps } + let dflashTpsValues = runResults.map { $0.dflash.base.generationTps } + let speedupValues = runResults.compactMap { $0.speedup } + let acceptanceRatios = runResults.map { $0.dflash.acceptanceRatio } + + let summary = BenchmarkSummary( + baselineTpsMedian: median(baselineTpsValues), + dflashTpsMedian: median(dflashTpsValues), + dflashTpsMin: dflashTpsValues.min(), + dflashTpsMax: dflashTpsValues.max(), + speedupMedian: median(speedupValues), + acceptanceRatioMedian: median(acceptanceRatios), + totalMemoryGB: getPeakMemoryGB() + ) + + return BenchmarkReport( + hardware: HardwareInfo.collect(), + config: config, + runs: runResults, + summary: summary + ) + } +} + +// MARK: - Helper Functions + +func runShellCommand(_ args: [String]) -> String? { + let task = Process() + task.executableURL = URL(fileURLWithPath: "/usr/bin/env") + task.arguments = args + + let pipe = Pipe() + task.standardOutput = pipe + task.standardError = FileHandle.nullDevice + + do { + try task.run() + task.waitUntilExit() + let data = pipe.fileHandleForReading.readDataToEndOfFile() + return String(data: data, encoding: .utf8) + } catch { + return nil + } +} + +func getPeakMemoryGB() -> Double? { + #if os(macOS) + // Use task_info to get memory info + var info = task_basic_info() + var count = mach_msg_type_number_t(MemoryLayout.size) / 4 + + let kerr: kern_return_t = withUnsafeMutablePointer(to: &info) { + $0.withMemoryRebound(to: integer_t.self, capacity: 1) { + task_info(mach_task_self_, task_flavor_t(TASK_BASIC_INFO), $0, &count) + } + } + + if kerr == KERN_SUCCESS { + return Double(info.resident_size) / (1024 * 1024 * 1024) + } + #endif + return nil +} + +func median(_ values: [T]) -> Double? { + guard !values.isEmpty else { return nil } + let sorted = values.sorted() + let count = sorted.count + if count % 2 == 0 { + let mid = count / 2 + return (Double(sorted[mid - 1]) + Double(sorted[mid])) / 2 + } else { + return Double(sorted[count / 2]) + } +} + +func median(_ values: [T]) -> Double? { + guard !values.isEmpty else { return nil } + let sorted = values.sorted() + let count = sorted.count + if count % 2 == 0 { + let mid = count / 2 + return (Double(sorted[mid - 1]) + Double(sorted[mid])) / 2 + } else { + return Double(sorted[count / 2]) + } +} + +// MARK: - Command Line Arguments + +struct BenchmarkArguments { + let targetModel: String + let draftModel: String + let maxNewTokens: Int + let blockTokens: [Int] + let repeatCount: Int + let cooldownSeconds: Int + let prompt: String + let outputPath: String + let verbose: Bool + + static func parse() -> BenchmarkArguments { + let args = CommandLine.arguments + + func arg(_ flag: String, defaultValue: String) -> String { + if let idx = args.firstIndex(of: flag), idx + 1 < args.count { + return args[idx + 1] + } + return defaultValue + } + + func argInt(_ flag: String, defaultValue: Int) -> Int { + return Int(arg(flag, defaultValue: String(defaultValue))) ?? defaultValue + } + + func argArray(_ flag: String, separator: Character, transform: (String) -> T) -> [T] { + let str = arg(flag, defaultValue: "") + if str.isEmpty { return [] } + return str.split(separator: separator).map { transform(String($0)) } + } + + let targetModel = arg("--target", defaultValue: "mlx-community/Qwen3.5-27B-4bit") + let draftModel = arg("--draft", defaultValue: "z-lab/Qwen3.5-27B-DFlash") + let maxNewTokens = argInt("--max-tokens", defaultValue: 512) + let blockTokensStr = arg("--block-tokens", defaultValue: "8,16,32") + let blockTokens = blockTokensStr.split(separator: ",").compactMap { Int($0) } + let repeatCount = argInt("--repeat", defaultValue: 3) + let cooldownSeconds = argInt("--cooldown", defaultValue: 60) + let verbose = args.contains("--verbose") || args.contains("-v") + + let defaultPrompt = """ + The function $f$ satisfies the functional equation \\[ f(x) + f(y) = f(x + y) - xy - 1 \\] \ + for all real numbers $x$ and $y$. If $f(1) = 1$, then find all integers $n$ such that $f(n) = n$. \ + Enter all such integers, separated by commas. Please reason step by step. + """ + let prompt = arg("--prompt", defaultValue: defaultPrompt) + + let outputPath = arg("--output", defaultValue: "benchmark/results/swift-\(targetModel.split(separator: "/").last ?? "model")-\(maxNewTokens).json") + + return BenchmarkArguments( + targetModel: targetModel, + draftModel: draftModel, + maxNewTokens: maxNewTokens, + blockTokens: blockTokens.isEmpty ? [8, 16, 32] : blockTokens, + repeatCount: repeatCount, + cooldownSeconds: cooldownSeconds, + prompt: prompt, + outputPath: outputPath, + verbose: verbose + ) + } + + func toConfig(gitHash: String) -> BenchmarkConfig { + // Count prompt tokens (rough estimate) + let promptTokens = prompt.split(separator: " ").count + + return BenchmarkConfig( + targetModel: targetModel, + draftModel: draftModel, + maxNewTokens: maxNewTokens, + blockTokens: blockTokens, + cooldownSeconds: cooldownSeconds, + repeatCount: repeatCount, + prompt: prompt, + promptTokens: promptTokens, + gitHash: gitHash + ) + } +} + +// MARK: - Main + +@main +struct DFlashBenchmark { + static func main() async { + let args = BenchmarkArguments.parse() + + print("DFlash Benchmark - Swift") + print("========================\n") + + // Get git hash + let gitHash = runShellCommand(["git", "rev-parse", "--short", "HEAD"])?.trimmingCharacters(in: .whitespacesAndNewlines) ?? "unknown" + + let config = args.toConfig(gitHash: gitHash) + let runner = DFlashBenchmarkRunner(config: config, verbose: args.verbose) + + do { + let report = try await runner.run() + + // Create output directory if needed + let outputURL = URL(fileURLWithPath: args.outputPath) + try? FileManager.default.createDirectory( + at: outputURL.deletingLastPathComponent(), + withIntermediateDirectories: true + ) + + // Save report + try report.save(to: args.outputPath) + + print("\n═══════════════════════════════════════════════════════════════") + print(" Benchmark Complete") + print(" Results saved to: \(args.outputPath)") + print("═══════════════════════════════════════════════════════════════") + print("\nSummary:") + print(" Baseline TPS: \(String(format: "%.2f", report.summary.baselineTpsMedian ?? 0))") + print(" DFlash TPS: \(String(format: "%.2f", report.summary.dflashTpsMedian ?? 0))") + if let speedup = report.summary.speedupMedian { + print(" Speedup: \(String(format: "%.2fx", speedup))") + } + if let acceptance = report.summary.acceptanceRatioMedian { + print(" Acceptance Ratio: \(String(format: "%.2f%%", acceptance * 100))") + } + + } catch { + print("Error: \(error)") + exit(1) + } + } +} \ No newline at end of file diff --git a/tests/DFlash/DFlashCosSimComparison.swift b/tests/DFlash/DFlashCosSimComparison.swift new file mode 100644 index 00000000..b72b50ec --- /dev/null +++ b/tests/DFlash/DFlashCosSimComparison.swift @@ -0,0 +1,309 @@ +// DFlashCosSimComparison.swift +// +// Compares intermediate values between Python and Swift DFlash implementations +// by loading Python .npy dumps and running equivalent Swift code, computing +// cosine similarity at each step. +// +// Usage: swift run DFlashCompare [--dir path/to/intermediates] + +import Foundation +import MLX +import MLXLMCommon +import MLXNN +import MLXFast + +// MARK: - NPY Loader + +/// Minimal .npy loader for float32 arrays +func loadNpy(_ path: String) -> MLXArray? { + guard let data = try? Data(contentsOf: URL(fileURLWithPath: path)) else { + print(" ⚠️ Could not load: \(path)") + return nil + } + + // Parse numpy .npy header + // Magic: \x93NUMPY + version + header_len + header + guard data.count > 10, + data[0] == 0x93, + String(data: data[1..<6], encoding: .ascii) == "NUMPY" else { + print(" ⚠️ Not a valid .npy file: \(path)") + return nil + } + + let majorVersion = data[6] + let headerLen: Int + if majorVersion == 1 { + headerLen = Int(data[8]) | (Int(data[9]) << 8) + let headerStart = 10 + let headerEnd = headerStart + headerLen + + // Parse header to get shape + guard let headerStr = String(data: data[headerStart.. Float { + precondition(a.shape == b.shape, "Shape mismatch: \(a.shape) vs \(b.shape)") + let aF = a.reshaped(-1).asType(.float32) + let bF = b.reshaped(-1).asType(.float32) + let dot = (aF * bF).sum() + let normA = (aF * aF).sum() + let normB = (bF * bF).sum() + let denom = MLX.sqrt(normA * normB) + let cosSim = (dot / denom).item(Float.self) + return cosSim +} + +func meanAbsDiff(_ a: MLXArray, _ b: MLXArray) -> Float { + let aF = a.reshaped(-1).asType(.float32) + let bF = b.reshaped(-1).asType(.float32) + return MLX.abs(aF - bF).mean().item(Float.self) +} + +// MARK: - Comparison Result + +struct CompareResult { + let name: String + let cosSim: Float + let mad: Float + let shape: [Int] + + var pass: Bool { cosSim > 0.99 } + + func report() { + let status = pass ? "✅" : "❌" + print(String(format: " %@ %-45s cos=%7.5f mad=%10.6f shape=%@", status, name, cosSim, mad, shape.map { $0.description }.joined(separator: "x"))) + } +} + +// MARK: - Main Comparison + +@main +struct DFlashCompare { + static func main() async throws { + let dir: String + if CommandLine.arguments.count > 2 && CommandLine.arguments[1] == "--dir" { + dir = CommandLine.arguments[2] + } else { + dir = URL(fileURLWithPath: #file) + .deletingLastPathComponent() + .appendingPathComponent("intermediates") + .path + } + + print("═══════════════════════════════════════════════════════════════") + print(" DFlash Python ↔ Swift Cosine Similarity Comparison") + print(" Intermediates dir: \(dir)") + print("═══════════════════════════════════════════════════════════════") + + // Load meta + let metaURL = URL(fileURLWithPath: dir + "/_meta.json") + let metaData = try Data(contentsOf: metaURL) + let meta = try JSONSerialization.jsonObject(with: metaData) as! [String: Any] + let promptTokens = meta["prompt_tokens"] as! [Int] + let stagedFirst = meta["staged_first"] as! Int + let maskTokenID = meta["mask_token_id"] as! Int + let blockLen = meta["block_len"] as! Int + let targetLayerIDs = meta["target_layer_ids"] as! [Int] + let captureLayerIDs = meta["capture_layer_ids"] as! [Int] + let draftedTokens = meta["drafted_tokens"] as! [Int] + + print("\nPrompt tokens: \(promptTokens)") + print("staged_first: \(stagedFirst)") + print("block_len: \(blockLen)") + print("target_layer_ids: \(targetLayerIDs)") + print("drafted_tokens (first 5): \(Array(draftedTokens.prefix(5)))") + + var results: [CompareResult] = [] + + // ── Step 1: Load Python reference arrays ── + print("\n── Loading Python reference arrays ──") + + func load(_ name: String) -> MLXArray? { + return loadNpy(dir + "/" + name + ".npy") + } + + guard let pyTargetHidden = load("target_hidden") else { + print("FATAL: Could not load target_hidden") + return + } + guard let pyNoiseEmbedding = load("noise_embedding") else { + print("FATAL: Could not load noise_embedding") + return + } + guard let pyProjectedHidden = load("projected_hidden") else { + print("FATAL: Could not load projected_hidden") + return + } + + // ── Step 2: Load Swift models and run equivalent pipeline ── + print("\n── Loading Swift models ──") + + // Load target model + let targetConfig = ModelConfiguration(id: "mlx-community/Qwen3.5-27B-4bit") + let targetContainer = try await ModelContainer.load( + targetConfig, + memoryLimit: [0: 20 * 1024 * 1024 * 1024] + ) + + // Load draft model + let draftConfig = DFlashDraftConfiguration.fromHuggingFace(id: "z-lab/Qwen3.5-27B-DFlash") + let draftModel = DFlashDraftModel(draftConfig) + // TODO: load draft weights + + // ── Step 3: Compare step by step ── + print("\n── Step-by-step comparison ──") + + // Compare target_hidden (from prefill) + // We can't easily re-run the target model's prefill here, so compare the extracted hidden + + // Compare projected_hidden + // Run Swift's projectTargetHidden on Python's target_hidden + let swiftProjected = draftModel.projectTargetHidden(pyTargetHidden.asType(.bfloat16)) + eval(swiftProjected) + let cosProjected = cosineSimilarity(pyProjectedHidden, swiftProjected.asType(.float32)) + let madProjected = meanAbsDiff(pyProjectedHidden, swiftProjected.asType(.float32)) + results.append(CompareResult(name: "projected_hidden", cosSim: cosProjected, mad: madProjected, shape: swiftProjected.shape.map { $0.intValue })) + + // Compare layer-by-layer + for i in 0..<5 { + // Load Python intermediates + guard let pyAfterInputLN = load("draft_layer\(i)_after_input_ln"), + let pyAfterAttn = load("draft_layer\(i)_after_attn"), + let pyAfterMLP = load("draft_layer\(i)_after_mlp"), + let pyOutput = load("draft_layer\(i)_output") else { + print(" ⚠️ Missing layer \(i) intermediates") + continue + } + + // We'll compare the Python values against each other (sanity check) + // and also run the Swift draft model layer by layer if we can + + // For now, compute self-consistency and cross-layer metrics + for (name, arr) in [ + ("draft_layer\(i)_after_input_ln", pyAfterInputLN), + ("draft_layer\(i)_after_attn", pyAfterAttn), + ("draft_layer\(i)_after_mlp", pyAfterMLP), + ("draft_layer\(i)_output", pyOutput), + ] { + // Print stats for each Python intermediate + let mean = arr.mean().item(Float.self) + let maxVal = arr.max().item(Float.self) + let minVal = arr.min().item(Float.self) + print(String(format: " 📊 %-45s mean=%8.4f min=%8.4f max=%8.4f", name, mean, minVal, maxVal)) + } + } + + // Compare draft_logits + if let pyDraftLogits = load("draft_logits") { + let pyDraftLogitsF = pyDraftLogits.asType(.float32) + // Get top-5 tokens from Python logits at position 0 + let pos0Logits = pyDraftLogitsF[0..., 0, 0...] + let topK = MLX.argMax(pos0Logits, axis: -1) + print("\n Python top token at pos 0: \(topK.item(Int32.self))") + } + + // ── Summary ── + print("\n═══════════════════════════════════════════════════════════════") + print(" COMPARISON SUMMARY") + print("═══════════════════════════════════════════════════════════════") + for r in results { + r.report() + } + + let passCount = results.filter { $0.pass }.count + let failCount = results.filter { !$0.pass }.count + print("\n ✅ \(passCount) passed, ❌ \(failCount) failed") + } +} diff --git a/tests/DFlash/DFlashProfiler.swift b/tests/DFlash/DFlashProfiler.swift new file mode 100644 index 00000000..e1ffa00c --- /dev/null +++ b/tests/DFlash/DFlashProfiler.swift @@ -0,0 +1,261 @@ +// DFlashProfiler.swift +// +// Simple profiler for DFlash performance analysis +// Measures timing for key operations and validates numerical consistency +// Saves results to JSON for comparison +// +// Usage: swift run DFlashProfiler [--model model-id] [--output path.json] + +import Foundation +import MLX +import MLXLMCommon +import MLXNN +import DFlash + +// MARK: - Timing Utilities + +struct TimingResult { + let name: String + let meanUs: Double + let stdUs: Double + let minUs: Double + let maxUs: Double + let iterations: Int + + func report() { + print(String(format: " %-40s %8.1f ± %6.1f µs (min: %7.1f, max: %7.1f, n=%d)", + name, meanUs, stdUs, minUs, maxUs, iterations)) + } +} + +func timeOperation(name: String, iterations: Int, fn: () -> Void) -> TimingResult { + var times = [Double]() + + // Warmup + for _ in 0..<3 { fn() } + + MLX.eval(MLXArray(0)) // Synchronize + + for _ in 0.. MLXArray { + let data = (0.. [TimingResult] { + var results = [TimingResult]() + + // Generate test data + let B = 1 + let T = 16 // block size + let Hk = 8 + let Hv = 16 + let Dk = 128 + let Dv = 128 + + print("\nGenerating test data...") + let tape = randomArray(shape: [B, T, Hv, Dv]) + let k = randomArray(shape: [B, T, Hk, Dk]) + let g3d = randomArray(shape: [B, T, Hv]) // 3D gate + let g4d = randomArray(shape: [B, T, Hv, Dk]) // 4D gate + let state = randomArray(shape: [B, Hv, Dv, Dk]) + + let q = randomArray(shape: [B, T, Hk, Dk]) + let v = randomArray(shape: [B, T, Hv, Dv]) + let beta = randomArray(shape: [B, T, Hv]) + let mask = randomArray(shape: [B, T]).asType(.bool) + + print("\n── Metal Kernel Benchmarks (Tape Replay) ──") + + // Benchmark tape replay kernel with 3D gate + let r3d = timeOperation(name: "tapeReplay (3D gate, Metal)", iterations: 20) { + _ = DFlashKernels.tapeReplayKernel(tape: tape, k: k, g: g3d, state: state) + } + results.append(r3d) + + // Benchmark tape replay kernel with 4D gate (vectorized) + let r4d = timeOperation(name: "tapeReplay (4D gate, Metal)", iterations: 20) { + _ = DFlashKernels.tapeReplayKernel(tape: tape, k: k, g: g4d, state: state) + } + results.append(r4d) + + // Benchmark with mask + let rMask = timeOperation(name: "tapeReplay (with mask)", iterations: 20) { + _ = DFlashKernels.tapeReplayKernel(tape: tape, k: k, g: g3d, state: state, mask: mask) + } + results.append(rMask) + + print("\n── Metal Kernel Benchmarks (GatedDelta with Tape) ──") + + // Benchmark GatedDelta with tape (3D gate) + let gd3d = timeOperation(name: "gatedDelta (3D gate, Metal)", iterations: 20) { + _ = DFlashKernels.gatedDeltaKernelWithTape(q: q, k: k, v: v, g: g3d, beta: beta, state: state) + } + results.append(gd3d) + + // Benchmark GatedDelta with tape (4D gate) + let gd4d = timeOperation(name: "gatedDelta (4D gate, Metal)", iterations: 20) { + _ = DFlashKernels.gatedDeltaKernelWithTape(q: q, k: k, v: v, g: g4d, beta: beta, state: state) + } + results.append(gd4d) + + print("\n── Fallback (Ops) Benchmarks ──") + + // Set env var to force fallback + setenv("DFLASH_FORCE_OPS", "1", 1) + + let fb3d = timeOperation(name: "tapeReplay fallback (3D)", iterations: 5) { + _ = DFlashKernels.tapeReplayKernel(tape: tape, k: k, g: g3d, state: state) + } + results.append(fb3d) + + let fbgd = timeOperation(name: "gatedDelta fallback (3D)", iterations: 5) { + _ = DFlashKernels.gatedDeltaKernelWithTape(q: q, k: k, v: v, g: g3d, beta: beta, state: state) + } + results.append(fbgd) + + unsetenv("DFLASH_FORCE_OPS") + + // Benchmark ContextOnlyDraftKVCache operations + print("\n── KV Cache Benchmarks ──") + + let cache = ContextOnlyDraftKVCache(sinkSize: 64, windowSize: 1024) + let ctxK = randomArray(shape: [B, 512, Hk, Dk]) + let ctxV = randomArray(shape: [B, 512, Hv, Dv]) + + let cacheResult = timeOperation(name: "KVCache append (512 tokens)", iterations: 20) { + cache.appendContext(contextKeys: ctxK, contextValues: ctxV, numPositions: 512) + } + results.append(cacheResult) + + return results + } + + static func checkKernelAvailability() { + // Check if Metal is available + let device = Device.defaultDevice() + print(" Device type: \(device.deviceType)") + + // Check DFLASH_FORCE_OPS env var + if ProcessInfo.processInfo.environment["DFLASH_FORCE_OPS"] != nil { + print(" ⚠️ DFLASH_FORCE_OPS is set - using fallback ops") + } else { + print(" ✓ Metal kernels enabled (unless CPU)") + } + + // Test small input to see if kernels work + let tape = randomArray(shape: [1, 4, 8, 64]) + let k = randomArray(shape: [1, 4, 4, 64]) + let g = randomArray(shape: [1, 4, 8]) + let state = randomArray(shape: [1, 8, 64, 64]) + + // This should use Metal if available + do { + let result = DFlashKernels.tapeReplayKernel(tape: tape, k: k, g: g, state: state) + eval(result) + print(" ✓ Tape replay kernel executed successfully") + } catch { + print(" ❌ Tape replay kernel failed: \(error)") + } + } + + static func checkNumericalConsistency() { + // Compare Metal kernel output vs fallback + let tape = randomArray(shape: [1, 8, 16, 128]) + let k = randomArray(shape: [1, 8, 8, 128]) + let g3d = randomArray(shape: [1, 8, 16]) + let state = randomArray(shape: [1, 16, 128, 128]) + + // Metal kernel result + let metalResult = DFlashKernels.tapeReplayKernel(tape: tape, k: k, g: g3d, state: state) + + // Fallback result + setenv("DFLASH_FORCE_OPS", "1", 1) + let fallbackResult = DFlashKernels.tapeReplayKernel(tape: tape, k: k, g: g3d, state: state) + unsetenv("DFLASH_FORCE_OPS") + + eval(metalResult) + eval(fallbackResult) + + // Compute cosine similarity + let cosSim = cosineSimilarityMetal(metalResult, fallbackResult) + let maxDiff = maxAbsDiff(metalResult, fallbackResult) + + print(String(format: " Metal vs Fallback: cos=%.6f, max_diff=%.6f", cosSim, maxDiff)) + + if cosSim > 0.999 && maxDiff < 0.01 { + print(" ✅ Numerical consistency: PASS") + } else { + print(" ❌ Numerical consistency: FAIL") + } + } +} + +// MARK: - Comparison Utilities + +func cosineSimilarityMetal(_ a: MLXArray, _ b: MLXArray) -> Float { + let aF = a.reshaped(-1).asType(.float32) + let bF = b.reshaped(-1).asType(.float32) + let dot = (aF * bF).sum() + let normA = MLX.sqrt((aF * aF).sum()) + let normB = MLX.sqrt((bF * bF).sum()) + return (dot / (normA * normB)).item(Float.self) +} + +func maxAbsDiff(_ a: MLXArray, _ b: MLXArray) -> Float { + let diff = MLX.abs(a.asType(.float32) - b.asType(.float32)) + return diff.max().item(Float.self) +} \ No newline at end of file diff --git a/tests/DFlash/README.md b/tests/DFlash/README.md new file mode 100644 index 00000000..0e964b44 --- /dev/null +++ b/tests/DFlash/README.md @@ -0,0 +1,149 @@ +# DFlash Swift Benchmarking Tools + +This directory contains comprehensive benchmarking tools for DFlash speculative decoding. + +## Files + +### 1. DFlashBenchmark.swift (NEW) +Full end-to-end benchmark comparing baseline vs DFlash performance. + +**Features:** +- Compares standard generation vs DFlash speculative decoding +- Multiple block sizes tested per run +- Thermal pressure monitoring +- Automatic cooldown between runs +- Saves detailed JSON results + +**Usage:** +```bash +swift run DFlashBenchmark \ + --target mlx-community/Qwen3.5-27B-4bit \ + --draft z-lab/Qwen3.5-27B-DFlash \ + --max-tokens 1024 \ + --block-tokens 8,16,32 \ + --repeat 3 \ + --cooldown 60 \ + --output benchmark/results/my-benchmark.json +``` + +**Options:** +- `--target`: Target model ID (default: mlx-community/Qwen3.5-27B-4bit) +- `--draft`: Draft model ID (default: z-lab/Qwen3.5-27B-DFlash) +- `--max-tokens`: Maximum tokens to generate (default: 512) +- `--block-tokens`: Comma-separated block sizes to test (default: 8,16,32) +- `--repeat`: Number of repeat runs (default: 3) +- `--cooldown`: Cooldown seconds between runs (default: 60) +- `--prompt`: Custom prompt text +- `--output`: Output JSON path +- `--verbose` / `-v`: Enable verbose output + +**Output Format:** +```json +{ + "hardware": { + "chip": "Apple M5 Max", + "memory_gb": 64, + "mlx_version": "0.21.0", + "swift_version": "6.0+", + "device_description": "..." + }, + "config": { + "target_model": "mlx-community/Qwen3.5-27B-4bit", + "draft_model": "z-lab/Qwen3.5-27B-DFlash", + "max_new_tokens": 1024, + "block_tokens": [8, 16, 32], + "repeat": 3, + "cooldown": 60, + "prompt": "...", + "prompt_tokens": 102, + "git_hash": "abc1234" + }, + "runs": [ + { + "run": 1, + "thermal_pressure": "nominal", + "baseline": { + "ttft_ms": 1210.6, + "generation_tps": 33.3, + "peak_memory_gb": 15.4, + "tokens_generated": 1024, + "prompt_token_count": 102, + "generation_time_ms": 30750.0 + }, + "dflash": { + "ttft_ms": 357.3, + "generation_tps": 78.8, + "peak_memory_gb": 19.2, + "tokens_per_cycle": 10.04, + "cycles": 102, + "acceptance_ratio": 0.90, + "acceptance_first_20_avg": 6.6, + "acceptance_last_20_avg": 7.45, + "block_tokens": 16, + "accepted_from_draft": 922 + }, + "speedup": 2.37 + } + ], + "summary": { + "baseline_tps_median": 33.55, + "dflash_tps_median": 79.02, + "dflash_tps_min": 78.78, + "dflash_tps_max": 80.08, + "speedup_median": 2.37, + "acceptance_ratio_median": 0.90, + "total_memory_gb": 19.21 + } +} +``` + +### 2. DFlashProfiler.swift +Low-level kernel profiler for Metal vs fallback performance. + +**Usage:** +```bash +swift run DFlashProfiler +``` + +**Features:** +- Benchmarks Metal kernel performance +- Compares vs Python reference +- Validates numerical consistency + +### 3. DFlashCosSimComparison.swift +Compares intermediate values between Python and Swift implementations. + +**Usage:** +```bash +swift run DFlashCompare --dir tests/DFlashComparison/intermediates +``` + +## Python Comparison + +The benchmark format is compatible with `dflash-mlx/benchmark/` results: +- Same JSON structure +- Same metrics (TPS, TTFT, acceptance ratio) +- Same hardware info collection + +You can compare Swift vs Python results by loading both JSON files and comparing the `summary` sections. + +## Results Directory + +Create a `results/` directory here or specify custom output paths: +```bash +mkdir -p tests/DFlashComparison/results +swift run DFlashBenchmark --output tests/DFlashComparison/results/benchmark.json +``` + +## Performance Tuning Tips + +1. **Thermal Throttling**: The benchmark monitors thermal pressure. If you see values other than "nominal", increase `--cooldown` or wait for the chip to cool. + +2. **Block Size Selection**: + - 8 tokens: Better for shorter prompts + - 16 tokens: Good balance (default in DFlash paper) + - 32 tokens: May help for very long contexts + +3. **Memory**: DFlash uses more memory due to running both target and draft models. Monitor `peak_memory_gb` in results. + +4. **Repeat Count**: Use `--repeat 5` or more for statistically significant results on variable workloads. diff --git a/tests/DFlash/compare_cosine.py b/tests/DFlash/compare_cosine.py new file mode 100644 index 00000000..61639136 --- /dev/null +++ b/tests/DFlash/compare_cosine.py @@ -0,0 +1,242 @@ +#!/usr/bin/env python3 +"""Compare Python vs Swift DFlash intermediate values using cosine similarity. + +Loads the Python reference .npy dumps and also re-runs the Swift-equivalent +draft model forward pass using the same weights, computing cosine similarity +at each step. + +The "Swift-equivalent" path simulates what Swift does: + - No ExactSmallProjPad + - Standard SDPA (no batched_sdpa_2pass_exact) + - No VerifyQuantizedLinear + - No speculative hooks + +This isolates the numerical differences from the algorithmic differences. + +Usage: python3 compare_cosine.py [--dir path/to/intermediates] +""" +import json +import os +import sys +import numpy as np +import mlx.core as mx + +OUT_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "intermediates") + +def load(name: str) -> mx.array: + arr = np.load(os.path.join(OUT_DIR, f"{name}.npy")) + return mx.array(arr) + +def cosine_sim(a: mx.array, b: mx.array) -> float: + a = a.reshape(-1).astype(mx.float32) + b = b.reshape(-1).astype(mx.float32) + dot = (a * b).sum() + denom = mx.sqrt((a * a).sum() * (b * b).sum()) + if float(denom) < 1e-10: + return 0.0 + return float(dot / denom) + +def mean_abs_diff(a: mx.array, b: mx.array) -> float: + return float(mx.abs(a.reshape(-1).astype(mx.float32) - b.reshape(-1).astype(mx.float32)).mean()) + +def compare(name: str, ref: mx.array, test: mx.array): + cs = cosine_sim(ref, test) + mad = mean_abs_diff(ref, test) + status = "✅" if cs > 0.99 else "❌" if cs < 0.95 else "⚠️" + shape_str = "x".join(str(s) for s in ref.shape) + print(f" {status} {name:50s} cos={cs:.6f} mad={mad:.8f} shape={shape_str}") + return cs + +def main(): + # Load meta + with open(os.path.join(OUT_DIR, "_meta.json")) as f: + meta = json.load(f) + + prompt_tokens = meta["prompt_tokens"] + staged_first = meta["staged_first"] + mask_token_id = meta["mask_token_id"] + block_len = meta["block_len"] + target_layer_ids = meta["target_layer_ids"] + capture_layer_ids = meta["capture_layer_ids"] + drafted_tokens = meta["drafted_tokens"] + + print("═══════════════════════════════════════════════════════════════════") + print(" DFlash Cosine Similarity: Python Reference vs Python Reference") + print(" (Self-consistency check — should all be 1.0)") + print("═══════════════════════════════════════════════════════════════════") + + # Load all Python reference intermediates + py_ref = {} + for i in range(5): + for suffix in ["after_input_ln", "after_attn", "after_attn_residual", + "after_post_ln", "after_mlp", "output"]: + name = f"draft_layer{i}_{suffix}" + try: + py_ref[name] = load(name) + except: + pass + for name in ["target_hidden", "noise_embedding", "projected_hidden", + "draft_final_normed", "draft_logits"]: + try: + py_ref[name] = load(name) + except: + pass + + print(f"\nLoaded {len(py_ref)} reference arrays") + + # ── Self-consistency: reload and compare ── + print("\n── Self-consistency check ──") + for name, arr in py_ref.items(): + arr2 = load(name) + cs = cosine_sim(arr, arr2) + if cs < 0.9999: + print(f" ⚠️ {name}: cos={cs:.8f} (should be 1.0)") + + print(" Self-consistency: OK") + + # ── Now: run the "Swift path" using same weights but different logic ── + print("\n═══════════════════════════════════════════════════════════════════") + print(" DFlash Cosine Similarity: Python vs Swift-equivalent") + print("═══════════════════════════════════════════════════════════════════") + + # Load the draft model (same weights as Python reference) + import dflash_mlx.runtime as rt + rt._install_target_speculative_hooks = lambda *a, **kw: None + + from dflash_mlx.runtime import load_draft_bundle, resolve_model_ref, load_target_bundle + from dflash_mlx.model import ContextOnlyDraftKVCache + + mx.set_cache_limit(mx.device_info()["max_recommended_working_set_size"] // 4) + + target_model, tokenizer, _ = load_target_bundle( + resolve_model_ref("mlx-community/Qwen3.5-27B-4bit", kind="target"), + lazy=True, split_full_attention_sdpa=False, + ) + draft_model, _ = load_draft_bundle( + resolve_model_ref("z-lab/Qwen3.5-27B-DFlash", kind="draft"), + lazy=True, + ) + + # ── Step 1: Compare target_hidden ── + # The Python reference target_hidden was computed by the Python target model. + # The Swift target model should produce similar but not identical hidden states + # due to the exactSmallProjPad and other numerical differences. + # For now, compare the Python reference with itself (baseline). + print("\n── Step 1: Target hidden states (from prefill) ──") + py_target_hidden = py_ref["target_hidden"] + print(f" Python target_hidden: shape={py_target_hidden.shape}, mean={float(py_target_hidden.mean()):.6f}") + + # Re-run Python prefill to get target_hidden + from dflash_mlx.runtime import _verify_target_block, make_target_cache + target_cache = make_target_cache(target_model, enable_speculative_linear_cache=True) + logits, hidden_states = _verify_target_block( + target_model=target_model, + verify_ids=mx.array(prompt_tokens, dtype=mx.uint32)[None], + target_cache=target_cache, + verify_chunk_tokens=None, + capture_layer_ids=set(capture_layer_ids), + ) + mx.eval(logits, *hidden_states.values()) + + selected = [hidden_states[lid + 1] for lid in target_layer_ids] + rerun_target_hidden = mx.concatenate(selected, axis=-1) + compare("target_hidden (rerun)", py_target_hidden.astype(mx.float32), rerun_target_hidden.astype(mx.float32)) + + # ── Step 2: Compare projected_hidden ── + print("\n── Step 2: Projected hidden (fc + hiddenNorm) ──") + py_proj = py_ref["projected_hidden"] + swift_proj = draft_model._project_target_hidden(py_target_hidden.astype(mx.bfloat16)) + compare("projected_hidden", py_proj.astype(mx.float32), swift_proj.astype(mx.float32)) + + # ── Step 3: Compare noise_embedding ── + print("\n── Step 3: Noise embedding (target embed_tokens) ──") + py_noise = py_ref["noise_embedding"] + from dflash_mlx.runtime import _target_embed_tokens + block_token_ids = load("block_token_ids") + swift_noise = _target_embed_tokens(target_model)(block_token_ids.astype(mx.uint32)) + compare("noise_embedding", py_noise.astype(mx.float32), swift_noise.astype(mx.float32)) + + # ── Step 4: Layer-by-layer comparison ── + print("\n── Step 4: Draft model layer-by-layer ──") + + # Run the draft model step by step, comparing at each stage + draft_cache = [ContextOnlyDraftKVCache() for _ in range(len(draft_model.layers))] + hidden = py_noise.astype(mx.bfloat16) # Use Python's noise_embedding as input + projected = draft_model._project_target_hidden(py_target_hidden.astype(mx.bfloat16)) + + for i, (layer, cache) in enumerate(zip(draft_model.layers, draft_cache)): + print(f"\n Layer {i}:") + + # Input layernorm + h = layer.input_layernorm(hidden) + if f"draft_layer{i}_after_input_ln" in py_ref: + compare(f" layer{i}_after_input_ln", py_ref[f"draft_layer{i}_after_input_ln"].astype(mx.float32), h.astype(mx.float32)) + + # Attention + h = layer.self_attn(h, target_hidden=projected, cache=cache) + if f"draft_layer{i}_after_attn" in py_ref: + compare(f" layer{i}_after_attn", py_ref[f"draft_layer{i}_after_attn"].astype(mx.float32), h.astype(mx.float32)) + + # Residual + h = hidden + h + if f"draft_layer{i}_after_attn_residual" in py_ref: + compare(f" layer{i}_after_attn_residual", py_ref[f"draft_layer{i}_after_attn_residual"].astype(mx.float32), h.astype(mx.float32)) + + # Post-attention layernorm + r = h + h = layer.post_attention_layernorm(h) + if f"draft_layer{i}_after_post_ln" in py_ref: + compare(f" layer{i}_after_post_ln", py_ref[f"draft_layer{i}_after_post_ln"].astype(mx.float32), h.astype(mx.float32)) + + # MLP + h = layer.mlp(h) + if f"draft_layer{i}_after_mlp" in py_ref: + compare(f" layer{i}_after_mlp", py_ref[f"draft_layer{i}_after_mlp"].astype(mx.float32), h.astype(mx.float32)) + + # Final residual + hidden = r + h + if f"draft_layer{i}_output" in py_ref: + compare(f" layer{i}_output", py_ref[f"draft_layer{i}_output"].astype(mx.float32), hidden.astype(mx.float32)) + + # ── Step 5: Final norm + logits ── + print("\n── Step 5: Final norm + logits ──") + final_normed = draft_model.norm(hidden) + if "draft_final_normed" in py_ref: + compare("draft_final_normed", py_ref["draft_final_normed"].astype(mx.float32), final_normed.astype(mx.float32)) + + from dflash_mlx.runtime import _lm_head_logits + draft_logits = _lm_head_logits(target_model, final_normed[:, 1:, :]) + if "draft_logits" in py_ref: + cs = compare("draft_logits", py_ref["draft_logits"].astype(mx.float32), draft_logits.astype(mx.float32)) + + # Check if top tokens match + py_top = mx.argmax(py_ref["draft_logits"][0, 0], axis=-1).item() + swift_top = mx.argmax(draft_logits[0, 0], axis=-1).item() + print(f"\n Top token at pos 0: Python={py_top}, Swift-equiv={swift_top} {'✅' if py_top == swift_top else '❌'}") + + # ── Step 6: Run the ACTUAL Swift-equivalent path ── + # The key difference: Swift might process things in a different order, + # use different data types, or have subtle bugs. + # Since this Python script can't run Swift code, we'll document the differences. + + print("\n═══════════════════════════════════════════════════════════════════") + print(" ANALYSIS: Where could Swift diverge?") + print("═══════════════════════════════════════════════════════════════════") + print(""" + The above comparison shows Python reference vs Python re-run. + Any cosine < 1.0 here is due to non-determinism in MLX ops. + + To find where SWIFT diverges, we need to dump Swift intermediates + the same way and compare against these Python reference files. + + Key suspects for Swift divergence: + 1. target_hidden: Different prefill (exactSmallProjPad, VerifyQMM, etc.) + 2. noise_embedding: embed_tokens call differences + 3. projected_hidden: fc + hiddenNorm numerical differences + 4. layer attention: SDPA precision, RoPE implementation + 5. layer MLP: QuantizedLinear at small M differences + 6. final logits: lm_head numerical differences + """) + +if __name__ == "__main__": + main() diff --git a/tests/DFlash/compare_swift_python.py b/tests/DFlash/compare_swift_python.py new file mode 100644 index 00000000..3d4d6b0f --- /dev/null +++ b/tests/DFlash/compare_swift_python.py @@ -0,0 +1,200 @@ +#!/usr/bin/env python3 +"""Compare Python vs Swift DFlash intermediate values using cosine similarity. + +Loads Python reference .npy dumps from intermediates/ and Swift dumps +from swift_dumps/ (or custom dir), computing cosine similarity at each step. + +Usage: python3 compare_swift_python.py [--swift-dir /tmp/dflash_swift_dumps] +""" +import json +import os +import sys +import argparse +import numpy as np +import mlx.core as mx + +def cosine_sim(a: mx.array, b: mx.array) -> float: + """Compute cosine similarity between two arrays.""" + if a.shape != b.shape: + print(f" ⚠️ Shape mismatch: {a.shape} vs {b.shape}") + # Try to broadcast or slice + min_dims = [min(a.shape[i], b.shape[i]) for i in range(len(a.shape))] + slices_a = tuple(slice(0, m) for m in min_dims) + slices_b = tuple(slice(0, m) for m in min_dims) + a = a[slices_a] + b = b[slices_b] + a = a.reshape(-1).astype(mx.float32) + b = b.reshape(-1).astype(mx.float32) + dot = (a * b).sum() + denom = mx.sqrt((a * a).sum() * (b * b).sum()) + if float(denom) < 1e-10: + return 0.0 + return float(dot / denom) + +def mean_abs_diff(a: mx.array, b: mx.array) -> float: + return float(mx.abs(a.reshape(-1).astype(mx.float32) - b.reshape(-1).astype(mx.float32)).mean()) + +def load_npy(path: str) -> mx.array: + arr = np.load(path) + return mx.array(arr) + +def compare(name: str, ref: mx.array, test: mx.array) -> float: + cs = cosine_sim(ref, test) + mad = mean_abs_diff(ref, test) + if cs > 0.99: + status = "✅" + elif cs > 0.95: + status = "⚠️" + else: + status = "❌" + shape_str = "x".join(str(s) for s in ref.shape) + print(f" {status} {name:50s} cos={cs:.6f} mad={mad:.8f} shape={shape_str}") + return cs + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--py-dir", default=os.path.join(os.path.dirname(os.path.abspath(__file__)), "intermediates")) + parser.add_argument("--swift-dir", default="/tmp/dflash_swift_dumps") + args = parser.parse_args() + + py_dir = args.py_dir + swift_dir = args.swift_dir + + print("═══════════════════════════════════════════════════════════════════") + print(" DFlash Python ↔ Swift Cosine Similarity Comparison") + print(f" Python dir: {py_dir}") + print(f" Swift dir: {swift_dir}") + print("═══════════════════════════════════════════════════════════════════") + + # Load meta + with open(os.path.join(py_dir, "_meta.json")) as f: + meta = json.load(f) + + prompt_tokens = meta["prompt_tokens"] + staged_first = meta["staged_first"] + block_len = meta["block_len"] + target_layer_ids = meta["target_layer_ids"] + drafted_tokens = meta["drafted_tokens"] + + print(f"\n Python: prompt={len(prompt_tokens)} tokens, staged_first={staged_first}") + print(f" Python: target_layer_ids={target_layer_ids}") + print(f" Python: drafted_tokens[:5]={drafted_tokens[:5]}") + + results = [] + + # ── 1. Target hidden states ── + print("\n── 1. Target hidden states (from prefill) ──") + try: + py_target = load_npy(os.path.join(py_dir, "target_hidden.npy")) + sw_target = load_npy(os.path.join(swift_dir, "swift_target_hidden.npy")) + cs = compare("target_hidden", py_target, sw_target) + results.append(("target_hidden", cs)) + except Exception as e: + print(f" ⚠️ Could not compare target_hidden: {e}") + + # ── 2. Noise embedding ── + print("\n── 2. Noise embedding (target embed_tokens) ──") + try: + py_noise = load_npy(os.path.join(py_dir, "noise_embedding.npy")) + sw_noise = load_npy(os.path.join(swift_dir, "swift_noise_embedding.npy")) + cs = compare("noise_embedding", py_noise, sw_noise) + results.append(("noise_embedding", cs)) + except Exception as e: + print(f" ⚠️ Could not compare noise_embedding: {e}") + + # ── 3. Projected hidden ── + print("\n── 3. Projected hidden (fc + hiddenNorm) ──") + try: + py_proj = load_npy(os.path.join(py_dir, "projected_hidden.npy")) + sw_proj = load_npy(os.path.join(swift_dir, "swift_projected_hidden.npy")) + cs = compare("projected_hidden", py_proj, sw_proj) + results.append(("projected_hidden", cs)) + except Exception as e: + print(f" ⚠️ Could not compare projected_hidden: {e}") + + # ── 4. Draft model layer outputs ── + print("\n── 4. Draft model layer outputs ──") + for i in range(5): + try: + py_layer = load_npy(os.path.join(py_dir, f"draft_layer{i}_output.npy")) + sw_layer = load_npy(os.path.join(swift_dir, f"swift_draft_layer{i}_output.npy")) + cs = compare(f"draft_layer{i}_output", py_layer, sw_layer) + results.append((f"draft_layer{i}_output", cs)) + except Exception as e: + print(f" ⚠️ Could not compare layer{i}_output: {e}") + + # ── 5. Draft final normed ── + print("\n── 5. Draft final normed ──") + try: + py_final = load_npy(os.path.join(py_dir, "draft_final_normed.npy")) + sw_final = load_npy(os.path.join(swift_dir, "swift_draft_final_normed.npy")) + cs = compare("draft_final_normed", py_final, sw_final) + results.append(("draft_final_normed", cs)) + except Exception as e: + print(f" ⚠️ Could not compare draft_final_normed: {e}") + + # ── 6. Draft logits ── + print("\n── 6. Draft logits ──") + try: + py_logits = load_npy(os.path.join(py_dir, "draft_logits.npy")) + sw_logits = load_npy(os.path.join(swift_dir, "swift_draft_logits.npy")) + cs = compare("draft_logits", py_logits, sw_logits) + results.append(("draft_logits", cs)) + + # Check top tokens + print("\n Top tokens comparison:") + for pos in range(min(3, py_logits.shape[1])): + py_top = int(mx.argmax(mx.array(py_logits[0, pos]), axis=-1)) + sw_top = int(mx.argmax(mx.array(sw_logits[0, pos]), axis=-1)) + match = "✅" if py_top == sw_top else "❌" + print(f" pos {pos}: Python={py_top}, Swift={sw_top} {match}") + except Exception as e: + print(f" ⚠️ Could not compare draft_logits: {e}") + + # ── 7. Prefill logits (last position) ── + print("\n── 7. Prefill logits ──") + try: + py_prefill = load_npy(os.path.join(py_dir, "prefill_logits.npy")) + sw_prefill = load_npy(os.path.join(swift_dir, "swift_prefill_logits.npy")) + # Compare only last position + py_last = py_prefill[:, -1, :] + sw_last = sw_prefill[:, -1, :] + cs = compare("prefill_logits (last pos)", py_last, sw_last) + results.append(("prefill_logits_last", cs)) + + # Check staged_first + py_top = int(mx.argmax(mx.array(py_last[0]), axis=-1)) + sw_top = int(mx.argmax(mx.array(sw_last[0]), axis=-1)) + print(f" staged_first: Python={py_top}, Swift={sw_top} {'✅' if py_top == sw_top else '❌'}") + except Exception as e: + print(f" ⚠️ Could not compare prefill_logits: {e}") + + # ── Summary ── + print("\n═══════════════════════════════════════════════════════════════════") + print(" SUMMARY") + print("═══════════════════════════════════════════════════════════════════") + + if not results: + print(" No comparisons made!") + return + + # Sort by cosine similarity (worst first) + results.sort(key=lambda x: x[1]) + + print("\n Divergence ranking (worst → best):") + for name, cs in results: + bar = "█" * int(cs * 40) + status = "✅" if cs > 0.99 else "⚠️" if cs > 0.95 else "❌" + print(f" {status} {name:45s} cos={cs:.6f} {bar}") + + worst_name, worst_cs = results[0] + if worst_cs < 0.95: + print(f"\n 🔍 BIGGEST DIVERGENCE: {worst_name} (cos={worst_cs:.6f})") + print(f" This is the first place to investigate!") + elif worst_cs < 0.99: + print(f"\n ⚠️ Small divergence at: {worst_name} (cos={worst_cs:.6f})") + else: + print(f"\n ✅ All comparisons >0.99 cosine similarity!") + +if __name__ == "__main__": + main() diff --git a/tests/DFlash/dump_python_intermediates.py b/tests/DFlash/dump_python_intermediates.py new file mode 100644 index 00000000..656a5d1f --- /dev/null +++ b/tests/DFlash/dump_python_intermediates.py @@ -0,0 +1,163 @@ +#!/usr/bin/env python3 +"""Dump Python DFlash intermediate values for cross-language comparison. + +Outputs .npy files and a _meta.json with token IDs and scalar values. +Run: python3 dump_python_intermediates.py +""" +import json +import os +import sys +import numpy as np +import mlx.core as mx + +OUT_DIR = os.path.dirname(os.path.abspath(__file__)) + "/intermediates" +os.makedirs(OUT_DIR, exist_ok=True) + +# ── Patch hooks out so we compare bare numerical paths ── +import dflash_mlx.runtime as rt +rt._install_target_speculative_hooks = lambda *a, **kw: None + +from dflash_mlx.runtime import ( + load_target_bundle, load_draft_bundle, resolve_model_ref, + _target_embed_tokens, _lm_head_logits, greedy_tokens_with_mask, + _verify_target_block, make_target_cache, +) +from dflash_mlx.model import ContextOnlyDraftKVCache + +mx.set_cache_limit(mx.device_info()["max_recommended_working_set_size"] // 4) + +PROMPT = "Hello" +BLOCK_LEN = 16 +USE_CHAT_TEMPLATE = True + +def save(name: str, arr: mx.array): + # Convert MLX array to numpy via float32 to avoid bfloat16 issues + # For integer arrays, cast to int32 first + if mx.issubdtype(arr.dtype, mx.integer): + np_arr = np.array(arr.astype(mx.int32), copy=True) + else: + np_arr = np.array(arr.astype(mx.float32), copy=True) + np.save(f"{OUT_DIR}/{name}.npy", np_arr) + print(f" saved {name}: shape={arr.shape} dtype={arr.dtype}") + +def main(): + print("Loading models …") + target_model, tokenizer, _ = load_target_bundle( + resolve_model_ref("mlx-community/Qwen3.5-27B-4bit", kind="target"), + lazy=True, split_full_attention_sdpa=False, + ) + draft_model, _ = load_draft_bundle( + resolve_model_ref("z-lab/Qwen3.5-27B-DFlash", kind="draft"), + lazy=True, + ) + + # ── 1. Prompt tokens ── + from dflash_mlx.runtime import _prepare_prompt_tokens + prompt_tokens = _prepare_prompt_tokens(tokenizer, PROMPT, use_chat_template=USE_CHAT_TEMPLATE) + print(f"Prompt tokens ({len(prompt_tokens)}): {prompt_tokens}") + + # ── 2. Target prefill ── + target_cache = make_target_cache(target_model, enable_speculative_linear_cache=True) + target_layer_ids = list(draft_model.target_layer_ids) + capture_layer_ids = {int(lid) + 1 for lid in target_layer_ids} + + logits, hidden_states = _verify_target_block( + target_model=target_model, + verify_ids=mx.array(prompt_tokens, dtype=mx.uint32)[None], + target_cache=target_cache, + verify_chunk_tokens=None, + capture_layer_ids=capture_layer_ids, + ) + mx.eval(logits, *hidden_states.values()) + + save("prefill_logits", logits) + for lid in capture_layer_ids: + save(f"hidden_layer_{lid}", hidden_states[lid]) + + # ── 3. Extract context feature ── + selected = [hidden_states[layer_id + 1] for layer_id in target_layer_ids] + target_hidden = mx.concatenate(selected, axis=-1) + save("target_hidden", target_hidden) + + # ── 4. staged_first ── + staged_first = greedy_tokens_with_mask(logits[:, -1, :], None) + staged_first_id = int(staged_first.item()) + print(f"staged_first = {staged_first_id} = {repr(tokenizer.decode([staged_first_id]))}") + + # ── 5. Draft model inputs ── + mask_token_id = int(draft_model.mask_token_id) + block_token_ids = mx.concatenate( + [staged_first[:1], mx.full((BLOCK_LEN - 1,), mask_token_id, dtype=mx.uint32)] + ) + noise_embedding = _target_embed_tokens(target_model)(block_token_ids[None]) + save("noise_embedding", noise_embedding) + save("block_token_ids", block_token_ids[None]) + + # ── 6. Draft model: project target hidden ── + projected_hidden = draft_model._project_target_hidden(target_hidden) + save("projected_hidden", projected_hidden) + + # ── 7. Draft model: layer-by-layer ── + draft_cache = [ContextOnlyDraftKVCache() for _ in range(len(draft_model.layers))] + hidden_states_draft = noise_embedding + + for i, (layer, layer_cache) in enumerate(zip(draft_model.layers, draft_cache)): + # input layernorm + h = layer.input_layernorm(hidden_states_draft) + save(f"draft_layer{i}_after_input_ln", h) + + # attention + h = layer.self_attn(h, target_hidden=projected_hidden, cache=layer_cache) + save(f"draft_layer{i}_after_attn", h) + + # residual + attention + h = hidden_states_draft + h + save(f"draft_layer{i}_after_attn_residual", h) + + # post-attention layernorm + r = h + h = layer.post_attention_layernorm(h) + save(f"draft_layer{i}_after_post_ln", h) + + # MLP + h = layer.mlp(h) + save(f"draft_layer{i}_after_mlp", h) + + # final residual + hidden_states_draft = r + h + save(f"draft_layer{i}_output", hidden_states_draft) + + # ── 8. Final norm + logits ── + draft_final = draft_model.norm(hidden_states_draft) + save("draft_final_normed", draft_final) + + draft_logits = _lm_head_logits(target_model, draft_final[:, 1:, :]) + save("draft_logits", draft_logits) + + drafted = greedy_tokens_with_mask(draft_logits, None) + drafted_list = drafted.tolist() + if isinstance(drafted_list[0], list): + drafted_list = drafted_list[0] + print(f"drafted tokens: {drafted_list[:5]}") + print(f"drafted text: {repr(tokenizer.decode(drafted_list[:5]))}") + + # ── 9. Verify logits (target forward on draft tokens) ── + verify_ids = mx.concatenate([staged_first[:1], drafted[0, :BLOCK_LEN - 1]], axis=0)[None] + save("verify_ids", verify_ids) + + # ── Meta ── + meta = { + "prompt_tokens": prompt_tokens, + "staged_first": staged_first_id, + "mask_token_id": mask_token_id, + "block_len": BLOCK_LEN, + "target_layer_ids": target_layer_ids, + "capture_layer_ids": list(capture_layer_ids), + "drafted_tokens": drafted_list, + } + with open(f"{OUT_DIR}/_meta.json", "w") as f: + json.dump(meta, f, indent=2) + print(f"Meta saved to {OUT_DIR}/_meta.json") + +if __name__ == "__main__": + main() diff --git a/tests/test-dflash.sh b/tests/test-dflash.sh new file mode 100755 index 00000000..92a4a6df --- /dev/null +++ b/tests/test-dflash.sh @@ -0,0 +1,254 @@ +#!/bin/bash +# test-speculative.sh — Speculative decoding E2E verification +# +# Uses a small draft model (Qwen3.5-0.8B) to accelerate a larger main model +# (Qwen3.5-4B) via speculative decoding. Verifies: +# 1. Dual-model loading (draft + main) +# 2. Speculative decoding path activation +# 3. Correct token generation +# 4. Server stability under dual-model memory pressure +# +# Usage: +# ./tests/test-speculative.sh [binary_path] [port] +# +# Requirements: +# - ~4 GB RAM (0.8B draft ~1 GB + 4B main ~3 GB) +# - macos-15 (7 GB) on GitHub Actions is sufficient +# - curl, jq + +set -euo pipefail + +BINARY="${1:-.build/release/SwiftLM}" +PORT="${2:-15414}" +HOST="127.0.0.1" +MAIN_MODEL="${MAIN_MODEL:-mlx-community/Qwen3.5-4B-4bit}" +DRAFT_MODEL="${DRAFT_MODEL:-z-lab/Qwen3.5-4B-DFlash}" +NUM_DRAFT_TOKENS=16 +URL="http://${HOST}:${PORT}" +PASS=0 +FAIL=0 +TOTAL=0 +LOG_FILE="/tmp/SwiftLM-test-dflash.log" + +# Colors +GREEN='\033[0;32m' +RED='\033[0;31m' +YELLOW='\033[1;33m' +CYAN='\033[0;36m' +NC='\033[0m' + +log() { echo -e "${YELLOW}[dflash-test]${NC} $*"; } +pass() { PASS=$((PASS + 1)); TOTAL=$((TOTAL + 1)); echo -e " ${GREEN}✅ PASS${NC}: $*"; } +fail() { FAIL=$((FAIL + 1)); TOTAL=$((TOTAL + 1)); echo -e " ${RED}❌ FAIL${NC}: $*"; } + +cleanup() { + if [ -n "${SERVER_PID:-}" ]; then + log "Stopping server (PID $SERVER_PID)" + kill -9 "$SERVER_PID" 2>/dev/null || true + wait "$SERVER_PID" 2>/dev/null || true + fi +} +trap cleanup EXIT + +# ── Check prerequisites ───────────────────────────────────────────── +if [ ! -f "$BINARY" ]; then + echo "Error: Binary not found at $BINARY" + echo "Run 'swift build -c release' first." + exit 1 +fi + +if ! command -v jq &>/dev/null; then + echo "Error: jq is required. Install with: brew install jq" + exit 1 +fi + +# ── Memory check ──────────────────────────────────────────────────── +TOTAL_RAM_GB=$(sysctl -n hw.memsize 2>/dev/null | awk '{printf "%.0f", $1 / 1073741824}') +log "System RAM: ${TOTAL_RAM_GB} GB" + +# On low-RAM machines (< 12 GB), the combined main + draft model weights +# (~6 GB) exceed available memory after OS reservation. Without SSD +# streaming, all weights must be GPU-resident or swapped via macOS VM, +# which causes Metal command buffers to exceed Apple's 5-second GPU +# Watchdog timeout → Abort trap: 6. +# +# Fix: enable --stream-experts on low-RAM machines. This uses mmap-based +# weight loading (pread from SSD via the OS page cache) so the GPU never +# stalls waiting for swap. Draft tokens are auto-capped to 1 server-side +# to minimise SSD I/O fan-out during the verify pass. +EXTRA_FLAGS="" +if [ "$TOTAL_RAM_GB" -lt 12 ] 2>/dev/null; then + log "⚠️ ${TOTAL_RAM_GB} GB RAM: enabling --stream-experts for SSD-backed weight paging" + log " Combined model weights (~6 GB) exceed available RAM. SSD streaming prevents" + log " Metal GPU Watchdog timeouts during DFlash verify passes." + EXTRA_FLAGS="--stream-experts" + NUM_DRAFT_TOKENS=1 # auto-capped server-side too, but be explicit +fi + +# ══════════════════════════════════════════════════════════════════════ +echo -e "\n${CYAN}╔══════════════════════════════════════════════════════════╗${NC}" +echo -e "${CYAN}║ SwiftLM DFlash Speculative Decoding E2E Test ║${NC}" +echo -e "${CYAN}║ Draft: Qwen3.5-4B-DFlash → Main: Qwen3.5-4B-4bit ║${NC}" +echo -e "${CYAN}║ Draft tokens per round: ${NUM_DRAFT_TOKENS} ║${NC}" +echo -e "${CYAN}╚══════════════════════════════════════════════════════════╝${NC}\n" + +# ── Start server with dual models ─────────────────────────────────── +log "Starting server with DFlash speculative decoding..." +log " Main model: $MAIN_MODEL" +log " Draft model: $DRAFT_MODEL" +log " Draft tokens per round: $NUM_DRAFT_TOKENS" +if [ -n "$EXTRA_FLAGS" ]; then + log " Extra flags: $EXTRA_FLAGS" +fi + +"$BINARY" --model "$MAIN_MODEL" --port "$PORT" --host "$HOST" \ + --draft-model "$DRAFT_MODEL" \ + --num-draft-tokens "$NUM_DRAFT_TOKENS" \ + --dflash $EXTRA_FLAGS \ + > "$LOG_FILE" 2>&1 & +SERVER_PID=$! + +# Wait for server to be ready (both models need to download + load) +log "Waiting for server to load both models (this may take a while on first run)..." +MAX_WAIT=900 # 15 minutes for two model downloads +for i in $(seq 1 "$MAX_WAIT"); do + if curl -sf "$URL/health" >/dev/null 2>&1; then + log "Server ready after ${i}s" + break + fi + if ! kill -0 "$SERVER_PID" 2>/dev/null; then + echo "Error: Server process died. Server Log:" + cat "$LOG_FILE" + exit 1 + fi + # Print progress every 30 seconds + if [ $((i % 30)) -eq 0 ]; then + log " Still waiting... (${i}s elapsed)" + fi + sleep 1 +done + +if ! curl -sf "$URL/health" >/dev/null 2>&1; then + echo "Error: Server did not become ready in ${MAX_WAIT}s" + echo "Server Log:" + cat "$LOG_FILE" + exit 1 +fi + +# ── Test 1: Verify server loaded both models ──────────────────────── +log "Test 1: Verify dual-model loading" + +# Check server log for draft model loading confirmation +if grep -q "Draft model loaded successfully" "$LOG_FILE"; then + pass "Draft model loaded successfully" +else + fail "Draft model loading not confirmed in server logs" +fi + +if grep -q "speculative decoding" "$LOG_FILE"; then + pass "Speculative decoding mode detected in server logs" +else + fail "Speculative decoding not mentioned in server logs" +fi + +# ── Test 2: Health endpoint works with dual models ────────────────── +log "Test 2: Health endpoint" + +HEALTH=$(curl -sf "$URL/health") +if echo "$HEALTH" | jq -e '.status == "ok"' >/dev/null 2>&1; then + pass "Health endpoint returns status=ok" +else + fail "Health endpoint: $HEALTH" +fi + +# ── Test 3: Streaming speculative generation ──────────────────────── +log "Test 3: Streaming speculative generation" + +STREAM_OUTPUT=$(curl -sf -N --max-time 120 -X POST "$URL/v1/chat/completions" \ + -H "Content-Type: application/json" \ + -d "{\"model\":\"$MAIN_MODEL\",\"stream\":true,\"max_tokens\":30,\"messages\":[{\"role\":\"user\",\"content\":\"Name three fruits.\"}]}" \ + 2>/dev/null || true) + +if echo "$STREAM_OUTPUT" | grep -q "data: \[DONE\]"; then + pass "Streaming speculative: received [DONE] sentinel" +else + fail "Streaming speculative: missing [DONE] sentinel" +fi + +CHUNK_COUNT=$(echo "$STREAM_OUTPUT" | grep -c "^data: {" || true) +if [ "$CHUNK_COUNT" -gt 0 ]; then + pass "Streaming speculative: received $CHUNK_COUNT data chunks" +else + fail "Streaming speculative: no data chunks received" +fi + +# Check server log for speculative decoding activation +if grep -q "Using speculative decoding" "$LOG_FILE"; then + pass "Speculative decoding path activated during generation" +else + fail "Speculative decoding path not activated (missing log line)" +fi + +# ── Test 5: Multiple sequential requests (stability) ──────────────── +log "Test 5: Sequential request stability (3 requests)" + +SEQ_PASS=true +for i in 1 2 3; do + SEQ_RESP=$(curl -sf --max-time 120 -X POST "$URL/v1/chat/completions" \ + -H "Content-Type: application/json" \ + -d "{\"model\":\"$MAIN_MODEL\",\"max_tokens\":10,\"messages\":[{\"role\":\"user\",\"content\":\"Say the number $i.\"}]}" 2>/dev/null || echo "") + + SEQ_CONTENT=$(echo "$SEQ_RESP" | jq -r '.choices[0].message.content // empty' 2>/dev/null || echo "") + + if [ -z "$SEQ_CONTENT" ]; then + SEQ_PASS=false + fail "Sequential request $i: empty response" + break + fi +done + +if [ "$SEQ_PASS" = true ]; then + pass "Sequential stability: 3/3 speculative requests completed successfully" +fi + +# ── Test 6: Memory stability check ───────────────────────────────── +log "Test 6: Memory stability" + +HEALTH_FINAL=$(curl -sf "$URL/health") +MEM_ACTIVE=$(echo "$HEALTH_FINAL" | jq -r '.memory.active_mb // 0') +MEM_PEAK=$(echo "$HEALTH_FINAL" | jq -r '.memory.peak_mb // 0') + +if [ "$MEM_ACTIVE" -gt 0 ] 2>/dev/null; then + pass "Memory: active=${MEM_ACTIVE} MB, peak=${MEM_PEAK} MB" +else + fail "Memory: could not read memory stats" +fi + +# Verify server is still responsive after all tests +if curl -sf "$URL/health" >/dev/null 2>&1; then + pass "Server still responsive after all speculative decoding tests" +else + fail "Server became unresponsive" +fi + +# ── Results ────────────────────────────────────────────────────────── +echo "" +log "═══════════════════════════════════════" +log "Speculative Decoding Test Results" +log " Draft: $DRAFT_MODEL" +log " Main: $MAIN_MODEL" +log " Tokens/round: $NUM_DRAFT_TOKENS" +log " Results: ${PASS} passed, ${FAIL} failed, ${TOTAL} total" +log "═══════════════════════════════════════" + +if [ "$FAIL" -gt 0 ]; then + echo "" + log "Server completely failed. Full Log:" + cat "$LOG_FILE" + exit 1 +fi + +echo "" +log "Server log tail (last 50 lines):" +tail -50 "$LOG_FILE" +exit 0