Skip to content

Commit b146473

Browse files
committed
update for MLXEmbedders
1 parent 48d2df8 commit b146473

9 files changed

Lines changed: 145 additions & 121 deletions

File tree

Package.resolved

Lines changed: 76 additions & 4 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Tools/embedder-tool/EmbedderRuntime+Embedding.swift

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,10 @@ extension EmbedderRuntime {
2626
embeddings: [], skippedIndices: [], fallbackDescription: nil)
2727
}
2828

29-
return try await container.perform { model, tokenizer, pooler in
29+
return try await container.perform { context in
3030
var skippedIndices: [Int] = []
3131

32+
let tokenizer = context.tokenizer
3233
let encoded = texts.enumerated().compactMap { index, text -> (Int, [Int])? in
3334
let tokens = tokenizer.encode(text: text, addSpecialTokens: true)
3435
guard !tokens.isEmpty else {
@@ -58,14 +59,14 @@ extension EmbedderRuntime {
5859
let mask = (padded .!= padToken)
5960
let tokenTypes = MLXArray.zeros(like: padded)
6061

61-
let outputs = model(
62+
let outputs = context.model(
6263
padded,
6364
positionIds: nil,
6465
tokenTypeIds: tokenTypes,
6566
attentionMask: mask
6667
)
6768

68-
let poolingModule = resolvedPooler(for: pooler)
69+
let poolingModule = resolvedPooler(for: context.pooling)
6970
let pooled = poolingModule(
7071
outputs,
7172
mask: mask,

Tools/embedder-tool/EmbedderTool.swift

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import ArgumentParser
55
import Foundation
66
import MLX
77
import MLXEmbedders
8+
import MLXLMCommon
89
import Tokenizers
910

1011
@main
@@ -17,7 +18,7 @@ struct EmbedderTool: AsyncParsableCommand {
1718
]
1819
)
1920

20-
private static let defaultModelConfiguration = ModelConfiguration.nomic_text_v1_5
21+
private static let defaultModelConfiguration = EmbedderRegistry.nomic_text_v1_5
2122

2223
@OptionGroup var model: ModelArguments
2324
@OptionGroup var corpus: CorpusArguments
@@ -42,9 +43,8 @@ struct EmbedderTool: AsyncParsableCommand {
4243
-> EmbedderRuntime
4344
{
4445
let loadedModel = try await model.load(default: defaultModelConfiguration)
45-
let baseStrategy = await loadedModel.container.perform { _, _, pooler in
46-
pooler.strategy
47-
}
46+
let baseStrategy = await loadedModel.container.poolingStrategy
47+
4848
return EmbedderRuntime(
4949
configuration: loadedModel.configuration,
5050
container: loadedModel.container,
@@ -58,7 +58,7 @@ struct EmbedderTool: AsyncParsableCommand {
5858

5959
struct EmbedderRuntime {
6060
let configuration: ModelConfiguration
61-
let container: ModelContainer
61+
let container: EmbedderModelContainer
6262
let baseStrategy: Pooling.Strategy
6363
let strategyOverride: Pooling.Strategy?
6464
let normalize: Bool

Tools/embedder-tool/ListCommand.swift

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import ArgumentParser
22
import Foundation
33
import MLXEmbedders
4+
import MLXLMCommon
45

56
struct ListCommand: AsyncParsableCommand {
67
static let configuration = CommandConfiguration(
@@ -12,13 +13,13 @@ struct ListCommand: AsyncParsableCommand {
1213
var includeDirectories = false
1314

1415
func run() async throws {
15-
let models = await MainActor.run { Array(ModelConfiguration.models) }
16+
let models = await MainActor.run { Array(EmbedderRegistry.shared.models) }
1617
.sorted { $0.name.localizedCaseInsensitiveCompare($1.name) == .orderedAscending }
1718

1819
for configuration in models {
1920
switch configuration.id {
20-
case .id(let identifier):
21-
print(identifier)
21+
case .id(let id, let revision):
22+
print("\(id)/\(revision)")
2223
case .directory(let url):
2324
if includeDirectories {
2425
print(url.path)

Tools/embedder-tool/ModelArguments.swift

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ struct ModelArguments: ParsableArguments {
2828
return localConfiguration
2929
}
3030

31-
return ModelConfiguration.configuration(id: model)
31+
return ModelConfiguration(id: model)
3232
}
3333

3434
var downloadURL: URL? {
@@ -38,20 +38,22 @@ struct ModelArguments: ParsableArguments {
3838

3939
struct LoadedEmbedderModel {
4040
let configuration: ModelConfiguration
41-
let container: ModelContainer
41+
let container: EmbedderModelContainer
4242
}
4343

4444
extension ModelArguments {
4545

4646
func load(default defaultConfiguration: ModelConfiguration) async throws -> LoadedEmbedderModel
4747
{
4848
let configuration = await configuration(default: defaultConfiguration)
49-
let hub = makeHub()
49+
let hub = #hubDownloader
50+
let loader = #huggingFaceTokenizerLoader
5051

5152
print("Loading model \(configuration.name)...")
5253

53-
let container = try await MLXEmbedders.loadModelContainer(
54-
hub: hub,
54+
let container = try await EmbedderModelFactory.shared.loadContainer(
55+
from: hub,
56+
using: loader,
5557
configuration: configuration,
5658
progressHandler: { progress in
5759
let percentage = Int(progress.fractionCompleted * 100)

Tools/llm-tool/LLMTool.swift

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ struct ModelArguments: ParsableArguments, Sendable {
4343
}
4444

4545
@Sendable
46-
func load(defaultModel: String, modelFactory: ModelFactory) async throws -> ModelContainer {
46+
func load(defaultModel: String, modelFactory: any ModelFactory) async throws -> ModelContainer {
4747
let modelConfiguration: ModelConfiguration
4848

4949
let modelName = self.model ?? defaultModel
@@ -58,7 +58,7 @@ struct ModelArguments: ParsableArguments, Sendable {
5858
modelConfiguration = modelFactory.configuration(id: modelName)
5959
}
6060

61-
return try await loadModelContainer(
61+
return try await modelFactory.loadContainer(
6262
from: self.downloader,
6363
using: #huggingFaceTokenizerLoader(),
6464
configuration: modelConfiguration)
@@ -320,7 +320,7 @@ struct EvaluateCommand: AsyncParsableCommand {
320320

321321
@MainActor
322322
mutating func run() async throws {
323-
let modelFactory: ModelFactory
323+
let modelFactory: any ModelFactory
324324
let defaultModel: ModelConfiguration
325325

326326
// Switch between LLM and VLM based on presence of media

0 commit comments

Comments
 (0)