Embedding gemma #7
base: main
Changes from all commits: a7725cf, 3eda8bf, d47b48c, 62c4b0f, 771f6f4, d888682, e96ceb5, 6e07a79, 383961d, 053232b
```diff
@@ -1,6 +1,7 @@
 // Copyright © 2024 Apple Inc.

 import Foundation
+import MLXLLM

 public enum StringOrNumber: Codable, Equatable, Sendable {
     case string(String)
```

```diff
@@ -76,6 +77,13 @@ private class ModelTypeRegistry: @unchecked Sendable {
             let model = Qwen3Model(configuration)
             return model
         },
+        "gemma3_text": {
+            url in
+            let configuration = try JSONDecoder().decode(
+                Gemma3TextConfiguration.self, from: Data(contentsOf: url))
```
Collaborator: I think it makes sense to copy this config type into Embedders rather than add new linkage. Even sharing config types between models in the same library is rarely done.
```diff
+            let model = EmbeddingGemma(configuration)
+            return model
+        },
     ]

     public func registerModelType(
```
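On the comment above: a minimal sketch of what copying the configuration type into Embedders might look like, so this file would no longer need `import MLXLLM`. The field names and coding keys below are illustrative assumptions, not the full upstream `Gemma3TextConfiguration`; a real copy would carry over every field that `Gemma3TextModel` actually reads.

```swift
// Hypothetical trimmed-down local copy for the Embedders target.
// Field set is illustrative only.
public struct Gemma3TextConfiguration: Codable, Sendable {
    public let hiddenSize: Int
    public let hiddenLayers: Int
    public let intermediateSize: Int
    public let vocabularySize: Int

    enum CodingKeys: String, CodingKey {
        case hiddenSize = "hidden_size"
        case hiddenLayers = "num_hidden_layers"
        case intermediateSize = "intermediate_size"
        case vocabularySize = "vocab_size"
    }
}
```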
New file (@@ -0,0 +1,111 @@):

```swift
import MLX
import MLXLLM
import MLXLMCommon
```

Collaborator, on lines +2 to +3: See elsewhere -- I think this should be done without adding linkage to additional libraries.
```swift
import MLXNN

public class EmbeddingGemma: Module, EmbeddingModel {
    @ModuleInfo private var model: Gemma3TextModel
    @ModuleInfo private var dense: [Module]

    public let config: Gemma3TextConfiguration
    public var vocabularySize: Int { config.vocabularySize }

    public init(_ config: Gemma3TextConfiguration) {
        self.config = config
        self.model = Gemma3TextModel(config)
        self.dense = [
            Linear(768, 3072, bias: false), Linear(3072, 768, bias: false),
        ]
    }

    public func callAsFunction(
        _ inputs: MLXArray, positionIds: MLXArray?, tokenTypeIds: MLXArray?,
        attentionMask: MLXArray?
    ) -> EmbeddingModelOutput {
        var out = model.getHiddenStates(inputs, mask: nil, cache: nil)

        // mean pooling
        let notPadding = inputs .!= 0
        let sum = (out * notPadding[.ellipsis, .newAxis]).sum(axis: 1)
        let nonMasked = notPadding.sum(axis: -1, keepDims: true)
        out = sum / nonMasked

        for dense in self.dense {
            if let dense = dense as? Linear {
                out = dense(out)
            } else if let dense = dense as? QuantizedLinear {
                out = dense(out)
            }
        }

        // normalize
        out = out.asType(Float32.self)
        let norm = maximum(norm(out, ord: 2.0, axis: -1, keepDims: true), MLXArray(1e-6))
        let pooledOutput = out / norm

        return EmbeddingModelOutput(hiddenStates: out, pooledOutput: pooledOutput)
    }

    /// Get hidden states before the dense projection head
    public func getHiddenStates(
        _ inputs: MLXArray, mask: MLXFast.ScaledDotProductAttentionMaskMode? = nil,
        cache: [KVCache]? = nil
    ) -> MLXArray {
        return model(inputs, mask: mask, cache: cache)
    }

    public func sanitize(
        weights: [String: MLXArray],
        quantizationConfig: MLXLMCommon.BaseConfiguration.Quantization? = nil
    ) -> [String: MLXArray] {
        var processedWeights = model.sanitize(
            weights: weights, quantizationConfig: quantizationConfig)

        // 1. Add a model. prefix to all model. weights
        processedWeights = Dictionary(
            uniqueKeysWithValues: processedWeights.map { key, value in
                if key.hasPrefix("model.") || key.hasPrefix("lm_head.") {
                    return ("model.\(key)", value)
                } else {
                    return (key, value)
                }
            })

        // 2. Apply quantization to dense layers, if needed
        let hasQuantizedDense = hasQuantizedWeights(layerPath: "dense.0", in: processedWeights)
        if hasQuantizedDense {
            let groupSize = quantizationConfig?.groupSize ?? 64
            let bits = quantizationConfig?.bits ?? 4

            quantize(model: self) { path, module in
                if hasQuantizedWeights(layerPath: path, in: processedWeights) {
                    return (groupSize, bits)
                }
                return nil
            }
        }

        return processedWeights.filter { key, _ in
            !key.contains("self_attn.rotary_emb.inv_freq")
        }
    }

    public func sanitize(weights: [String: MLXArray]) -> [String: MLXArray] {
        sanitize(weights: weights, quantizationConfig: nil)
    }

    /// Check if a layer has quantized weights
    private func hasQuantizedWeights(layerPath: String, in weights: [String: MLXArray]) -> Bool {
        let scalesKey = "\(layerPath).scales"
        let biasesKey = "\(layerPath).biases"
        let weightKey = "\(layerPath).weight"

        let hasScales = weights[scalesKey] != nil
        let hasBiases = weights[biasesKey] != nil
        let hasWeight = weights[weightKey]?.dtype == .uint32

        return hasScales && hasBiases && hasWeight
    }
}
```
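For orientation, a hypothetical usage sketch of the class above. The `configURL` and token ids are made up for illustration, weight loading via `sanitize(weights:)` is omitted, and right-padding with 0 matches the pooling code's assumption that the pad token id is 0.

```swift
// Decode the model configuration from a local config.json (path assumed).
let config = try JSONDecoder().decode(
    Gemma3TextConfiguration.self, from: Data(contentsOf: configURL))
let model = EmbeddingGemma(config)

// A [batch, sequence] array of token ids, right-padded with 0.
let inputs = MLXArray([2, 106, 1645, 108, 0, 0] as [Int32], [1, 6])
let output = model(
    inputs, positionIds: nil, tokenTypeIds: nil, attentionMask: nil)
let embedding = output.pooledOutput  // L2-normalized, shape [1, 768]
```

The `[1, 768]` output shape follows from the dense head above: `Linear(768, 3072)` followed by `Linear(3072, 768)`.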
Collaborator: Is this just to pick up the quantization? Think the Embedders should not require MLXLMCommon / MLXLLM if possible. A copy of Quantization is OK. If we end up with a lot of duplication then I think MLXLMCommon might make sense.
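If the MLXLMCommon linkage is dropped as suggested, a minimal local copy of the quantization config might look like the sketch below. It assumes the conventional `quantization` block in config.json with `group_size` and `bits` keys, which is all that `sanitize` reads via `groupSize` and `bits` above.

```swift
// Hypothetical local replacement for
// MLXLMCommon.BaseConfiguration.Quantization.
public struct Quantization: Codable, Sendable {
    public let groupSize: Int
    public let bits: Int

    enum CodingKeys: String, CodingKey {
        case groupSize = "group_size"
        case bits
    }
}
```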