Skip to content
Open
7 changes: 7 additions & 0 deletions Libraries/Embedders/Bert.swift
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Copyright © 2024 Apple Inc.

import MLX
import MLXLMCommon
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this just to pick up the quantization? I think the Embedders should not require MLXLMCommon / MLXLLM if possible. A copy of Quantization is OK. If we end up with a lot of duplication then I think MLXLMCommon might make sense.

import MLXNN

extension MLXArray {
Expand Down Expand Up @@ -196,6 +197,12 @@ public class BertModel: Module, EmbeddingModel {
result[key] = item.value
}.filter { key, _ in key != "embeddings.position_ids" }
}

/// Sanitize weights, optionally honoring a quantization configuration.
///
/// Bert has no quantized-weight support, so a non-nil `quantizationConfig`
/// is a configuration error. A nil config must NOT trap: the loader now
/// always calls this overload (see `loadSynchronous`), so unconditional
/// `fatalError` would break loading of every unquantized Bert checkpoint.
///
/// - Parameters:
///   - weights: Raw checkpoint tensors keyed by parameter path.
///   - quantizationConfig: Quantization settings from `config.json`, if any.
/// - Returns: The cleaned-up weights from the standard `sanitize(weights:)`.
public func sanitize(
    weights: [String: MLXArray], quantizationConfig: MLXLMCommon.BaseConfiguration.Quantization?
) -> [String: MLXArray] {
    guard quantizationConfig == nil else {
        fatalError("Bert does not support quantization")
    }
    // Delegate to the existing per-model cleanup (e.g. dropping
    // "embeddings.position_ids").
    return sanitize(weights: weights)
}
}

