|
1 | 1 | // Copyright © 2024 Apple Inc. |
2 | 2 |
|
| 3 | +import HuggingFace |
3 | 4 | import MLX |
| 5 | +import MLXHuggingFace |
4 | 6 | import MLXLLM |
5 | 7 | import MLXLMCommon |
6 | 8 | import MLXNN |
@@ -141,7 +143,12 @@ class LoRAEvaluator { |
141 | 143 | progress = .init(title: "Loading \(name)", current: 0, limit: 1) |
142 | 144 | } |
143 | 145 |
|
| 146 | + let downloader = #hubDownloader() |
| 147 | + let loader = #huggingFaceTokenizerLoader() |
| 148 | + |
144 | 149 | let modelContainer = try await LLMModelFactory.shared.loadContainer( |
| 150 | + from: downloader, |
| 151 | + using: loader, |
145 | 152 | configuration: modelConfiguration |
146 | 153 | ) { |
147 | 154 | progress in |
@@ -186,7 +193,7 @@ class LoRAEvaluator { |
186 | 193 | let modelContainer = try await loadModel() |
187 | 194 |
|
188 | 195 | // apply LoRA adapters and train |
189 | | - let modelAdapter = try await modelContainer.perform { context in |
| 196 | + let _ = try await modelContainer.perform { context in |
190 | 197 | try LoRAContainer.from( |
191 | 198 | model: context.model, |
192 | 199 | configuration: LoRAConfiguration(numLayers: loraLayers) |
@@ -263,22 +270,28 @@ class LoRAEvaluator { |
263 | 270 | let modelContainer = try await loadModel() |
264 | 271 |
|
265 | 272 | // evaluate |
266 | | - let result = try await modelContainer.perform { context in |
267 | | - let input = try await context.processor.prepare(input: .init(prompt: prompt)) |
268 | | - return try MLXLMCommon.generate( |
269 | | - input: input, parameters: generateParameters, context: context |
270 | | - ) { tokens in |
271 | | - if tokens.count % evaluateShowEvery == 0 { |
272 | | - let fullOutput = context.tokenizer.decode(tokens: tokens) |
273 | | - Task { @MainActor in |
274 | | - self.output = fullOutput |
275 | | - } |
| 273 | + let input = try await modelContainer.processor.prepare(input: .init(prompt: prompt)) |
| 274 | + |
| 275 | + var count = 0 |
| 276 | + var output = "" |
| 277 | + for try await item in try await modelContainer.generate( |
| 278 | + input: input, parameters: generateParameters |
| 279 | + ) { |
| 280 | + switch item { |
| 281 | + case .chunk(let string): |
| 282 | + count += 1 |
| 283 | + output += string |
| 284 | + |
| 285 | + if count % evaluateShowEvery == 0 { |
| 286 | + self.output = output |
276 | 287 | } |
277 | | - return tokens.count >= maxTokens ? .stop : .more |
| 288 | + |
| 289 | + default: |
| 290 | + break |
278 | 291 | } |
279 | 292 | } |
280 | 293 |
|
281 | | - self.output = result.output |
| 294 | + self.output = output |
282 | 295 | self.progress = nil |
283 | 296 | } |
284 | 297 | } |
0 commit comments