Skip to content

Commit 31e568f

Browse files
committed
mlx-swift-examples prep for mlx-swift-lm 3.x release
1 parent c684488 commit 31e568f

14 files changed

Lines changed: 419 additions & 51 deletions

File tree

Applications/LLMBasic/ChatModel.swift

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
11
// Copyright © 2025 Apple Inc.
22

3+
import HuggingFace
4+
import MLXHuggingFace
35
import MLXLLM
46
import MLXLMCommon
57
import SwiftUI
8+
import Tokenizers
69

710
/// which model to load
811
private let modelConfiguration = LLMRegistry.gemma3_1B_qat_4bit
@@ -40,7 +43,11 @@ private let generateParameters = GenerateParameters(temperature: 0.5)
4043
case .idle:
4144
let task = Task {
4245
// download and report progress
43-
try await loadModelContainer(configuration: modelConfiguration) { value in
46+
try await loadModelContainer(
47+
from: #hubDownloader(),
48+
using: #huggingFaceTokenizerLoader(),
49+
configuration: modelConfiguration
50+
) { value in
4451
Task { @MainActor in
4552
self.progress = value.fractionCompleted
4653
}

Applications/LLMEval/ViewModels/LLMEvaluator.swift

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
11
// Copyright © 2025 Apple Inc.
22

33
import Hub
4+
import HuggingFace
45
import MLX
6+
import MLXHuggingFace
57
import MLXLLM
68
import MLXLMCommon
79
import Metal
810
import SwiftUI
11+
import Tokenizers
912

1013
@Observable
1114
@MainActor
@@ -101,14 +104,11 @@ class LLMEvaluator {
101104

102105
Memory.cacheLimit = 20 * 1024 * 1024
103106

104-
let hub = HubApi(
105-
downloadBase: FileManager.default.urls(for: .cachesDirectory, in: .userDomainMask).first
106-
)
107-
108107
do {
109-
let modelDirectory = try await downloadModel(
110-
hub: hub,
111-
configuration: modelConfiguration
108+
let downloader = #hubDownloader()
109+
110+
let resolved = try await resolve(
111+
configuration: modelConfiguration, from: downloader, useLatest: false
112112
) { [weak self] progress in
113113
Task { @MainActor in
114114
self?.updateDownloadProgress(progress)
@@ -117,8 +117,9 @@ class LLMEvaluator {
117117

118118
// Verify the download succeeded by checking for model files
119119
let fileManager = FileManager.default
120-
let directoryExists = fileManager.fileExists(atPath: modelDirectory.path)
121-
let contents = (try? fileManager.contentsOfDirectory(atPath: modelDirectory.path)) ?? []
120+
let directoryExists = fileManager.fileExists(atPath: resolved.modelDirectory.path)
121+
let contents =
122+
(try? fileManager.contentsOfDirectory(atPath: resolved.modelDirectory.path)) ?? []
122123
let hasSafetensors = contents.contains { $0.hasSuffix(".safetensors") }
123124

124125
if !directoryExists || !hasSafetensors {
@@ -137,9 +138,8 @@ class LLMEvaluator {
137138
totalSize = nil
138139

139140
let modelContainer = try await LLMModelFactory.shared.loadContainer(
140-
hub: hub,
141-
configuration: modelConfiguration
142-
) { _ in }
141+
from: resolved.modelDirectory,
142+
using: #huggingFaceTokenizerLoader())
143143

144144
let numParams = await modelContainer.perform { $0.model.numParameters() }
145145

Applications/LoRATrainingExample/ContentView.swift

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
// Copyright © 2024 Apple Inc.
22

3+
import HuggingFace
34
import MLX
5+
import MLXHuggingFace
46
import MLXLLM
57
import MLXLMCommon
68
import MLXNN
@@ -141,7 +143,12 @@ class LoRAEvaluator {
141143
progress = .init(title: "Loading \(name)", current: 0, limit: 1)
142144
}
143145

146+
let downloader = #hubDownloader()
147+
let loader = #huggingFaceTokenizerLoader()
148+
144149
let modelContainer = try await LLMModelFactory.shared.loadContainer(
150+
from: downloader,
151+
using: loader,
145152
configuration: modelConfiguration
146153
) {
147154
progress in

Applications/MLXChatExample/Services/MLXService.swift

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,13 @@
66
//
77

88
import Foundation
9+
import HuggingFace
910
import MLX
11+
import MLXHuggingFace
1012
import MLXLLM
1113
import MLXLMCommon
1214
import MLXVLM
15+
import Tokenizers
1316

1417
/// A service class that manages machine learning models for text and vision-language tasks.
1518
/// This class handles model loading, caching, and text generation using various LLM and VLM models.
@@ -63,9 +66,14 @@ class MLXService {
6366
VLMModelFactory.shared
6467
}
6568

69+
let downloader = #hubDownloader()
70+
let loader = #huggingFaceTokenizerLoader()
71+
6672
// Load model and track download progress
6773
let container = try await factory.loadContainer(
68-
hub: .default, configuration: model.configuration
74+
from: downloader,
75+
using: loader,
76+
configuration: model.configuration
6977
) { progress in
7078
Task { @MainActor in
7179
self.modelDownloadProgress = progress

Package.swift

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,12 @@ let package = Package(
1515
targets: ["StableDiffusion"]),
1616
],
1717
dependencies: [
18-
.package(url: "https://github.com/ml-explore/mlx-swift", .upToNextMinor(from: "0.30.3")),
18+
.package(url: "https://github.com/ml-explore/mlx-swift", .upToNextMinor(from: "0.31.3")),
19+
20+
// Note: used by StableDiffusion library to download weights
1921
.package(
2022
url: "https://github.com/huggingface/swift-transformers",
21-
.upToNextMinor(from: "1.1.0")
23+
.upToNextMajor(from: "1.3.0")
2224
),
2325
.package(url: "https://github.com/1024jp/GzipSwift", "6.0.1" ... "6.0.1"), // Only needed by MLXMNIST
2426
],

Tools/Tutorial/Tutorial.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ struct Tutorial {
5050
print(x[1])
5151

5252
// make an array of shape [2, 2] filled with ones
53-
let y = MLXArray.ones([2, 2])
53+
let y = MLXArray.ones([2, 2], type: Float.self)
5454

5555
// pointwise add x and y
5656
let z = x + y

Tools/embedder-tool/EmbedderRuntime+Embedding.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import Foundation
22
import MLX
33
import MLXEmbedders
4-
import Tokenizers
4+
import MLXLMCommon
55

66
public struct RuntimeEmbeddingResult {
77
public let embeddings: [(index: Int, vector: [Float])]

Tools/embedder-tool/ModelArguments.swift

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,11 @@
22

33
import ArgumentParser
44
import Foundation
5-
import Hub
5+
import HuggingFace
66
import MLXEmbedders
7+
import MLXHuggingFace
8+
import MLXLMCommon
9+
import Tokenizers
710

811
struct ModelArguments: ParsableArguments {
912

@@ -63,12 +66,15 @@ extension ModelArguments {
6366
return LoadedEmbedderModel(configuration: configuration, container: container)
6467
}
6568

66-
private func makeHub() -> HubApi {
67-
if let downloadURL {
68-
return HubApi(downloadBase: downloadURL)
69-
}
70-
71-
return HubApi()
69+
var downloader: any Downloader {
70+
let client =
71+
if let download {
72+
HubClient(cache: HubCache(cacheDirectory: download))
73+
} else {
74+
HubClient()
75+
}
76+
let downloader = #hubDownloader(client)
77+
return downloader
7278
}
7379
}
7480

Tools/llm-tool/LLMTool.swift

Lines changed: 47 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,9 @@
33
import ArgumentParser
44
import CoreImage
55
import Foundation
6-
import Hub
6+
import HuggingFace
77
import MLX
8+
import MLXHuggingFace
89
import MLXLLM
910
import MLXLMCommon
1011
import MLXVLM
@@ -30,6 +31,17 @@ struct ModelArguments: ParsableArguments, Sendable {
3031
@Option(help: "Hub download directory")
3132
var download: URL?
3233

34+
var downloader: any Downloader {
35+
let client =
36+
if let download {
37+
HubClient(cache: HubCache(cacheDirectory: download))
38+
} else {
39+
HubClient()
40+
}
41+
let downloader = #hubDownloader(client)
42+
return downloader
43+
}
44+
3345
@Sendable
3446
func load(defaultModel: String, modelFactory: ModelFactory) async throws -> ModelContainer {
3547
let modelConfiguration: ModelConfiguration
@@ -46,14 +58,10 @@ struct ModelArguments: ParsableArguments, Sendable {
4658
modelConfiguration = modelFactory.configuration(id: modelName)
4759
}
4860

49-
let hub =
50-
if let download {
51-
HubApi(downloadBase: download)
52-
} else {
53-
HubApi()
54-
}
55-
56-
return try await modelFactory.loadContainer(hub: hub, configuration: modelConfiguration)
61+
return try await loadModelContainer(
62+
from: self.downloader,
63+
using: #huggingFaceTokenizerLoader(),
64+
configuration: modelConfiguration)
5765
}
5866
}
5967

@@ -157,6 +165,9 @@ struct GenerateArguments: ParsableArguments, Sendable {
157165
@Flag(name: .shortAndLong, help: "If true only print the generated output")
158166
var quiet = false
159167

168+
@Flag(name: .customLong("tool-time"), help: "Enable time telling tool")
169+
var useTimeTool = false
170+
160171
var generateParameters: GenerateParameters {
161172
GenerateParameters(
162173
maxTokens: maxTokens,
@@ -167,6 +178,23 @@ struct GenerateArguments: ParsableArguments, Sendable {
167178
repetitionContextSize: repetitionContextSize)
168179
}
169180

181+
var toolSpecs: [MLXLMCommon.ToolSpec] {
182+
var tools = [MLXLMCommon.ToolSpec]()
183+
184+
if useTimeTool {
185+
tools.append(timeTool.schema)
186+
}
187+
188+
return tools
189+
}
190+
191+
func call(toolCall: ToolCall) async throws -> String {
192+
if useTimeTool && toolCall.function.name == timeTool.name {
193+
return try await toolCall.execute(with: timeTool).toolResult
194+
}
195+
return "Unknown tool: \(toolCall.function.name)"
196+
}
197+
170198
func prepare(
171199
_ context: inout ModelContext
172200
) {
@@ -188,7 +216,14 @@ struct GenerateArguments: ParsableArguments, Sendable {
188216
print(string, terminator: "")
189217
case .info(let info):
190218
return (info, output)
191-
case .toolCall:
219+
case .toolCall(let toolCall):
220+
do {
221+
// TODO: consider using ChatSession here instead of invoking the tool manually
222+
let x = try await call(toolCall: toolCall)
223+
print("TOOL RESULT: \(x)")
224+
} catch {
225+
print("\nError executing tool: \(error.localizedDescription)")
226+
}
192227
break
193228
}
194229
}
@@ -323,7 +358,8 @@ struct EvaluateCommand: AsyncParsableCommand {
323358
modelContainer,
324359
instructions: generate.system,
325360
generateParameters: generate.generateParameters,
326-
processing: media.processing
361+
processing: media.processing,
362+
tools: generate.toolSpecs
327363
)
328364

329365
if !generate.quiet {

Tools/llm-tool/LoraCommands.swift

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,9 @@
33
import ArgumentParser
44
import Foundation
55
import Hub
6+
import HuggingFace
67
import MLX
8+
import MLXHuggingFace
79
import MLXLLM
810
import MLXLMCommon
911
import MLXNN
@@ -185,8 +187,19 @@ struct LoRAFuseCommand: AsyncParsableCommand {
185187
if output.hasPrefix("/") {
186188
outputURL = URL(filePath: output)
187189
} else {
188-
let repo = HubApi.Repo(id: output)
189-
outputURL = HubApi().localRepoLocation(repo)
190+
let cache =
191+
if let download = args.args.download {
192+
HubCache(cacheDirectory: download)
193+
} else {
194+
HubCache.default
195+
}
196+
197+
let parts = output.components(separatedBy: "/")
198+
guard parts.count == 2 else {
199+
fatalError("output must be org/name, e.g. mlx-community/mistral-lora: \(output)")
200+
}
201+
let repo = Repo.ID(namespace: parts[0], name: parts[1])
202+
outputURL = cache.repoDirectory(repo: repo, kind: .model)
190203
}
191204

192205
let (modelContainer, modelAdapter) = try await args.load()
@@ -196,9 +209,14 @@ struct LoRAFuseCommand: AsyncParsableCommand {
196209
try context.model.fuse(with: modelAdapter)
197210
}
198211

212+
let resolved = try await resolve(
213+
configuration: modelContainer.configuration,
214+
from: args.args.downloader,
215+
useLatest: false, progressHandler: { _ in })
216+
199217
// make the new directory and copy files from source model
200218
try FileManager.default.createDirectory(at: outputURL, withIntermediateDirectories: true)
201-
let inputURL = await modelContainer.configuration.modelDirectory()
219+
let inputURL = resolved.modelDirectory
202220
let enumerator = FileManager.default.enumerator(
203221
at: inputURL, includingPropertiesForKeys: nil)!
204222
for url in enumerator.allObjects.compactMap({ $0 as? URL }) {

0 commit comments

Comments
 (0)