public class DistilBertModel: BertModel {
Expand Down
8 changes: 8 additions & 0 deletions Libraries/Embedders/Configuration.swift
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Copyright © 2024 Apple Inc.

import Foundation
import MLXLLM

public enum StringOrNumber: Codable, Equatable, Sendable {
case string(String)
Expand Down Expand Up @@ -76,6 +77,13 @@ private class ModelTypeRegistry: @unchecked Sendable {
let model = Qwen3Model(configuration)
return model
},
"gemma3_text": {
url in
let configuration = try JSONDecoder().decode(
Gemma3TextConfiguration.self, from: Data(contentsOf: url))
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it makes sense to copy this config type into Embedders rather than add new linkage. Even sharing config types between models in the same library is rarely done.

let model = EmbeddingGemma(configuration)
return model
},
]

public func registerModelType(
Expand Down
111 changes: 111 additions & 0 deletions Libraries/Embedders/EmbeddingGemma.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
import MLX
import MLXLLM
import MLXLMCommon
Comment on lines +2 to +3
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See elsewhere -- I think this should be done without adding linkage to additional libraries.

import MLXNN

/// Text-embedding model built on a Gemma3 text backbone with a two-layer
/// dense projection head, producing mean-pooled, L2-normalized embeddings.
public class EmbeddingGemma: Module, EmbeddingModel {
    /// Underlying Gemma3 text backbone (type provided by MLXLLM).
    @ModuleInfo private var model: Gemma3TextModel
    /// Projection head applied after pooling. Stored as `[Module]` so the
    /// `Linear` layers can be replaced by `QuantizedLinear` when
    /// `quantize(model:)` runs inside `sanitize`.
    @ModuleInfo private var dense: [Module]

    public let config: Gemma3TextConfiguration
    public var vocabularySize: Int { config.vocabularySize }

    public init(_ config: Gemma3TextConfiguration) {
        self.config = config
        self.model = Gemma3TextModel(config)
        // 768 -> 3072 -> 768 projection head, no bias.
        // NOTE(review): dimensions are hard-coded — presumably correct for the
        // 300m checkpoints registered in Models.swift; confirm they should not
        // instead be derived from `config` before other sizes are added.
        self.dense = [
            Linear(768, 3072, bias: false), Linear(3072, 768, bias: false),
        ]
    }

    /// Computes sentence embeddings for a batch of token ids.
    ///
    /// - Note: `positionIds`, `tokenTypeIds`, and `attentionMask` are accepted
    ///   for `EmbeddingModel` conformance but are currently ignored; padding is
    ///   inferred as token id 0 instead.
    ///   NOTE(review): confirm that discarding `attentionMask` (mask passed to
    ///   the backbone is `nil`) is intentional.
    /// - Returns: `hiddenStates` holds the pooled, projected vector BEFORE
    ///   normalization (not per-token states); `pooledOutput` is the
    ///   L2-normalized embedding.
    public func callAsFunction(
        _ inputs: MLXArray, positionIds: MLXArray?, tokenTypeIds: MLXArray?,
        attentionMask: MLXArray?
    ) -> EmbeddingModelOutput {
        var out = model.getHiddenStates(inputs, mask: nil, cache: nil)

        // mean pooling over non-padding positions (token id 0 == padding)
        let notPadding = inputs .!= 0
        let sum = (out * notPadding[.ellipsis, .newAxis]).sum(axis: 1)
        let nonMasked = notPadding.sum(axis: -1, keepDims: true)
        out = sum / nonMasked

        // Apply the projection head. Each layer is either a plain `Linear`
        // or, after quantization in `sanitize`, a `QuantizedLinear`.
        for dense in self.dense {
            if let dense = dense as? Linear {
                out = dense(out)
            } else if let dense = dense as? QuantizedLinear {
                out = dense(out)
            }
        }

        // normalize: L2 norm clamped to 1e-6 to avoid division by zero
        out = out.asType(Float32.self)
        let norm = maximum(norm(out, ord: 2.0, axis: -1, keepDims: true), MLXArray(1e-6))
        let pooledOutput = out / norm

        return EmbeddingModelOutput(hiddenStates: out, pooledOutput: pooledOutput)
    }

    /// Get hidden states before the dense projection head
    /// NOTE(review): this calls `model(...)` directly while `callAsFunction`
    /// uses `model.getHiddenStates(...)` — confirm both return the same
    /// pre-head hidden states rather than logits.
    public func getHiddenStates(
        _ inputs: MLXArray, mask: MLXFast.ScaledDotProductAttentionMaskMode? = nil,
        cache: [KVCache]? = nil
    ) -> MLXArray {
        return model(inputs, mask: mask, cache: cache)
    }

    /// Preprocesses checkpoint weights and, when the checkpoint carries
    /// quantized dense-head tensors, quantizes the matching live modules.
    ///
    /// Order matters: the backbone's own `sanitize` runs first, then keys are
    /// re-prefixed to match this wrapper's module paths, and only then are the
    /// (re-keyed) weights inspected for quantization markers.
    public func sanitize(
        weights: [String: MLXArray],
        quantizationConfig: MLXLMCommon.BaseConfiguration.Quantization? = nil
    )
        -> [String: MLXArray]
    {
        var processedWeights = model.sanitize(
            weights: weights, quantizationConfig: quantizationConfig)

        // 1. Add a model. prefix to all model. weights
        // (the backbone lives at `self.model`, so checkpoint keys like
        // "model.layers..." must become "model.model.layers...")
        processedWeights = Dictionary(
            uniqueKeysWithValues: processedWeights.map { key, value in
                if key.hasPrefix("model.") || key.hasPrefix("lm_head.") {
                    return ("model.\(key)", value)
                } else {
                    return (key, value)
                }
            })

        // 2. Apply quantization to dense layers, if needed
        // Presence of scales/biases plus a uint32 weight on "dense.0" signals a
        // quantized projection head in the checkpoint.
        let hasQuantizedDense = hasQuantizedWeights(layerPath: "dense.0", in: processedWeights)
        if hasQuantizedDense {
            // Fall back to common MLX defaults when config.json omits them.
            let groupSize = quantizationConfig?.groupSize ?? 64
            let bits = quantizationConfig?.bits ?? 4

            // Side effect: replaces matching `Linear` modules in `self` with
            // `QuantizedLinear` so the quantized tensors can be loaded.
            quantize(model: self) { path, module in
                if hasQuantizedWeights(layerPath: path, in: processedWeights) {
                    return (groupSize, bits)
                }
                return nil
            }
        }

        // Rotary inverse-frequency buffers are recomputed at runtime, not loaded.
        return processedWeights.filter { key, _ in
            !key.contains("self_attn.rotary_emb.inv_freq")
        }
    }

    /// Protocol-required single-argument variant; no quantization config.
    public func sanitize(weights: [String: MLXArray]) -> [String: MLXArray] {
        sanitize(weights: weights, quantizationConfig: nil)
    }

    /// Check if a layer has quantized weights
    /// (scales + biases present, and the packed weight stored as uint32).
    private func hasQuantizedWeights(layerPath: String, in weights: [String: MLXArray]) -> Bool {
        let scalesKey = "\(layerPath).scales"
        let biasesKey = "\(layerPath).biases"
        let weightKey = "\(layerPath).weight"

        let hasScales = weights[scalesKey] != nil
        let hasBiases = weights[biasesKey] != nil
        let hasWeight = weights[weightKey]?.dtype == .uint32

        return hasScales && hasBiases && hasWeight
    }
}
8 changes: 6 additions & 2 deletions Libraries/Embedders/EmbeddingModel.swift
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import Foundation
@preconcurrency import Hub
import MLX
import MLXLMCommon
import MLXNN
import Tokenizers

Expand Down Expand Up @@ -87,8 +88,8 @@ extension Module {
}

public struct EmbeddingModelOutput {
let hiddenStates: MLXArray?
let pooledOutput: MLXArray?
public let hiddenStates: MLXArray?
public let pooledOutput: MLXArray?
}

public protocol EmbeddingModel: Module {
Expand All @@ -99,6 +100,9 @@ public protocol EmbeddingModel: Module {
) -> EmbeddingModelOutput
/// Optionally preprocess the weights and modify / remove values as needed.
func sanitize(weights: [String: MLXArray]) -> [String: MLXArray]
func sanitize(
weights: [String: MLXArray], quantizationConfig: MLXLMCommon.BaseConfiguration.Quantization?
) -> [String: MLXArray]
}

extension EmbeddingModel {
Expand Down
5 changes: 4 additions & 1 deletion Libraries/Embedders/Load.swift
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import Foundation
@preconcurrency import Hub
import MLX
import MLXLMCommon
import MLXNN
import Tokenizers

Expand Down Expand Up @@ -60,6 +61,8 @@ func loadSynchronous(modelDirectory: URL) throws -> EmbeddingModel {
let configurationURL = modelDirectory.appending(component: "config.json")
let baseConfig = try JSONDecoder().decode(
BaseConfiguration.self, from: Data(contentsOf: configurationURL))
let commonBaseConfig = try JSONDecoder().decode(
MLXLMCommon.BaseConfiguration.self, from: Data(contentsOf: configurationURL))

let modelType = ModelType(rawValue: baseConfig.modelType)
let model = try modelType.createModel(configuration: configurationURL)
Expand All @@ -78,7 +81,7 @@ func loadSynchronous(modelDirectory: URL) throws -> EmbeddingModel {
}

// per-model cleanup
weights = model.sanitize(weights: weights)
weights = model.sanitize(weights: weights, quantizationConfig: commonBaseConfig.quantization)

// quantize if needed
if let perLayerQuantization = baseConfig.perLayerQuantization {
Expand Down
12 changes: 12 additions & 0 deletions Libraries/Embedders/Models.swift
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,14 @@ extension ModelConfiguration {
id: "mixedbread-ai/mxbai-embed-large-v1")
public static let qwen3_embedding = ModelConfiguration(
id: "mlx-community/Qwen3-Embedding-0.6B-4bit-DWQ")
public static let embeddinggemma_300m = ModelConfiguration(
id: "mlx-community/embeddinggemma-300m-bf16")
public static let embeddinggemma_300m_8bit = ModelConfiguration(
id: "mlx-community/embeddinggemma-300m-8bit")
public static let embeddinggemma_300m_6bit = ModelConfiguration(
id: "mlx-community/embeddinggemma-300m-6bit")
public static let embeddinggemma_300m_4bit = ModelConfiguration(
id: "mlx-community/embeddinggemma-300m-4bit")

private enum BootstrapState: Sendable {
case idle
Expand Down Expand Up @@ -141,6 +149,10 @@ extension ModelConfiguration {
bge_m3,
mixedbread_large,
qwen3_embedding,
embeddinggemma_300m,
embeddinggemma_300m_8bit,
embeddinggemma_300m_6bit,
embeddinggemma_300m_4bit,
])
bootstrapState = .bootstrapped

Expand Down
7 changes: 7 additions & 0 deletions Libraries/Embedders/NomicBert.swift
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import Foundation
import MLX
import MLXLMCommon
import MLXNN

class NomicEmbedding: Module {
Expand Down Expand Up @@ -390,6 +391,12 @@ public class NomicBertModel: Module, EmbeddingModel {
result[key] = item.value
}
}

/// Sanitize weights, optionally honoring a quantization configuration.
///
/// NomicBert has no quantized-weight support, so a non-nil
/// `quantizationConfig` is a configuration error. A nil config must NOT
/// trap: the loader now always calls this overload (see `loadSynchronous`),
/// so unconditional `fatalError` would break every unquantized Nomic load.
///
/// - Parameters:
///   - weights: Raw checkpoint tensors keyed by parameter path.
///   - quantizationConfig: Quantization settings from `config.json`, if any.
/// - Returns: The cleaned-up weights from the standard `sanitize(weights:)`.
public func sanitize(
    weights: [String: MLXArray], quantizationConfig: MLXLMCommon.BaseConfiguration.Quantization?
) -> [String: MLXArray] {
    guard quantizationConfig == nil else {
        fatalError("Nomic does not support quantization")
    }
    // Delegate to the existing per-model key cleanup.
    return sanitize(weights: weights)
}
}

public struct NomicBertConfiguration: Decodable, Sendable {
Expand Down
6 changes: 6 additions & 0 deletions Libraries/Embedders/Qwen3.swift
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,12 @@ public class Qwen3Model: Module, EmbeddingModel {

return sanitizedWeights
}

/// Sanitize weights, optionally honoring a quantization configuration.
///
/// This Qwen3 embedder has no quantized-weight support, so a non-nil
/// `quantizationConfig` is a configuration error. A nil config must NOT
/// trap: the loader now always calls this overload (see `loadSynchronous`),
/// so unconditional `fatalError` would break every unquantized Qwen3 load.
///
/// - Parameters:
///   - weights: Raw checkpoint tensors keyed by parameter path.
///   - quantizationConfig: Quantization settings from `config.json`, if any.
/// - Returns: The cleaned-up weights from the standard `sanitize(weights:)`.
public func sanitize(
    weights: [String: MLXArray], quantizationConfig: MLXLMCommon.BaseConfiguration.Quantization?
) -> [String: MLXArray] {
    guard quantizationConfig == nil else {
        fatalError("Qwen3 does not support quantization")
    }
    // Delegate to the existing per-model key cleanup.
    return sanitize(weights: weights)
}
}

public struct Qwen3Configuration: Codable, Sendable {
Expand Down
Loading