-
Notifications
You must be signed in to change notification settings - Fork 383
mlx-swift-examples prep for mlx-swift-lm 3.x release #468
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,11 +1,14 @@ | ||
| // Copyright © 2025 Apple Inc. | ||
|
|
||
| import Hub | ||
| import HuggingFace | ||
| import MLX | ||
| import MLXHuggingFace | ||
| import MLXLLM | ||
| import MLXLMCommon | ||
| import Metal | ||
| import SwiftUI | ||
| import Tokenizers | ||
|
|
||
| @Observable | ||
| @MainActor | ||
|
|
@@ -101,14 +104,11 @@ class LLMEvaluator { | |
|
|
||
| Memory.cacheLimit = 20 * 1024 * 1024 | ||
|
|
||
| let hub = HubApi( | ||
| downloadBase: FileManager.default.urls(for: .cachesDirectory, in: .userDomainMask).first | ||
| ) | ||
|
|
||
| do { | ||
| let modelDirectory = try await downloadModel( | ||
| hub: hub, | ||
| configuration: modelConfiguration | ||
| let downloader = #hubDownloader() | ||
|
|
||
| let resolved = try await resolve( | ||
| configuration: modelConfiguration, from: downloader, useLatest: false | ||
|
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This is a little more complex because it separates out the download from the load. |
||
| ) { [weak self] progress in | ||
| Task { @MainActor in | ||
| self?.updateDownloadProgress(progress) | ||
|
|
@@ -117,8 +117,9 @@ class LLMEvaluator { | |
|
|
||
| // Verify the download succeeded by checking for model files | ||
| let fileManager = FileManager.default | ||
| let directoryExists = fileManager.fileExists(atPath: modelDirectory.path) | ||
| let contents = (try? fileManager.contentsOfDirectory(atPath: modelDirectory.path)) ?? [] | ||
| let directoryExists = fileManager.fileExists(atPath: resolved.modelDirectory.path) | ||
| let contents = | ||
| (try? fileManager.contentsOfDirectory(atPath: resolved.modelDirectory.path)) ?? [] | ||
| let hasSafetensors = contents.contains { $0.hasSuffix(".safetensors") } | ||
|
|
||
| if !directoryExists || !hasSafetensors { | ||
|
|
@@ -137,9 +138,8 @@ class LLMEvaluator { | |
| totalSize = nil | ||
|
|
||
| let modelContainer = try await LLMModelFactory.shared.loadContainer( | ||
| hub: hub, | ||
| configuration: modelConfiguration | ||
| ) { _ in } | ||
| from: resolved.modelDirectory, | ||
| using: #huggingFaceTokenizerLoader()) | ||
|
|
||
| let numParams = await modelContainer.perform { $0.model.numParameters() } | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,6 +1,8 @@ | ||
| // Copyright © 2024 Apple Inc. | ||
|
|
||
| import HuggingFace | ||
| import MLX | ||
| import MLXHuggingFace | ||
| import MLXLLM | ||
| import MLXLMCommon | ||
| import MLXNN | ||
|
|
@@ -141,7 +143,12 @@ class LoRAEvaluator { | |
| progress = .init(title: "Loading \(name)", current: 0, limit: 1) | ||
| } | ||
|
|
||
| let downloader = #hubDownloader() | ||
| let loader = #huggingFaceTokenizerLoader() | ||
|
Comment on lines
+146
to
+147
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This doesn't work yet -- there is a local |
||
|
|
||
| let modelContainer = try await LLMModelFactory.shared.loadContainer( | ||
| from: downloader, | ||
| using: loader, | ||
| configuration: modelConfiguration | ||
| ) { | ||
| progress in | ||
|
|
@@ -186,7 +193,7 @@ class LoRAEvaluator { | |
| let modelContainer = try await loadModel() | ||
|
|
||
| // apply LoRA adapters and train | ||
| let modelAdapter = try await modelContainer.perform { context in | ||
| let _ = try await modelContainer.perform { context in | ||
| try LoRAContainer.from( | ||
| model: context.model, | ||
| configuration: LoRAConfiguration(numLayers: loraLayers) | ||
|
|
@@ -263,22 +270,28 @@ class LoRAEvaluator { | |
| let modelContainer = try await loadModel() | ||
|
|
||
| // evaluate | ||
| let result = try await modelContainer.perform { context in | ||
| let input = try await context.processor.prepare(input: .init(prompt: prompt)) | ||
| return try MLXLMCommon.generate( | ||
| input: input, parameters: generateParameters, context: context | ||
| ) { tokens in | ||
| if tokens.count % evaluateShowEvery == 0 { | ||
| let fullOutput = context.tokenizer.decode(tokens: tokens) | ||
| Task { @MainActor in | ||
| self.output = fullOutput | ||
| } | ||
| let input = try await modelContainer.processor.prepare(input: .init(prompt: prompt)) | ||
|
|
||
| var count = 0 | ||
| var output = "" | ||
| for try await item in try await modelContainer.generate( | ||
| input: input, parameters: generateParameters | ||
| ) { | ||
| switch item { | ||
| case .chunk(let string): | ||
| count += 1 | ||
| output += string | ||
|
|
||
| if count % evaluateShowEvery == 0 { | ||
| self.output = output | ||
| } | ||
| return tokens.count >= maxTokens ? .stop : .more | ||
|
|
||
| default: | ||
| break | ||
| } | ||
| } | ||
|
|
||
| self.output = result.output | ||
| self.output = output | ||
| self.progress = nil | ||
| } | ||
| } | ||
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Along with linking the HuggingFace libraries, this is the pattern for adopting the new API.