diff --git a/Libraries/MLXLMCommon/LanguageModel.swift b/Libraries/MLXLMCommon/LanguageModel.swift index 01b551fca..572af280d 100644 --- a/Libraries/MLXLMCommon/LanguageModel.swift +++ b/Libraries/MLXLMCommon/LanguageModel.swift @@ -61,6 +61,7 @@ public struct LMInput { public let text: Text public let image: ProcessedImage? public let video: ProcessedVideo? + public let audio: ProcessedAudio? /// Representation of tokenized input text. public struct Text { @@ -120,17 +121,33 @@ public struct LMInput { } } + /// Representation of prepared input audio. + public struct ProcessedAudio { + + /// Mel spectrogram features, shape [batch, frames, melBins] or [frames, melBins]. + public let features: MLXArray + /// Optional attention mask indicating padding frames (True = padding). + public let mask: MLXArray? + + public init(features: MLXArray, mask: MLXArray? = nil) { + self.features = features + self.mask = mask + } + } + public init(tokens: MLXArray, mask: MLXArray? = nil) { self.init(text: .init(tokens: tokens, mask: mask)) } public init( text: LMInput.Text, image: LMInput.ProcessedImage? = nil, - video: LMInput.ProcessedVideo? = nil + video: LMInput.ProcessedVideo? = nil, + audio: LMInput.ProcessedAudio? = nil ) { self.text = text self.image = image self.video = video + self.audio = audio } } diff --git a/Libraries/MLXLMCommon/UserInput.swift b/Libraries/MLXLMCommon/UserInput.swift index 8aac26e71..eae3f00d7 100644 --- a/Libraries/MLXLMCommon/UserInput.swift +++ b/Libraries/MLXLMCommon/UserInput.swift @@ -177,6 +177,9 @@ public struct UserInput { /// collect the videos from the chat messages, otherwise these are the stored videos with the ``UserInput``. public var videos = [Video]() + /// Audio inputs as raw PCM float arrays (16 kHz mono expected by Gemma 4). + public var audios = [[Float]]() + public var tools: [ToolSpec]? /// Additional values provided for the chat template rendering context diff --git a/Libraries/MLXVLM/Models/Gemma4.swift b/Libraries/MLXVLM/Models/Gemma4.swift index cac757f98..e1444a222 100644 --- a/Libraries/MLXVLM/Models/Gemma4.swift +++ b/Libraries/MLXVLM/Models/Gemma4.swift @@ -384,9 +384,69 @@ public struct Gemma4VisionConfiguration: Codable, Sendable { } } +public struct Gemma4AudioConfiguration: Codable, Sendable { + public let hiddenSize: Int + public let numHiddenLayers: Int + public let numAttentionHeads: Int + public let subsamplingConvChannels: [Int] + public let convKernelSize: Int + public let residualWeight: Float + public let attentionChunkSize: Int + public let attentionContextLeft: Int + public let attentionContextRight: Int + public let attentionLogitCap: Float + public let attentionInvalidLogitsValue: Float + public let useClippedLinears: Bool + public let rmsNormEps: Float + public let gradientClipping: Float + public let outputProjDims: Int? + + enum CodingKeys: String, CodingKey { + case hiddenSize = "hidden_size" + case numHiddenLayers = "num_hidden_layers" + case numAttentionHeads = "num_attention_heads" + case subsamplingConvChannels = "subsampling_conv_channels" + case convKernelSize = "conv_kernel_size" + case residualWeight = "residual_weight" + case attentionChunkSize = "attention_chunk_size" + case attentionContextLeft = "attention_context_left" + case attentionContextRight = "attention_context_right" + case attentionLogitCap = "attention_logit_cap" + case attentionInvalidLogitsValue = "attention_invalid_logits_value" + case useClippedLinears = "use_clipped_linears" + case rmsNormEps = "rms_norm_eps" + case gradientClipping = "gradient_clipping" + case outputProjDims = "output_proj_dims" + } + + public init(from decoder: any Swift.Decoder) throws { + let c = try decoder.container(keyedBy: CodingKeys.self) + hiddenSize = try c.decodeIfPresent(Int.self, forKey: .hiddenSize) ?? 1024 + numHiddenLayers = try c.decodeIfPresent(Int.self, forKey: .numHiddenLayers) ?? 12 + numAttentionHeads = try c.decodeIfPresent(Int.self, forKey: .numAttentionHeads) ?? 8 + subsamplingConvChannels = + try c.decodeIfPresent([Int].self, forKey: .subsamplingConvChannels) ?? [128, 32] + convKernelSize = try c.decodeIfPresent(Int.self, forKey: .convKernelSize) ?? 5 + residualWeight = try c.decodeIfPresent(Float.self, forKey: .residualWeight) ?? 0.5 + attentionChunkSize = try c.decodeIfPresent(Int.self, forKey: .attentionChunkSize) ?? 12 + attentionContextLeft = try c.decodeIfPresent(Int.self, forKey: .attentionContextLeft) ?? 13 + attentionContextRight = try c.decodeIfPresent(Int.self, forKey: .attentionContextRight) ?? 0 + attentionLogitCap = try c.decodeIfPresent(Float.self, forKey: .attentionLogitCap) ?? 50.0 + attentionInvalidLogitsValue = + try c.decodeIfPresent(Float.self, forKey: .attentionInvalidLogitsValue) ?? -1e9 + useClippedLinears = + try c.decodeIfPresent(Bool.self, forKey: .useClippedLinears) ?? true + rmsNormEps = try c.decodeIfPresent(Float.self, forKey: .rmsNormEps) ?? 1e-6 + gradientClipping = + try c.decodeIfPresent(Float.self, forKey: .gradientClipping) ?? 1e10 + outputProjDims = try c.decodeIfPresent(Int.self, forKey: .outputProjDims) + } +} + public struct Gemma4Configuration: Codable, Sendable { public let textConfiguration: Gemma4TextConfiguration public let visionConfiguration: Gemma4VisionConfiguration + public let audioConfiguration: Gemma4AudioConfiguration? public let modelType: String public let quantization: BaseConfiguration.Quantization? public let imageTokenId: Int @@ -407,6 +467,7 @@ public struct Gemma4Configuration: Codable, Sendable { enum CodingKeys: String, CodingKey { case textConfiguration = "text_config" case visionConfiguration = "vision_config" + case audioConfiguration = "audio_config" case modelType = "model_type" case quantization case imageTokenId = "image_token_id" @@ -426,6 +487,8 @@ public struct Gemma4Configuration: Codable, Sendable { Gemma4TextConfiguration.self, forKey: CodingKeys.textConfiguration) visionConfiguration = try c.decode( Gemma4VisionConfiguration.self, forKey: CodingKeys.visionConfiguration) + audioConfiguration = try c.decodeIfPresent( + Gemma4AudioConfiguration.self, forKey: CodingKeys.audioConfiguration) modelType = try c.decodeIfPresent(String.self, forKey: CodingKeys.modelType) ?? "gemma4" quantization = try c.decodeIfPresent( BaseConfiguration.Quantization.self, forKey: CodingKeys.quantization) @@ -1641,12 +1704,719 @@ private final class Gemma4MultimodalEmbedder: Module, UnaryLayer { } } +// MARK: - Audio + +private final class Gemma4AudioRMSNorm: Module, UnaryLayer { + let eps: Float + @ModuleInfo var weight: MLXArray + + init(dimensions: Int, eps: Float = 1e-6) { + self.eps = eps + self._weight.wrappedValue = MLXArray.ones([dimensions]) + super.init() + } + + func callAsFunction(_ x: MLXArray) -> MLXArray { + MLXFast.rmsNorm(x, weight: weight, eps: eps) + } +} + +private final class Gemma4AudioClippableLinear: Module, UnaryLayer { + @ModuleInfo(key: "linear") var linear: Linear + @ModuleInfo(key: "input_min") var inputMin: MLXArray? + @ModuleInfo(key: "input_max") var inputMax: MLXArray? + @ModuleInfo(key: "output_min") var outputMin: MLXArray? + @ModuleInfo(key: "output_max") var outputMax: MLXArray? + let useClipping: Bool + + init(inFeatures: Int, outFeatures: Int, bias: Bool = false, useClipping: Bool = true) { + self.useClipping = useClipping + self._linear.wrappedValue = Linear(inFeatures, outFeatures, bias: bias) + if useClipping { + self._inputMin.wrappedValue = MLXArray(-Float.infinity) + self._inputMax.wrappedValue = MLXArray(Float.infinity) + self._outputMin.wrappedValue = MLXArray(-Float.infinity) + self._outputMax.wrappedValue = MLXArray(Float.infinity) + } + super.init() + } + + func callAsFunction(_ x: MLXArray) -> MLXArray { + let clippedInput: MLXArray + if let inputMin, let inputMax { + clippedInput = clip(x, min: inputMin, max: inputMax) + } else { + clippedInput = x + } + let projected = linear(clippedInput) + if let outputMin, let outputMax { + return clip(projected, min: outputMin, max: outputMax) + } + return projected + } +} + +/// LayerNorm without bias, matching `nn.LayerNorm(dims, bias=False)` in the Python model. +/// The checkpoint stores a single `weight` parameter at the `norm` key. +private final class Gemma4AudioLayerNorm: Module, UnaryLayer { + @ModuleInfo var weight: MLXArray + let eps: Float + + init(dimensions: Int, eps: Float = 1e-6) { + self.eps = eps + self._weight.wrappedValue = MLXArray.ones([dimensions]) + super.init() + } + + func callAsFunction(_ x: MLXArray) -> MLXArray { + let xFloat = x.asType(.float32) + let meanVal = MLX.mean(xFloat, axis: -1, keepDims: true) + let variance = MLX.mean((xFloat - meanVal).square(), axis: -1, keepDims: true) + let normalized = (xFloat - meanVal) * rsqrt(variance + eps) + return (normalized * weight.asType(.float32)).asType(x.dtype) + } +} + +private final class Gemma4SSCPConvBlock: Module { + let timeStride: Int = 2 + let padding: (Int, Int, Int, Int) = (1, 1, 1, 1) + + @ModuleInfo(key: "conv") var conv: Conv2d + @ModuleInfo(key: "norm") var norm: Gemma4AudioLayerNorm + + init(config: Gemma4AudioConfiguration, idx: Int) { + let inChannels = idx == 0 ? 1 : config.subsamplingConvChannels[idx - 1] + let outChannels = config.subsamplingConvChannels[idx] + + // Conv2d: MLX expects [B, H, W, C], weight [C_out, kH, kW, C_in] + self._conv.wrappedValue = Conv2d( + inputChannels: inChannels, + outputChannels: outChannels, + kernelSize: 3, + stride: 2, + padding: 0, + bias: false + ) + + self._norm.wrappedValue = Gemma4AudioLayerNorm( + dimensions: outChannels, eps: config.rmsNormEps) + super.init() + } + + func callAsFunction(_ x: MLXArray, mask: MLXArray) -> (MLXArray, MLXArray) { + // x: [B, T, F, C] (MLX channel-last) + // mask: [B, T] (True = invalid/padding) + + // Zero out invalid positions + var x = MLX.where( + expandedDimensions(expandedDimensions(mask, axis: -1), axis: -1), + MLXArray(0.0, dtype: x.dtype), x) + + // Manual padding on T and F dims + x = MLX.padded( + x, + widths: [ + .init((0, 0)), .init((padding.0, padding.1)), + .init((padding.2, padding.3)), .init((0, 0)), + ]) + + x = conv(x) // [B, T_out, F_out, C_out] + + // Downsample mask by time stride + let tOut = x.dim(1) + let downsampled = mask[0..., .stride(by: timeStride)] + let outputMask = downsampled[0..., .. (MLXArray, MLXArray) { + // audioMel: [B, T, F_in] + // Add channel dim: [B, T, F, 1] + var x = expandedDimensions(audioMel, axis: -1) + + var currentMask = mask + (x, currentMask) = layer0(x, mask: currentMask) + (x, currentMask) = layer1(x, mask: currentMask) + + // Flatten F*C -> [B, T, F*C] + let batchSize = x.dim(0) + let timeSteps = x.dim(1) + let freqBins = x.dim(2) + let channels = x.dim(3) + x = x.reshaped(batchSize, timeSteps, freqBins * channels) + + // Project to hidden_size + x = inputProjLinear(x) + + return (x, currentMask) + } +} + +private final class Gemma4ConformerFeedForward: Module { + let gradientClipping: Float + let residualWeight: Float + + @ModuleInfo(key: "pre_layer_norm") var preLayerNorm: Gemma4AudioRMSNorm + @ModuleInfo(key: "ffw_layer_1") var ffwLayer1: Gemma4AudioClippableLinear + @ModuleInfo(key: "ffw_layer_2") var ffwLayer2: Gemma4AudioClippableLinear + @ModuleInfo(key: "post_layer_norm") var postLayerNorm: Gemma4AudioRMSNorm + + init(config: Gemma4AudioConfiguration) { + self.gradientClipping = config.gradientClipping + self.residualWeight = config.residualWeight + + self._preLayerNorm.wrappedValue = Gemma4AudioRMSNorm(dimensions: config.hiddenSize) + self._ffwLayer1.wrappedValue = Gemma4AudioClippableLinear( + inFeatures: config.hiddenSize, outFeatures: config.hiddenSize * 4, + useClipping: config.useClippedLinears) + self._ffwLayer2.wrappedValue = Gemma4AudioClippableLinear( + inFeatures: config.hiddenSize * 4, outFeatures: config.hiddenSize, + useClipping: config.useClippedLinears) + self._postLayerNorm.wrappedValue = Gemma4AudioRMSNorm(dimensions: config.hiddenSize) + super.init() + } + + func callAsFunction(_ x: MLXArray) -> MLXArray { + let residual = x + var h = clip(x, min: -gradientClipping, max: gradientClipping) + h = preLayerNorm(h) + h = ffwLayer1(h) + h = silu(h) + h = ffwLayer2(h) + h = clip(h, min: -gradientClipping, max: gradientClipping) + h = postLayerNorm(h) + return residual + h * residualWeight + } +} + +private final class Gemma4AudioRelativePositionEmbedding: Module { + let numHeads: Int + let channels: Int + let headDim: Int + let maxBackward: Int + let maxForward: Int + let invTimescales: MLXArray + + @ModuleInfo(key: "pos_proj") var posProj: Linear + + init(config: Gemma4AudioConfiguration) { + self.numHeads = config.numAttentionHeads + self.channels = config.hiddenSize + self.headDim = config.hiddenSize / config.numAttentionHeads + self.maxBackward = max(0, config.attentionContextLeft - 1) + self.maxForward = config.attentionContextRight + + self._posProj.wrappedValue = Linear( + config.hiddenSize, config.numAttentionHeads * headDim, bias: false) + + let minTimescale: Float = 1.0 + let maxTimescale: Float = 10000.0 + let numTimescales = config.hiddenSize / 2 + let logTimescaleIncrement = + Foundation.log(maxTimescale / minTimescale) / Float(max(numTimescales - 1, 1)) + self.invTimescales = + MLXArray(minTimescale) + * MLX.exp(MLXArray(0 ..< numTimescales).asType(.float32) * (-logTimescaleIncrement)) + + super.init() + } + + private func getTimingSignal(_ position: MLXArray, dtype: DType) -> MLXArray { + let posFloat = position.asType(.float32) + let pos = expandedDimensions(posFloat, axis: -1) + let invTS = invTimescales.reshaped(1, 1, -1) + let scaledTime = pos * invTS + let signal = concatenated([sin(scaledTime), cos(scaledTime)], axis: -1) + return signal.asType(dtype) + } + + private func relativeShift( + _ termBD: MLXArray, batchSize: Int, numHeads: Int, numBlocks: Int, + blockSize: Int, contextSize: Int, maxSpanPlus1: Int + ) -> MLXArray { + let padAmount = (contextSize + 1) - maxSpanPlus1 + var shifted = MLX.padded( + termBD, + widths: [ + .init((0, 0)), .init((0, 0)), .init((0, 0)), .init((0, 0)), .init((0, padAmount)), + ]) + shifted = shifted.reshaped(batchSize, numHeads, numBlocks, blockSize * (contextSize + 1)) + shifted = shifted[0..., 0..., 0..., ..<(blockSize * contextSize)] + shifted = shifted.reshaped(batchSize, numHeads, numBlocks, blockSize, contextSize) + return shifted + } + + func callAsFunction(queries: MLXArray, keys: MLXArray) -> MLXArray { + // queries: [B, U, W, N, H], keys: [B, U, C, N, H] + let batchSize = queries.dim(0) + let numBlocks = queries.dim(1) + let blockSize = queries.dim(2) + let contextSize = keys.dim(2) + + let posIndices = MLXArray( + stride(from: maxBackward, through: -maxForward, by: -1).map { Int32($0) } + ) + .reshaped(1, -1) + let maxSpanPlus1 = posIndices.dim(1) + + var sinEmb = getTimingSignal(posIndices, dtype: queries.dtype) + sinEmb = posProj(sinEmb.asType(posProj.weight.dtype)) + sinEmb = sinEmb.reshaped(maxSpanPlus1, numHeads, headDim) + sinEmb = sinEmb.asType(queries.dtype) + + // queries_p: [B, N, U, W, H], keys_p: [B, N, U, H, C] + let queriesP = queries.transposed(0, 3, 1, 2, 4) + let keysP = keys.transposed(0, 3, 1, 4, 2) + let termAC = queriesP.matmul(keysP) + + // sin_emb_t: [N, H, maxSpan] + let sinEmbT = sinEmb.transposed(1, 2, 0) + let qReshaped = queriesP.reshaped(batchSize, numHeads, numBlocks * blockSize, headDim) + var termBD = qReshaped.matmul(sinEmbT).reshaped( + batchSize, numHeads, numBlocks, blockSize, maxSpanPlus1) + + termBD = relativeShift( + termBD, batchSize: batchSize, numHeads: numHeads, numBlocks: numBlocks, + blockSize: blockSize, contextSize: contextSize, maxSpanPlus1: maxSpanPlus1) + + return termAC + termBD + } +} + +private final class Gemma4AudioAttention: Module { + let numHeads: Int + let hiddenSize: Int + let headDim: Int + let chunkSize: Int + let maxFutureHorizon: Int + let maxPastHorizon: Int + let contextSize: Int + let invalidLogitsValue: Float + let softcap: Float + let qScale: Float + let kScale: Float + + @ModuleInfo(key: "relative_k_proj") var relativeKProj: Linear + @ParameterInfo(key: "per_dim_scale") var perDimScale: MLXArray + @ModuleInfo(key: "q_proj") var qProj: Gemma4AudioClippableLinear + @ModuleInfo(key: "k_proj") var kProj: Gemma4AudioClippableLinear + @ModuleInfo(key: "v_proj") var vProj: Gemma4AudioClippableLinear + @ModuleInfo(key: "post") var post: Gemma4AudioClippableLinear + + // Relative position embedding (inline) + // Note: relPosInvTimescales is NOT a model parameter — it's a computed constant. + // Store as [Float] to avoid MLX Module treating it as a loadable weight. + private let relPosNumHeads: Int + private let relPosHeadDim: Int + private let relPosMaxBackward: Int + private let relPosMaxForward: Int + private let relPosInvTimescalesData: [Float] + + init(config: Gemma4AudioConfiguration) { + self.numHeads = config.numAttentionHeads + self.hiddenSize = config.hiddenSize + self.headDim = config.hiddenSize / config.numAttentionHeads + self.chunkSize = config.attentionChunkSize + self.maxFutureHorizon = config.attentionContextRight + self.maxPastHorizon = max(0, config.attentionContextLeft - 1) + self.contextSize = chunkSize + maxPastHorizon + maxFutureHorizon + self.invalidLogitsValue = config.attentionInvalidLogitsValue + self.softcap = config.attentionLogitCap + + self.qScale = pow(Float(headDim), -0.5) / Foundation.log(2.0) + self.kScale = Foundation.log(1 + Foundation.exp(1.0)) / Foundation.log(2.0) + + self._relativeKProj.wrappedValue = Linear( + config.hiddenSize, config.numAttentionHeads * headDim, bias: false) + self._perDimScale.wrappedValue = MLXArray.zeros([headDim]) + self._qProj.wrappedValue = Gemma4AudioClippableLinear( + inFeatures: config.hiddenSize, outFeatures: numHeads * headDim, + useClipping: config.useClippedLinears) + self._kProj.wrappedValue = Gemma4AudioClippableLinear( + inFeatures: config.hiddenSize, outFeatures: numHeads * headDim, + useClipping: config.useClippedLinears) + self._vProj.wrappedValue = Gemma4AudioClippableLinear( + inFeatures: config.hiddenSize, outFeatures: numHeads * headDim, + useClipping: config.useClippedLinears) + self._post.wrappedValue = Gemma4AudioClippableLinear( + inFeatures: config.hiddenSize, outFeatures: config.hiddenSize, + useClipping: config.useClippedLinears) + + // Relative position embedding setup + self.relPosNumHeads = numHeads + self.relPosHeadDim = headDim + self.relPosMaxBackward = maxPastHorizon + self.relPosMaxForward = maxFutureHorizon + + let minTimescale: Float = 1.0 + let maxTimescale: Float = 10000.0 + let numTimescales = config.hiddenSize / 2 + let logTimescaleIncrement = + Foundation.log(maxTimescale / minTimescale) / Float(max(numTimescales - 1, 1)) + self.relPosInvTimescalesData = (0 ..< numTimescales).map { i in + minTimescale * Foundation.exp(Float(i) * (-logTimescaleIncrement)) + } + + super.init() + } + + private func padDim1(_ x: MLXArray, padLeft: Int, padRight: Int) -> MLXArray { + var widths = Array(repeating: IntOrPair((0, 0)), count: x.ndim) + widths[1] = IntOrPair((padLeft, padRight)) + return MLX.padded(x, widths: widths) + } + + private func convertToBlock(_ x: MLXArray) -> MLXArray { + // [B, T, ...] -> [B, num_blocks, chunk_size, ...] + let batchSize = x.dim(0) + let timeSteps = x.dim(1) + let rest = Array(x.shape.dropFirst(2)) + let numBlocks = (timeSteps + chunkSize - 1) / chunkSize + let padLen = numBlocks * chunkSize - timeSteps + var result = x + if padLen > 0 { + result = padDim1(result, padLeft: 0, padRight: padLen) + } + return result.reshaped([batchSize, numBlocks, chunkSize] + rest) + } + + private func extractBlockContext(_ x: MLXArray) -> MLXArray { + // [B, T, ...] -> [B, num_blocks, context_size, ...] + let padLeft = maxPastHorizon + let padRight = maxFutureHorizon + chunkSize - 1 + let padded = padDim1(x, padLeft: padLeft, padRight: padRight) + let tPadded = padded.dim(1) + let numBlocks = (tPadded - contextSize) / chunkSize + 1 + + // Build indices: starts[:, None] + offsets[None, :] + let starts = MLXArray( + stride(from: 0, to: numBlocks * chunkSize, by: chunkSize).map { + Int32($0) + }) + let offsets = MLXArray((0 ..< contextSize).map { Int32($0) }) + let indices = expandedDimensions(starts, axis: 1) + expandedDimensions(offsets, axis: 0) + // indices: [numBlocks, contextSize] + + // Gather using advanced indexing + // padded: [B, T_padded, ...rest] + // We need padded[:, indices] which gives [B, numBlocks, contextSize, ...rest] + return padded[0..., indices] + } + + private func relPosTimingSignal(_ position: MLXArray, dtype: DType) -> MLXArray { + let posFloat = position.asType(.float32) + let pos = expandedDimensions(posFloat, axis: -1) + let invTS = MLXArray(relPosInvTimescalesData).reshaped(1, 1, -1) + let scaledTime = pos * invTS + let signal = concatenated([sin(scaledTime), cos(scaledTime)], axis: -1) + return signal.asType(dtype) + } + + private func relPosRelativeShift( + _ termBD: MLXArray, batchSize: Int, numHeads: Int, numBlocks: Int, + blockSize: Int, contextSize: Int, maxSpanPlus1: Int + ) -> MLXArray { + let padAmount = (contextSize + 1) - maxSpanPlus1 + var shifted = MLX.padded( + termBD, + widths: [ + .init((0, 0)), .init((0, 0)), .init((0, 0)), .init((0, 0)), .init((0, padAmount)), + ]) + shifted = shifted.reshaped(batchSize, numHeads, numBlocks, blockSize * (contextSize + 1)) + shifted = shifted[0..., 0..., 0..., ..<(blockSize * contextSize)] + shifted = shifted.reshaped(batchSize, numHeads, numBlocks, blockSize, contextSize) + return shifted + } + + private func computeRelativePositionLogits(queries: MLXArray, keys: MLXArray) -> MLXArray { + // queries: [B, U, W, N, H], keys: [B, U, C, N, H] + let batchSize = queries.dim(0) + let numBlocks = queries.dim(1) + let blockSize = queries.dim(2) + let ctxSize = keys.dim(2) + + let posIndices = MLXArray( + stride(from: relPosMaxBackward, through: -relPosMaxForward, by: -1).map { Int32($0) } + ).reshaped(1, -1) + let maxSpanPlus1 = posIndices.dim(1) + + var sinEmb = relPosTimingSignal(posIndices, dtype: queries.dtype) + sinEmb = relativeKProj(sinEmb.asType(relativeKProj.weight.dtype)) + sinEmb = sinEmb.reshaped(maxSpanPlus1, relPosNumHeads, relPosHeadDim) + sinEmb = sinEmb.asType(queries.dtype) + + let queriesP = queries.transposed(0, 3, 1, 2, 4) + let keysP = keys.transposed(0, 3, 1, 4, 2) + let termAC = queriesP.matmul(keysP) + + let sinEmbT = sinEmb.transposed(1, 2, 0) + let qReshaped = queriesP.reshaped( + batchSize, relPosNumHeads, numBlocks * blockSize, relPosHeadDim) + var termBD = qReshaped.matmul(sinEmbT).reshaped( + batchSize, relPosNumHeads, numBlocks, blockSize, maxSpanPlus1) + + termBD = relPosRelativeShift( + termBD, batchSize: batchSize, numHeads: relPosNumHeads, numBlocks: numBlocks, + blockSize: blockSize, contextSize: ctxSize, maxSpanPlus1: maxSpanPlus1) + + return termAC + termBD + } + + func callAsFunction( + _ hiddenStates: MLXArray, mask: MLXArray, causalValidMask: MLXArray + ) -> MLXArray { + let batchSize = hiddenStates.dim(0) + let timeSteps = hiddenStates.dim(1) + let qkvShape = [batchSize, timeSteps, numHeads, headDim] + + var q = qProj(hiddenStates).asType(.float32).reshaped(qkvShape) + var k = kProj(hiddenStates).asType(.float32).reshaped(qkvShape) + let v = vProj(hiddenStates).asType(.float32).reshaped(qkvShape) + + let pds = softplus(perDimScale) + q = q * (qScale * pds) + k = k * kScale + + let queryBlocks = convertToBlock(q) // [B, U, W, N, H] + let keyBlocks = extractBlockContext(k) // [B, U, C, N, H] + let valueBlocks = extractBlockContext(v) // [B, U, C, N, H] + let numBlocks = queryBlocks.dim(1) + + // Build validity condition + let validMask = logicalNot(mask) // True = valid + let extractedValid = extractBlockContext(validMask) // [B, U, C] + // condition: [B, 1, U, W, C] + let condition = + expandedDimensions(expandedDimensions(extractedValid, axis: 1), axis: 3) + * expandedDimensions( + expandedDimensions(expandedDimensions(causalValidMask, axis: 0), axis: 0), axis: 0) + + var logits = computeRelativePositionLogits(queries: queryBlocks, keys: keyBlocks) + logits = tanh(logits / softcap) * softcap + logits = MLX.where( + condition .> 0, logits, MLXArray(invalidLogitsValue, dtype: logits.dtype)) + + let probs = softmax(logits, axis: -1) + // context = einsum("bnuwc,bucnh->buwnh", probs, valueBlocks) + var context = einsum("bnuwc,bucnh->buwnh", probs, valueBlocks) + context = context.reshaped(batchSize, numBlocks * chunkSize, numHeads, headDim) + context = context[0..., .. [B, T, D] and post-project + context = context.reshaped(batchSize, timeSteps, numHeads * headDim) + return post(context) + } +} + +private final class Gemma4ConformerLightConv1d: Module { + let gradientClipping: Float + let causalPadding: Int + + @ModuleInfo(key: "pre_layer_norm") var preLayerNorm: Gemma4AudioRMSNorm + @ModuleInfo(key: "linear_start") var linearStart: Gemma4AudioClippableLinear + @ModuleInfo(key: "depthwise_conv1d") var depthwiseConv1d: Conv1d + @ModuleInfo(key: "conv_norm") var convNorm: Gemma4AudioRMSNorm + @ModuleInfo(key: "linear_end") var linearEnd: Gemma4AudioClippableLinear + + init(config: Gemma4AudioConfiguration) { + self.gradientClipping = config.gradientClipping + self.causalPadding = config.convKernelSize - 1 + + self._preLayerNorm.wrappedValue = Gemma4AudioRMSNorm( + dimensions: config.hiddenSize, eps: config.rmsNormEps) + self._linearStart.wrappedValue = Gemma4AudioClippableLinear( + inFeatures: config.hiddenSize, outFeatures: config.hiddenSize * 2, + useClipping: config.useClippedLinears) + // Depthwise conv1d: groups = hidden_size so weight shape is [out, kernel, 1] + self._depthwiseConv1d.wrappedValue = Conv1d( + inputChannels: config.hiddenSize, + outputChannels: config.hiddenSize, + kernelSize: config.convKernelSize, + stride: 1, + padding: 0, + groups: config.hiddenSize, + bias: false + ) + self._convNorm.wrappedValue = Gemma4AudioRMSNorm( + dimensions: config.hiddenSize, eps: config.rmsNormEps) + self._linearEnd.wrappedValue = Gemma4AudioClippableLinear( + inFeatures: config.hiddenSize, outFeatures: config.hiddenSize, + useClipping: config.useClippedLinears) + super.init() + } + + func callAsFunction(_ x: MLXArray) -> MLXArray { + let residual = x + + var h = preLayerNorm(x) + h = linearStart(h) + + // GLU: split in half along last dim and gate + let halfDim = h.dim(-1) / 2 + let x1 = h[0..., 0..., .. MLXArray { + var h = feedForward1(x) + + // Attention with pre/post norm and residual + let residual = h + h = clip(h, min: -gradientClipping, max: gradientClipping) + h = normPreAttn(h) + h = selfAttn(h, mask: mask, causalValidMask: causalValidMask) + h = clip(h, min: -gradientClipping, max: gradientClipping) + h = residual + normPostAttn(h) + + // Zero out invalid positions before lconv1d + let validityMask = expandedDimensions(logicalNot(mask), axis: -1).asType(h.dtype) + h = h * validityMask + + h = lconv1d(h) + h = feedForward2(h) + h = clip(h, min: -gradientClipping, max: gradientClipping) + return normOut(h) + } +} + +private final class Gemma4AudioEncoder: Module { + let config: Gemma4AudioConfiguration + + @ModuleInfo(key: "subsample_conv_projection") var subsampleConvProjection: + Gemma4SubSampleConvProjection + @ModuleInfo(key: "layers") var layers: [Gemma4ConformerBlock] + @ModuleInfo(key: "output_proj") var outputProj: Linear? + + init(config: Gemma4AudioConfiguration) { + self.config = config + self._subsampleConvProjection.wrappedValue = Gemma4SubSampleConvProjection(config: config) + self._layers.wrappedValue = (0 ..< config.numHiddenLayers).map { _ in + Gemma4ConformerBlock(config: config) + } + if let outputProjDims = config.outputProjDims { + self._outputProj.wrappedValue = Linear( + config.hiddenSize, outputProjDims, bias: true) + } + super.init() + } + + private func buildCausalValidMask() -> MLXArray { + let chunkSize = config.attentionChunkSize + let maxFutureHorizon = config.attentionContextRight + let maxPastHorizon = max(0, config.attentionContextLeft - 1) + let upperDiagonal = maxPastHorizon + maxFutureHorizon + let ctxSize = chunkSize + maxPastHorizon + maxFutureHorizon + + let lowerCausal = tril(MLXArray.ones([ctxSize, chunkSize])).transposed() + let upperCausal = tril( + MLXArray.ones([chunkSize, ctxSize]), + k: upperDiagonal) + let maskResult = (lowerCausal * upperCausal).asType(.bool) + return maskResult + } + + func callAsFunction(_ audioMel: MLXArray, audioMelMask: MLXArray) -> (MLXArray, MLXArray) { + var (audioEncodings, currentMask) = subsampleConvProjection(audioMel, mask: audioMelMask) + + let causalValidMask = buildCausalValidMask() + + for block in layers { + audioEncodings = block( + audioEncodings, mask: currentMask, causalValidMask: causalValidMask) + } + + if let outputProj { + audioEncodings = outputProj(audioEncodings) + } + + if currentMask.dim(1) != audioEncodings.dim(1) { + let targetLen = audioEncodings.dim(1) + currentMask = currentMask[0..., ..