Skip to content

Commit e091e33

Browse files
Commit message: mlx-swift-lm 3.x update
1 parent (c684488) · commit e091e33

16 files changed

Lines changed: 548 additions & 115 deletions

File tree

Applications/LLMBasic/ChatModel.swift

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
11
// Copyright © 2025 Apple Inc.
22

3+
import HuggingFace
4+
import MLXHuggingFace
35
import MLXLLM
46
import MLXLMCommon
57
import SwiftUI
8+
import Tokenizers
69

710
/// which model to load
811
private let modelConfiguration = LLMRegistry.gemma3_1B_qat_4bit
@@ -40,7 +43,11 @@ private let generateParameters = GenerateParameters(temperature: 0.5)
4043
case .idle:
4144
let task = Task {
4245
// download and report progress
43-
try await loadModelContainer(configuration: modelConfiguration) { value in
46+
try await LLMModelFactory.shared.loadContainer(
47+
from: #hubDownloader(),
48+
using: #huggingFaceTokenizerLoader(),
49+
configuration: modelConfiguration
50+
) { value in
4451
Task { @MainActor in
4552
self.progress = value.fractionCompleted
4653
}

Applications/LLMEval/ViewModels/LLMEvaluator.swift

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
11
// Copyright © 2025 Apple Inc.
22

33
import Hub
4+
import HuggingFace
45
import MLX
6+
import MLXHuggingFace
57
import MLXLLM
68
import MLXLMCommon
79
import Metal
810
import SwiftUI
11+
import Tokenizers
912

1013
@Observable
1114
@MainActor
@@ -101,14 +104,11 @@ class LLMEvaluator {
101104

102105
Memory.cacheLimit = 20 * 1024 * 1024
103106

104-
let hub = HubApi(
105-
downloadBase: FileManager.default.urls(for: .cachesDirectory, in: .userDomainMask).first
106-
)
107-
108107
do {
109-
let modelDirectory = try await downloadModel(
110-
hub: hub,
111-
configuration: modelConfiguration
108+
let downloader = #hubDownloader()
109+
110+
let resolved = try await resolve(
111+
configuration: modelConfiguration, from: downloader, useLatest: false
112112
) { [weak self] progress in
113113
Task { @MainActor in
114114
self?.updateDownloadProgress(progress)
@@ -117,8 +117,9 @@ class LLMEvaluator {
117117

118118
// Verify the download succeeded by checking for model files
119119
let fileManager = FileManager.default
120-
let directoryExists = fileManager.fileExists(atPath: modelDirectory.path)
121-
let contents = (try? fileManager.contentsOfDirectory(atPath: modelDirectory.path)) ?? []
120+
let directoryExists = fileManager.fileExists(atPath: resolved.modelDirectory.path)
121+
let contents =
122+
(try? fileManager.contentsOfDirectory(atPath: resolved.modelDirectory.path)) ?? []
122123
let hasSafetensors = contents.contains { $0.hasSuffix(".safetensors") }
123124

124125
if !directoryExists || !hasSafetensors {
@@ -137,9 +138,8 @@ class LLMEvaluator {
137138
totalSize = nil
138139

139140
let modelContainer = try await LLMModelFactory.shared.loadContainer(
140-
hub: hub,
141-
configuration: modelConfiguration
142-
) { _ in }
141+
from: resolved.modelDirectory,
142+
using: #huggingFaceTokenizerLoader())
143143

144144
let numParams = await modelContainer.perform { $0.model.numParameters() }
145145

Applications/LoRATrainingExample/ContentView.swift

Lines changed: 26 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
// Copyright © 2024 Apple Inc.
22

3+
import HuggingFace
34
import MLX
5+
import MLXHuggingFace
46
import MLXLLM
57
import MLXLMCommon
68
import MLXNN
@@ -141,7 +143,12 @@ class LoRAEvaluator {
141143
progress = .init(title: "Loading \(name)", current: 0, limit: 1)
142144
}
143145

146+
let downloader = #hubDownloader()
147+
let loader = #huggingFaceTokenizerLoader()
148+
144149
let modelContainer = try await LLMModelFactory.shared.loadContainer(
150+
from: downloader,
151+
using: loader,
145152
configuration: modelConfiguration
146153
) {
147154
progress in
@@ -186,7 +193,7 @@ class LoRAEvaluator {
186193
let modelContainer = try await loadModel()
187194

188195
// apply LoRA adapters and train
189-
let modelAdapter = try await modelContainer.perform { context in
196+
let _ = try await modelContainer.perform { context in
190197
try LoRAContainer.from(
191198
model: context.model,
192199
configuration: LoRAConfiguration(numLayers: loraLayers)
@@ -263,22 +270,28 @@ class LoRAEvaluator {
263270
let modelContainer = try await loadModel()
264271

265272
// evaluate
266-
let result = try await modelContainer.perform { context in
267-
let input = try await context.processor.prepare(input: .init(prompt: prompt))
268-
return try MLXLMCommon.generate(
269-
input: input, parameters: generateParameters, context: context
270-
) { tokens in
271-
if tokens.count % evaluateShowEvery == 0 {
272-
let fullOutput = context.tokenizer.decode(tokens: tokens)
273-
Task { @MainActor in
274-
self.output = fullOutput
275-
}
273+
let input = try await modelContainer.processor.prepare(input: .init(prompt: prompt))
274+
275+
var count = 0
276+
var output = ""
277+
for try await item in try await modelContainer.generate(
278+
input: input, parameters: generateParameters
279+
) {
280+
switch item {
281+
case .chunk(let string):
282+
count += 1
283+
output += string
284+
285+
if count % evaluateShowEvery == 0 {
286+
self.output = output
276287
}
277-
return tokens.count >= maxTokens ? .stop : .more
288+
289+
default:
290+
break
278291
}
279292
}
280293

281-
self.output = result.output
294+
self.output = output
282295
self.progress = nil
283296
}
284297
}

Applications/MLXChatExample/Services/MLXService.swift

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,13 @@
66
//
77

88
import Foundation
9+
import HuggingFace
910
import MLX
11+
import MLXHuggingFace
1012
import MLXLLM
1113
import MLXLMCommon
1214
import MLXVLM
15+
import Tokenizers
1316

1417
/// A service class that manages machine learning models for text and vision-language tasks.
1518
/// This class handles model loading, caching, and text generation using various LLM and VLM models.
@@ -63,9 +66,14 @@ class MLXService {
6366
VLMModelFactory.shared
6467
}
6568

69+
let downloader = #hubDownloader()
70+
let loader = #huggingFaceTokenizerLoader()
71+
6672
// Load model and track download progress
6773
let container = try await factory.loadContainer(
68-
hub: .default, configuration: model.configuration
74+
from: downloader,
75+
using: loader,
76+
configuration: model.configuration
6977
) { progress in
7078
Task { @MainActor in
7179
self.modelDownloadProgress = progress

Package.resolved

Lines changed: 76 additions & 4 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Package.swift

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,12 @@ let package = Package(
1515
targets: ["StableDiffusion"]),
1616
],
1717
dependencies: [
18-
.package(url: "https://github.com/ml-explore/mlx-swift", .upToNextMinor(from: "0.30.3")),
18+
.package(url: "https://github.com/ml-explore/mlx-swift", .upToNextMinor(from: "0.31.3")),
19+
20+
// Note: used by StableDiffusion library to download weights
1921
.package(
2022
url: "https://github.com/huggingface/swift-transformers",
21-
.upToNextMinor(from: "1.1.0")
23+
.upToNextMajor(from: "1.3.0")
2224
),
2325
.package(url: "https://github.com/1024jp/GzipSwift", "6.0.1" ... "6.0.1"), // Only needed by MLXMNIST
2426
],

Tools/Tutorial/Tutorial.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ struct Tutorial {
5050
print(x[1])
5151

5252
// make an array of shape [2, 2] filled with ones
53-
let y = MLXArray.ones([2, 2])
53+
let y = MLXArray.ones([2, 2], type: Float.self)
5454

5555
// pointwise add x and y
5656
let z = x + y

Tools/embedder-tool/EmbedderRuntime+Embedding.swift

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import Foundation
22
import MLX
33
import MLXEmbedders
4-
import Tokenizers
4+
import MLXLMCommon
55

66
public struct RuntimeEmbeddingResult {
77
public let embeddings: [(index: Int, vector: [Float])]
@@ -26,9 +26,10 @@ extension EmbedderRuntime {
2626
embeddings: [], skippedIndices: [], fallbackDescription: nil)
2727
}
2828

29-
return try await container.perform { model, tokenizer, pooler in
29+
return try await container.perform { context in
3030
var skippedIndices: [Int] = []
3131

32+
let tokenizer = context.tokenizer
3233
let encoded = texts.enumerated().compactMap { index, text -> (Int, [Int])? in
3334
let tokens = tokenizer.encode(text: text, addSpecialTokens: true)
3435
guard !tokens.isEmpty else {
@@ -58,14 +59,14 @@ extension EmbedderRuntime {
5859
let mask = (padded .!= padToken)
5960
let tokenTypes = MLXArray.zeros(like: padded)
6061

61-
let outputs = model(
62+
let outputs = context.model(
6263
padded,
6364
positionIds: nil,
6465
tokenTypeIds: tokenTypes,
6566
attentionMask: mask
6667
)
6768

68-
let poolingModule = resolvedPooler(for: pooler)
69+
let poolingModule = resolvedPooler(for: context.pooling)
6970
let pooled = poolingModule(
7071
outputs,
7172
mask: mask,

0 commit comments

Comments (0)