diff --git a/Libraries/BenchmarkHelpers/BenchmarkHelpers.swift b/Libraries/BenchmarkHelpers/BenchmarkHelpers.swift new file mode 100644 index 000000000..93a0d78a6 --- /dev/null +++ b/Libraries/BenchmarkHelpers/BenchmarkHelpers.swift @@ -0,0 +1,464 @@ +// Shared benchmark logic for measuring model loading, tokenizer performance, +// and download performance. +// Integration packages inject their own Downloader and TokenizerLoader. + +import Foundation +import MLX +import MLXEmbedders +import MLXLLM +import MLXLMCommon +import MLXVLM + +// MARK: - No-Op Tokenizer + +/// A tokenizer loader that returns a stub tokenizer. Useful for benchmarking +/// model loading in downloader integration packages that don't provide a +/// real tokenizer. +public struct NoOpTokenizerLoader: TokenizerLoader { + public init() {} + + public func load(from directory: URL) async throws -> any Tokenizer { + NoOpTokenizer() + } +} + +private struct NoOpTokenizer: Tokenizer { + func encode(text: String, addSpecialTokens: Bool) -> [Int] { [] } + func decode(tokenIds: [Int], skipSpecialTokens: Bool) -> String { "" } + func convertTokenToId(_ token: String) -> Int? { nil } + func convertIdToToken(_ id: Int) -> String? { nil } + var bosToken: String? { nil } + var eosToken: String? { nil } + var unknownToken: String? { nil } + + func applyChatTemplate( + messages: [[String: any Sendable]], + tools: [[String: any Sendable]]?, + additionalContext: [String: any Sendable]? + ) throws -> [Int] { + throw MLXLMCommon.TokenizerError.missingChatTemplate + } +} + +// MARK: - Stats + +public struct BenchmarkStats: Sendable { + public let mean: Double + public let median: Double + public let stdDev: Double + public let min: Double + public let max: Double + + public init(times: [Double]) { + precondition(!times.isEmpty, "BenchmarkStats requires at least one timing measurement") + let sorted = times.sorted() + self.min = sorted.first! + self.max = sorted.last! + let mean = times.reduce(0, +) / Double(times.count) + self.mean = mean + self.median = sorted[sorted.count / 2] + + let squaredDiffs = times.map { ($0 - mean) * ($0 - mean) } + self.stdDev = sqrt(squaredDiffs.reduce(0, +) / Double(times.count)) + } + + public func printSummary(label: String) { + print("\(label) results:") + print(" Mean: \(String(format: "%.1f", mean))ms") + print(" Median: \(String(format: "%.1f", median))ms") + print(" StdDev: \(String(format: "%.1f", stdDev))ms") + print(" Range: \(String(format: "%.1f", min))-\(String(format: "%.1f", max))ms") + } +} + +// MARK: - Benchmark Text + +public enum BenchmarkDefaults { + public static let textSource = BenchmarkTextSource.prideAndPrejudice + public static let tokenizationTextCharacterCount = 20_000 + public static let decodingTextCharacterCount = 200_000 + public static let loadingRuns = 7 + public static let downloadRuns = 7 + public static let tokenizationRuns = 25 + public static let decodingRuns = 25 + public static let decodesPerRun = 10 +} + +public struct BenchmarkTextSource: Sendable { + public let name: String + public let url: URL + public let contentStartMarker: String? + + public init(name: String, url: URL, contentStartMarker: String? = nil) { + self.name = name + self.url = url + self.contentStartMarker = contentStartMarker + } + + public static let prideAndPrejudice = BenchmarkTextSource( + name: "pride-and-prejudice", + url: URL(string: "https://www.gutenberg.org/ebooks/1342.txt.utf-8")!, + contentStartMarker: "It is a truth universally acknowledged" + ) +} + +public enum BenchmarkTextError: LocalizedError { + case invalidResponse(URL) + case decodeFailed(URL) + case contentStartMarkerNotFound(String) + + public var errorDescription: String? { + switch self { + case .invalidResponse(let url): + return "Unexpected response while fetching benchmark text from \(url.absoluteString)." + case .decodeFailed(let url): + return "Failed to decode benchmark text from \(url.absoluteString) as UTF-8." + case .contentStartMarkerNotFound(let marker): + return "Benchmark text start marker not found: '\(marker)'." + } + } +} + +private func benchmarkTextCacheURL(for source: BenchmarkTextSource) -> URL { + FileManager.default.temporaryDirectory + .appending(component: "BenchmarkHelpers", directoryHint: .isDirectory) + .appending(component: "\(source.name).txt") +} + +private func normalizeBenchmarkText(_ text: String) -> String { + text.replacingOccurrences(of: "\r\n", with: "\n") + .replacingOccurrences(of: "\r", with: "\n") +} + +private func trimmedBenchmarkText(_ text: String, source: BenchmarkTextSource) throws -> String { + guard let marker = source.contentStartMarker else { + return text + } + guard let markerRange = text.range(of: marker) else { + throw BenchmarkTextError.contentStartMarkerNotFound(marker) + } + return String(text[markerRange.lowerBound...]) +} + +private func fetchBenchmarkText(source: BenchmarkTextSource) async throws -> String { + let cacheURL = benchmarkTextCacheURL(for: source) + let fileManager = FileManager.default + + if fileManager.fileExists(atPath: cacheURL.path) { + let cached = try String(contentsOf: cacheURL, encoding: .utf8) + return try trimmedBenchmarkText(normalizeBenchmarkText(cached), source: source) + } + + let (data, response) = try await URLSession.shared.data(from: source.url) + guard let httpResponse = response as? HTTPURLResponse, + (200 ..< 300).contains(httpResponse.statusCode) + else { + throw BenchmarkTextError.invalidResponse(source.url) + } + guard let downloaded = String(data: data, encoding: .utf8) else { + throw BenchmarkTextError.decodeFailed(source.url) + } + + try fileManager.createDirectory( + at: cacheURL.deletingLastPathComponent(), + withIntermediateDirectories: true + ) + try downloaded.write(to: cacheURL, atomically: true, encoding: .utf8) + + return try trimmedBenchmarkText(normalizeBenchmarkText(downloaded), source: source) +} + +/// Load benchmark text from a remote public-domain source and cache it locally in the temporary directory. +public func loadBenchmarkText( + source: BenchmarkTextSource = BenchmarkDefaults.textSource, + characterCount: Int = 20_000 +) async throws -> String { + precondition(characterCount > 0, "characterCount must be greater than zero") + let text = try await fetchBenchmarkText(source: source) + return String(text.prefix(characterCount)) +} + +public func loadTokenizationBenchmarkText( + source: BenchmarkTextSource = BenchmarkDefaults.textSource +) async throws -> String { + try await loadBenchmarkText( + source: source, + characterCount: BenchmarkDefaults.tokenizationTextCharacterCount + ) +} + +public func loadDecodingBenchmarkText( + source: BenchmarkTextSource = BenchmarkDefaults.textSource +) async throws -> String { + try await loadBenchmarkText( + source: source, + characterCount: BenchmarkDefaults.decodingTextCharacterCount + ) +} + +private func resolveTokenizerDirectory( + from downloader: any Downloader, + configuration: MLXLMCommon.ModelConfiguration, + useLatest: Bool +) async throws -> URL { + switch configuration.tokenizerSource { + case .id(let id, let revision): + return try await downloader.download( + id: id, + revision: revision, + matching: tokenizerDownloadPatterns, + useLatest: useLatest, + progressHandler: { _ in } + ) + case .directory(let directory): + return directory + case nil: + switch configuration.id { + case .id(let id, let revision): + return try await downloader.download( + id: id, + revision: revision, + matching: tokenizerDownloadPatterns, + useLatest: useLatest, + progressHandler: { _ in } + ) + case .directory(let directory): + return directory + } + } +} + +private func loadTokenizerForBenchmark( + from downloader: any Downloader, + using tokenizerLoader: any TokenizerLoader, + configuration: MLXLMCommon.ModelConfiguration, + useLatest: Bool +) async throws -> any Tokenizer { + let tokenizerDirectory = try await resolveTokenizerDirectory( + from: downloader, + configuration: configuration, + useLatest: useLatest + ) + return try await tokenizerLoader.load(from: tokenizerDirectory) +} + +// MARK: - Benchmark Runners + +/// Benchmark tokenizer loading without downloading model weights or initializing a model. +public func benchmarkTokenizerLoading( + from downloader: any Downloader, + using tokenizerLoader: any TokenizerLoader, + configuration: MLXLMCommon.ModelConfiguration = .init(id: "mlx-community/Qwen3-0.6B-4bit"), + useLatest: Bool = false, + runs: Int = BenchmarkDefaults.loadingRuns +) async throws -> BenchmarkStats { + let tokenizerDirectory = try await resolveTokenizerDirectory( + from: downloader, + configuration: configuration, + useLatest: useLatest + ) + + _ = try await tokenizerLoader.load(from: tokenizerDirectory) + + var times: [Double] = [] + for i in 1 ... runs { + let start = CFAbsoluteTimeGetCurrent() + _ = try await tokenizerLoader.load(from: tokenizerDirectory) + let elapsed = (CFAbsoluteTimeGetCurrent() - start) * 1000 + times.append(elapsed) + print("Tokenizer load run \(i): \(String(format: "%.1f", elapsed))ms") + } + + return BenchmarkStats(times: times) +} + +/// Benchmark tokenization on a preloaded tokenizer without initializing a model. +public func benchmarkTokenization( + from downloader: any Downloader, + using tokenizerLoader: any TokenizerLoader, + configuration: MLXLMCommon.ModelConfiguration = .init(id: "mlx-community/Qwen3-0.6B-4bit"), + text: String = "The quick brown fox jumps over the lazy dog.", + addSpecialTokens: Bool = true, + useLatest: Bool = false, + runs: Int = BenchmarkDefaults.tokenizationRuns +) async throws -> BenchmarkStats { + let tokenizer = try await loadTokenizerForBenchmark( + from: downloader, + using: tokenizerLoader, + configuration: configuration, + useLatest: useLatest + ) + + _ = tokenizer.encode(text: text, addSpecialTokens: addSpecialTokens) + + var times: [Double] = [] + for i in 1 ... runs { + let start = CFAbsoluteTimeGetCurrent() + let tokenIds = tokenizer.encode(text: text, addSpecialTokens: addSpecialTokens) + let elapsed = (CFAbsoluteTimeGetCurrent() - start) * 1000 + times.append(elapsed) + print( + "Tokenization run \(i): \(String(format: "%.1f", elapsed))ms (\(tokenIds.count) tokens)" + ) + } + + return BenchmarkStats(times: times) +} + +/// Benchmark decoding on a preloaded tokenizer without initializing a model. +public func benchmarkDecoding( + from downloader: any Downloader, + using tokenizerLoader: any TokenizerLoader, + configuration: MLXLMCommon.ModelConfiguration = .init(id: "mlx-community/Qwen3-0.6B-4bit"), + text: String = "The quick brown fox jumps over the lazy dog.", + addSpecialTokens: Bool = true, + skipSpecialTokens: Bool = false, + useLatest: Bool = false, + runs: Int = BenchmarkDefaults.decodingRuns, + decodesPerRun: Int = BenchmarkDefaults.decodesPerRun +) async throws -> BenchmarkStats { + precondition(decodesPerRun > 0, "decodesPerRun must be greater than zero") + + let tokenizer = try await loadTokenizerForBenchmark( + from: downloader, + using: tokenizerLoader, + configuration: configuration, + useLatest: useLatest + ) + let tokenIds = tokenizer.encode(text: text, addSpecialTokens: addSpecialTokens) + + _ = tokenizer.decode(tokenIds: tokenIds, skipSpecialTokens: skipSpecialTokens) + + var times: [Double] = [] + for i in 1 ... runs { + var decoded = "" + let start = CFAbsoluteTimeGetCurrent() + for _ in 0 ..< decodesPerRun { + decoded = tokenizer.decode(tokenIds: tokenIds, skipSpecialTokens: skipSpecialTokens) + } + let elapsed = (CFAbsoluteTimeGetCurrent() - start) * 1000 / Double(decodesPerRun) + times.append(elapsed) + print( + "Decoding run \(i): \(String(format: "%.1f", elapsed))ms avg over \(decodesPerRun)x " + + "(\(decoded.count) chars)" + ) + } + + return BenchmarkStats(times: times) +} + +/// Benchmark LLM model loading. Performs a warm-up run, then measures `runs` timed loads. +public func benchmarkLLMLoading( + from downloader: any Downloader, + using tokenizerLoader: any TokenizerLoader, + modelId: String = "mlx-community/Qwen3-0.6B-4bit", + runs: Int = BenchmarkDefaults.loadingRuns +) async throws -> BenchmarkStats { + let config = MLXLMCommon.ModelConfiguration(id: modelId) + + _ = try await LLMModelFactory.shared.load( + from: downloader, using: tokenizerLoader, configuration: config + ) { _ in } + Memory.clearCache() + + var times: [Double] = [] + for i in 1 ... runs { + let start = CFAbsoluteTimeGetCurrent() + _ = try await LLMModelFactory.shared.load( + from: downloader, using: tokenizerLoader, configuration: config + ) { _ in } + let elapsed = (CFAbsoluteTimeGetCurrent() - start) * 1000 + times.append(elapsed) + print("LLM load run \(i): \(String(format: "%.1f", elapsed))ms") + Memory.clearCache() + } + + return BenchmarkStats(times: times) +} + +/// Benchmark VLM model loading. Performs a warm-up run, then measures `runs` timed loads. +public func benchmarkVLMLoading( + from downloader: any Downloader, + using tokenizerLoader: any TokenizerLoader, + modelId: String = "mlx-community/Qwen2-VL-2B-Instruct-4bit", + runs: Int = BenchmarkDefaults.loadingRuns +) async throws -> BenchmarkStats { + let config = MLXLMCommon.ModelConfiguration(id: modelId) + + _ = try await VLMModelFactory.shared.load( + from: downloader, using: tokenizerLoader, configuration: config + ) { _ in } + Memory.clearCache() + + var times: [Double] = [] + for i in 1 ... runs { + let start = CFAbsoluteTimeGetCurrent() + _ = try await VLMModelFactory.shared.load( + from: downloader, using: tokenizerLoader, configuration: config + ) { _ in } + let elapsed = (CFAbsoluteTimeGetCurrent() - start) * 1000 + times.append(elapsed) + print("VLM load run \(i): \(String(format: "%.1f", elapsed))ms") + Memory.clearCache() + } + + return BenchmarkStats(times: times) +} + +/// Benchmark embedding model loading. Performs a warm-up run, then measures `runs` timed loads. +public func benchmarkEmbeddingLoading( + from downloader: any Downloader, + using tokenizerLoader: any TokenizerLoader, + configuration: MLXEmbedders.ModelConfiguration = .init( + id: "mlx-community/Qwen3-Embedding-0.6B-8bit"), + runs: Int = BenchmarkDefaults.loadingRuns +) async throws -> BenchmarkStats { + _ = try await MLXEmbedders.loadModelContainer( + from: downloader, using: tokenizerLoader, configuration: configuration + ) { _ in } + Memory.clearCache() + + var times: [Double] = [] + for i in 1 ... runs { + let start = CFAbsoluteTimeGetCurrent() + _ = try await MLXEmbedders.loadModelContainer( + from: downloader, using: tokenizerLoader, configuration: configuration + ) { _ in } + let elapsed = (CFAbsoluteTimeGetCurrent() - start) * 1000 + times.append(elapsed) + print("Embedding load run \(i): \(String(format: "%.1f", elapsed))ms") + Memory.clearCache() + } + + return BenchmarkStats(times: times) +} + +// MARK: - Download Benchmarks + +/// Benchmark download cache hit performance. Ensures the model is cached with a warm-up +/// download, then measures repeated cache lookups. +public func benchmarkDownloadCacheHit( + from downloader: any Downloader, + modelId: String = "mlx-community/Qwen3-0.6B-4bit", + runs: Int = BenchmarkDefaults.downloadRuns +) async throws -> BenchmarkStats { + let patterns = modelDownloadPatterns + + // Warm-up: ensure the model is cached + _ = try await downloader.download( + id: modelId, revision: "main", matching: patterns, + useLatest: false, progressHandler: { _ in }) + + var times: [Double] = [] + for i in 1 ... runs { + let start = CFAbsoluteTimeGetCurrent() + _ = try await downloader.download( + id: modelId, revision: "main", matching: patterns, + useLatest: false, progressHandler: { _ in }) + let elapsed = (CFAbsoluteTimeGetCurrent() - start) * 1000 + times.append(elapsed) + print("Download cache hit run \(i): \(String(format: "%.1f", elapsed))ms") + } + + return BenchmarkStats(times: times) +} diff --git a/Libraries/IntegrationTestHelpers/IntegrationTestHelpers.swift b/Libraries/IntegrationTestHelpers/IntegrationTestHelpers.swift new file mode 100644 index 000000000..ed4089134 --- /dev/null +++ b/Libraries/IntegrationTestHelpers/IntegrationTestHelpers.swift @@ -0,0 +1,614 @@ +// Shared integration test logic for verifying end-to-end model loading and generation. +// Integration packages inject their own Downloader and TokenizerLoader, then call +// these functions which run the test and throw on failure. + +import CoreImage +import Foundation +import MLX +import MLXEmbedders +import MLXLLM +import MLXLMCommon +import MLXVLM + +// Both MLXLMCommon and MLXEmbedders define ModelContainer. +public typealias LMModelContainer = MLXLMCommon.ModelContainer +public typealias EmbeddingModelContainer = MLXEmbedders.ModelContainer + +// MARK: - Error + +public struct IntegrationTestFailure: LocalizedError { + public let errorDescription: String? + + public init(_ message: String) { + self.errorDescription = message + } +} + +private func check(_ condition: Bool, _ message: String) throws { + guard condition else { throw IntegrationTestFailure(message) } +} + +// MARK: - Model IDs + +public enum IntegrationTestModelIDs { + public static let llm = "mlx-community/Qwen3-4B-Instruct-2507-4bit" + public static let vlm = "mlx-community/Qwen3-VL-4B-Instruct-4bit" + public static let lfm2 = "mlx-community/LFM2-2.6B-Exp-4bit" + public static let glm4 = "mlx-community/GLM-4-9B-0414-4bit" +} + +// MARK: - Model Loading + +/// Shared model cache that loads each model at most once per test run. +public actor IntegrationTestModels { + private let downloader: any Downloader + private let tokenizerLoader: any TokenizerLoader + + private var llmTask: Task? + private var vlmTask: Task? + private var lfm2Task: Task? + private var glm4Task: Task? + + public init(downloader: any Downloader, tokenizerLoader: any TokenizerLoader) { + self.downloader = downloader + self.tokenizerLoader = tokenizerLoader + } + + public func llmContainer() async throws -> LMModelContainer { + if let task = llmTask { + return try await task.value + } + let downloader = self.downloader + let tokenizerLoader = self.tokenizerLoader + let id = IntegrationTestModelIDs.llm + let task = Task { + print("Loading LLM: \(id)") + let container = try await LLMModelFactory.shared.loadContainer( + from: downloader, using: tokenizerLoader, + configuration: .init(id: id), + progressHandler: logProgress(id) + ) + print("Loaded LLM: \(id)") + return container + } + llmTask = task + return try await task.value + } + + public func vlmContainer() async throws -> LMModelContainer { + if let task = vlmTask { + return try await task.value + } + let downloader = self.downloader + let tokenizerLoader = self.tokenizerLoader + let id = IntegrationTestModelIDs.vlm + let task = Task { + print("Loading VLM: \(id)") + let container = try await VLMModelFactory.shared.loadContainer( + from: downloader, using: tokenizerLoader, + configuration: .init(id: id), + progressHandler: logProgress(id) + ) + print("Loaded VLM: \(id)") + return container + } + vlmTask = task + return try await task.value + } + + public func lfm2Container() async throws -> LMModelContainer { + if let task = lfm2Task { + return try await task.value + } + let downloader = self.downloader + let tokenizerLoader = self.tokenizerLoader + let id = IntegrationTestModelIDs.lfm2 + let task = Task { + print("Loading LFM2: \(id)") + let container = try await LLMModelFactory.shared.loadContainer( + from: downloader, using: tokenizerLoader, + configuration: .init(id: id), + progressHandler: logProgress(id) + ) + print("Loaded LFM2: \(id)") + return container + } + lfm2Task = task + return try await task.value + } + + public func glm4Container() async throws -> LMModelContainer { + if let task = glm4Task { + return try await task.value + } + let downloader = self.downloader + let tokenizerLoader = self.tokenizerLoader + let id = IntegrationTestModelIDs.glm4 + let task = Task { + print("Loading GLM4: \(id)") + let container = try await LLMModelFactory.shared.loadContainer( + from: downloader, using: tokenizerLoader, + configuration: .init(id: id), + progressHandler: logProgress(id) + ) + print("Loaded GLM4: \(id)") + return container + } + glm4Task = task + return try await task.value + } + + public func embeddingContainer() async throws -> EmbeddingModelContainer { + let downloader = self.downloader + let tokenizerLoader = self.tokenizerLoader + let id = "nomic_text_v1_5" + print("Loading embedding model: \(id)") + let container = try await MLXEmbedders.loadModelContainer( + from: downloader, using: tokenizerLoader, configuration: .nomic_text_v1_5, + progressHandler: logProgress(id) + ) + print("Loaded embedding model: \(id)") + return container + } +} + +// MARK: - ChatSession Tests + +private let generateParameters = GenerateParameters(maxTokens: 200, temperature: 0) + +public enum ChatSessionTests { + + public static func oneShot(container: LMModelContainer) async throws { + let session = ChatSession(container, generateParameters: generateParameters) + let result = try await streamAndCollect( + session.streamResponse( + to: "What is 2+2? Reply with just the number."), label: "One-shot") + try check( + result.contains("4") || result.lowercased().contains("four"), + "Expected '4' or 'four' in response, got: \(result)" + ) + } + + public static func oneShotStream(container: LMModelContainer) async throws { + let session = ChatSession(container, generateParameters: generateParameters) + let result = try await streamAndCollect( + session.streamResponse( + to: "What is 2+2? Reply with just the number."), label: "Stream") + try check( + result.contains("4") || result.lowercased().contains("four"), + "Expected '4' or 'four' in streamed response, got: \(result)" + ) + } + + public static func multiTurnConversation(container: LMModelContainer) async throws { + let session = ChatSession( + container, instructions: "You are a helpful assistant. Keep responses brief.", + generateParameters: generateParameters) + + _ = try await streamAndCollect( + session.streamResponse( + to: "My name is Alice."), label: "Turn 1") + + let response2 = try await streamAndCollect( + session.streamResponse( + to: "What is my name?"), label: "Turn 2") + + try check( + response2.lowercased().contains("alice"), + "Expected 'Alice' in response, got: \(response2)" + ) + } + + public static func visionModel(container: LMModelContainer) async throws { + let session = ChatSession(container, generateParameters: generateParameters) + let redImage = CIImage(color: .red).cropped( + to: CGRect(x: 0, y: 0, width: 100, height: 100)) + + let result = try await streamAndCollect( + session.streamResponse( + to: "What color is this image? Reply with just the color name.", + image: .ciImage(redImage)), label: "Vision") + try check( + result.lowercased().contains("red"), + "Expected 'red' in response, got: \(result)" + ) + } + + public static func streamDetailsWithTools(container: LMModelContainer) async throws { + let tools: [ToolSpec] = [weatherToolSchema] + let session = ChatSession(container, generateParameters: generateParameters, tools: tools) + + var responseText = "" + var toolCalls: [ToolCall] = [] + + var info: GenerateCompletionInfo? + print("Tools: ", terminator: "") + for try await generation in session.streamDetails( + to: "What is the weather in San Francisco?", images: [], videos: []) + { + switch generation { + case .chunk(let text): + print(text, terminator: "") + responseText += text + case .toolCall(let toolCall): + toolCalls.append(toolCall) + case .info(let completionInfo): + info = completionInfo + } + } + print() + if let info { + print( + "Generation info: \(info.generationTokenCount) tokens, stop reason: \(info.stopReason)" + ) + } + if !toolCalls.isEmpty { + print("Tool calls: \(toolCalls)") + } + + try check( + !responseText.isEmpty || !toolCalls.isEmpty, + "Expected either text or tool calls, got neither (generated \(info?.generationTokenCount ?? 0) tokens, stop reason: \(String(describing: info?.stopReason)))" + ) + + // If we got tool calls, feed back a tool result and verify the model responds + if !toolCalls.isEmpty { + let followUp = try await streamAndCollect( + session.streamResponse( + to: "Foggy with a high in the low 60s, clearing later in the day", + role: .tool, images: [], videos: []), + label: "Tool result") + try check( + !followUp.isEmpty, + "Expected a response after providing tool result, got empty string" + ) + } + } + + public static func toolInvocation(container: LMModelContainer) async throws { + struct EmptyInput: Codable {} + + struct TimeOutput: Codable { + let time: String + } + + let timeTool = Tool( + name: "get_time", + description: "Get the current date and time including day of week.", + parameters: [] + ) { _ in + TimeOutput(time: "Wed Feb 18 17:50:43 PST 2026") + } + + let session = ChatSession( + container, generateParameters: generateParameters, + tools: [timeTool.schema] + ) { toolCall in + if toolCall.function.name == timeTool.name { + return try await toolCall.execute(with: timeTool).toolResult + } + return "Unknown tool: \(toolCall.function.name)" + } + + let result = try await streamAndCollect( + session.streamResponse( + to: "What day of week is it?"), label: "Tool invocation") + + try check( + result.lowercased().contains("wed") || result.lowercased().contains("wednesday"), + "Expected 'Wed' or 'Wednesday' in response, got: \(result)" + ) + } + + public static func promptRehydration(container: LMModelContainer) async throws { + let history: [Chat.Message] = [ + .system("You are a helpful assistant."), + .user("My name is Bob."), + .assistant("Hello Bob! How can I help you today?"), + ] + + let session = ChatSession( + container, history: history, generateParameters: generateParameters) + let response = try await streamAndCollect( + session.streamResponse( + to: "What is my name?"), label: "Rehydration") + + try check( + response.lowercased().contains("bob"), + "Expected 'Bob' in response (prompt rehydration), got: \(response)" + ) + } +} + +// MARK: - Stream Helper + +private func streamAndCollect( + _ stream: AsyncThrowingStream, + label: String +) async throws -> String { + var result = "" + print("\(label): ", terminator: "") + for try await token in stream { + print(token, terminator: "") + result += token + } + print() + return result +} + +// MARK: - Embedder Tests + +public enum EmbedderTests { + + public static func gemma3Embedder( + downloader: any Downloader, tokenizerLoader: any TokenizerLoader + ) async throws { + let modelId = "mlx-community/gemma-3-1b-it-qat-4bit" + print("Loading Gemma 3 embedding model: \(modelId)") + let modelContainer = try await MLXEmbedders.loadModelContainer( + from: downloader, using: tokenizerLoader, configuration: .init(id: modelId), + progressHandler: logProgress(modelId) + ) + print("Loaded Gemma 3 embedding model: \(modelId)") + + let inputs = [ + "The Coca-Cola Company is a soft drink company based in Atlanta, Georgia, USA.", + "In the United States, PepsiCo Inc. is a leading soft drink company.", + ] + + let resultEmbeddings = await modelContainer.perform { + (model: EmbeddingModel, tokenizer: Tokenizer, pooling: Pooling) -> [[Float]] in + let encoded = inputs.map { + tokenizer.encode(text: $0, addSpecialTokens: true) + } + let maxLength = encoded.reduce(into: 1) { acc, elem in + acc = max(acc, elem.count) + } + + let padded = stacked( + encoded.map { elem in + MLXArray( + elem + + Array( + repeating: tokenizer.eosTokenId ?? 0, + count: maxLength - elem.count)) + }) + + let mask = (padded .!= (tokenizer.eosTokenId ?? 0)) + let tokenTypes = MLXArray.zeros(like: padded) + + let modelOutput = model( + padded, positionIds: nil, tokenTypeIds: tokenTypes, attentionMask: mask) + + let result = pooling( + modelOutput, + normalize: true, applyLayerNorm: true + ) + result.eval() + return result.map { $0.asArray(Float.self) } + } + + try check( + resultEmbeddings.count == inputs.count, + "Should have one embedding per input, got \(resultEmbeddings.count)" + ) + for embedding in resultEmbeddings { + try check( + embedding.count == 1152, + "Gemma 3 1B embedding size should be 1152, got \(embedding.count)" + ) + let l2Norm = sqrt(embedding.map { $0 * $0 }.reduce(0, +)) + try check( + abs(l2Norm - 1.0) < 0.05, + "Embeddings should be approximately L2-normalized, got L2 norm \(l2Norm)" + ) + } + + let similarity = zip(resultEmbeddings[0], resultEmbeddings[1]).map(*).reduce(0, +) + try check( + similarity > 0.0, + "Similarity between related sentences should be positive, got \(similarity)" + ) + } + + public static func readmeExample(container: EmbeddingModelContainer) async throws { + let searchInputs = [ + "search_query: Animals in Tropical Climates.", + "search_document: Elephants", + "search_document: Horses", + "search_document: Polar Bears", + ] + + let resultEmbeddings = await container.perform { + (model: EmbeddingModel, tokenizer: Tokenizer, pooling: Pooling) -> [[Float]] in + let inputs = searchInputs.map { + tokenizer.encode(text: $0, addSpecialTokens: true) + } + let maxLength = inputs.reduce(into: 16) { acc, elem in + acc = max(acc, elem.count) + } + let padded = stacked( + inputs.map { elem in + MLXArray( + elem + + Array( + repeating: tokenizer.eosTokenId ?? 0, + count: maxLength - elem.count)) + }) + let mask = (padded .!= tokenizer.eosTokenId ?? 0) + let tokenTypes = MLXArray.zeros(like: padded) + let result = pooling( + model(padded, positionIds: nil, tokenTypeIds: tokenTypes, attentionMask: mask), + normalize: true, applyLayerNorm: true + ) + result.eval() + return result.map { $0.asArray(Float.self) } + } + + let searchQueryEmbedding = resultEmbeddings[0] + let documentEmbeddings = resultEmbeddings[1...] + let similarities = documentEmbeddings.map { docEmbedding in + zip(searchQueryEmbedding, docEmbedding).map(*).reduce(0, +) + } + let documentNames = searchInputs[1...].map { + $0.replacingOccurrences(of: "search_document: ", with: "") + } + + let expectedSimilarities: [Float] = [0.6854175, 0.6644787, 0.63326025] + let tolerance: Float = 1e-4 + + for (index, resultSimilarity) in similarities.enumerated() { + try check( + abs(resultSimilarity - expectedSimilarities[index]) < tolerance, + "Similarity mismatch for \(documentNames[index]): expected \(expectedSimilarities[index]), got \(resultSimilarity)" + ) + } + } +} + +// MARK: - Tool Call Tests + +public enum ToolCallTests { + + public static func lfm2FormatAutoDetection(container: LMModelContainer) async throws { + let config = await container.configuration + try check( + config.toolCallFormat == ToolCallFormat.lfm2, + "Expected .lfm2 tool call format, got: \(String(describing: config.toolCallFormat))" + ) + } + + public static func lfm2EndToEndGeneration(container: LMModelContainer) async throws { + let (result, toolCalls) = try await generateWithTools( + container: container, + userMessage: "What's the weather in Tokyo?") + + print("LFM2 Output:", result) + print("LFM2 Tool Calls:", toolCalls) + + if !toolCalls.isEmpty { + let toolCall = toolCalls[0] + try check( + toolCall.function.name == "get_weather", + "Expected tool name 'get_weather', got: \(toolCall.function.name)" + ) + if case .string(let location) = toolCall.function.arguments["location"] { + try check( + location.lowercased().contains("tokyo"), + "Expected location containing 'Tokyo', got: \(location)" + ) + } + } + } + + public static func glm4FormatAutoDetection(container: LMModelContainer) async throws { + let config = await container.configuration + try check( + config.toolCallFormat == ToolCallFormat.glm4, + "Expected .glm4 tool call format, got: \(String(describing: config.toolCallFormat))" + ) + } + + public static func glm4EndToEndGeneration(container: LMModelContainer) async throws { + let (result, toolCalls) = try await generateWithTools( + container: container, + userMessage: "What's the weather in Paris?") + + print("GLM4 Output:", result) + print("GLM4 Tool Calls:", toolCalls) + + if !toolCalls.isEmpty { + let toolCall = toolCalls[0] + try check( + toolCall.function.name == "get_weather", + "Expected tool name 'get_weather', got: \(toolCall.function.name)" + ) + if case .string(let location) = toolCall.function.arguments["location"] { + try check( + location.lowercased().contains("paris"), + "Expected location containing 'Paris', got: \(location)" + ) + } + } + } + + private static func generateWithTools( + container: LMModelContainer, + userMessage: String + ) async throws -> (text: String, toolCalls: [ToolCall]) { + try await container.perform { context in + let input = UserInput( + chat: [ + .system( + "You are a helpful assistant with access to tools. When asked about weather, use the get_weather function." + ), + .user(userMessage), + ], + tools: [weatherToolSchema] + ) + let lmInput = try await context.processor.prepare(input: input) + let stream = try generate( + input: lmInput, + parameters: GenerateParameters(maxTokens: 100), + context: context + ) + + var text = "" + var toolCalls: [ToolCall] = [] + for try await generation in stream { + switch generation { + case .chunk(let chunk): + text += chunk + case .toolCall(let toolCall): + toolCalls.append(toolCall) + case .info: + break + } + } + return (text, toolCalls) + } + } +} + +// MARK: - Progress Logging + +private func logProgress(_ label: String) -> @Sendable (Progress) -> Void { + let lock = NSLock() + nonisolated(unsafe) var lastThreshold = -1 + return { progress in + let pct = Int(progress.fractionCompleted * 100) + let threshold = pct / 5 + lock.lock() + let shouldPrint = threshold > lastThreshold + if shouldPrint { lastThreshold = threshold } + lock.unlock() + if shouldPrint { + print(" \(label): \(pct)%") + } + } +} + +// MARK: - Shared Constants + +private let weatherToolSchema: ToolSpec = [ + "type": "function", + "function": [ + "name": "get_weather", + "description": "Get the current weather for a location", + "parameters": [ + "type": "object", + "properties": [ + "location": [ + "type": "string", + "description": "The city name, e.g. San Francisco", + ] as [String: any Sendable], + "unit": [ + "type": "string", + "enum": ["celsius", "fahrenheit"], + "description": "Temperature unit", + ] as [String: any Sendable], + ] as [String: any Sendable], + "required": ["location"], + ] as [String: any Sendable], + ] as [String: any Sendable], +] diff --git a/Libraries/IntegrationTestHelpers/README.md b/Libraries/IntegrationTestHelpers/README.md new file mode 100644 index 000000000..39ec6ff14 --- /dev/null +++ b/Libraries/IntegrationTestHelpers/README.md @@ -0,0 +1,10 @@ +# Integration Test Helpers + +`IntegrationTestHelpers` and `BenchmarkHelpers` provide shared test logic for verifying end-to-end model loading, inference, tokenizer performance, and download performance. They are designed to be used by integration packages that supply their own `Downloader` and `TokenizerLoader` implementations. + +## Integration packages + +- [Swift Tokenizers MLX](https://github.com/DePasqualeOrg/swift-tokenizers-mlx): Uses [Swift Tokenizers](https://github.com/DePasqualeOrg/swift-tokenizers) and [Swift HF API](https://github.com/DePasqualeOrg/swift-hf-api) +- [Swift Transformers MLX](https://github.com/DePasqualeOrg/swift-transformers-mlx): Uses [Swift Transformers](https://github.com/huggingface/swift-transformers) and [Swift Hugging Face](https://github.com/huggingface/swift-huggingface) + +Integration tests and benchmarks are run from those packages. diff --git a/Libraries/MLXEmbedders/EmbeddingModel.swift b/Libraries/MLXEmbedders/EmbeddingModel.swift index a869c24fd..ca8f90882 100644 --- a/Libraries/MLXEmbedders/EmbeddingModel.swift +++ b/Libraries/MLXEmbedders/EmbeddingModel.swift @@ -1,10 +1,9 @@ // Copyright © 2024 Apple Inc. import Foundation -@preconcurrency import Hub import MLX +import MLXLMCommon import MLXNN -import Tokenizers /// Container for models that guarantees single threaded access. /// @@ -44,23 +43,21 @@ public actor ModelContainer { self.pooler = pooler } - /// build the model and tokenizer without passing non-sendable data over isolation barriers + /// Build the model and tokenizer without passing non-sendable data over isolation barriers public init( - hub: HubApi, modelDirectory: URL, - configuration: ModelConfiguration + tokenizerDirectory: URL, + configuration: ModelConfiguration, + tokenizerLoader: any TokenizerLoader ) async throws { - // Load tokenizer config and model in parallel using async let. - async let tokenizerConfigTask = loadTokenizerConfig( - configuration: configuration, hub: hub) + // Load tokenizer and model in parallel + async let tokenizerTask = tokenizerLoader.load(from: tokenizerDirectory) self.model = try loadSynchronous( modelDirectory: modelDirectory, modelName: configuration.name) self.pooler = loadPooling(modelDirectory: modelDirectory, model: model) - let (tokenizerConfig, tokenizerData) = try await tokenizerConfigTask - self.tokenizer = try PreTrainedTokenizer( - tokenizerConfig: tokenizerConfig, tokenizerData: tokenizerData) + self.tokenizer = try await tokenizerTask } /// Perform an action on the model and/or tokenizer. Callers _must_ eval any `MLXArray` before returning as diff --git a/Libraries/MLXEmbedders/Load.swift b/Libraries/MLXEmbedders/Load.swift index ed9518db4..4f6cd84f5 100644 --- a/Libraries/MLXEmbedders/Load.swift +++ b/Libraries/MLXEmbedders/Load.swift @@ -1,11 +1,9 @@ // Copyright © 2024 Apple Inc. import Foundation -@preconcurrency import Hub import MLX import MLXLMCommon import MLXNN -import Tokenizers /// Errors encountered during the model loading and initialization process. /// @@ -26,9 +24,6 @@ public enum EmbedderError: LocalizedError { /// The configuration file exists but contains invalid JSON or missing required fields. case configurationDecodingError(String, String, DecodingError) - /// Thrown when the tokenizer configuration is missing from the model bundle or Hub. - case missingTokenizerConfig - /// A human-readable description of the error. public var errorDescription: String? { switch self { @@ -39,8 +34,6 @@ public enum EmbedderError: LocalizedError { case .configurationDecodingError(let file, let modelName, let decodingError): let errorDetail = extractDecodingErrorDetail(decodingError) return "Failed to parse \(file) for model '\(modelName)': \(errorDetail)" - case .missingTokenizerConfig: - return "Missing tokenizer configuration" } } @@ -70,43 +63,48 @@ public enum EmbedderError: LocalizedError { } } -/// Prepares the local model directory by downloading files from the Hub or resolving a local path. -/// -/// If the `ModelConfiguration` identifies a remote repo, this function downloads weights -/// (`.safetensors`) and config files. It includes a fallback mechanism: if the user is -/// offline or unauthorized, it attempts to resolve the files from the local cache. +/// Resolve model and tokenizer directories from a ``ModelConfiguration`` +/// using a ``Downloader``. /// /// - Parameters: -/// - hub: The `HubApi` instance for managing downloads. +/// - downloader: The downloader to use for fetching remote resources. /// - configuration: The configuration identifying the model. +/// - useLatest: When true, always checks the provider for updates. /// - progressHandler: A closure to monitor download progress. -/// - Returns: A `URL` pointing to the directory containing model files. -func prepareModelDirectory( - hub: HubApi, +/// - Returns: A tuple of (modelDirectory, tokenizerDirectory). +func resolveDirectories( + from downloader: any Downloader, configuration: ModelConfiguration, + useLatest: Bool = false, progressHandler: @Sendable @escaping (Progress) -> Void -) async throws -> URL { - do { - switch configuration.id { - case .id(let id): - let repo = Hub.Repo(id: id) - let modelFiles = ["*.safetensors", "config.json", "*/config.json"] - return try await hub.snapshot( - from: repo, matching: modelFiles, progressHandler: progressHandler) - - case .directory(let directory): - return directory - } - } catch Hub.HubClientError.authorizationRequired { - return configuration.modelDirectory(hub: hub) - } catch { - let nserror = error as NSError - if nserror.domain == NSURLErrorDomain && nserror.code == NSURLErrorNotConnectedToInternet { - return configuration.modelDirectory(hub: hub) - } else { - throw error - } +) async throws -> (modelDirectory: URL, tokenizerDirectory: URL) { + let modelDirectory: URL + switch configuration.id { + case .id(let id, let revision): + modelDirectory = try await downloader.download( + id: id, revision: revision, + matching: modelDownloadPatterns, + useLatest: useLatest, + progressHandler: progressHandler) + case .directory(let directory): + modelDirectory = directory + } + + let tokenizerDirectory: URL + switch configuration.tokenizerSource { + case .id(let id, let revision): + tokenizerDirectory = try await downloader.download( + id: id, revision: revision, + matching: tokenizerDownloadPatterns, + useLatest: useLatest, + progressHandler: { _ in }) + case .directory(let directory): + tokenizerDirectory = directory + case nil: + tokenizerDirectory = modelDirectory } + + return (modelDirectory, tokenizerDirectory) } /// Asynchronously loads the `EmbeddingModel` and its associated `Tokenizer`. @@ -116,19 +114,23 @@ func prepareModelDirectory( /// structure is being built synchronously. /// /// - Parameters: -/// - hub: The `HubApi` instance (defaults to a new instance). +/// - downloader: The ``Downloader`` to use for fetching remote resources. /// - configuration: The model configuration. +/// - useLatest: When true, always checks the provider for updates. /// - progressHandler: A closure for tracking download progress. /// - Returns: A tuple containing the initialized `EmbeddingModel` and `Tokenizer`. public func load( - hub: HubApi = defaultHubApi, + from downloader: any Downloader, + using tokenizerLoader: any TokenizerLoader, configuration: ModelConfiguration, + useLatest: Bool = false, progressHandler: @Sendable @escaping (Progress) -> Void = { _ in } ) async throws -> (EmbeddingModel, Tokenizer) { - let modelDirectory = try await prepareModelDirectory( - hub: hub, configuration: configuration, progressHandler: progressHandler) + let (modelDirectory, tokenizerDirectory) = try await resolveDirectories( + from: downloader, configuration: configuration, useLatest: useLatest, + progressHandler: progressHandler) - async let tokenizerTask = loadTokenizer(configuration: configuration, hub: hub) + async let tokenizerTask = tokenizerLoader.load(from: tokenizerDirectory) let model = try loadSynchronous(modelDirectory: modelDirectory, modelName: configuration.name) let tokenizer = try await tokenizerTask @@ -213,17 +215,65 @@ func loadSynchronous(modelDirectory: URL, modelName: String) throws -> Embedding /// or tasks may need to access the embedding model simultaneously. /// /// - Parameters: -/// - hub: The `HubApi` instance. +/// - downloader: The ``Downloader`` to use for fetching remote resources. /// - configuration: The model configuration. +/// - useLatest: When true, always checks the provider for updates. /// - progressHandler: A closure for tracking download progress. /// - Returns: A thread-safe `ModelContainer` instance. public func loadModelContainer( - hub: HubApi = defaultHubApi, + from downloader: any Downloader, + using tokenizerLoader: any TokenizerLoader, configuration: ModelConfiguration, + useLatest: Bool = false, progressHandler: @Sendable @escaping (Progress) -> Void = { _ in } ) async throws -> ModelContainer { - let modelDirectory = try await prepareModelDirectory( - hub: hub, configuration: configuration, progressHandler: progressHandler) + let (modelDirectory, tokenizerDirectory) = try await resolveDirectories( + from: downloader, configuration: configuration, useLatest: useLatest, + progressHandler: progressHandler) + return try await ModelContainer( - hub: hub, modelDirectory: modelDirectory, configuration: configuration) + modelDirectory: modelDirectory, + tokenizerDirectory: tokenizerDirectory, + configuration: configuration, + tokenizerLoader: tokenizerLoader) +} + +/// Load an embedding model from a local directory. +/// +/// No downloader is needed — the model and tokenizer are loaded from +/// the given directory. +/// +/// - Parameter directory: The local directory containing model files. +/// - Returns: A tuple containing the initialized `EmbeddingModel` and `Tokenizer`. +public func load( + from directory: URL, + using tokenizerLoader: any TokenizerLoader +) async throws -> (EmbeddingModel, Tokenizer) { + let name = + directory.deletingLastPathComponent().lastPathComponent + "/" + + directory.lastPathComponent + async let tokenizerTask = tokenizerLoader.load(from: directory) + let model = try loadSynchronous(modelDirectory: directory, modelName: name) + let tokenizer = try await tokenizerTask + return (model, tokenizer) +} + +/// Load an embedding model container from a local directory. +/// +/// No downloader is needed — the model and tokenizer are loaded from +/// the given directory. +/// +/// - Parameters: +/// - directory: The local directory containing model files. +/// - tokenizerLoader: The ``TokenizerLoader`` to use for loading the tokenizer. +/// - Returns: A thread-safe `ModelContainer` instance. +public func loadModelContainer( + from directory: URL, + using tokenizerLoader: any TokenizerLoader +) async throws -> ModelContainer { + try await ModelContainer( + modelDirectory: directory, + tokenizerDirectory: directory, + configuration: ModelConfiguration(directory: directory), + tokenizerLoader: tokenizerLoader) } diff --git a/Libraries/MLXEmbedders/Models.swift b/Libraries/MLXEmbedders/Models.swift index ac954b945..06489c4bb 100644 --- a/Libraries/MLXEmbedders/Models.swift +++ b/Libraries/MLXEmbedders/Models.swift @@ -1,7 +1,7 @@ // Copyright © 2024 Apple Inc. import Foundation -import Hub +import MLXLMCommon /// A registry and configuration provider for embedding models. /// @@ -22,7 +22,7 @@ public struct ModelConfiguration: Sendable { /// The backing storage for the model's location. public enum Identifier: Sendable { /// A Hugging Face Hub repository identifier (e.g., "BAAI/bge-small-en-v1.5"). - case id(String) + case id(String, revision: String = "main") /// A file system URL pointing to a local model directory. case directory(URL) } @@ -36,67 +36,44 @@ public struct ModelConfiguration: Sendable { /// it returns a path-based name (e.g., "ParentDir/ModelDir"). public var name: String { switch id { - case .id(let string): + case .id(let string, _): string case .directory(let url): url.deletingLastPathComponent().lastPathComponent + "/" + url.lastPathComponent } } - /// An optional alternate Hub ID to use specifically for loading the tokenizer. + /// Where to load the tokenizer from when it differs from the model directory. /// - /// Use this if the model weights and tokenizer configuration are hosted in different repositories. - public let tokenizerId: String? - - /// An optional override string for specifying a specific tokenizer implementation. - /// - /// This is useful for providing compatibility hints to `swift-tokenizers` before - /// official support is updated. - public let overrideTokenizer: String? + /// - `.id`: download from a remote provider (requires a ``Downloader``) + /// - `.directory`: load from a local path + /// - `nil`: use the same directory as the model + public let tokenizerSource: TokenizerSource? /// Initializes a configuration using a Hub repository ID. /// - Parameters: /// - id: The Hugging Face repo ID. - /// - tokenizerId: Optional alternate repo for the tokenizer. - /// - overrideTokenizer: Optional specific tokenizer implementation name. + /// - revision: The Git revision to use (defaults to "main"). + /// - tokenizerSource: Optional alternate source for the tokenizer. public init( id: String, - tokenizerId: String? = nil, - overrideTokenizer: String? = nil + revision: String = "main", + tokenizerSource: TokenizerSource? = nil ) { - self.id = .id(id) - self.tokenizerId = tokenizerId - self.overrideTokenizer = overrideTokenizer + self.id = .id(id, revision: revision) + self.tokenizerSource = tokenizerSource } /// Initializes a configuration using a local directory. /// - Parameters: /// - directory: The `URL` of the model on disk. - /// - tokenizerId: Optional alternate repo for the tokenizer. - /// - overrideTokenizer: Optional specific tokenizer implementation name. + /// - tokenizerSource: Optional alternate source for the tokenizer. public init( directory: URL, - tokenizerId: String? = nil, - overrideTokenizer: String? = nil + tokenizerSource: TokenizerSource? = nil ) { self.id = .directory(directory) - self.tokenizerId = tokenizerId - self.overrideTokenizer = overrideTokenizer - } - - /// Resolves the local file system URL where the model is (or will be) stored. - /// - /// - Parameter hub: The `HubApi` used to resolve Hub paths. - /// - Returns: A `URL` pointing to the local directory. - public func modelDirectory(hub: HubApi = HubApi()) -> URL { - switch id { - case .id(let id): - let repo = Hub.Repo(id: id) - return hub.localRepoLocation(repo) - - case .directory(let directory): - return directory - } + self.tokenizerSource = tokenizerSource } // MARK: - Registry Management diff --git a/Libraries/MLXEmbedders/README.md b/Libraries/MLXEmbedders/README.md index 40e6a3100..10d9c5408 100644 --- a/Libraries/MLXEmbedders/README.md +++ b/Libraries/MLXEmbedders/README.md @@ -5,42 +5,104 @@ This directory contains ports of popular Encoders / Embedding Models. ## Usage Example ```swift - let modelContainer = try await loadModelContainer(configuration: .nomic_text_v1_5) - let searchInputs = [ - "search_query: Animals in Tropical Climates.", - "search_document: Elephants", - "search_document: Horses", - "search_document: Polar Bears", - ] - - // Generate embeddings - let resultEmbeddings = await modelContainer.perform { - (model: EmbeddingModel, tokenizer: Tokenizer, pooling: Pooling) -> [[Float]] in - let inputs = searchInputs.map { - tokenizer.encode(text: $0, addSpecialTokens: true) - } - // Pad to longest - let maxLength = inputs.reduce(into: 16) { acc, elem in - acc = max(acc, elem.count) - } - - let padded = stacked( - inputs.map { elem in - MLXArray( - elem - + Array( - repeating: tokenizer.eosTokenId ?? 0, - count: maxLength - elem.count)) - }) - let mask = (padded .!= tokenizer.eosTokenId ?? 0) - let tokenTypes = MLXArray.zeros(like: padded) - let result = pooling( - model(padded, positionIds: nil, tokenTypeIds: tokenTypes, attentionMask: mask), - normalize: true, applyLayerNorm: true - ) - result.eval() - return result.map { $0.asArray(Float.self) } +import MLXEmbedders +import MLXEmbeddersHuggingFace +import MLXLMTokenizers + +let modelContainer = try await loadModelContainer( + using: TokenizersLoader(), + configuration: .nomic_text_v1_5 +) +let searchInputs = [ + "search_query: Animals in Tropical Climates.", + "search_document: Elephants", + "search_document: Horses", + "search_document: Polar Bears", +] + +// Generate embeddings +let resultEmbeddings = await modelContainer.perform { + (model: EmbeddingModel, tokenizer: Tokenizer, pooling: Pooling) -> [[Float]] in + let inputs = searchInputs.map { + tokenizer.encode(text: $0, addSpecialTokens: true) + } + // Pad to longest + let maxLength = inputs.reduce(into: 16) { acc, elem in + acc = max(acc, elem.count) } + + let padded = stacked( + inputs.map { elem in + MLXArray( + elem + + Array( + repeating: tokenizer.eosTokenId ?? 0, + count: maxLength - elem.count)) + }) + let mask = (padded .!= tokenizer.eosTokenId ?? 0) + let tokenTypes = MLXArray.zeros(like: padded) + let result = pooling( + model(padded, positionIds: nil, tokenTypeIds: tokenTypes, attentionMask: mask), + normalize: true, applyLayerNorm: true + ) + result.eval() + return result.map { $0.asArray(Float.self) } +} +``` + +Load from a local directory: + +```swift +import MLXEmbedders +import MLXLMTokenizers + +let modelDirectory = URL(filePath: "/path/to/embedder") +let modelContainer = try await loadModelContainer( + from: modelDirectory, + using: TokenizersLoader() +) +``` + +Use a custom Hugging Face client: + +```swift +import MLXEmbedders +import MLXEmbeddersHuggingFace +import MLXLMTokenizers + +let hub = HubClient(token: "hf_...") +let modelContainer = try await loadModelContainer( + from: hub, + using: TokenizersLoader(), + configuration: .nomic_text_v1_5 +) +``` + +Use a custom downloader: + +```swift +import MLXEmbedders +import MLXLMCommon +import MLXLMTokenizers + +struct S3Downloader: Downloader { + func download( + id: String, + revision: String?, + matching patterns: [String], + useLatest: Bool, + progressHandler: @Sendable @escaping (Progress) -> Void + ) async throws -> URL { + // Download files and return a local directory URL. + return URL(filePath: "/tmp/embedder") + } +} + +let modelContainer = try await loadModelContainer( + from: S3Downloader(), + using: TokenizersLoader(), + configuration: .init(id: "my-bucket/my-embedder") +) ``` diff --git a/Libraries/MLXEmbedders/Tokenizer.swift b/Libraries/MLXEmbedders/Tokenizer.swift deleted file mode 100644 index e7300f22a..000000000 --- a/Libraries/MLXEmbedders/Tokenizer.swift +++ /dev/null @@ -1,83 +0,0 @@ -// Copyright © 2024 Apple Inc. - -import Foundation -import Hub -import Tokenizers - -/// Asynchronously loads and initializes a pretrained tokenizer. -/// -/// This function serves as the primary entry point for preparing a tokenizer. It fetches -/// configuration and vocabulary data—either from the Hugging Face Hub or a local -/// directory—and initializes a `PreTrainedTokenizer`. -/// -/// - Parameters: -/// - configuration: The `ModelConfiguration` containing the model ID or directory path. -/// - hub: An instance of `HubApi` used to manage downloads and file access. -/// - Returns: An initialized `Tokenizer` ready for encoding and decoding text. -/// - Throws: `EmbedderError.missingTokenizerConfig` if the configuration files cannot be found, -/// or standard network/parsing errors. -public func loadTokenizer(configuration: ModelConfiguration, hub: HubApi) async throws -> Tokenizer -{ - let (tokenizerConfig, tokenizerData) = try await loadTokenizerConfig( - configuration: configuration, hub: hub) - - return try PreTrainedTokenizer( - tokenizerConfig: tokenizerConfig, tokenizerData: tokenizerData) -} - -/// Retrieves the raw configuration and data files required to build a tokenizer. -/// -/// This internal helper handles the logic of determining where to fetch files from. -/// It includes a robust fallback: if a network request fails due to lack of internet -/// connectivity, it attempts to load the files from the local model directory. -/// -/// - Parameters: -/// - configuration: The model configuration providing the `tokenizerId` or `modelDirectory`. -/// - hub: The `HubApi` interface for remote or local file resolution. -/// - Returns: A tuple containing the `tokenizerConfig` and `tokenizerData` configurations. -/// - Throws: `NSURLError` for network issues (other than offline status). -/// - Throws: `EmbedderError.missingTokenizerConfig` if the configuration files are -/// successfully accessed but do not contain a valid `tokenizerConfig` payload. -/// This typically occurs when the model repository or directory is missing a -/// `tokenizer_config.json` file. -func loadTokenizerConfig( - configuration: ModelConfiguration, - hub: HubApi -) async throws -> (Config, Config) { - // from AutoTokenizer.from() -- this lets us override parts of the configuration - let config: LanguageModelConfigurationFromHub - - switch configuration.id { - case .id(let id): - do { - // Attempt to load from the remote Hub or Hub cache - let loaded = LanguageModelConfigurationFromHub( - modelName: configuration.tokenizerId ?? id, hubApi: hub) - - // Trigger an async fetch to verify the config exists - _ = try await loaded.tokenizerConfig - config = loaded - } catch { - let nserror = error as NSError - if nserror.domain == NSURLErrorDomain - && nserror.code == NSURLErrorNotConnectedToInternet - { - // Fallback: Internet connection is offline, load from the local model directory - config = LanguageModelConfigurationFromHub( - modelFolder: configuration.modelDirectory(hub: hub), hubApi: hub) - } else { - // Re-throw if it's a critical error (e.g., 404, parsing error) - throw error - } - } - case .directory(let directory): - // Load directly from a specified local directory - config = LanguageModelConfigurationFromHub(modelFolder: directory, hubApi: hub) - } - - guard let tokenizerConfig = try await config.tokenizerConfig else { - throw EmbedderError.missingTokenizerConfig - } - let tokenizerData = try await config.tokenizerData - return (tokenizerConfig, tokenizerData) -} diff --git a/Libraries/MLXHuggingFace/Macros.swift b/Libraries/MLXHuggingFace/Macros.swift new file mode 100644 index 000000000..7160ae1ee --- /dev/null +++ b/Libraries/MLXHuggingFace/Macros.swift @@ -0,0 +1,55 @@ +import Foundation +import MLXLMCommon + +@freestanding(expression) +public macro hubDownloader(_ hub: Any) -> MLXLMCommon.Downloader = + #externalMacro(module: "MLXHuggingFaceMacros", type: "DownloaderMacro") + +@freestanding(expression) +public macro hubDownloader() -> MLXLMCommon.Downloader = + #externalMacro(module: "MLXHuggingFaceMacros", type: "DownloaderMacro") + +@freestanding(expression) +public macro adaptHuggingFaceTokenizer(_ upstream: Any) -> MLXLMCommon.Tokenizer = + #externalMacro(module: "MLXHuggingFaceMacros", type: "TokenizerAdaptorMacro") + +@freestanding(expression) +public macro huggingFaceTokenizerLoader() -> MLXLMCommon.TokenizerLoader = + #externalMacro(module: "MLXHuggingFaceMacros", type: "TokenizerLoaderMacro") + +@freestanding(expression) +public macro huggingFaceLoadModelContainer( + configuration: ModelConfiguration +) -> ModelContainer = + #externalMacro(module: "MLXHuggingFaceMacros", type: "LoadContainerMacro") + +@freestanding(expression) +public macro huggingFaceLoadModelContainer( + configuration: ModelConfiguration, + progressHandler: @Sendable @escaping (Progress) -> Void +) -> ModelContainer = + #externalMacro(module: "MLXHuggingFaceMacros", type: "LoadContainerMacro") + +@freestanding(expression) +public macro huggingFaceLoadModel( + configuration: ModelConfiguration +) -> ModelContext = + #externalMacro(module: "MLXHuggingFaceMacros", type: "LoadContextMacro") + +@freestanding(expression) +public macro huggingFaceLoadModel( + configuration: ModelConfiguration, + progressHandler: @Sendable @escaping (Progress) -> Void +) -> ModelContext = + #externalMacro(module: "MLXHuggingFaceMacros", type: "LoadContextMacro") + +public enum HuggingFaceDownloaderError: LocalizedError { + case invalidRepositoryID(String) + + public var errorDescription: String? { + switch self { + case .invalidRepositoryID(let id): + return "Invalid Hugging Face repository ID: '\(id)'. Expected format 'namespace/name'." + } + } +} diff --git a/Libraries/MLXHuggingFaceMacros/HuggingFaceIntegrationMacros.swift b/Libraries/MLXHuggingFaceMacros/HuggingFaceIntegrationMacros.swift new file mode 100644 index 000000000..74796efbc --- /dev/null +++ b/Libraries/MLXHuggingFaceMacros/HuggingFaceIntegrationMacros.swift @@ -0,0 +1,225 @@ +import SwiftCompilerPlugin +import SwiftSyntax +import SwiftSyntaxBuilder +import SwiftSyntaxMacros + +@main +struct Macros: CompilerPlugin { + let providingMacros: [Macro.Type] = [ + DownloaderMacro.self, + TokenizerAdaptorMacro.self, + TokenizerLoaderMacro.self, + LoadContainerMacro.self, + LoadContextMacro.self, + ] +} + +public struct DownloaderMacro: ExpressionMacro { + public static func expansion( + of node: some FreestandingMacroExpansionSyntax, + in context: some MacroExpansionContext + ) throws -> ExprSyntax { + let argument = node.arguments.first?.expression.description ?? "HubClient()" + + return + """ + // make sure you: + // + // import HuggingFace + // + { (hubApi: HubClient) -> MLXLMCommon.Downloader in + struct HubBridge: MLXLMCommon.Downloader { + private let upstream: HubClient + + init(_ upstream: HubClient) { + self.upstream = upstream + } + + public func download( + id: String, + revision: String?, + matching patterns: [String], + useLatest: Bool, + progressHandler: @Sendable @escaping (Progress) -> Void + ) async throws -> URL { + guard let repoID = HuggingFace.Repo.ID(rawValue: id) else { + throw HuggingFaceDownloaderError.invalidRepositoryID(id) + } + let revision = revision ?? "main" + + return try await upstream.downloadSnapshot( + of: repoID, + revision: revision, + matching: patterns, + progressHandler: { @MainActor progress in + progressHandler(progress) + } + ) + } + } + + return HubBridge(hubApi) + }(\(raw: argument)) + """ + } +} + +public struct TokenizerAdaptorMacro: ExpressionMacro { + public static func expansion( + of node: some FreestandingMacroExpansionSyntax, + in context: some MacroExpansionContext + ) throws -> ExprSyntax { + guard let argument = node.arguments.first?.expression else { + throw MacroExpansionError.message("#adaptHuggingFaceTokenizer requires an argument") + } + + return + """ + // make sure you: + // + // import Tokenizers + // + { (huggingFaceTokenizer: Tokenizers.Tokenizer) -> MLXLMCommon.Tokenizer in + struct TokenizerBridge: MLXLMCommon.Tokenizer { + private let upstream: any Tokenizers.Tokenizer + + init(_ upstream: any Tokenizers.Tokenizer) { + self.upstream = upstream + } + + func encode(text: String, addSpecialTokens: Bool) -> [Int] { + upstream.encode(text: text, addSpecialTokens: addSpecialTokens) + } + + // swift-transformers uses `decode(tokens:)` instead of `decode(tokenIds:)`. + func decode(tokenIds: [Int], skipSpecialTokens: Bool) -> String { + upstream.decode(tokens: tokenIds, skipSpecialTokens: skipSpecialTokens) + } + + func convertTokenToId(_ token: String) -> Int? { + upstream.convertTokenToId(token) + } + + func convertIdToToken(_ id: Int) -> String? { + upstream.convertIdToToken(id) + } + + var bosToken: String? { upstream.bosToken } + var eosToken: String? { upstream.eosToken } + var unknownToken: String? { upstream.unknownToken } + + func applyChatTemplate( + messages: [[String: any Sendable]], + tools: [[String: any Sendable]]?, + additionalContext: [String: any Sendable]? + ) throws -> [Int] { + do { + return try upstream.applyChatTemplate( + messages: messages, tools: tools, additionalContext: additionalContext) + } catch Tokenizers.TokenizerError.missingChatTemplate { + throw MLXLMCommon.TokenizerError.missingChatTemplate + } + } + } + + return TokenizerBridge(huggingFaceTokenizer) + }(\(argument)) + """ + } +} + +public struct TokenizerLoaderMacro: ExpressionMacro { + public static func expansion( + of node: some FreestandingMacroExpansionSyntax, + in context: some MacroExpansionContext + ) throws -> ExprSyntax { + return + """ + { () -> MLXLMCommon.TokenizerLoader in + struct TransformersLoader: MLXLMCommon.TokenizerLoader { + public init() {} + + public func load(from directory: URL) async throws -> any MLXLMCommon.Tokenizer { + // make sure you: + // + // import Tokenizers + // + let upstream = try await AutoTokenizer.from(modelFolder: directory) + return #adaptHuggingFaceTokenizer(upstream) + } + } + + return TransformersLoader() + }() + """ + } +} + +public struct LoadContainerMacro: ExpressionMacro { + public static func expansion( + of node: some FreestandingMacroExpansionSyntax, + in context: some MacroExpansionContext + ) throws -> ExprSyntax { + guard let configuration = node.arguments.first?.expression else { + throw MacroExpansionError.message( + "#huggingFaceLoadModelContainer requires a configuration") + } + + let progress = + if let expr = node.arguments.first(where: { $0.label?.text == "progressHandler" })? + .expression + { + expr.description + } else { + "{ _ in }" + } + + return + """ + loadModelContainer( + from: #hubDownloader(), + using: #huggingFaceTokenizerLoader(), + configuration: \(configuration), + progressHandler: \(raw: progress)) + """ + } +} + +public struct LoadContextMacro: ExpressionMacro { + public static func expansion( + of node: some FreestandingMacroExpansionSyntax, + in context: some MacroExpansionContext + ) throws -> ExprSyntax { + guard let configuration = node.arguments.first?.expression else { + throw MacroExpansionError.message("#huggingFaceLoadModel requires a configuration") + } + + let progress = + if let expr = node.arguments.first(where: { $0.label?.text == "progressHandler" })? + .expression + { + expr.description + } else { + "{ _ in }" + } + + return + """ + loadModel( + from: #hubDownloader(), + using: #huggingFaceTokenizerLoader(), + configuration: \(configuration), + progressHandler: \(raw: progress)) + """ + } +} + +enum MacroExpansionError: Error, CustomStringConvertible { + case message(String) + + var description: String { + switch self { + case .message(let text): return text + } + } +} diff --git a/Libraries/MLXLLM/Documentation.docc/Documentation.md b/Libraries/MLXLLM/Documentation.docc/Documentation.md index 1d17aeba4..c3d5c204a 100644 --- a/Libraries/MLXLLM/Documentation.docc/Documentation.md +++ b/Libraries/MLXLLM/Documentation.docc/Documentation.md @@ -16,7 +16,10 @@ See . Using LLMs and VLMs is as easy as this: ```swift -let model = try await loadModel(id: "mlx-community/Qwen3-4B-4bit") +let model = try await loadModel( + using: TokenizersLoader(), + id: "mlx-community/Qwen3-4B-4bit" +) let session = ChatSession(model) print(try await session.respond(to: "What are two things to see in San Francisco?") print(try await session.respond(to: "How about a great place to eat?") diff --git a/Libraries/MLXLLM/Documentation.docc/evaluation.md b/Libraries/MLXLLM/Documentation.docc/evaluation.md index ed3a2a320..9551b4ad8 100644 --- a/Libraries/MLXLLM/Documentation.docc/evaluation.md +++ b/Libraries/MLXLLM/Documentation.docc/evaluation.md @@ -5,7 +5,10 @@ The simplified LLM/VLM API allows you to load a model and evaluate prompts with For example, this loads a model and asks a question and a follow-on question: ```swift -let model = try await loadModel(id: "mlx-community/Qwen3-4B-4bit") +let model = try await loadModel( + using: TokenizersLoader(), + id: "mlx-community/Qwen3-4B-4bit" +) let session = ChatSession(model) print(try await session.respond(to: "What are two things to see in San Francisco?") print(try await session.respond(to: "How about a great place to eat?") @@ -26,7 +29,10 @@ users want to see the text as it is generated -- you can do this with a stream: ```swift -let model = try await loadModel(id: "mlx-community/Qwen3-4B-4bit") +let model = try await loadModel( + using: TokenizersLoader(), + id: "mlx-community/Qwen3-4B-4bit" +) let session = ChatSession(model) for try await item in session.streamResponse(to: "Why is the sky blue?") { @@ -41,7 +47,10 @@ This same API supports VLMs as well. Simply present the image or video to the ``ChatSession``: ```swift -let model = try await loadModel(id: "mlx-community/Qwen2.5-VL-3B-Instruct-4bit") +let model = try await loadModel( + using: TokenizersLoader(), + id: "mlx-community/Qwen2.5-VL-3B-Instruct-4bit" +) let session = ChatSession(model) let answer1 = try await session.respond( diff --git a/Libraries/MLXLLM/Documentation.docc/using-model.md b/Libraries/MLXLLM/Documentation.docc/using-model.md index 4b1ae11b0..e4d4f400e 100644 --- a/Libraries/MLXLLM/Documentation.docc/using-model.md +++ b/Libraries/MLXLLM/Documentation.docc/using-model.md @@ -16,7 +16,13 @@ let modelFactory: ModelFactory // e.g. LLMRegistry.llama3_8B_4bit let modelConfiguration: ModelConfiguration -let container = try await modelFactory.loadContainer(configuration: modelConfiguration) +// e.g. TokenizersLoader() from MLXLMTokenizers +let tokenizerLoader: any TokenizerLoader + +let container = try await modelFactory.loadContainer( + using: tokenizerLoader, + configuration: modelConfiguration +) ``` The `container` provides an isolation context (an `actor`) to run inference in the model. @@ -34,16 +40,15 @@ The flow inside the `ModelFactory` goes like this: public class LLMModelFactory: ModelFactory { public func _load( - hub: HubApi, configuration: ModelConfiguration, - progressHandler: @Sendable @escaping (Progress) -> Void + configuration: ResolvedModelConfiguration, + tokenizerLoader: any TokenizerLoader ) async throws -> ModelContext { - // download the weight and config using HubApi + // modelDirectory and tokenizerDirectory are already resolved // load the base configuration // using the typeRegistry create a model (random weights) // load the weights, apply quantization as needed, update the model // calls model.sanitize() for weight preparation - // load the tokenizer - // (vlm) load the processor configuration, create the processor + // load the tokenizer via tokenizerLoader.load(from: directory) } } ``` diff --git a/Libraries/MLXLLM/LLMModel.swift b/Libraries/MLXLLM/LLMModel.swift index f120d6ea5..d686b7ef7 100644 --- a/Libraries/MLXLLM/LLMModel.swift +++ b/Libraries/MLXLLM/LLMModel.swift @@ -2,7 +2,6 @@ import MLX import MLXLMCommon -import Tokenizers /// Marker protocol for LLMModels public protocol LLMModel: LanguageModel, LoRAModel { diff --git a/Libraries/MLXLLM/LLMModelFactory.swift b/Libraries/MLXLLM/LLMModelFactory.swift index 366774ec2..057b376a6 100644 --- a/Libraries/MLXLLM/LLMModelFactory.swift +++ b/Libraries/MLXLLM/LLMModelFactory.swift @@ -1,10 +1,8 @@ // Copyright © 2024 Apple Inc. import Foundation -import Hub import MLX import MLXLMCommon -import Tokenizers /// Creates a function that decodes configuration data and instantiates a model with the proper configuration private func create( @@ -18,7 +16,7 @@ private func create( /// Registry of model type, e.g 'llama', to functions that can instantiate the model from configuration. /// -/// Typically called via ``LLMModelFactory/load(hub:configuration:progressHandler:)``. +/// Typically called via ``LLMModelFactory/load(from:configuration:progressHandler:)``. public enum LLMTypeRegistry { /// Shared instance with default model types. @@ -107,7 +105,6 @@ public class LLMRegistry: AbstractModelRegistry, @unchecked Sendable { static public let codeLlama13b4bit = ModelConfiguration( id: "mlx-community/CodeLlama-13b-Instruct-hf-4bit-MLX", - overrideTokenizer: "PreTrainedTokenizer", defaultPrompt: "func sortArray(_ array: [Int]) -> String { }" ) @@ -132,28 +129,22 @@ public class LLMRegistry: AbstractModelRegistry, @unchecked Sendable { id: "mlx-community/Phi-3.5-MoE-instruct-4bit", defaultPrompt: "What is the gravity on Mars and the moon?", extraEOSTokens: ["<|end|>"] - ) { - prompt in - "<|user|>\n\(prompt)<|end|>\n<|assistant|>\n" - } + ) static public let gemma2bQuantized = ModelConfiguration( id: "mlx-community/quantized-gemma-2b-it", - overrideTokenizer: "PreTrainedTokenizer", // https://www.promptingguide.ai/models/gemma defaultPrompt: "what is the difference between lettuce and cabbage?" ) static public let gemma_2_9b_it_4bit = ModelConfiguration( id: "mlx-community/gemma-2-9b-it-4bit", - overrideTokenizer: "PreTrainedTokenizer", // https://www.promptingguide.ai/models/gemma defaultPrompt: "What is the difference between lettuce and cabbage?" ) static public let gemma_2_2b_it_4bit = ModelConfiguration( id: "mlx-community/gemma-2-2b-it-4bit", - overrideTokenizer: "PreTrainedTokenizer", // https://www.promptingguide.ai/models/gemma defaultPrompt: "What is the difference between lettuce and cabbage?" ) @@ -194,7 +185,6 @@ public class LLMRegistry: AbstractModelRegistry, @unchecked Sendable { static public let qwen205b4bit = ModelConfiguration( id: "mlx-community/Qwen1.5-0.5B-Chat-4bit", - overrideTokenizer: "PreTrainedTokenizer", defaultPrompt: "why is the sky blue?" ) @@ -484,12 +474,10 @@ public final class LLMModelFactory: ModelFactory { public let modelRegistry: AbstractModelRegistry public func _load( - hub: HubApi, configuration: ModelConfiguration, - progressHandler: @Sendable @escaping (Progress) -> Void + configuration: ResolvedModelConfiguration, + tokenizerLoader: any TokenizerLoader ) async throws -> ModelContext { - // download weights and config - let modelDirectory = try await downloadModel( - hub: hub, configuration: configuration, progressHandler: progressHandler) + let modelDirectory = configuration.modelDirectory // Load config.json once and decode for both base config and model-specific config let configurationURL = modelDirectory.appending(component: "config.json") @@ -528,17 +516,16 @@ public final class LLMModelFactory: ModelFactory { eosTokenIds = Set(genEosIds) // Override per Python mlx-lm behavior } - // Create mutable configuration with loaded EOS token IDs + // Build a ModelConfiguration with loaded EOS token IDs and tool call format var mutableConfiguration = configuration mutableConfiguration.eosTokenIds = eosTokenIds - - // Auto-detect tool call format from model type if not explicitly set if mutableConfiguration.toolCallFormat == nil { mutableConfiguration.toolCallFormat = ToolCallFormat.infer(from: baseConfig.modelType) } - // Load tokenizer and weights in parallel using async let. - async let tokenizerTask = loadTokenizer(configuration: configuration, hub: hub) + // Load tokenizer and weights in parallel + async let tokenizerTask = tokenizerLoader.load( + from: configuration.tokenizerDirectory) try loadWeights( modelDirectory: modelDirectory, model: model, @@ -553,12 +540,25 @@ public final class LLMModelFactory: ModelFactory { DefaultMessageGenerator() } + // Build a ModelConfiguration for the ModelContext + let tokenizerSource: TokenizerSource? = + configuration.tokenizerDirectory == modelDirectory + ? nil + : .directory(configuration.tokenizerDirectory) + let modelConfig = ModelConfiguration( + directory: modelDirectory, + tokenizerSource: tokenizerSource, + defaultPrompt: configuration.defaultPrompt, + extraEOSTokens: mutableConfiguration.extraEOSTokens, + eosTokenIds: mutableConfiguration.eosTokenIds, + toolCallFormat: mutableConfiguration.toolCallFormat) + let processor = LLMUserInputProcessor( - tokenizer: tokenizer, configuration: mutableConfiguration, + tokenizer: tokenizer, configuration: modelConfig, messageGenerator: messageGenerator) return .init( - configuration: mutableConfiguration, model: model, processor: processor, + configuration: modelConfig, model: model, processor: processor, tokenizer: tokenizer) } diff --git a/Libraries/MLXLLM/LoraTrain.swift b/Libraries/MLXLLM/LoraTrain.swift index 0365aceea..0b498a9ac 100644 --- a/Libraries/MLXLLM/LoraTrain.swift +++ b/Libraries/MLXLLM/LoraTrain.swift @@ -5,7 +5,6 @@ import MLX import MLXLMCommon import MLXNN import MLXOptimizers -import Tokenizers /// Equivalent to `lora.py/iterate_batches()`. Used internally by ``LoRATrain``. struct LoRABatchIterator: Sequence, IteratorProtocol { diff --git a/Libraries/MLXLLM/Models/Apertus.swift b/Libraries/MLXLLM/Models/Apertus.swift index fbe92de5d..0f7f27bc6 100644 --- a/Libraries/MLXLLM/Models/Apertus.swift +++ b/Libraries/MLXLLM/Models/Apertus.swift @@ -2,7 +2,6 @@ import Foundation import MLX import MLXLMCommon import MLXNN -import Tokenizers // port of https://github.com/ml-explore/mlx-lm/blob/main/mlx_lm/models/apertus.py diff --git a/Libraries/MLXLLM/Models/Bitnet.swift b/Libraries/MLXLLM/Models/Bitnet.swift index 2c2f6ae9b..e28aedca2 100644 --- a/Libraries/MLXLLM/Models/Bitnet.swift +++ b/Libraries/MLXLLM/Models/Bitnet.swift @@ -9,7 +9,6 @@ import Foundation import MLX import MLXLMCommon import MLXNN -import Tokenizers // port of https://github.com/ml-explore/mlx-lm/blob/main/mlx_lm/models/bitnet.py diff --git a/Libraries/MLXLLM/Models/Gemma.swift b/Libraries/MLXLLM/Models/Gemma.swift index 1f512b93e..2c9dee8eb 100644 --- a/Libraries/MLXLLM/Models/Gemma.swift +++ b/Libraries/MLXLLM/Models/Gemma.swift @@ -4,7 +4,6 @@ import Foundation import MLX import MLXLMCommon import MLXNN -import Tokenizers // Port of https://github.com/ml-explore/mlx-lm/blob/main/mlx_lm/models/gemma.py diff --git a/Libraries/MLXLLM/Models/Gemma2.swift b/Libraries/MLXLLM/Models/Gemma2.swift index 24780c4de..eaf30b23e 100644 --- a/Libraries/MLXLLM/Models/Gemma2.swift +++ b/Libraries/MLXLLM/Models/Gemma2.swift @@ -4,7 +4,6 @@ import Foundation import MLX import MLXLMCommon import MLXNN -import Tokenizers // Port of https://github.com/ml-explore/mlx-lm/blob/main/mlx_lm/models/gemma2.py diff --git a/Libraries/MLXLLM/Models/Llama.swift b/Libraries/MLXLLM/Models/Llama.swift index 3f47069fb..b81ab7545 100644 --- a/Libraries/MLXLLM/Models/Llama.swift +++ b/Libraries/MLXLLM/Models/Llama.swift @@ -4,7 +4,6 @@ import Foundation import MLX import MLXLMCommon import MLXNN -import Tokenizers // port of https://github.com/ml-explore/mlx-lm/blob/main/mlx_lm/models/llama.py diff --git a/Libraries/MLXLLM/README.md b/Libraries/MLXLLM/README.md index 4937356b5..435dc973a 100644 --- a/Libraries/MLXLLM/README.md +++ b/Libraries/MLXLLM/README.md @@ -13,9 +13,8 @@ This is a port of several models from: - https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/models/ -using the Hugging Face swift transformers package to provide tokenization: - -- https://github.com/huggingface/swift-transformers +Tokenization is provided via the `TokenizerLoader` protocol – see the main +[README](../../README.md) for available integration packages. The [LLMModelFactory.swift](LLMModelFactory.swift) provides minor overrides and customization -- if you require overrides for the tokenizer or prompt customizations they can be @@ -73,10 +72,17 @@ See [llm-tool](../../Tools/llm-tool) Using LLMs and VLMs from MLXLMCommon is as easy as: ```swift -let model = try await loadModel(id: "mlx-community/Qwen3-4B-4bit") +import MLXLLM +import MLXLMHuggingFace +import MLXLMTokenizers + +let model = try await loadModel( + using: TokenizersLoader(), + id: "mlx-community/Qwen3-4B-4bit" +) let session = ChatSession(model) -print(try await session.respond(to: "What are two things to see in San Francisco?") -print(try await session.respond(to: "How about a great place to eat?") +print(try await session.respond(to: "What are two things to see in San Francisco?")) +print(try await session.respond(to: "How about a great place to eat?")) ``` For more information see diff --git a/Libraries/MLXLMCommon/Adapters/ModelAdapter.swift b/Libraries/MLXLMCommon/Adapters/ModelAdapter.swift index a0d52df7c..62ae85d58 100644 --- a/Libraries/MLXLMCommon/Adapters/ModelAdapter.swift +++ b/Libraries/MLXLMCommon/Adapters/ModelAdapter.swift @@ -6,7 +6,6 @@ // import Foundation -import Hub import MLX import MLXNN diff --git a/Libraries/MLXLMCommon/Adapters/ModelAdapterFactory.swift b/Libraries/MLXLMCommon/Adapters/ModelAdapterFactory.swift index d41406601..86487bd06 100644 --- a/Libraries/MLXLMCommon/Adapters/ModelAdapterFactory.swift +++ b/Libraries/MLXLMCommon/Adapters/ModelAdapterFactory.swift @@ -6,7 +6,6 @@ // import Foundation -import Hub import MLX import MLXNN @@ -34,7 +33,7 @@ private struct ModelAdapterBaseConfiguration: Decodable { } } -/// A factory responsible for loading and creating model adapters from hub configurations. +/// A factory responsible for loading and creating model adapters from configurations. public final class ModelAdapterFactory: Sendable { /// Shared instance of the adapter factory. @@ -63,18 +62,27 @@ public final class ModelAdapterFactory: Sendable { self.registry = registry } - /// Loads a model adapter from the hub using the provided model configuration. + /// Loads a model adapter using a ``Downloader`` and the provided model configuration. /// /// This method fetches the adapter configuration and weights, decodes the appropriate /// fine-tuning format, and initializes a `ModelAdapter` accordingly. public func load( - hub: HubApi = HubApi(), + from downloader: any Downloader, configuration: ModelConfiguration, + useLatest: Bool = false, progressHandler: @Sendable @escaping (Progress) -> Void = { _ in } ) async throws -> ModelAdapter { - let adapterDirectory = try await downloadModel( - hub: hub, configuration: configuration, progressHandler: progressHandler - ) + let adapterDirectory: URL + switch configuration.id { + case .id(let id, let revision): + adapterDirectory = try await downloader.download( + id: id, revision: revision, + matching: ["*.safetensors", "*.json"], + useLatest: useLatest, + progressHandler: progressHandler) + case .directory(let directory): + adapterDirectory = directory + } let configurationURL = adapterDirectory.appending(component: "adapter_config.json") let configurationData = try Data(contentsOf: configurationURL) diff --git a/Libraries/MLXLMCommon/BaseConfiguration.swift b/Libraries/MLXLMCommon/BaseConfiguration.swift index ac6d97731..537a382d0 100644 --- a/Libraries/MLXLMCommon/BaseConfiguration.swift +++ b/Libraries/MLXLMCommon/BaseConfiguration.swift @@ -6,7 +6,7 @@ import MLX /// Base ``LanguageModel`` configuration -- provides `modelType` /// and `quantization` (used in loading the model). /// -/// This is used by ``ModelFactory/load(hub:configuration:progressHandler:)`` +/// This is used by ``ModelFactory/load(from:configuration:progressHandler:)`` /// to determine the type of model to load. public struct BaseConfiguration: Codable, Sendable { public let modelType: String diff --git a/Libraries/MLXLMCommon/ChatSession.swift b/Libraries/MLXLMCommon/ChatSession.swift index be8b0d71d..886637afe 100644 --- a/Libraries/MLXLMCommon/ChatSession.swift +++ b/Libraries/MLXLMCommon/ChatSession.swift @@ -3,7 +3,6 @@ import CoreGraphics import Foundation import MLX -import Tokenizers /// Simplified API for multi-turn conversations with LLMs and VLMs. /// diff --git a/Libraries/MLXLMCommon/Documentation.docc/porting.md b/Libraries/MLXLMCommon/Documentation.docc/porting.md index 3a7092279..7b7c9b78a 100644 --- a/Libraries/MLXLMCommon/Documentation.docc/porting.md +++ b/Libraries/MLXLMCommon/Documentation.docc/porting.md @@ -595,8 +595,14 @@ Now we can load the model using `llm-tool` or the `LLMEval` example application, ```swift let modelConfiguration = ModelConfiguration(id: "mlx-community/quantized-gemma-2b-it") -// This will download the weights from Hugging Face Hub and load the model -let container = try await MLXModelFactory.shared.loadContainer(configuration: modelConfiguration) +// e.g. TokenizersLoader() from MLXLMTokenizers +let tokenizerLoader: any TokenizerLoader + +// This will download the weights and load the model +let container = try await MLXModelFactory.shared.loadContainer( + using: tokenizerLoader, + configuration: modelConfiguration +) // Prepare the prompt and parameters used to generate the response let generateParameters = GenerateParameters() diff --git a/Libraries/MLXLMCommon/Downloader.swift b/Libraries/MLXLMCommon/Downloader.swift new file mode 100644 index 000000000..939ebe278 --- /dev/null +++ b/Libraries/MLXLMCommon/Downloader.swift @@ -0,0 +1,110 @@ +import Foundation + +/// A protocol for downloading model repository snapshots to local directories. +/// +/// Conforming types encapsulate the full download lifecycle — cache check, network +/// download, and fallback to cache on failure. Each conformance owns its own caching +/// strategy. The return value is always a local directory URL containing the requested files. +/// +/// The protocol is provider-agnostic. `id` is a plain `String` that each conformance +/// interprets however it wants (e.g. `"org/model"` for Hugging Face, a four-part +/// Kaggle handle, an S3 path). `revision` is optional for providers without versioning. +/// +/// ## See Also +/// - ``ResolvedModelConfiguration`` +/// - ``TokenizerSource`` +public protocol Downloader: Sendable { + /// Download (or retrieve from cache) a snapshot of a repository. + /// + /// - Parameters: + /// - id: Provider-specific repository identifier + /// - revision: Optional revision (branch, tag, commit hash, version number). + /// Providers without versioning receive `nil`. + /// - patterns: Glob patterns to filter which files to download + /// (e.g. `["*.safetensors", "*.json", "*.jinja"]`) + /// - useLatest: When `true`, check the provider for updates even if a cached + /// version exists. When `false`, return the cached version if available. + /// - progressHandler: Callback for download progress + /// - Returns: Local directory URL containing the downloaded files + func download( + id: String, + revision: String?, + matching patterns: [String], + useLatest: Bool, + progressHandler: @Sendable @escaping (Progress) -> Void + ) async throws -> URL +} + +/// Identifies where a tokenizer should be loaded from. +/// +/// Used by ``ModelConfiguration`` to specify an alternate tokenizer source +/// when the tokenizer files are not co-located with the model weights. +/// +/// - ``id(_:)`` downloads tokenizer files from a remote provider (requires a ``Downloader``) +/// - ``directory(_:)`` loads tokenizer files from a local path +/// +/// When `nil` on a ``ModelConfiguration``, the tokenizer is loaded from the +/// same directory as the model. +public enum TokenizerSource: Sendable, Equatable { + /// A provider-specific repository identifier for downloading tokenizer files. + /// - Parameters: + /// - id: The repository identifier (e.g. `"org/tokenizer-name"`). + /// - revision: Optional revision (branch, tag, commit hash). When `nil`, + /// the ``Downloader`` decides the default (typically `"main"`). + case id(String, revision: String? = nil) + /// A local directory containing tokenizer files. + case directory(URL) +} + +/// A fully resolved model configuration where all sources have been resolved +/// to local directory paths. +/// +/// Created by resolving a ``ModelConfiguration`` — downloading remote sources +/// via a ``Downloader`` and mapping behavioral properties. Factory implementations +/// receive this type in their `_load` method, so they work purely with local files. +/// +/// ## See Also +/// - ``ModelConfiguration/resolved(modelDirectory:tokenizerDirectory:)`` +/// - ``Downloader`` +public struct ResolvedModelConfiguration: Sendable { + public var modelDirectory: URL + public var tokenizerDirectory: URL + public var name: String + public var defaultPrompt: String + public var extraEOSTokens: Set + public var eosTokenIds: Set + public var toolCallFormat: ToolCallFormat? + + public init( + modelDirectory: URL, + tokenizerDirectory: URL, + name: String, + defaultPrompt: String, + extraEOSTokens: Set, + eosTokenIds: Set, + toolCallFormat: ToolCallFormat? + ) { + self.modelDirectory = modelDirectory + self.tokenizerDirectory = tokenizerDirectory + self.name = name + self.defaultPrompt = defaultPrompt + self.extraEOSTokens = extraEOSTokens + self.eosTokenIds = eosTokenIds + self.toolCallFormat = toolCallFormat + } +} + +extension ResolvedModelConfiguration { + /// Convenience for loading everything from a single local directory. + public init(directory: URL) { + self.init( + modelDirectory: directory, + tokenizerDirectory: directory, + name: directory.deletingLastPathComponent().lastPathComponent + "/" + + directory.lastPathComponent, + defaultPrompt: "", + extraEOSTokens: [], + eosTokenIds: [], + toolCallFormat: nil) + } +} diff --git a/Libraries/MLXLMCommon/Evaluate.swift b/Libraries/MLXLMCommon/Evaluate.swift index 226f52e76..65ef1fd93 100644 --- a/Libraries/MLXLMCommon/Evaluate.swift +++ b/Libraries/MLXLMCommon/Evaluate.swift @@ -3,7 +3,6 @@ import Foundation import MLX import MLXNN -import Tokenizers /// A `LogitSampler` is responsible for sampling `logits` produced by /// a ``LanguageModel`` to produce a token. @@ -34,13 +33,13 @@ public protocol LogitSampler { /// See also: ``LogitSampler`` public protocol LogitProcessor { - /// called before token generation starts with the text tokens of the prompt + /// Called before token generation starts with the text tokens of the prompt mutating func prompt(_ prompt: MLXArray) - /// called to visit and possibly modify the logits + /// Called to visit and possibly modify the logits func process(logits: MLXArray) -> MLXArray - /// called to provide the sampled token + /// Called to provide the sampled token mutating func didSample(token: MLXArray) } @@ -73,22 +72,22 @@ public struct GenerateParameters: Sendable { /// Step to begin using a quantized KV cache when kvBits is non-nil (default: 0) public var quantizedKVStart: Int - /// sampling temperature + /// Sampling temperature public var temperature: Float - /// top p sampling + /// Top-p sampling public var topP: Float - /// top k sampling (0 disables) + /// Top-k sampling (0 disables) public var topK: Int - /// min p sampling threshold relative to the highest probability token (0 disables) + /// Min-p sampling threshold relative to the highest probability token (0 disables) public var minP: Float - /// penalty factor for repeating tokens + /// Penalty factor for repeating tokens public var repetitionPenalty: Float? - /// number of tokens to consider for repetition penalty + /// Number of tokens to consider for repetition penalty public var repetitionContextSize: Int /// additive penalty for tokens that appear in recent context @@ -706,45 +705,61 @@ public struct GenerateResult { /// /// - Parameters: /// - inputText: The input text used for generation. - /// - tokens: The array of tokens generated. + /// - tokenIds: The array of generated token IDs. /// - output: The generated output string. /// - promptTime: The time taken to prompt the input. /// - generateTime: The time taken to generate the output. public init( - inputText: LMInput.Text, tokens: [Int], output: String, promptTime: TimeInterval, + inputText: LMInput.Text, tokenIds: [Int], output: String, promptTime: TimeInterval, generateTime: TimeInterval ) { self.inputText = inputText - self.tokens = tokens + self.tokenIds = tokenIds self.output = output self.promptTime = promptTime self.generateTime = generateTime } + @available(*, deprecated, renamed: "init(inputText:tokenIds:output:promptTime:generateTime:)") + public init( + inputText: LMInput.Text, tokens: [Int], output: String, promptTime: TimeInterval, + generateTime: TimeInterval + ) { + self.init( + inputText: inputText, tokenIds: tokens, output: output, promptTime: promptTime, + generateTime: generateTime) + } + /// input (prompt, images, etc.) public let inputText: LMInput.Text - @available(*, deprecated, message: "use inputText") - public var promptTokens: [Int] { + /// The token IDs of the input prompt. + public var promptTokenIds: [Int] { inputText.tokens.asArray(Int.self) } - /// output tokens - public let tokens: [Int] + @available(*, deprecated, renamed: "promptTokenIds") + public var promptTokens: [Int] { promptTokenIds } + + /// Generated token IDs + public let tokenIds: [Int] + + @available(*, deprecated, renamed: "tokenIds") + public var tokens: [Int] { tokenIds } - /// output text + /// Output text public let output: String /// The number of tokens included in the input prompt. public var promptTokenCount: Int { inputText.tokens.size } /// The number of tokens generated by the language model. - public var generationTokenCount: Int { tokens.count } + public var generationTokenCount: Int { tokenIds.count } - /// time to process the prompt / generate the first token + /// Time to process the prompt (generate the first token) public let promptTime: TimeInterval - /// time to generate the remaining tokens + /// Time to generate the remaining tokens public let generateTime: TimeInterval /// The number of tokens processed per second during the prompt phase. @@ -754,7 +769,7 @@ public struct GenerateResult { /// The number of tokens generated per second during the generation phase. public var tokensPerSecond: Double { - Double(tokens.count) / generateTime + Double(tokenIds.count) / generateTime } public func summary() -> String { @@ -767,53 +782,53 @@ public struct GenerateResult { /// Action from token visitor callback in deprecated callback-based generate functions. public enum GenerateDisposition: Sendable { - /// keep producing tokens until an EOS token is produced + /// Keep producing tokens until an EOS token is produced case more - /// stop producing tokens, e.g. a token limit has been hit + /// Stop producing tokens, e.g. a token limit has been hit case stop } private struct SynchronousGenerationLoopResult { - let generatedTokens: [Int] + let generatedTokenIds: [Int] let promptTime: TimeInterval let generateTime: TimeInterval let promptPrefillTime: TimeInterval let stopReason: GenerateStopReason } -private func buildStopTokenIDs( +private func buildStopTokenIds( modelConfiguration: ModelConfiguration, tokenizer: Tokenizer ) -> Set { // Build complete EOS token set from all sources. - var stopTokenIDs = modelConfiguration.eosTokenIds + var stopTokenIds = modelConfiguration.eosTokenIds if let tokenizerEOS = tokenizer.eosTokenId { - stopTokenIDs.insert(tokenizerEOS) + stopTokenIds.insert(tokenizerEOS) } for token in modelConfiguration.extraEOSTokens { if let id = tokenizer.convertTokenToId(token) { - stopTokenIDs.insert(id) + stopTokenIds.insert(id) } } - return stopTokenIDs + return stopTokenIds } private func runSynchronousGenerationLoop( modelConfiguration: ModelConfiguration, tokenizer: Tokenizer, iterator: TokenIterator, - didGenerate: (_ token: Int, _ generatedTokens: [Int]) -> GenerateDisposition + didGenerate: (_ token: Int, _ generatedTokenIds: [Int]) -> GenerateDisposition ) -> SynchronousGenerationLoopResult { var start = Date.timeIntervalSinceReferenceDate var promptTime: TimeInterval = 0 - let stopTokenIDs = buildStopTokenIDs( + let stopTokenIds = buildStopTokenIds( modelConfiguration: modelConfiguration, tokenizer: tokenizer ) - var generatedTokens = [Int]() + var generatedTokenIds = [Int]() var iterator = iterator var stopReason: GenerateStopReason? @@ -826,14 +841,14 @@ private func runSynchronousGenerationLoop( } // Check for end-of-sequence tokens. - if token == tokenizer.unknownTokenId || stopTokenIDs.contains(token) { + if token == tokenizer.unknownTokenId || stopTokenIds.contains(token) { stopReason = .stop break } - generatedTokens.append(token) + generatedTokenIds.append(token) - if didGenerate(token, generatedTokens) == .stop { + if didGenerate(token, generatedTokenIds) == .stop { stopReason = .cancelled break } @@ -858,7 +873,7 @@ private func runSynchronousGenerationLoop( Stream().synchronize() return SynchronousGenerationLoopResult( - generatedTokens: generatedTokens, + generatedTokenIds: generatedTokenIds, promptTime: promptTime, generateTime: generateTime, promptPrefillTime: iterator.promptPrefillTime, @@ -960,8 +975,8 @@ public func generate( } return GenerateResult( - inputText: input.text, tokens: result.generatedTokens, - output: context.tokenizer.decode(tokens: result.generatedTokens), + inputText: input.text, tokenIds: result.generatedTokenIds, + output: context.tokenizer.decode(tokenIds: result.generatedTokenIds), promptTime: result.promptTime + result.promptPrefillTime, generateTime: result.generateTime ) @@ -1023,7 +1038,7 @@ public func generate( return GenerateCompletionInfo( promptTokenCount: input.text.tokens.size, - generationTokenCount: result.generatedTokens.count, + generationTokenCount: result.generatedTokenIds.count, promptTime: result.promptTime + result.promptPrefillTime, generationTime: result.generateTime, stopReason: result.stopReason @@ -1279,7 +1294,7 @@ private func generateLoopTask( var tokenCount = 0 var stopReason: GenerateStopReason? - let stopTokenIDs = buildStopTokenIDs( + let stopTokenIds = buildStopTokenIds( modelConfiguration: modelConfiguration, tokenizer: tokenizer ) @@ -1298,7 +1313,7 @@ private func generateLoopTask( } // Check for end-of-sequence tokens - if token == tokenizer.unknownTokenId || stopTokenIDs.contains(token) { + if token == tokenizer.unknownTokenId || stopTokenIds.contains(token) { if includeStopToken { tokenCount += 1 if !handler.onStopToken(token, emit: continuation.yield) { diff --git a/Libraries/MLXLMCommon/LanguageModel.swift b/Libraries/MLXLMCommon/LanguageModel.swift index 838834fb2..5142ccbfd 100644 --- a/Libraries/MLXLMCommon/LanguageModel.swift +++ b/Libraries/MLXLMCommon/LanguageModel.swift @@ -1,10 +1,8 @@ // Copyright © 2024 Apple Inc. import Foundation -import Hub import MLX import MLXNN -import Tokenizers /// Time/Height/Width struct to represent information about input images. public struct THW: Sendable { diff --git a/Libraries/MLXLMCommon/Load.swift b/Libraries/MLXLMCommon/Load.swift index b587754fd..1161620db 100644 --- a/Libraries/MLXLMCommon/Load.swift +++ b/Libraries/MLXLMCommon/Load.swift @@ -1,63 +1,12 @@ // Copyright © 2024 Apple Inc. import Foundation -import Hub import MLX import MLXNN -import Tokenizers - -/// Download the model using the `HubApi`. -/// -/// This will download `*.safetensors` and `*.json` if the ``ModelConfiguration`` -/// represents a Hub id, e.g. `mlx-community/gemma-2-2b-it-4bit`. -/// -/// This is typically called via ``ModelFactory/load(hub:configuration:progressHandler:)`` -/// -/// - Parameters: -/// - hub: HubApi instance -/// - configuration: the model identifier -/// - progressHandler: callback for progress -/// - Returns: URL for the directory containing downloaded files -public func downloadModel( - hub: HubApi, configuration: ModelConfiguration, - progressHandler: @Sendable @escaping (Progress) -> Void -) async throws -> URL { - do { - switch configuration.id { - case .id(let id, let revision): - // download the model weights - let repo = Hub.Repo(id: id) - let modelFiles = ["*.safetensors", "*.json", "*.jinja"] - return try await hub.snapshot( - from: repo, - revision: revision, - matching: modelFiles, - progressHandler: progressHandler - ) - case .directory(let directory): - return directory - } - - } catch Hub.HubClientError.authorizationRequired { - // an authorizationRequired means (typically) that the named repo doesn't exist on - // on the server so retry with local only configuration - return configuration.modelDirectory(hub: hub) - - } catch { - let nserror = error as NSError - if nserror.domain == NSURLErrorDomain && nserror.code == NSURLErrorNotConnectedToInternet { - // Error Domain=NSURLErrorDomain Code=-1009 "The Internet connection appears to be offline." - // fall back to the local directory - return configuration.modelDirectory(hub: hub) - } else { - throw error - } - } -} /// Load model weights. /// -/// This is typically called via ``ModelFactory/load(hub:configuration:progressHandler:)``. +/// This is typically called via ``ModelFactory/load(from:configuration:progressHandler:)``. /// This function loads all `safetensor` files in the given `modelDirectory`, /// calls ``LanguageModel/sanitize(weights:metadata:)`` to allow per-model preprocessing, /// applies optional quantization, and diff --git a/Libraries/MLXLMCommon/ModelConfiguration.swift b/Libraries/MLXLMCommon/ModelConfiguration.swift index d4478d1d2..d3d50e3ce 100644 --- a/Libraries/MLXLMCommon/ModelConfiguration.swift +++ b/Libraries/MLXLMCommon/ModelConfiguration.swift @@ -1,13 +1,26 @@ // Copyright © 2024 Apple Inc. import Foundation -import Hub /// Configuration for a given model name with overrides for prompts and tokens. /// /// See e.g. `MLXLM.ModelRegistry` for an example of use. public struct ModelConfiguration: Sendable { + public enum DirectoryError: LocalizedError, Equatable { + case unresolvedModelDirectory(String) + case unresolvedTokenizerDirectory(String) + + public var errorDescription: String? { + switch self { + case .unresolvedModelDirectory(let id): + return "Model configuration '\(id)' has not been resolved to a local directory." + case .unresolvedTokenizerDirectory(let id): + return "Tokenizer source '\(id)' has not been resolved to a local directory." + } + } + } + public enum Identifier: Sendable { case id(String, revision: String = "main") case directory(URL) @@ -24,11 +37,47 @@ public struct ModelConfiguration: Sendable { } } - /// pull the tokenizer from an alternate id - public let tokenizerId: String? + /// The resolved local directory containing model files. + /// + /// - Throws: ``DirectoryError/unresolvedModelDirectory(_:)`` if this configuration still + /// identifies a remote model by ID rather than a local directory. + package var modelDirectory: URL { + get throws { + switch id { + case .directory(let directory): + return directory + case .id(let id, _): + throw DirectoryError.unresolvedModelDirectory(id) + } + } + } + + /// The resolved local directory containing tokenizer files. + /// + /// If ``tokenizerSource`` is `nil`, this falls back to ``modelDirectory``. + /// + /// - Throws: ``DirectoryError/unresolvedTokenizerDirectory(_:)`` if the tokenizer still + /// points to a remote source by ID, or ``DirectoryError/unresolvedModelDirectory(_:)`` + /// if no separate tokenizer source is set and the model itself is unresolved. + package var tokenizerDirectory: URL { + get throws { + switch tokenizerSource { + case .directory(let directory): + return directory + case .id(let id, _): + throw DirectoryError.unresolvedTokenizerDirectory(id) + case nil: + return try modelDirectory + } + } + } - /// overrides for TokenizerModel/knownTokenizers -- useful before swift-transformers is updated - public let overrideTokenizer: String? + /// Where to load the tokenizer from when it differs from the model directory. + /// + /// - `.id`: download from a remote provider (requires a ``Downloader``) + /// - `.directory`: load from a local path + /// - `nil`: use the same directory as the model + public let tokenizerSource: TokenizerSource? /// A reasonable default prompt for the model public var defaultPrompt: String @@ -44,15 +93,13 @@ public struct ModelConfiguration: Sendable { public init( id: String, revision: String = "main", - tokenizerId: String? = nil, overrideTokenizer: String? = nil, - defaultPrompt: String = "hello", + tokenizerSource: TokenizerSource? = nil, + defaultPrompt: String = "", extraEOSTokens: Set = [], - toolCallFormat: ToolCallFormat? = nil, - preparePrompt: (@Sendable (String) -> String)? = nil + toolCallFormat: ToolCallFormat? = nil ) { self.id = .id(id, revision: revision) - self.tokenizerId = tokenizerId - self.overrideTokenizer = overrideTokenizer + self.tokenizerSource = tokenizerSource self.defaultPrompt = defaultPrompt self.extraEOSTokens = extraEOSTokens self.toolCallFormat = toolCallFormat @@ -60,32 +107,38 @@ public struct ModelConfiguration: Sendable { public init( directory: URL, - tokenizerId: String? = nil, overrideTokenizer: String? = nil, - defaultPrompt: String = "hello", + tokenizerSource: TokenizerSource? = nil, + defaultPrompt: String = "", extraEOSTokens: Set = [], eosTokenIds: Set = [], toolCallFormat: ToolCallFormat? = nil ) { self.id = .directory(directory) - self.tokenizerId = tokenizerId - self.overrideTokenizer = overrideTokenizer + self.tokenizerSource = tokenizerSource self.defaultPrompt = defaultPrompt self.extraEOSTokens = extraEOSTokens self.eosTokenIds = eosTokenIds self.toolCallFormat = toolCallFormat } - public func modelDirectory(hub: HubApi = HubApi()) -> URL { - switch id { - case .id(let id, _): - // download the model weights and config - let repo = Hub.Repo(id: id) - return hub.localRepoLocation(repo) - - case .directory(let directory): - return directory - } + /// Maps this configuration's behavioral properties into a + /// ``ResolvedModelConfiguration`` with the given directories. + /// + /// This is a pure data mapping with no I/O. The directories should + /// already be resolved (downloaded or local) before calling this method. + public func resolved( + modelDirectory: URL, tokenizerDirectory: URL + ) -> ResolvedModelConfiguration { + ResolvedModelConfiguration( + modelDirectory: modelDirectory, + tokenizerDirectory: tokenizerDirectory, + name: name, + defaultPrompt: defaultPrompt, + extraEOSTokens: extraEOSTokens, + eosTokenIds: eosTokenIds, + toolCallFormat: toolCallFormat) } + } extension ModelConfiguration: Equatable { diff --git a/Libraries/MLXLMCommon/ModelContainer.swift b/Libraries/MLXLMCommon/ModelContainer.swift index 6ed5586f8..9ecc309d7 100644 --- a/Libraries/MLXLMCommon/ModelContainer.swift +++ b/Libraries/MLXLMCommon/ModelContainer.swift @@ -1,10 +1,8 @@ // Copyright © 2024 Apple Inc. import Foundation -import Hub import MLX import MLXNN -import Tokenizers /// Container for models that guarantees single threaded access. /// @@ -130,6 +128,20 @@ public final class ModelContainer: Sendable { // MARK: - Thread-safe convenience methods + /// The resolved local model directory for the loaded container. + public var modelDirectory: URL { + get async throws { + try (await configuration).modelDirectory + } + } + + /// The resolved local tokenizer directory for the loaded container. + public var tokenizerDirectory: URL { + get async throws { + try (await configuration).tokenizerDirectory + } + } + /// Prepare user input for generation. /// /// This method safely prepares input within the actor's isolation, @@ -195,11 +207,16 @@ public final class ModelContainer: Sendable { /// Decode token IDs to a string. /// - /// - Parameter tokens: Array of token IDs + /// - Parameter tokenIds: Array of token IDs /// - Returns: Decoded string - public func decode(tokens: [Int]) async -> String { + public func decode(tokenIds: [Int]) async -> String { let tokenizer = await self.tokenizer - return tokenizer.decode(tokens: tokens) + return tokenizer.decode(tokenIds: tokenIds) + } + + @available(*, deprecated, renamed: "decode(tokenIds:)") + public func decode(tokens: [Int]) async -> String { + await decode(tokenIds: tokens) } /// Encode a string to token IDs. diff --git a/Libraries/MLXLMCommon/ModelFactory.swift b/Libraries/MLXLMCommon/ModelFactory.swift index 5f77ac21a..f6964bcf9 100644 --- a/Libraries/MLXLMCommon/ModelFactory.swift +++ b/Libraries/MLXLMCommon/ModelFactory.swift @@ -1,8 +1,10 @@ // Copyright © 2024 Apple Inc. import Foundation -import Hub -import Tokenizers + +/// File patterns required to resolve a tokenizer without downloading model weights. +package let tokenizerDownloadPatterns = ["*.json", "*.jinja"] +package let modelDownloadPatterns = ["*.safetensors"] + tokenizerDownloadPatterns public enum ModelFactoryError: LocalizedError { case unsupportedModelType(String) @@ -53,7 +55,7 @@ public enum ModelFactoryError: LocalizedError { /// Context of types that work together to provide a ``LanguageModel``. /// -/// A ``ModelContext`` is created by ``ModelFactory/load(hub:configuration:progressHandler:)``. +/// A ``ModelContext`` is created by ``ModelFactory/load(from:configuration:progressHandler:)``. /// This contains the following: /// /// - ``ModelConfiguration`` -- identifier for the model @@ -61,7 +63,7 @@ public enum ModelFactoryError: LocalizedError { /// - ``UserInputProcessor`` -- can convert ``UserInput`` into ``LMInput`` /// - `Tokenizer` -- the tokenizer used by ``UserInputProcessor`` /// -/// See also ``ModelFactory/loadContainer(hub:configuration:progressHandler:)`` and +/// See also ``ModelFactory/loadContainer(from:configuration:progressHandler:)`` and /// ``ModelContainer``. public struct ModelContext { public var configuration: ModelConfiguration @@ -83,24 +85,19 @@ public struct ModelContext { /// Protocol for code that can load models. /// /// ## See Also -/// - ``loadModel(hub:id:progressHandler:)`` -/// - ``loadModel(hub:directory:progressHandler:)`` -/// - ``loadModelContainer(hub:id:progressHandler:)`` -/// - ``loadModelContainer(hub:directory:progressHandler:)`` +/// - ``loadModel(from:id:progressHandler:)`` +/// - ``loadModel(from:)-ModelContext`` +/// - ``loadModelContainer(from:id:progressHandler:)`` +/// - ``loadModelContainer(from:)-ModelContainer`` public protocol ModelFactory: Sendable { var modelRegistry: AbstractModelRegistry { get } func _load( - hub: HubApi, configuration: ModelConfiguration, - progressHandler: @Sendable @escaping (Progress) -> Void + configuration: ResolvedModelConfiguration, + tokenizerLoader: any TokenizerLoader ) async throws -> ModelContext - func _loadContainer( - hub: HubApi, configuration: ModelConfiguration, - progressHandler: @Sendable @escaping (Progress) -> Void - ) async throws -> ModelContainer - } extension ModelFactory { @@ -124,170 +121,255 @@ extension ModelFactory { } -/// Default instance of HubApi to use. This is configured to save downloads into the caches directory. -public let defaultHubApi: HubApi = { - HubApi(downloadBase: FileManager.default.urls(for: .cachesDirectory, in: .userDomainMask).first) -}() - extension ModelFactory { - /// Load a model identified by a ``ModelConfiguration`` and produce a ``ModelContext``. + /// Load a model from a ``Downloader`` and ``ModelConfiguration``, + /// producing a ``ModelContext``. /// - /// This method returns a ``ModelContext``. See also - /// ``loadContainer(hub:configuration:progressHandler:)`` for a method that - /// returns a ``ModelContainer``. + /// This resolves the configuration (downloading remote sources via the downloader) + /// and then loads the model from local files. /// /// ## See Also - /// - ``loadModel(hub:id:progressHandler:)`` - /// - ``loadModelContainer(hub:id:progressHandler:)`` + /// - ``loadModel(from:configuration:useLatest:progressHandler:)`` + /// - ``loadModelContainer(from:configuration:useLatest:progressHandler:)`` public func load( - hub: HubApi = defaultHubApi, configuration: ModelConfiguration, + from downloader: any Downloader, + using tokenizerLoader: any TokenizerLoader, + configuration: ModelConfiguration, + useLatest: Bool = false, progressHandler: @Sendable @escaping (Progress) -> Void = { _ in } ) async throws -> sending ModelContext { - try await _load(hub: hub, configuration: configuration, progressHandler: progressHandler) + let resolved = try await resolve( + configuration: configuration, from: downloader, + useLatest: useLatest, progressHandler: progressHandler) + return try await _load(configuration: resolved, tokenizerLoader: tokenizerLoader) } - /// Load a model identified by a ``ModelConfiguration`` and produce a ``ModelContainer``. + /// Load a model from a ``Downloader`` and ``ModelConfiguration``, + /// producing a ``ModelContainer``. public func loadContainer( - hub: HubApi = defaultHubApi, configuration: ModelConfiguration, + from downloader: any Downloader, + using tokenizerLoader: any TokenizerLoader, + configuration: ModelConfiguration, + useLatest: Bool = false, progressHandler: @Sendable @escaping (Progress) -> Void = { _ in } ) async throws -> ModelContainer { - try await _loadContainer( - hub: hub, configuration: configuration, progressHandler: progressHandler) + let resolved = try await resolve( + configuration: configuration, from: downloader, + useLatest: useLatest, progressHandler: progressHandler) + let context = try await _load(configuration: resolved, tokenizerLoader: tokenizerLoader) + return ModelContainer(context: context) } - public func _loadContainer( - hub: HubApi = defaultHubApi, configuration: ModelConfiguration, - progressHandler: @Sendable @escaping (Progress) -> Void = { _ in } + /// Load a model from a local directory, producing a ``ModelContext``. + /// + /// No downloader is needed — the model and tokenizer are loaded from + /// the given directory. + public func load( + from directory: URL, + using tokenizerLoader: any TokenizerLoader + ) async throws -> sending ModelContext { + try await _load( + configuration: .init(directory: directory), tokenizerLoader: tokenizerLoader) + } + + /// Load a model from a local directory, producing a ``ModelContainer``. + public func loadContainer( + from directory: URL, + using tokenizerLoader: any TokenizerLoader ) async throws -> ModelContainer { let context = try await _load( - hub: hub, configuration: configuration, progressHandler: progressHandler) + configuration: .init(directory: directory), tokenizerLoader: tokenizerLoader) return ModelContainer(context: context) } } -/// Load a model given a ``ModelConfiguration``. +/// Resolve a ``ModelConfiguration`` into a ``ResolvedModelConfiguration`` by +/// downloading remote sources via a ``Downloader``. +/// +/// This handles the `.id` vs `.directory` switch for the model source and +/// resolves ``TokenizerSource`` for the tokenizer. +public func resolve( + configuration: ModelConfiguration, + from downloader: any Downloader, + useLatest: Bool, + progressHandler: @Sendable @escaping (Progress) -> Void +) async throws -> ResolvedModelConfiguration { + let modelDirectory: URL + switch configuration.id { + case .id(let id, let revision): + modelDirectory = try await downloader.download( + id: id, revision: revision, + matching: modelDownloadPatterns, + useLatest: useLatest, + progressHandler: progressHandler) + case .directory(let directory): + modelDirectory = directory + } + + let tokenizerDirectory: URL + switch configuration.tokenizerSource { + case .id(let id, let revision): + tokenizerDirectory = try await downloader.download( + id: id, revision: revision, + matching: tokenizerDownloadPatterns, + useLatest: useLatest, + progressHandler: { _ in }) + case .directory(let directory): + tokenizerDirectory = directory + case nil: + tokenizerDirectory = modelDirectory + } + + return configuration.resolved( + modelDirectory: modelDirectory, + tokenizerDirectory: tokenizerDirectory) +} + +/// Load a model given a ``ModelConfiguration``, downloading via a ``Downloader``. /// -/// This will load and return a ``ModelContext``. This holds the model and tokenzier without -/// an `actor` providing an isolation context. Use this call when you control the isolation context -/// and can hold the ``ModelContext`` directly. +/// Returns a ``ModelContext`` holding the model and tokenizer without +/// an `actor` providing an isolation context. /// /// - Parameters: -/// - hub: optional HubApi -- by default uses ``defaultHubApi`` +/// - downloader: the ``Downloader`` to use for fetching remote resources +/// - tokenizerLoader: the ``TokenizerLoader`` to use for loading the tokenizer /// - configuration: a ``ModelConfiguration`` +/// - useLatest: when true, always checks the provider for the latest version /// - progressHandler: optional callback for progress /// - Returns: a ``ModelContext`` public func loadModel( - hub: HubApi = defaultHubApi, configuration: ModelConfiguration, + from downloader: any Downloader, + using tokenizerLoader: any TokenizerLoader, + configuration: ModelConfiguration, + useLatest: Bool = false, progressHandler: @Sendable @escaping (Progress) -> Void = { _ in } ) async throws -> sending ModelContext { try await load { - try await $0.load(hub: hub, configuration: configuration, progressHandler: progressHandler) + try await $0.load( + from: downloader, using: tokenizerLoader, configuration: configuration, + useLatest: useLatest, progressHandler: progressHandler) } } -/// Load a model given a ``ModelConfiguration``. +/// Load a model given a ``ModelConfiguration``, downloading via a ``Downloader``. /// -/// This will load and return a ``ModelContainer``. This holds a ``ModelContext`` +/// Returns a ``ModelContainer`` holding a ``ModelContext`` /// inside an actor providing isolation control for the values. /// /// - Parameters: -/// - hub: optional HubApi -- by default uses ``defaultHubApi`` +/// - downloader: the ``Downloader`` to use for fetching remote resources +/// - tokenizerLoader: the ``TokenizerLoader`` to use for loading the tokenizer /// - configuration: a ``ModelConfiguration`` +/// - useLatest: when true, always checks the provider for the latest version /// - progressHandler: optional callback for progress /// - Returns: a ``ModelContainer`` public func loadModelContainer( - hub: HubApi = defaultHubApi, configuration: ModelConfiguration, + from downloader: any Downloader, + using tokenizerLoader: any TokenizerLoader, + configuration: ModelConfiguration, + useLatest: Bool = false, progressHandler: @Sendable @escaping (Progress) -> Void = { _ in } ) async throws -> sending ModelContainer { try await load { try await $0.loadContainer( - hub: hub, configuration: configuration, progressHandler: progressHandler) + from: downloader, using: tokenizerLoader, configuration: configuration, + useLatest: useLatest, progressHandler: progressHandler) } } -/// Load a model given a huggingface identifier. +/// Load a model given a model identifier, downloading via a ``Downloader``. /// -/// This will load and return a ``ModelContext``. This holds the model and tokenzier without -/// an `actor` providing an isolation context. Use this call when you control the isolation context -/// and can hold the ``ModelContext`` directly. +/// Returns a ``ModelContext`` holding the model and tokenizer without +/// an `actor` providing an isolation context. /// /// - Parameters: -/// - hub: optional HubApi -- by default uses ``defaultHubApi`` -/// - id: huggingface model identifier, e.g "mlx-community/Qwen3-4B-4bit" +/// - downloader: the ``Downloader`` to use for fetching remote resources +/// - tokenizerLoader: the ``TokenizerLoader`` to use for loading the tokenizer +/// - id: model identifier, e.g "mlx-community/Qwen3-4B-4bit" +/// - revision: revision to download (defaults to "main") +/// - useLatest: when true, always checks the provider for the latest version /// - progressHandler: optional callback for progress /// - Returns: a ``ModelContext`` public func loadModel( - hub: HubApi = defaultHubApi, id: String, revision: String = "main", + from downloader: any Downloader, + using tokenizerLoader: any TokenizerLoader, + id: String, + revision: String = "main", + useLatest: Bool = false, progressHandler: @Sendable @escaping (Progress) -> Void = { _ in } ) async throws -> sending ModelContext { try await load { try await $0.load( - hub: hub, configuration: .init(id: id, revision: revision), - progressHandler: progressHandler) + from: downloader, using: tokenizerLoader, + configuration: .init(id: id, revision: revision), + useLatest: useLatest, progressHandler: progressHandler) } } -/// Load a model given a huggingface identifier. +/// Load a model given a model identifier, downloading via a ``Downloader``. /// -/// This will load and return a ``ModelContainer``. This holds a ``ModelContext`` +/// Returns a ``ModelContainer`` holding a ``ModelContext`` /// inside an actor providing isolation control for the values. /// /// - Parameters: -/// - hub: optional HubApi -- by default uses ``defaultHubApi`` -/// - id: huggingface model identifier, e.g "mlx-community/Qwen3-4B-4bit" +/// - downloader: the ``Downloader`` to use for fetching remote resources +/// - tokenizerLoader: the ``TokenizerLoader`` to use for loading the tokenizer +/// - id: model identifier, e.g "mlx-community/Qwen3-4B-4bit" +/// - revision: revision to download (defaults to "main") +/// - useLatest: when true, always checks the provider for the latest version /// - progressHandler: optional callback for progress /// - Returns: a ``ModelContainer`` public func loadModelContainer( - hub: HubApi = defaultHubApi, id: String, revision: String = "main", + from downloader: any Downloader, + using tokenizerLoader: any TokenizerLoader, + id: String, + revision: String = "main", + useLatest: Bool = false, progressHandler: @Sendable @escaping (Progress) -> Void = { _ in } ) async throws -> sending ModelContainer { try await load { try await $0.loadContainer( - hub: hub, configuration: .init(id: id, revision: revision), - progressHandler: progressHandler) + from: downloader, using: tokenizerLoader, + configuration: .init(id: id, revision: revision), + useLatest: useLatest, progressHandler: progressHandler) } } -/// Load a model given a directory of configuration and weights. +/// Load a model from a local directory of configuration and weights. /// -/// This will load and return a ``ModelContext``. This holds the model and tokenzier without -/// an `actor` providing an isolation context. Use this call when you control the isolation context -/// and can hold the ``ModelContext`` directly. +/// Returns a ``ModelContext`` holding the model and tokenizer without +/// an `actor` providing an isolation context. /// /// - Parameters: -/// - hub: optional HubApi -- by default uses ``defaultHubApi`` /// - directory: directory of configuration and weights -/// - progressHandler: optional callback for progress +/// - tokenizerLoader: the ``TokenizerLoader`` to use for loading the tokenizer /// - Returns: a ``ModelContext`` public func loadModel( - hub: HubApi = defaultHubApi, directory: URL, - progressHandler: @Sendable @escaping (Progress) -> Void = { _ in } + from directory: URL, + using tokenizerLoader: any TokenizerLoader ) async throws -> sending ModelContext { try await load { - try await $0.load( - hub: hub, configuration: .init(directory: directory), progressHandler: progressHandler) + try await $0.load(from: directory, using: tokenizerLoader) } } -/// Load a model given a directory of configuration and weights. +/// Load a model from a local directory of configuration and weights. /// -/// This will load and return a ``ModelContainer``. This holds a ``ModelContext`` +/// Returns a ``ModelContainer`` holding a ``ModelContext`` /// inside an actor providing isolation control for the values. /// /// - Parameters: -/// - hub: optional HubApi -- by default uses ``defaultHubApi`` /// - directory: directory of configuration and weights -/// - progressHandler: optional callback for progress +/// - tokenizerLoader: the ``TokenizerLoader`` to use for loading the tokenizer /// - Returns: a ``ModelContainer`` public func loadModelContainer( - hub: HubApi = defaultHubApi, directory: URL, - progressHandler: @Sendable @escaping (Progress) -> Void = { _ in } + from directory: URL, + using tokenizerLoader: any TokenizerLoader ) async throws -> sending ModelContainer { try await load { - try await $0.loadContainer( - hub: hub, configuration: .init(directory: directory), progressHandler: progressHandler) + try await $0.loadContainer(from: directory, using: tokenizerLoader) } } @@ -342,19 +424,19 @@ public protocol ModelFactoryTrampoline { /// Registry of ``ModelFactory`` trampolines. /// -/// This allows ``loadModel(hub:id:progressHandler:)`` to use any ``ModelFactory`` instances +/// This allows ``loadModel(from:id:progressHandler:)`` to use any ``ModelFactory`` instances /// available but be defined in the `LLMCommon` layer. This is not typically used directly -- it is -/// called via ``loadModel(hub:id:progressHandler:)``: +/// called via ``loadModel(from:id:progressHandler:)``: /// /// ```swift /// let model = try await loadModel(id: "mlx-community/Qwen3-4B-4bit") /// ``` /// /// ## See Also -/// - ``loadModel(hub:id:progressHandler:)`` -/// - ``loadModel(hub:directory:progressHandler:)`` -/// - ``loadModelContainer(hub:id:progressHandler:)`` -/// - ``loadModelContainer(hub:directory:progressHandler:)`` +/// - ``loadModel(from:id:progressHandler:)`` +/// - ``loadModel(from:)-ModelContext`` +/// - ``loadModelContainer(from:id:progressHandler:)`` +/// - ``loadModelContainer(from:)-ModelContainer`` final public class ModelFactoryRegistry: @unchecked Sendable { public static let shared = ModelFactoryRegistry() diff --git a/Libraries/MLXLMCommon/README.md b/Libraries/MLXLMCommon/README.md index e7eef3e9d..2910c16d5 100644 --- a/Libraries/MLXLMCommon/README.md +++ b/Libraries/MLXLMCommon/README.md @@ -9,16 +9,81 @@ # Quick Start -Using LLMs and VLMs from MLXLMCommon is as easy as: +Using LLMs and VLMs is as easy as: ```swift -let model = try await loadModel(id: "mlx-community/Qwen3-4B-4bit") +import MLXLLM +import MLXLMHuggingFace +import MLXLMTokenizers + +let model = try await loadModel( + from: HubClient.default, + using: TokenizersLoader(), + id: "mlx-community/Qwen3-4B-4bit" +) let session = ChatSession(model) -print(try await session.respond(to: "What are two things to see in San Francisco?") -print(try await session.respond(to: "How about a great place to eat?") +print(try await session.respond(to: "What are two things to see in San Francisco?")) +print(try await session.respond(to: "How about a great place to eat?")) +``` + +## More Loading Scenarios + +Load from a local directory: + +```swift +import MLXLLM +import MLXLMTokenizers + +let modelDirectory = URL(filePath: "/path/to/model") +let container = try await loadModelContainer( + from: modelDirectory, + using: TokenizersLoader() +) ``` -For more information see +Use a custom Hugging Face client: + +```swift +import MLXLLM +import MLXLMHuggingFace +import MLXLMTokenizers + +let hub = HubClient(token: "hf_...") +let container = try await loadModelContainer( + from: hub, + using: TokenizersLoader(), + id: "mlx-community/Qwen3-4B-4bit" +) +``` + +Use a custom downloader: + +```swift +import MLXLLM +import MLXLMCommon +import MLXLMTokenizers + +struct S3Downloader: Downloader { + func download( + id: String, + revision: String?, + matching patterns: [String], + useLatest: Bool, + progressHandler: @Sendable @escaping (Progress) -> Void + ) async throws -> URL { + // Download files and return a local directory URL. + return URL(filePath: "/tmp/model") + } +} + +let container = try await loadModelContainer( + from: S3Downloader(), + using: TokenizersLoader(), + id: "my-bucket/my-model" +) +``` + +For more information see [Evaluation](https://swiftpackageindex.com/ml-explore/mlx-swift-lm/main/documentation/mlxlmcommon/evaluation) or [Using Models](https://swiftpackageindex.com/ml-explore/mlx-swift-lm/main/documentation/mlxlmcommon/using-model) for more advanced API. @@ -38,13 +103,29 @@ of language models, from LLMs to VLMs: A model is typically loaded by using a `ModelFactory` and a `ModelConfiguration`: ```swift +import MLXLMCommon +import MLXLMHuggingFace +import MLXLMTokenizers + // e.g. VLMModelFactory.shared let modelFactory: ModelFactory // e.g. VLMRegistry.paligemma3bMix4488bit let modelConfiguration: ModelConfiguration -let container = try await modelFactory.loadContainer(configuration: modelConfiguration) +let container = try await modelFactory.loadContainer( + from: HubClient.default, + using: TokenizersLoader(), + configuration: modelConfiguration +) + +// Custom Hub client (token, endpoint, etc.). +let customHub = HubClient(token: "hf_...") +let privateContainer = try await modelFactory.loadContainer( + from: customHub, + using: TokenizersLoader(), + configuration: modelConfiguration +) ``` The `container` provides an isolation context (an `actor`) to run inference in the model. @@ -62,15 +143,15 @@ The flow inside the `ModelFactory` goes like this: public class VLMModelFactory: ModelFactory { public func _load( - hub: HubApi, configuration: ModelConfiguration, - progressHandler: @Sendable @escaping (Progress) -> Void + configuration: ResolvedModelConfiguration, + tokenizerLoader: any TokenizerLoader ) async throws -> ModelContext { - // download the weight and config using HubApi + // modelDirectory and tokenizerDirectory are already resolved // load the base configuration // using the typeRegistry create a model (random weights) // load the weights, apply quantization as needed, update the model // calls model.sanitize() for weight preparation - // load the tokenizer + // load the tokenizer via tokenizerLoader.load(from: directory) // (vlm) load the processor configuration, create the processor } } diff --git a/Libraries/MLXLMCommon/Registries/ProcessorTypeRegistry.swift b/Libraries/MLXLMCommon/Registries/ProcessorTypeRegistry.swift index 67d1491eb..fa0e6974e 100644 --- a/Libraries/MLXLMCommon/Registries/ProcessorTypeRegistry.swift +++ b/Libraries/MLXLMCommon/Registries/ProcessorTypeRegistry.swift @@ -1,7 +1,6 @@ // Copyright © 2024 Apple Inc. import Foundation -import Tokenizers public actor ProcessorTypeRegistry { diff --git a/Libraries/MLXLMCommon/Tokenizer.swift b/Libraries/MLXLMCommon/Tokenizer.swift index 1d14fa1e0..9573e7a2c 100644 --- a/Libraries/MLXLMCommon/Tokenizer.swift +++ b/Libraries/MLXLMCommon/Tokenizer.swift @@ -1,126 +1,84 @@ // Copyright © 2024 Apple Inc. import Foundation -import Hub -import Tokenizers -struct TokenizerError: Error { - let message: String +/// A protocol for tokenizing text into token IDs and decoding token IDs into text. +public protocol Tokenizer: Sendable { + func encode(text: String, addSpecialTokens: Bool) -> [Int] + func decode(tokenIds: [Int], skipSpecialTokens: Bool) -> String + func convertTokenToId(_ token: String) -> Int? + func convertIdToToken(_ id: Int) -> String? + + var bosToken: String? { get } + var eosToken: String? { get } + var unknownToken: String? { get } + + func applyChatTemplate( + messages: [[String: any Sendable]], + tools: [[String: any Sendable]]?, + additionalContext: [String: any Sendable]? + ) throws -> [Int] } -public func loadTokenizer(configuration: ModelConfiguration, hub: HubApi) async throws -> Tokenizer -{ - let (tokenizerConfig, tokenizerData) = try await loadTokenizerConfig( - configuration: configuration, hub: hub) - - return try PreTrainedTokenizer( - tokenizerConfig: tokenizerConfig, tokenizerData: tokenizerData) -} +extension Tokenizer { + public func encode(text: String) -> [Int] { + encode(text: text, addSpecialTokens: true) + } -public func loadTokenizerConfig(configuration: ModelConfiguration, hub: HubApi) async throws -> ( - Config, Config -) { - // from AutoTokenizer.from() -- this lets us override parts of the configuration - let config: LanguageModelConfigurationFromHub - - switch configuration.id { - case .id(let id, let revision): - do { - // the load can fail (async when we try to use it) - let loaded = LanguageModelConfigurationFromHub( - modelName: configuration.tokenizerId ?? id, revision: revision, hubApi: hub) - _ = try await loaded.tokenizerConfig - config = loaded - } catch { - let nserror = error as NSError - if nserror.domain == NSURLErrorDomain - && nserror.code == NSURLErrorNotConnectedToInternet - { - // Internet connection appears to be offline -- fall back to loading from - // the local directory - config = LanguageModelConfigurationFromHub( - modelFolder: configuration.modelDirectory(hub: hub), hubApi: hub) - } else { - throw error - } - } - case .directory(let directory): - config = LanguageModelConfigurationFromHub(modelFolder: directory, hubApi: hub) + public func decode(tokenIds: [Int]) -> String { + decode(tokenIds: tokenIds, skipSpecialTokens: false) } - guard var tokenizerConfig = try await config.tokenizerConfig else { - throw TokenizerError(message: "missing config") + public var eosTokenId: Int? { + guard let eosToken else { return nil } + return convertTokenToId(eosToken) } - let tokenizerData = try await config.tokenizerData - tokenizerConfig = updateTokenizerConfig(tokenizerConfig) + public var unknownTokenId: Int? { + guard let unknownToken else { return nil } + return convertTokenToId(unknownToken) + } - return (tokenizerConfig, tokenizerData) -} + public func applyChatTemplate( + messages: [[String: any Sendable]] + ) throws -> [Int] { + try applyChatTemplate(messages: messages, tools: nil, additionalContext: nil) + } -private func updateTokenizerConfig(_ tokenizerConfig: Config) -> Config { - // Workaround: replacement tokenizers for unhandled values in swift-transformers - if let tokenizerClass = tokenizerConfig.tokenizerClass?.string(), - let replacement = replacementTokenizers[tokenizerClass] - { - if var dictionary = tokenizerConfig.dictionary() { - dictionary["tokenizer_class"] = .init(replacement) - return Config(dictionary) - } + public func applyChatTemplate( + messages: [[String: any Sendable]], + tools: [[String: any Sendable]]? + ) throws -> [Int] { + try applyChatTemplate(messages: messages, tools: tools, additionalContext: nil) } - return tokenizerConfig } -public class TokenizerReplacementRegistry: @unchecked Sendable { - - // Note: using NSLock as we have very small (just dictionary get/set) - // critical sections and expect no contention. this allows the methods - // to remain synchronous. - private let lock = NSLock() - - /// overrides for TokenizerModel/knownTokenizers - private var replacementTokenizers = [ - "InternLM2Tokenizer": "PreTrainedTokenizer", - "Qwen2Tokenizer": "PreTrainedTokenizer", - "Qwen3Tokenizer": "PreTrainedTokenizer", - "CohereTokenizer": "PreTrainedTokenizer", - "GPTNeoXTokenizer": "PreTrainedTokenizer", - "TokenizersBackend": "PreTrainedTokenizer", - ] - - public subscript(key: String) -> String? { - get { - lock.withLock { - replacementTokenizers[key] - } - } - set { - lock.withLock { - replacementTokenizers[key] = newValue - } +public enum TokenizerError: LocalizedError { + case missingChatTemplate + + public var errorDescription: String? { + switch self { + case .missingChatTemplate: + "This tokenizer does not have a chat template." } } } -public let replacementTokenizers = TokenizerReplacementRegistry() - public protocol StreamingDetokenizer: IteratorProtocol { - mutating func append(token: Int) - } public struct NaiveStreamingDetokenizer: StreamingDetokenizer { - let tokenizer: Tokenizer + let tokenizer: any Tokenizer var segmentTokens = [Int]() var segment = "" - public init(tokenizer: Tokenizer) { + public init(tokenizer: any Tokenizer) { self.tokenizer = tokenizer } - mutating public func append(token: Int) { + public mutating func append(token: Int) { segmentTokens.append(token) } @@ -129,14 +87,14 @@ public struct NaiveStreamingDetokenizer: StreamingDetokenizer { segmentTokens.removeAll() if let lastToken { segmentTokens.append(lastToken) - segment = tokenizer.decode(tokens: segmentTokens) + segment = tokenizer.decode(tokenIds: segmentTokens) } else { segment = "" } } public mutating func next() -> String? { - let newSegment = tokenizer.decode(tokens: segmentTokens) + let newSegment = tokenizer.decode(tokenIds: segmentTokens) let new = newSegment.suffix(newSegment.count - segment.count) // if the new segment ends with REPLACEMENT CHARACTER this means @@ -153,5 +111,4 @@ public struct NaiveStreamingDetokenizer: StreamingDetokenizer { return String(new) } - } diff --git a/Libraries/MLXLMCommon/TokenizerLoader.swift b/Libraries/MLXLMCommon/TokenizerLoader.swift new file mode 100644 index 000000000..a4f2b9b33 --- /dev/null +++ b/Libraries/MLXLMCommon/TokenizerLoader.swift @@ -0,0 +1,6 @@ +import Foundation + +/// A protocol for loading tokenizers from local directories. +public protocol TokenizerLoader: Sendable { + func load(from directory: URL) async throws -> any Tokenizer +} diff --git a/Libraries/MLXLMCommon/Tool/Parsers/GemmaFunctionParser.swift b/Libraries/MLXLMCommon/Tool/Parsers/GemmaFunctionParser.swift index c1917c318..3539381c0 100644 --- a/Libraries/MLXLMCommon/Tool/Parsers/GemmaFunctionParser.swift +++ b/Libraries/MLXLMCommon/Tool/Parsers/GemmaFunctionParser.swift @@ -69,7 +69,13 @@ public struct GemmaFunctionParser: ToolCallParser, Sendable { ? String(argsStr[argsStr.index(after: commaIdx)...]) : "" // Try JSON decode, fallback to string - arguments[key] = tryParseJSON(value) ?? value + if let data = value.data(using: .utf8), + let json = deserializeJSON(data) + { + arguments[key] = json + } else { + arguments[key] = value + } } return ToolCall(function: .init(name: funcName, arguments: arguments)) diff --git a/Libraries/MLXLMCommon/Tool/Parsers/ParserUtilities.swift b/Libraries/MLXLMCommon/Tool/Parsers/ParserUtilities.swift index 0d80850d8..a12d1064d 100644 --- a/Libraries/MLXLMCommon/Tool/Parsers/ParserUtilities.swift +++ b/Libraries/MLXLMCommon/Tool/Parsers/ParserUtilities.swift @@ -2,30 +2,42 @@ import Foundation -// MARK: - Basic Deserialization +// MARK: - JSON to Sendable Bridge -private func asSendable(_ value: Any) -> (any Sendable)? { +/// Convert a JSON-deserialized value to `any Sendable`. +/// +/// `JSONSerialization` returns `Any`, but all JSON types it produces +/// (String, NSNumber, NSNull, Array, Dictionary) are Sendable. +func asSendable(_ value: Any) -> any Sendable { switch value { - case let dict as [String: Any]: - return dict.compactMapValues { asSendable($0) } - case let array as [Any]: - return array.compactMap { asSendable($0) } - case let sendable as any Sendable: - return sendable - default: - return nil + case let s as String: return s + case let n as NSNumber: return n + case let a as [Any]: return a.map(asSendable) + case let d as [String: Any]: return d.mapValues(asSendable) + case let null as NSNull: return null + default: return "\(value)" } } +/// Deserialize JSON data, returning a Sendable value. +func deserializeJSON(_ data: Data) -> (any Sendable)? { + guard let object = try? JSONSerialization.jsonObject(with: data) else { return nil } + return asSendable(object) +} + +// MARK: - Basic Deserialization + /// Deserialize a string value to JSON or return as string. /// /// Attempts JSON parsing first, falling back to the original string value. /// Reference: Python's `ast.literal_eval` / `json.loads` pattern func tryParseJSON(_ value: String) -> (any Sendable)? { - if let data = value.data(using: .utf8) { - return try? asSendable(JSONSerialization.jsonObject(with: data)) - } - return nil + guard let data = value.data(using: .utf8) else { return nil } + return deserializeJSON(data) +} + +func deserialize(_ value: String) -> any Sendable { + tryParseJSON(value) ?? value } // MARK: - Schema Lookup Functions @@ -93,7 +105,7 @@ func extractTypesFromSchema(_ schema: [String: any Sendable]?) -> [String] { } // Handle enum - infer types from enum values - if let enumValues = schema["enum"] as? [Any], !enumValues.isEmpty { + if let enumValues = schema["enum"] as? [any Sendable], !enumValues.isEmpty { for value in enumValues { switch value { case is NSNull: types.insert("null") @@ -101,8 +113,8 @@ func extractTypesFromSchema(_ schema: [String: any Sendable]?) -> [String] { case is Int: types.insert("integer") case is Double: types.insert("number") case is String: types.insert("string") - case is [Any]: types.insert("array") - case is [String: Any]: types.insert("object") + case is [any Sendable]: types.insert("array") + case is [String: any Sendable]: types.insert("object") default: break } } diff --git a/Libraries/MLXLMCommon/Tool/Tool.swift b/Libraries/MLXLMCommon/Tool/Tool.swift index 81b08859e..2373724d0 100644 --- a/Libraries/MLXLMCommon/Tool/Tool.swift +++ b/Libraries/MLXLMCommon/Tool/Tool.swift @@ -1,7 +1,8 @@ // Copyright © 2025 Apple Inc. import Foundation -import Tokenizers + +public typealias ToolSpec = [String: any Sendable] /// Protocol defining the requirements for a tool. public protocol ToolProtocol: Sendable { diff --git a/Libraries/MLXLMCommon/UserInput.swift b/Libraries/MLXLMCommon/UserInput.swift index 7b8954bcc..0a2e63de4 100644 --- a/Libraries/MLXLMCommon/UserInput.swift +++ b/Libraries/MLXLMCommon/UserInput.swift @@ -4,7 +4,6 @@ import CoreImage import Foundation import MLX -import Tokenizers public typealias Message = [String: any Sendable] diff --git a/Libraries/MLXLMCommon/WiredMemoryUtils.swift b/Libraries/MLXLMCommon/WiredMemoryUtils.swift index 54ee458cb..70602cb16 100644 --- a/Libraries/MLXLMCommon/WiredMemoryUtils.swift +++ b/Libraries/MLXLMCommon/WiredMemoryUtils.swift @@ -3,7 +3,6 @@ import Foundation import MLX import MLXNN -import Tokenizers /// Result of a wired memory measurement pass. public struct WiredMemoryMeasurement: Sendable { @@ -32,7 +31,7 @@ public enum WiredMemoryUtils { /// /// This does not attempt to generate semantically meaningful text; it only ensures /// a valid token sequence of the requested length for memory sizing purposes. - private static func makeTokenIDs( + private static func makeTokenIds( count: Int, tokenizer: Tokenizer, seedText: String = " hello" @@ -75,8 +74,8 @@ public enum WiredMemoryUtils { tokenizer: Tokenizer, seedText: String = " hello" ) -> LMInput { - let tokenIDs = makeTokenIDs(count: count, tokenizer: tokenizer, seedText: seedText) - return LMInput(tokens: MLXArray(tokenIDs)) + let tokenIds = makeTokenIds(count: count, tokenizer: tokenizer, seedText: seedText) + return LMInput(tokens: MLXArray(tokenIds)) } /// Run a prefill-only pass to populate caches for the given input. diff --git a/Libraries/MLXVLM/Models/FastVLM.swift b/Libraries/MLXVLM/Models/FastVLM.swift index 16716afe1..dcfe56650 100644 --- a/Libraries/MLXVLM/Models/FastVLM.swift +++ b/Libraries/MLXVLM/Models/FastVLM.swift @@ -9,11 +9,9 @@ import CoreImage import Foundation -import Hub import MLX import MLXLMCommon import MLXNN -import Tokenizers // MARK: - Configuration @@ -1005,7 +1003,7 @@ public struct FastVLMProcessor: UserInputProcessor { let promptTokens = try tokenizer.applyChatTemplate( messages: messages, tools: input.tools, additionalContext: input.additionalContext) - let decoded = tokenizer.decode(tokens: promptTokens, skipSpecialTokens: false) + let decoded = tokenizer.decode(tokenIds: promptTokens, skipSpecialTokens: false) // Find and replace with token id -200 let pieces = decoded.split(separator: imageToken) diff --git a/Libraries/MLXVLM/Models/Gemma3.swift b/Libraries/MLXVLM/Models/Gemma3.swift index df7259084..d0c968f75 100644 --- a/Libraries/MLXVLM/Models/Gemma3.swift +++ b/Libraries/MLXVLM/Models/Gemma3.swift @@ -2,7 +2,6 @@ import CoreImage import MLX import MLXLMCommon import MLXNN -import Tokenizers // Based on https://github.com/Blaizzy/mlx-vlm/tree/main/mlx_vlm/models/gemma3 diff --git a/Libraries/MLXVLM/Models/GlmOcr.swift b/Libraries/MLXVLM/Models/GlmOcr.swift index 1b25d7e8b..31b1e010f 100644 --- a/Libraries/MLXVLM/Models/GlmOcr.swift +++ b/Libraries/MLXVLM/Models/GlmOcr.swift @@ -9,11 +9,9 @@ import CoreImage import Foundation -import Hub import MLX import MLXLMCommon import MLXNN -import Tokenizers // MARK: - Language diff --git a/Libraries/MLXVLM/Models/Idefics3.swift b/Libraries/MLXVLM/Models/Idefics3.swift index ca6a5a6e7..7c8cd5ff3 100644 --- a/Libraries/MLXVLM/Models/Idefics3.swift +++ b/Libraries/MLXVLM/Models/Idefics3.swift @@ -9,11 +9,9 @@ import CoreImage import Foundation -import Hub import MLX import MLXLMCommon import MLXNN -import Tokenizers // MARK: - Configuration diff --git a/Libraries/MLXVLM/Models/LFM2VL.swift b/Libraries/MLXVLM/Models/LFM2VL.swift index 170c3e728..0d99a766a 100644 --- a/Libraries/MLXVLM/Models/LFM2VL.swift +++ b/Libraries/MLXVLM/Models/LFM2VL.swift @@ -2,11 +2,9 @@ import CoreImage import Foundation -import Hub import MLX import MLXLMCommon import MLXNN -import Tokenizers // MARK: - Vision diff --git a/Libraries/MLXVLM/Models/Mistral3.swift b/Libraries/MLXVLM/Models/Mistral3.swift index d64bac359..57209ab34 100644 --- a/Libraries/MLXVLM/Models/Mistral3.swift +++ b/Libraries/MLXVLM/Models/Mistral3.swift @@ -3,7 +3,6 @@ import Foundation import MLX import MLXLMCommon import MLXNN -import Tokenizers // Port of https://github.com/Blaizzy/mlx-vlm/tree/main/mlx_vlm/models/mistral3 // Note: Mistral3 reuses the vision model from Pixtral @@ -1027,7 +1026,7 @@ public struct Mistral3VLMProcessor: UserInputProcessor { ) // Decode to find and replace image placeholder token - let decoded = tokenizer.decode(tokens: promptTokens, skipSpecialTokens: false) + let decoded = tokenizer.decode(tokenIds: promptTokens, skipSpecialTokens: false) // Process image to get dimensions let preprocessResult = try preprocessImage( diff --git a/Libraries/MLXVLM/Models/Paligemma.swift b/Libraries/MLXVLM/Models/Paligemma.swift index 416e3e5ab..cc16d0cbe 100644 --- a/Libraries/MLXVLM/Models/Paligemma.swift +++ b/Libraries/MLXVLM/Models/Paligemma.swift @@ -4,11 +4,9 @@ import CoreImage import Foundation -import Hub import MLX import MLXLMCommon import MLXNN -import Tokenizers // MARK: - Language diff --git a/Libraries/MLXVLM/Models/Pixtral.swift b/Libraries/MLXVLM/Models/Pixtral.swift index cbb4a4982..69dcc5530 100644 --- a/Libraries/MLXVLM/Models/Pixtral.swift +++ b/Libraries/MLXVLM/Models/Pixtral.swift @@ -3,7 +3,6 @@ import Foundation import MLX import MLXLMCommon import MLXNN -import Tokenizers // Port of https://github.com/Blaizzy/mlx-vlm/tree/main/mlx_vlm/models/pixtral diff --git a/Libraries/MLXVLM/Models/Qwen25VL.swift b/Libraries/MLXVLM/Models/Qwen25VL.swift index c380d10e3..abd4912d9 100644 --- a/Libraries/MLXVLM/Models/Qwen25VL.swift +++ b/Libraries/MLXVLM/Models/Qwen25VL.swift @@ -2,11 +2,9 @@ import CoreImage import Foundation -import Hub import MLX import MLXLMCommon import MLXNN -import Tokenizers // MARK: - Language diff --git a/Libraries/MLXVLM/Models/Qwen2VL.swift b/Libraries/MLXVLM/Models/Qwen2VL.swift index 2f44f04f3..701f755f1 100644 --- a/Libraries/MLXVLM/Models/Qwen2VL.swift +++ b/Libraries/MLXVLM/Models/Qwen2VL.swift @@ -4,11 +4,9 @@ import CoreImage import Foundation -import Hub import MLX import MLXLMCommon import MLXNN -import Tokenizers // MARK: - Language diff --git a/Libraries/MLXVLM/Models/Qwen3VL.swift b/Libraries/MLXVLM/Models/Qwen3VL.swift index 27b6e8a50..22461a92d 100644 --- a/Libraries/MLXVLM/Models/Qwen3VL.swift +++ b/Libraries/MLXVLM/Models/Qwen3VL.swift @@ -4,11 +4,9 @@ import CoreImage import Foundation -import Hub import MLX import MLXLMCommon import MLXNN -import Tokenizers private enum Qwen3VLError: Error { case featureTokenMismatch(expected: Int, actual: Int) diff --git a/Libraries/MLXVLM/Models/QwenVL.swift b/Libraries/MLXVLM/Models/QwenVL.swift index 7678fd07b..acd4602aa 100644 --- a/Libraries/MLXVLM/Models/QwenVL.swift +++ b/Libraries/MLXVLM/Models/QwenVL.swift @@ -1,10 +1,8 @@ import CoreImage import Foundation -import Hub import MLX import MLXLMCommon import MLXNN -import Tokenizers // port of https://github.com/Blaizzy/mlx-vlm/tree/main/mlx_vlm/models/qwen2_vl diff --git a/Libraries/MLXVLM/Models/SmolVLM2.swift b/Libraries/MLXVLM/Models/SmolVLM2.swift index efb9752b7..c3b6ffa7b 100644 --- a/Libraries/MLXVLM/Models/SmolVLM2.swift +++ b/Libraries/MLXVLM/Models/SmolVLM2.swift @@ -10,7 +10,6 @@ import CoreMedia import Foundation import MLX import MLXLMCommon -import Tokenizers // MARK: - Configuration and modeling are Idefics3 @@ -241,7 +240,7 @@ public struct SmolVLMProcessor: UserInputProcessor { let promptTokens = try tokenizer.applyChatTemplate( messages: messages, tools: input.tools, additionalContext: input.additionalContext) - let decoded = tokenizer.decode(tokens: promptTokens, skipSpecialTokens: false) + let decoded = tokenizer.decode(tokenIds: promptTokens, skipSpecialTokens: false) let image = try input.images[0].asCIImage().toSRGB() let (tiles, imageRows, imageCols) = tiles(from: image) @@ -308,7 +307,7 @@ public struct SmolVLMProcessor: UserInputProcessor { // Unfortunately we don't have a "render" option in Tokenizers yet, so decoding let promptTokens = try tokenizer.applyChatTemplate( messages: messagesWithSystem(messages)) - let decoded = tokenizer.decode(tokens: promptTokens, skipSpecialTokens: false) + let decoded = tokenizer.decode(tokenIds: promptTokens, skipSpecialTokens: false) let video = input.videos[0] diff --git a/Libraries/MLXVLM/README.md b/Libraries/MLXVLM/README.md index 83120543e..2c5f1f262 100644 --- a/Libraries/MLXVLM/README.md +++ b/Libraries/MLXVLM/README.md @@ -12,11 +12,18 @@ Using LLMs and VLMs from MLXLMCommon is as easy as: ```swift -let model = try await loadModel(id: "mlx-community/Qwen2.5-VL-3B-Instruct-4bit") +import MLXVLM +import MLXLMHuggingFace +import MLXLMTokenizers + +let model = try await loadModel( + using: TokenizersLoader(), + id: "mlx-community/Qwen2.5-VL-3B-Instruct-4bit" +) let session = ChatSession(model) let answer1 = try await session.respond( - to: "what kind of creature is in the picture?" + to: "what kind of creature is in the picture?", image: .url(URL(fileURLWithPath: "support/test.jpg")) ) print(answer1) @@ -39,9 +46,8 @@ This is a port of several models from: - https://github.com/Blaizzy/mlx-vlm -using the Hugging Face swift transformers package to provide tokenization: - -- https://github.com/huggingface/swift-transformers +Tokenization is provided via the `TokenizerLoader` protocol – see the main +[README](../../README.md) for available integration packages. The [VLMModelFactory.swift](VLMModelFactory.swift) provides minor overrides and customization -- if you require overrides for the tokenizer or prompt customizations they can be @@ -298,7 +304,7 @@ media as needed. For example it might: - modify the prompt by injecting `` tokens that the model expects In the python implementations, much of this code typically lives in the `transformers` -package from huggingface -- inspection will be required to determine which code +package from Hugging Face -- inspection will be required to determine which code is called and what it does. You can examine the processors in the `Models` directory: they reference the files and functions that they are based on. diff --git a/Libraries/MLXVLM/VLMModelFactory.swift b/Libraries/MLXVLM/VLMModelFactory.swift index c3f65df7d..64cee4494 100644 --- a/Libraries/MLXVLM/VLMModelFactory.swift +++ b/Libraries/MLXVLM/VLMModelFactory.swift @@ -1,10 +1,8 @@ // Copyright © 2024 Apple Inc. import Foundation -import Hub import MLX import MLXLMCommon -import Tokenizers public enum VLMError: LocalizedError, Equatable { case imageRequired @@ -77,7 +75,7 @@ private func create( /// Registry of model type, e.g 'llama', to functions that can instantiate the model from configuration. /// -/// Typically called via ``LLMModelFactory/load(hub:configuration:progressHandler:)``. +/// Typically called via ``LLMModelFactory/load(from:configuration:progressHandler:)``. public enum VLMTypeRegistry { /// Shared instance with default model types. @@ -284,12 +282,10 @@ public final class VLMModelFactory: ModelFactory { public let modelRegistry: AbstractModelRegistry public func _load( - hub: HubApi, configuration: ModelConfiguration, - progressHandler: @Sendable @escaping (Progress) -> Void + configuration: ResolvedModelConfiguration, + tokenizerLoader: any TokenizerLoader ) async throws -> sending ModelContext { - // download weights and config - let modelDirectory = try await downloadModel( - hub: hub, configuration: configuration, progressHandler: progressHandler) + let modelDirectory = configuration.modelDirectory // Load config.json once and decode for both base config and model-specific config let configurationURL = modelDirectory.appending(component: "config.json") @@ -328,7 +324,6 @@ public final class VLMModelFactory: ModelFactory { eosTokenIds = Set(genEosIds) // Override per Python mlx-lm behavior } - // Create mutable configuration with loaded EOS token IDs var mutableConfiguration = configuration mutableConfiguration.eosTokenIds = eosTokenIds @@ -337,11 +332,13 @@ public final class VLMModelFactory: ModelFactory { mutableConfiguration.toolCallFormat = ToolCallFormat.infer(from: baseConfig.modelType) } - // Load tokenizer, processor config, and weights in parallel using async let. + // Load tokenizer from model directory (or alternate tokenizer repo), + // processor config, and weights in parallel using async let. // Note: loadProcessorConfig does synchronous I/O but is marked async to enable // parallel scheduling. This may briefly block a cooperative thread pool thread, // but the config file is small and model loading is not a high-concurrency path. - async let tokenizerTask = loadTokenizer(configuration: configuration, hub: hub) + async let tokenizerTask = tokenizerLoader.load( + from: configuration.tokenizerDirectory) async let processorConfigTask = loadProcessorConfig(from: modelDirectory) try loadWeights( @@ -375,8 +372,21 @@ public final class VLMModelFactory: ModelFactory { configuration: processorConfigData, processorType: processorType, tokenizer: tokenizer) + // Build a ModelConfiguration for the ModelContext + let tokenizerSource: TokenizerSource? = + configuration.tokenizerDirectory == modelDirectory + ? nil + : .directory(configuration.tokenizerDirectory) + let modelConfig = ModelConfiguration( + directory: modelDirectory, + tokenizerSource: tokenizerSource, + defaultPrompt: configuration.defaultPrompt, + extraEOSTokens: mutableConfiguration.extraEOSTokens, + eosTokenIds: mutableConfiguration.eosTokenIds, + toolCallFormat: mutableConfiguration.toolCallFormat) + return .init( - configuration: mutableConfiguration, model: model, processor: processor, + configuration: modelConfig, model: model, processor: processor, tokenizer: tokenizer) } diff --git a/Package.swift b/Package.swift index b2a8c2528..8dcf54c51 100644 --- a/Package.swift +++ b/Package.swift @@ -1,6 +1,7 @@ // swift-tools-version: 6.1 // The swift-tools-version declares the minimum version of Swift required to build this package. +import CompilerPluginSupport import PackageDescription let package = Package( @@ -24,13 +25,19 @@ let package = Package( .library( name: "MLXEmbedders", targets: ["MLXEmbedders"]), + .library( + name: "MLXHuggingFace", + targets: ["MLXHuggingFace"]), + .library( + name: "BenchmarkHelpers", + targets: ["BenchmarkHelpers"]), + .library( + name: "IntegrationTestHelpers", + targets: ["IntegrationTestHelpers"]), ], dependencies: [ .package(url: "https://github.com/ml-explore/mlx-swift", .upToNextMinor(from: "0.31.3")), - .package( - url: "https://github.com/huggingface/swift-transformers", - .upToNextMinor(from: "1.2.0") - ), + .package(url: "https://github.com/swiftlang/swift-syntax.git", from: "600.0.0-latest"), ], targets: [ .target( @@ -40,7 +47,6 @@ let package = Package( .product(name: "MLX", package: "mlx-swift"), .product(name: "MLXNN", package: "mlx-swift"), .product(name: "MLXOptimizers", package: "mlx-swift"), - .product(name: "Transformers", package: "swift-transformers"), ], path: "Libraries/MLXLLM", exclude: [ @@ -54,7 +60,6 @@ let package = Package( .product(name: "MLX", package: "mlx-swift"), .product(name: "MLXNN", package: "mlx-swift"), .product(name: "MLXOptimizers", package: "mlx-swift"), - .product(name: "Transformers", package: "swift-transformers"), ], path: "Libraries/MLXVLM", exclude: [ @@ -67,7 +72,6 @@ let package = Package( .product(name: "MLX", package: "mlx-swift"), .product(name: "MLXNN", package: "mlx-swift"), .product(name: "MLXOptimizers", package: "mlx-swift"), - .product(name: "Transformers", package: "swift-transformers"), ], path: "Libraries/MLXLMCommon", exclude: [ @@ -79,7 +83,6 @@ let package = Package( dependencies: [ .product(name: "MLX", package: "mlx-swift"), .product(name: "MLXNN", package: "mlx-swift"), - .product(name: "Transformers", package: "swift-transformers"), .target(name: "MLXLMCommon"), ], path: "Libraries/MLXEmbedders", @@ -87,49 +90,61 @@ let package = Package( "README.md" ] ), - .testTarget( - name: "MLXLMTests", + .target( + name: "BenchmarkHelpers", dependencies: [ - .product(name: "MLX", package: "mlx-swift"), - .product(name: "MLXNN", package: "mlx-swift"), - .product(name: "MLXOptimizers", package: "mlx-swift"), - .product(name: "Transformers", package: "swift-transformers"), "MLXLMCommon", "MLXLLM", "MLXVLM", "MLXEmbedders", + .product(name: "MLX", package: "mlx-swift"), ], - path: "Tests/MLXLMTests", - exclude: [ - "README.md" + path: "Libraries/BenchmarkHelpers" + ), + .target( + name: "IntegrationTestHelpers", + dependencies: [ + "MLXLMCommon", + "MLXLLM", + "MLXVLM", + "MLXEmbedders", + .product(name: "MLX", package: "mlx-swift"), ], - resources: [.process("Resources/1080p_30.mov"), .process("Resources/audio_only.mov")] + path: "Libraries/IntegrationTestHelpers", + exclude: ["README.md"] ), .testTarget( - name: "MLXLMIntegrationTests", + name: "MLXLMTests", dependencies: [ .product(name: "MLX", package: "mlx-swift"), .product(name: "MLXNN", package: "mlx-swift"), .product(name: "MLXOptimizers", package: "mlx-swift"), - .product(name: "Transformers", package: "swift-transformers"), "MLXLMCommon", "MLXLLM", "MLXVLM", "MLXEmbedders", ], - path: "Tests/MLXLMIntegrationTests", + path: "Tests/MLXLMTests", exclude: [ "README.md" - ] + ], + resources: [.process("Resources/1080p_30.mov"), .process("Resources/audio_only.mov")] ), - .testTarget( - name: "Benchmarks", + .macro( + name: "MLXHuggingFaceMacros", dependencies: [ - "MLXLLM", - "MLXVLM", + .product(name: "SwiftSyntaxMacros", package: "swift-syntax"), + .product(name: "SwiftCompilerPlugin", package: "swift-syntax"), + ], + path: "Libraries/MLXHuggingFaceMacros" + ), + .target( + name: "MLXHuggingFace", + dependencies: [ + "MLXHuggingFaceMacros", "MLXLMCommon", ], - path: "Tests/Benchmarks" + path: "Libraries/MLXHuggingFace" ), ] ) @@ -137,7 +152,6 @@ let package = Package( if Context.environment["MLX_SWIFT_BUILD_DOC"] == "1" || Context.environment["SPI_GENERATE_DOCS"] == "1" { - // docc builder package.dependencies.append( .package(url: "https://github.com/apple/swift-docc-plugin", from: "1.3.0") ) diff --git a/README.md b/README.md index dfab56302..bf4e54f17 100644 --- a/README.md +++ b/README.md @@ -1,68 +1,298 @@ # MLX Swift LM -MLX Swift LM is a Swift package to build tools and applications with large -language models (LLMs) and vision language models (VLMs) in [MLX Swift](https://github.com/ml-explore/mlx-swift). +> [!IMPORTANT] +> The `main` branch is a _new_ major version number: 3.x. In order +> to decouple from tokenizer and downloader packages some breaking +> changes were introduced. See [Breaking Changes](#breaking-changes) for more information. + +MLX Swift LM is a Swift package to build tools and applications with large language models (LLMs) and vision language models (VLMs) in [MLX Swift](https://github.com/ml-explore/mlx-swift). Some key features include: -- Integration with the Hugging Face Hub to easily use thousands of LLMs with a single command. +- Model loading with integrations for a variety of tokenizer and model downloading packages. - Low-rank (LoRA) and full model fine-tuning with support for quantized models. - Many model architectures for both LLMs and VLMs. -For some example applications and tools that use MLX Swift LM check out -the [MLX Swift Examples](https://github.com/ml-explore/mlx-swift-examples). - -# Using MLX Swift LM +For some example applications and tools that use MLX Swift LM, check out [MLX Swift Examples](https://github.com/ml-explore/mlx-swift-examples). -The MLXLLM, MLXVLM, MLXLMCommon, and MLXEmbedders libraries are available -as Swift Packages. +## Usage -Add the following dependency to your Package.swift: +Add the core package to your `Package.swift`: ```swift -.package(url: "https://github.com/ml-explore/mlx-swift-lm/", branch: "main"), +.package(url: "https://github.com/ml-explore/mlx-swift-lm", branch: "main"), ``` -or use the latest release: +Then add your preferred tokenizer and downloader integrations: ```swift -.package(url: "https://github.com/ml-explore/mlx-swift-lm/", .upToNextMinor(from: "2.29.1")), +.package(url: "https://github.com/DePasqualeOrg/swift-tokenizers-mlx", from: "0.1.0"), +.package(url: "https://github.com/DePasqualeOrg/swift-hf-api-mlx", from: "0.1.0"), ``` -Then add one or more libraries to the target as a dependency: +And add the libraries to your target: ```swift .target( name: "YourTargetName", dependencies: [ - .product(name: "MLXLLM", package: "mlx-swift-lm") + .product(name: "MLXLLM", package: "mlx-swift-lm"), + .product(name: "MLXLMTokenizers", package: "swift-tokenizers-mlx"), + .product(name: "MLXLMHuggingFace", package: "swift-hf-api-mlx"), ]), ``` -Alternatively, add `https://github.com/ml-explore/mlx-swift-lm/` to the -`Project Dependencies` and set the `Dependency Rule` to `Branch` and `main` in -Xcode. +### Tokenizer and Downloader Integrations + +MLX Swift LM focuses on model implementations. Tokenization and model downloading are handled by separate packages. Adapters make it easy to use your preferred downloader and tokenizer packages. + +| Downloader package | Adapter | +| ------------------------------------------------------------ | ------------------------------------------------------------ | +| [huggingface/swift-huggingface](https://github.com/huggingface/swift-huggingface) | [DePasqualeOrg/swift-huggingface-mlx](https://github.com/DePasqualeOrg/swift-huggingface-mlx) | +| [DePasqualeOrg/swift-hf-api](https://github.com/DePasqualeOrg/swift-hf-api) | [DePasqualeOrg/swift-hf-api-mlx](https://github.com/DePasqualeOrg/swift-hf-api-mlx) | + +| Tokenizer package | Adapter | +| ------------------------------------------------------------ | ------------------------------------------------------------ | +| [DePasqualeOrg/swift-tokenizers](https://github.com/DePasqualeOrg/swift-tokenizers) | [DePasqualeOrg/swift-tokenizers-mlx](https://github.com/DePasqualeOrg/swift-tokenizers-mlx) | +| [huggingface/swift-transformers](https://github.com/huggingface/swift-transformers) | [DePasqualeOrg/swift-transformers-mlx](https://github.com/DePasqualeOrg/swift-transformers-mlx) | + +> **Note:** The adapters are offered for convenience and are not required. You can also use tokenizer and downloader packages directly by setting up the required protocol conformance for MLX Swift LM. See the integration packages for examples of how to do this. -# Quick Start +### Quick Start -See also [MLXLMCommon](Libraries/MLXLMCommon). You can get started with a wide -variety of open weights LLMs and VLMs using this simplified API: +You can get started with a wide variety of open-weights LLMs and VLMs using this simplified API (for more details, see [MLXLMCommon](Libraries/MLXLMCommon)): ```swift -let model = try await loadModel(id: "mlx-community/Qwen3-4B-4bit") +import MLXLLM +import MLXLMHuggingFace +import MLXLMTokenizers + +let model = try await loadModel( + from: HubClient.default, + using: TokenizersLoader(), + id: "mlx-community/Qwen3-4B-4bit" +) let session = ChatSession(model) print(try await session.respond(to: "What are two things to see in San Francisco?")) print(try await session.respond(to: "How about a great place to eat?")) ``` +Loading from a local directory: + +```swift +import MLXLLM +import MLXLMTokenizers + +let modelDirectory = URL(filePath: "/path/to/model") +let container = try await loadModelContainer( + from: modelDirectory, + using: TokenizersLoader() +) +``` + +Use a custom Hugging Face client: + +```swift +import MLXLLM +import MLXLMHuggingFace +import MLXLMTokenizers + +let hub = HubClient(token: "hf_...") +let container = try await loadModelContainer( + from: hub, + using: TokenizersLoader(), + id: "mlx-community/Qwen3-4B-4bit" +) +``` + +Use a custom downloader: + +```swift +import MLXLLM +import MLXLMCommon +import MLXLMTokenizers + +struct S3Downloader: Downloader { + func download( + id: String, + revision: String?, + matching patterns: [String], + useLatest: Bool, + progressHandler: @Sendable @escaping (Progress) -> Void + ) async throws -> URL { + // Download files and return a local directory URL. + return URL(filePath: "/tmp/model") + } +} + +let container = try await loadModelContainer( + from: S3Downloader(), + using: TokenizersLoader(), + id: "my-bucket/my-model" +) +``` + Or use the underlying API to control every aspect of the evaluation. -# Documentation +## Migrating to Version 3 + +Version 3 of MLX Swift LM decouples the tokenizer and downloader implementations. See the [integrations](#Tokenizer and Downloader Integrations) section for details. + +### New dependencies + +Add your preferred tokenizer and downloader adapters: + +```swift +// Before (2.x) – single dependency +.package(url: "https://github.com/ml-explore/mlx-swift-lm/", from: "2.30.0"), + +// After (3.x) – core + adapters +.package(url: "https://github.com/ml-explore/mlx-swift-lm/", from: "3.0.0"), +.package(url: "https://github.com/DePasqualeOrg/swift-tokenizers-mlx/", from: "0.1.0"), +.package(url: "https://github.com/DePasqualeOrg/swift-hf-api-mlx/", from: "0.1.0"), +``` + +And add their products to your target: + +```swift +.product(name: "MLXLMTokenizers", package: "swift-tokenizers-mlx"), +.product(name: "MLXLMHFAPI", package: "swift-hf-api-mlx"), + +// If you use MLXEmbedders: +.product(name: "MLXEmbeddersTokenizers", package: "swift-tokenizers-mlx"), +.product(name: "MLXEmbeddersHFAPI", package: "swift-hf-api-mlx"), +``` + +### New imports + +```swift +// Before (2.x) +import MLXLLM + +// After (3.x) +import MLXLLM +import MLXLMHFAPI // Downloader adapter +import MLXLMTokenizers // Tokenizer adapter +``` + +If you use MLXEmbedders: + +```swift +import MLXEmbedders +import MLXEmbeddersHFAPI // Downloader adapter +import MLXEmbeddersTokenizers // Tokenizer adapter +``` + +### Loading API changes + +The core APIs now include a `from:` parameter of type `URL` or `any Downloader` as well as a `using:` parameter for the tokenizer loader. Tokenizer integration packages may supply convenience methods with a default tokenizer loader, allowing you to omit the `using:` parameter. + +The most visible call-site changes are: + +- `hub:` → `from:`: Models are now loaded from a directory `URL` or `Downloader`. +- `HubApi` → `HubClient`: A new implementation of the Hugging Face Hub client is used. + +Example when downloading from Hugging Face: + +```swift +// Before (2.x) – hub defaulted to HubApi() +let container = try await loadModelContainer( + id: "mlx-community/Qwen3-4B-4bit" +) + +// After (3.x) – Using Swift Hugging Face + Swift Tokenizers +let container = try await loadModelContainer( + from: HubClient.default, + id: "mlx-community/Qwen3-4B-4bit" +) +``` + +At the lower-level core API, you can still pass any `Downloader` and any `TokenizerLoader` explicitly. + +Loading from a local directory: + +```swift +// Before (2.x) +let container = try await loadModelContainer(directory: modelDirectory) + +// After (3.x) +let container = try await loadModelContainer(from: modelDirectory) +``` + +Loading with a model factory: + +```swift +let container = try await LLMModelFactory.shared.loadContainer( + from: HubClient.default, + configuration: modelConfiguration +) +``` + +Loading an embedder: + +```swift +import MLXEmbedders +import MLXEmbeddersHFAPI +import MLXEmbeddersTokenizers + +let container = try await loadModelContainer( + from: HubClient.default, + configuration: .configuration(id: "sentence-transformers/all-MiniLM-L6-v2") +) +``` + +### Renamed methods + +`decode(tokens:)` is renamed to `decode(tokenIds:)` to align with the `transformers` library in Python: + +```swift +// Before (2.x) +let text = tokenizer.decode(tokens: ids) + +// After (3.0) +let text = tokenizer.decode(tokenIds: ids) +``` + +## Documentation Developers can use these examples in their own programs -- just import the swift package! - [Porting and implementing models](https://swiftpackageindex.com/ml-explore/mlx-swift-lm/main/documentation/mlxlmcommon/porting) -- [MLXLLMCommon](https://swiftpackageindex.com/ml-explore/mlx-swift-lm/main/documentation/mlxlmcommon) -- common API for LLM and VLM -- [MLXLLM](https://swiftpackageindex.com/ml-explore/mlx-swift-lm/main/documentation/mlxllm) -- large language model example implementations -- [MLXVLM](https://swiftpackageindex.com/ml-explore/mlx-swift-lm/main/documentation/mlxvlm) -- vision language model example implementations -- [MLXEmbedders](https://swiftpackageindex.com/ml-explore/mlx-swift-lm/main/documentation/mlxembedders) -- popular Encoders / Embedding models example implementations +- [MLXLLMCommon](https://swiftpackageindex.com/ml-explore/mlx-swift-lm/main/documentation/mlxlmcommon): Common API for LLM and VLM +- [MLXLLM](https://swiftpackageindex.com/ml-explore/mlx-swift-lm/main/documentation/mlxllm): Large language model example implementations +- [MLXVLM](https://swiftpackageindex.com/ml-explore/mlx-swift-lm/main/documentation/mlxvlm): Vision language model example implementations +- [MLXEmbedders](https://swiftpackageindex.com/ml-explore/mlx-swift-lm/main/documentation/mlxembedders): Popular encoders and embedding models example implementations + +## Breaking Changes + +### Loading API + +The `hub` parameter (previously `HubApi`) has been replaced with `from` (any `Downloader` or `URL` for a local directory). Functions that previously defaulted to `defaultHubApi` no longer have a default – callers must either pass a `Downloader` explicitly or use the convenience methods in `MLXLMHuggingFace` / `MLXEmbeddersHuggingFace`, which default to `HubClient.default`. + +For most users who were using the default Hub client, adding `import MLXLMHuggingFace` or `import MLXEmbeddersHuggingFace` and using the convenience overloads is sufficient. + +Users who were passing a custom `HubApi` instance should create a `HubClient` instead and pass it as the `from` parameter. `HubClient` conforms to `Downloader` via `MLXLMHuggingFace`. + +### `ModelConfiguration` + +- `tokenizerId` and `overrideTokenizer` have been replaced by `tokenizerSource: TokenizerSource?`, which supports `.id(String)` for remote sources and `.directory(URL)` for local paths. +- `preparePrompt` has been removed. This shouldn't be used anyway, since support for chat templates is available. +- `modelDirectory(hub:)` has been removed. For local directories, pass the `URL` directly to the loading functions. For remote models, the `Downloader` protocol handles resolution. + +### Tokenizer loading + +`loadTokenizer(configuration:hub:)` has been removed. Tokenizer loading now uses `AutoTokenizer.from(directory:)` from Swift Tokenizers directly. + +`replacementTokenizers` (the `TokenizerReplacementRegistry`) has been removed. Use `AutoTokenizer.register(_:for:)` from Swift Tokenizers instead. + +### `defaultHubApi` + +The `defaultHubApi` global has been removed. Hugging Face Hub access is now provided by `HubClient.default` from the `HuggingFace` module. + +### Low-level APIs + +- `downloadModel(hub:configuration:progressHandler:)` → `Downloader.download(id:revision:matching:useLatest:progressHandler:)` +- `loadTokenizerConfig(configuration:hub:)` → `AutoTokenizer.from(directory:)` +- `ModelFactory._load(hub:configuration:progressHandler:)` → `_load(configuration: ResolvedModelConfiguration)` +- `ModelFactory._loadContainer`: removed (base `loadContainer` now builds the container from `_load`) + diff --git a/Tests/Benchmarks/ModelLoadingBenchmarks.swift b/Tests/Benchmarks/ModelLoadingBenchmarks.swift deleted file mode 100644 index cdbeb9458..000000000 --- a/Tests/Benchmarks/ModelLoadingBenchmarks.swift +++ /dev/null @@ -1,112 +0,0 @@ -import Foundation -import Hub -import MLX -import MLXLLM -import MLXLMCommon -import MLXVLM -import Testing - -private let benchmarksEnabled = ProcessInfo.processInfo.environment["RUN_BENCHMARKS"] != nil - -private struct BenchmarkStats { - let mean: Double - let median: Double - let stdDev: Double - let min: Double - let max: Double - - init(times: [Double]) { - precondition(!times.isEmpty, "BenchmarkStats requires at least one timing measurement") - let sorted = times.sorted() - self.min = sorted.first ?? 0 - self.max = sorted.last ?? 0 - let mean = times.reduce(0, +) / Double(times.count) - self.mean = mean - self.median = sorted[sorted.count / 2] - - let squaredDiffs = times.map { ($0 - mean) * ($0 - mean) } - self.stdDev = sqrt(squaredDiffs.reduce(0, +) / Double(times.count)) - } - - func printSummary(label: String) { - print("\(label) results:") - print(" Mean: \(String(format: "%.0f", mean))ms") - print(" Median: \(String(format: "%.0f", median))ms") - print(" StdDev: \(String(format: "%.1f", stdDev))ms") - print(" Range: \(String(format: "%.0f", min))-\(String(format: "%.0f", max))ms") - } -} - -@Suite(.serialized) -struct ModelLoadingBenchmarks { - - /// Benchmark LLM model loading - /// Tests: parallel tokenizer/weights, single config.json read - @Test(.enabled(if: benchmarksEnabled)) - func loadLLM() async throws { - let modelId = "mlx-community/Qwen3-0.6B-4bit" - let hub = HubApi() - let config = ModelConfiguration(id: modelId) - - // Warm-up run: ensure model is downloaded and caches are primed - _ = try await LLMModelFactory.shared.load(hub: hub, configuration: config) { _ in } - Memory.clearCache() - - // Benchmark multiple runs - let runs = 7 - var times: [Double] = [] - - for i in 1 ... runs { - let start = CFAbsoluteTimeGetCurrent() - - _ = try await LLMModelFactory.shared.load( - hub: hub, - configuration: config - ) { _ in } - - let elapsed = (CFAbsoluteTimeGetCurrent() - start) * 1000 - times.append(elapsed) - print("LLM load run \(i): \(String(format: "%.0f", elapsed))ms") - - // Clear GPU cache to ensure independent measurements - Memory.clearCache() - } - - BenchmarkStats(times: times).printSummary(label: "LLM load") - } - - /// Benchmark VLM model loading - /// Tests: parallel tokenizer/weights, single config.json read, parallel processor config - @Test(.enabled(if: benchmarksEnabled)) - func loadVLM() async throws { - let modelId = "mlx-community/Qwen2-VL-2B-Instruct-4bit" - let hub = HubApi() - let config = ModelConfiguration(id: modelId) - - // Warm-up run: ensure model is downloaded and caches are primed - _ = try await VLMModelFactory.shared.load(hub: hub, configuration: config) { _ in } - Memory.clearCache() - - // Benchmark multiple runs - let runs = 7 - var times: [Double] = [] - - for i in 1 ... runs { - let start = CFAbsoluteTimeGetCurrent() - - _ = try await VLMModelFactory.shared.load( - hub: hub, - configuration: config - ) { _ in } - - let elapsed = (CFAbsoluteTimeGetCurrent() - start) * 1000 - times.append(elapsed) - print("VLM load run \(i): \(String(format: "%.0f", elapsed))ms") - - // Clear GPU cache to ensure independent measurements - Memory.clearCache() - } - - BenchmarkStats(times: times).printSummary(label: "VLM load") - } -} diff --git a/Tests/MLXLMIntegrationTests/ChatSessionIntegrationTests.swift b/Tests/MLXLMIntegrationTests/ChatSessionIntegrationTests.swift deleted file mode 100644 index 206b00aef..000000000 --- a/Tests/MLXLMIntegrationTests/ChatSessionIntegrationTests.swift +++ /dev/null @@ -1,195 +0,0 @@ -// Copyright © 2025 Apple Inc. - -import CoreImage -import Foundation -import MLX -import MLXLLM -import MLXLMCommon -import MLXNN -import MLXOptimizers -import MLXVLM -import Tokenizers -import XCTest - -/// Tests for the streamlined API using real models -public class ChatSessionIntegrationTests: XCTestCase { - - nonisolated(unsafe) static var llmContainer: ModelContainer! - nonisolated(unsafe) static var vlmContainer: ModelContainer! - - override public class func setUp() { - super.setUp() - // Load models once for all tests - let llmExpectation = XCTestExpectation(description: "Load LLM") - let vlmExpectation = XCTestExpectation(description: "Load VLM") - - Task { - do { - llmContainer = try await IntegrationTestModels.shared.llmContainer() - llmExpectation.fulfill() - } catch { - fatalError("Unable to load llm: \(error)") - } - } - - Task { - do { - vlmContainer = try await IntegrationTestModels.shared.vlmContainer() - vlmExpectation.fulfill() - } catch { - fatalError("Unable to load vlm: \(error)") - } - } - - _ = XCTWaiter.wait(for: [llmExpectation, vlmExpectation], timeout: 300) - } - - func testOneShot() async throws { - let session = ChatSession(Self.llmContainer) - let result = try await session.respond(to: "What is 2+2? Reply with just the number.") - print("One-shot result:", result) - XCTAssertTrue(result.contains("4") || result.lowercased().contains("four")) - } - - func testOneShotStream() async throws { - let session = ChatSession(Self.llmContainer) - var result = "" - for try await token in session.streamResponse( - to: "What is 2+2? Reply with just the number.") - { - print(token, terminator: "") - result += token - } - print() // newline - XCTAssertTrue(result.contains("4") || result.lowercased().contains("four")) - } - - func testMultiTurnConversation() async throws { - let session = ChatSession( - Self.llmContainer, instructions: "You are a helpful assistant. Keep responses brief.") - - let response1 = try await session.respond(to: "My name is Alice.") - print("Response 1:", response1) - - let response2 = try await session.respond(to: "What is my name?") - print("Response 2:", response2) - - // If multi-turn works, response2 should mention "Alice" - XCTAssertTrue( - response2.lowercased().contains("alice"), - "Model should remember the name 'Alice' from previous turn") - } - - func testVisionModel() async throws { - let session = ChatSession(Self.vlmContainer) - - // Create a simple red image for testing - let redImage = CIImage(color: .red).cropped(to: CGRect(x: 0, y: 0, width: 100, height: 100)) - - let result = try await session.respond( - to: "What color is this image? Reply with just the color name.", - image: .ciImage(redImage)) - print("Vision result:", result) - XCTAssertTrue(result.lowercased().contains("red")) - } - - func testStreamDetailsWithTools() async throws { - let tools: [ToolSpec] = [ - [ - "type": "function", - "function": [ - "name": "get_weather", - "description": "Get the current weather for a location", - "parameters": [ - "type": "object", - "properties": [ - "location": [ - "type": "string", - "description": "The city name", - ] as [String: any Sendable] - ] as [String: any Sendable], - "required": ["location"], - ] as [String: any Sendable], - ] as [String: any Sendable], - ] as ToolSpec - ] - let session = ChatSession(Self.llmContainer, tools: tools) - - var responseText = "" - var toolCalls: [ToolCall] = [] - - // Use streamDetails to receive tool calls (respond/streamResponse drops them) - for try await generation in session.streamDetails( - to: "What is the weather in San Francisco?", - images: [], - videos: [] - ) { - switch generation { - case .chunk(let text): - responseText += text - case .toolCall(let toolCall): - toolCalls.append(toolCall) - case .info: - break - } - } - - print("Tools result text:", responseText) - print("Tool calls:", toolCalls) - - // The model should either produce a tool call or mention the tool/weather - let hasContent = responseText.count > 0 || !toolCalls.isEmpty - XCTAssertTrue(hasContent, "Response should contain either text or tool calls") - - let weather = try await session.respond( - to: "Foggy with a high in the low 60s, clearing later in the day", role: .tool) - XCTAssertTrue(weather.contains("fog"), "Weather should mention fog: \(weather)") - } - - func testToolInvocation() async throws { - struct EmptyInput: Codable {} - - struct TimeOutput: Codable { - let time: String - } - - let timeTool = Tool( - name: "get_time", - description: "Get the current date and time including day of week.", - parameters: [] - ) { _ in - TimeOutput(time: "Wed Feb 18 17:50:43 PST 2026") - } - - let session = ChatSession(Self.llmContainer, tools: [timeTool.schema]) { toolCall in - if toolCall.function.name == timeTool.name { - return try await toolCall.execute(with: timeTool).toolResult - } - return "Unknown tool: \(toolCall.function.name)" - } - - let day = try await session.respond(to: "What day of week is it?") - XCTAssertTrue(day.contains("Wed"), "Weather should mention Wed: \(day)") - } - - func testPromptRehydration() async throws { - // Simulate a persisted history (e.g. loaded from SwiftData) - let history: [Chat.Message] = [ - .system("You are a helpful assistant."), - .user("My name is Bob."), - .assistant("Hello Bob! How can I help you today?"), - ] - - let session = ChatSession(Self.llmContainer, history: history) - - // Ask a question that requires the context - let response = try await session.respond(to: "What is my name?") - - print("Rehydration result:", response) - - XCTAssertTrue( - response.lowercased().contains("bob"), - "Model should recognize the name 'Bob' from the injected history, proving successful prompt re-hydration." - ) - } -} diff --git a/Tests/MLXLMIntegrationTests/EmbedderIntegrationTests.swift b/Tests/MLXLMIntegrationTests/EmbedderIntegrationTests.swift deleted file mode 100644 index ba311464d..000000000 --- a/Tests/MLXLMIntegrationTests/EmbedderIntegrationTests.swift +++ /dev/null @@ -1,145 +0,0 @@ -// Copyright © 2026 Apple Inc. - -import Foundation -import MLX -import MLXEmbedders -import MLXLMCommon -import MLXVLM -import Testing -import Tokenizers - -struct EmbedderIntegrationtests { - - private func readeMeExampleResult() async throws -> ([String], [[Float]]) { - let modelContainer = try await loadModelContainer(configuration: .nomic_text_v1_5) - let searchInputs = [ - "search_query: Animals in Tropical Climates.", - "search_document: Elephants", - "search_document: Horses", - "search_document: Polar Bears", - ] - - // Generate embeddings - let resultEmbeddings = await modelContainer.perform { - (model: EmbeddingModel, tokenizer: Tokenizer, pooling: Pooling) -> [[Float]] in - let inputs = searchInputs.map { - tokenizer.encode(text: $0, addSpecialTokens: true) - } - // Pad to longest - let maxLength = inputs.reduce(into: 16) { acc, elem in - acc = max(acc, elem.count) - } - - let padded = stacked( - inputs.map { elem in - MLXArray( - elem - + Array( - repeating: tokenizer.eosTokenId ?? 0, - count: maxLength - elem.count)) - }) - let mask = (padded .!= tokenizer.eosTokenId ?? 0) - let tokenTypes = MLXArray.zeros(like: padded) - let result = pooling( - model(padded, positionIds: nil, tokenTypeIds: tokenTypes, attentionMask: mask), - normalize: true, applyLayerNorm: true - ) - result.eval() - return result.map { $0.asArray(Float.self) } - } - - return (searchInputs, resultEmbeddings) - } - - @Test("MLXEmbedders README.md example") - func testReadMeExample() async throws { - let (searchInputs, resultEmbeddings) = try await readeMeExampleResult() - - // Compute similarities - let searchQueryEmbedding = resultEmbeddings[0] - let documentEmbeddings = resultEmbeddings[1...] - let similarities = documentEmbeddings.map { documentEmbedding in - zip(searchQueryEmbedding, documentEmbedding).map(*).reduce(0, +) - } - let documentNames = searchInputs[1...].map { - $0.replacingOccurrences(of: "search_document: ", with: "") - } - let expectedSimilarities: [Float] = [ - 0.6854175, // Elephants - 0.6644787, // Horses - 0.63326025, // Polar Bears - ] - - for (index, resultSimilarity) in similarities.enumerated() { - #expect( - abs(resultSimilarity - expectedSimilarities[index]) < 0.01, - "The expected similarity does not match the result similarity for \(documentNames[index])" - ) - } - } - - @Test("Gemma 3 Embedder integration") - func testGemma3Embedder() async throws { - // Gemma 3 1B model - let modelId = "mlx-community/gemma-3-1b-it-qat-4bit" - let modelContainer = try await loadModelContainer(configuration: .init(id: modelId)) - - let inputs = [ - "The Coca-Cola Company is a soft drink company based in Atlanta, Georgia, USA.", - "In the United States, PepsiCo Inc. is a leading soft drink company.", - ] - - let resultEmbeddings = await modelContainer.perform { - (model: EmbeddingModel, tokenizer: Tokenizer, pooling: Pooling) -> [[Float]] in - let encoded = inputs.map { - tokenizer.encode(text: $0, addSpecialTokens: true) - } - // Pad to longest sequence - let maxLength = encoded.reduce(into: 1) { acc, elem in - acc = max(acc, elem.count) - } - - let padded = stacked( - encoded.map { elem in - MLXArray( - elem - + Array( - repeating: tokenizer.eosTokenId ?? 0, - count: maxLength - elem.count)) - }) - - // Mask out padding tokens - let mask = (padded .!= (tokenizer.eosTokenId ?? 0)) - let tokenTypes = MLXArray.zeros(like: padded) - - // Generate embeddings. EmbeddingGemma returns a pooledOutput by default. - let modelOutput = model( - padded, positionIds: nil, tokenTypeIds: tokenTypes, attentionMask: mask) - - // Pooling strategy .cls (the default if no pooling config exists) - // will pick up the pooledOutput from the EmbeddingGemma model. - let result = pooling( - modelOutput, - normalize: true, applyLayerNorm: true - ) - result.eval() - return result.map { $0.asArray(Float.self) } - } - - #expect(resultEmbeddings.count == inputs.count, "Should have one embedding per input") - for embedding in resultEmbeddings { - // Gemma 3 1B hidden size is 1152 - #expect(embedding.count == 1152, "Gemma 3 1B embedding size should be 1152") - - // Verify normalization (L2 norm should be ~1.0) - let l2Norm = sqrt(embedding.map { $0 * $0 }.reduce(0, +)) - #expect(abs(l2Norm - 1.0) < 0.05, "Embeddings should be approximately L2-normalized") - } - - // Basic semantic check: similarity between related sentences should be positive - let similarity = zip(resultEmbeddings[0], resultEmbeddings[1]).map(*).reduce(0, +) - //print("similarity: \(similarity)") - #expect(similarity > 0.0, "Similarity between related sentences should be positive") - } - -} diff --git a/Tests/MLXLMIntegrationTests/IntegrationTestModels.swift b/Tests/MLXLMIntegrationTests/IntegrationTestModels.swift deleted file mode 100644 index fbd84d164..000000000 --- a/Tests/MLXLMIntegrationTests/IntegrationTestModels.swift +++ /dev/null @@ -1,46 +0,0 @@ -// Copyright © 2025 Apple Inc. - -import Foundation -import MLXLLM -import MLXLMCommon -import MLXVLM - -enum IntegrationTestModelIDs { - static let llmModelId = "mlx-community/Qwen3-4B-Instruct-2507-4bit" - static let vlmModelId = "mlx-community/Qwen3-VL-4B-Instruct-4bit" -} - -actor IntegrationTestModels { - static let shared = IntegrationTestModels() - - private var llmTask: Task? - private var vlmTask: Task? - - func llmContainer() async throws -> ModelContainer { - if let task = llmTask { - return try await task.value - } - - let task = Task { - try await LLMModelFactory.shared.loadContainer( - configuration: .init(id: IntegrationTestModelIDs.llmModelId) - ) - } - llmTask = task - return try await task.value - } - - func vlmContainer() async throws -> ModelContainer { - if let task = vlmTask { - return try await task.value - } - - let task = Task { - try await VLMModelFactory.shared.loadContainer( - configuration: .init(id: IntegrationTestModelIDs.vlmModelId) - ) - } - vlmTask = task - return try await task.value - } -} diff --git a/Tests/MLXLMIntegrationTests/README.md b/Tests/MLXLMIntegrationTests/README.md deleted file mode 100644 index 8b1378917..000000000 --- a/Tests/MLXLMIntegrationTests/README.md +++ /dev/null @@ -1 +0,0 @@ - diff --git a/Tests/MLXLMIntegrationTests/ToolCallIntegrationTests.swift b/Tests/MLXLMIntegrationTests/ToolCallIntegrationTests.swift deleted file mode 100644 index 3f40d9882..000000000 --- a/Tests/MLXLMIntegrationTests/ToolCallIntegrationTests.swift +++ /dev/null @@ -1,377 +0,0 @@ -// Copyright © 2025 Apple Inc. - -import Foundation -import MLX -import MLXLLM -import MLXLMCommon -import MLXVLM -import XCTest - -/// Integration tests for tool call format auto-detection and end-to-end parsing. -/// -/// These tests verify that: -/// 1. Tool call formats are correctly auto-detected from model_type -/// 2. Tool calls are correctly parsed from actual model generation output -/// -/// References: -/// - LFM2: https://github.com/ml-explore/mlx-lm/blob/main/mlx_lm/tool_parsers/default.py -/// - GLM4: https://github.com/ml-explore/mlx-lm/blob/main/mlx_lm/tool_parsers/glm47.py -public class ToolCallIntegrationTests: XCTestCase { - - // MARK: - Model IDs - - static let lfm2ModelId = "mlx-community/LFM2-2.6B-Exp-4bit" - static let glm4ModelId = "mlx-community/GLM-4-9B-0414-4bit" - static let mistral3ModelId = "mlx-community/Ministral-3-3B-Instruct-2512-4bit" - - // MARK: - Shared State - - nonisolated(unsafe) static var lfm2Container: ModelContainer? - nonisolated(unsafe) static var glm4Container: ModelContainer? - nonisolated(unsafe) static var mistral3Container: ModelContainer? - - // MARK: - Tool Schema - - static let weatherToolSchema: [[String: any Sendable]] = [ - [ - "type": "function", - "function": [ - "name": "get_weather", - "description": "Get the current weather for a location", - "parameters": [ - "type": "object", - "properties": [ - "location": [ - "type": "string", - "description": "The city name, e.g. San Francisco", - ] as [String: any Sendable], - "unit": [ - "type": "string", - "enum": ["celsius", "fahrenheit"], - "description": "Temperature unit", - ] as [String: any Sendable], - ] as [String: any Sendable], - "required": ["location"], - ] as [String: any Sendable], - ] as [String: any Sendable], - ] - ] - - // MARK: - Setup - - override public class func setUp() { - super.setUp() - - let lfm2Expectation = XCTestExpectation(description: "Load LFM2") - let glm4Expectation = XCTestExpectation(description: "Load GLM4") - let mistral3Expectation = XCTestExpectation(description: "Load Mistral3") - - Task { - do { - lfm2Container = try await LLMModelFactory.shared.loadContainer( - configuration: .init(id: lfm2ModelId) - ) - } catch { - print("Failed to load LFM2: \(error)") - } - lfm2Expectation.fulfill() - } - - Task { - do { - glm4Container = try await LLMModelFactory.shared.loadContainer( - configuration: .init(id: glm4ModelId) - ) - } catch { - print("Failed to load GLM4: \(error)") - } - glm4Expectation.fulfill() - } - - Task { - do { - mistral3Container = try await VLMModelFactory.shared.loadContainer( - configuration: .init(id: mistral3ModelId) - ) - } catch { - print("Failed to load Mistral3: \(error)") - } - mistral3Expectation.fulfill() - } - - _ = XCTWaiter.wait( - for: [lfm2Expectation, glm4Expectation, mistral3Expectation], timeout: 600) - } - - // MARK: - LFM2 Tests - - func testLFM2ToolCallFormatAutoDetection() async throws { - guard let container = Self.lfm2Container else { - throw XCTSkip("LFM2 model not available") - } - - let config = await container.configuration - XCTAssertEqual( - config.toolCallFormat, .lfm2, - "LFM2 model should auto-detect .lfm2 tool call format" - ) - } - - func testLFM2EndToEndToolCallGeneration() async throws { - guard let container = Self.lfm2Container else { - throw XCTSkip("LFM2 model not available") - } - - // Create input with tool schema - let input = UserInput( - chat: [ - .system( - "You are a helpful assistant with access to tools. When asked about weather, use the get_weather function." - ), - .user("What's the weather in Tokyo?"), - ], - tools: Self.weatherToolSchema - ) - - // Generate with tools - let (result, toolCalls) = try await generateWithTools( - container: container, - input: input, - maxTokens: 100 - ) - - print("LFM2 Output: \(result)") - print("LFM2 Tool Calls: \(toolCalls)") - - // Verify we got a tool call (model may or may not call the tool) - if !toolCalls.isEmpty { - let toolCall = toolCalls.first! - XCTAssertEqual(toolCall.function.name, "get_weather") - // Location should contain something related to Tokyo - if let location = toolCall.function.arguments["location"]?.asString { - XCTAssertTrue( - location.lowercased().contains("tokyo"), - "Expected location to contain 'Tokyo', got: \(location)" - ) - } - } - } - - // MARK: - GLM4 Tests - - func testGLM4ToolCallFormatAutoDetection() async throws { - guard let container = Self.glm4Container else { - throw XCTSkip("GLM4 model not available") - } - - let config = await container.configuration - XCTAssertEqual( - config.toolCallFormat, .glm4, - "GLM4 model should auto-detect .glm4 tool call format" - ) - } - - func testGLM4EndToEndToolCallGeneration() async throws { - guard let container = Self.glm4Container else { - throw XCTSkip("GLM4 model not available") - } - - // Create input with tool schema - let input = UserInput( - chat: [ - .system( - "You are a helpful assistant with access to tools. When asked about weather, use the get_weather function." - ), - .user("What's the weather in Paris?"), - ], - tools: Self.weatherToolSchema - ) - - // Generate with tools - let (result, toolCalls) = try await generateWithTools( - container: container, - input: input, - maxTokens: 100 - ) - - print("GLM4 Output: \(result)") - print("GLM4 Tool Calls: \(toolCalls)") - - // Verify we got a tool call (model may or may not call the tool) - if !toolCalls.isEmpty { - let toolCall = toolCalls.first! - XCTAssertEqual(toolCall.function.name, "get_weather") - // Location should contain something related to Paris - if let location = toolCall.function.arguments["location"]?.asString { - XCTAssertTrue( - location.lowercased().contains("paris"), - "Expected location to contain 'Paris', got: \(location)" - ) - } - } - } - - // MARK: - Mistral3 Tests - - func testMistral3ToolCallFormatAutoDetection() async throws { - guard let container = Self.mistral3Container else { - throw XCTSkip("Mistral3 model not available") - } - - let config = await container.configuration - XCTAssertEqual( - config.toolCallFormat, .mistral, - "Mistral3 model should auto-detect .mistral tool call format" - ) - } - - func testMistral3EndToEndToolCallGeneration() async throws { - guard let container = Self.mistral3Container else { - throw XCTSkip("Mistral3 model not available") - } - - let input = UserInput( - chat: [ - .system( - "You are a helpful assistant with access to tools. When asked about weather, use the get_weather function." - ), - .user("What's the weather in Tokyo?"), - ], - tools: Self.weatherToolSchema - ) - - let (result, toolCalls) = try await generateWithTools( - container: container, - input: input, - maxTokens: 100 - ) - - print("Mistral3 Output: \(result)") - print("Mistral3 Tool Calls: \(toolCalls)") - - // Verify we got a tool call (model may or may not call the tool) - if !toolCalls.isEmpty { - let toolCall = toolCalls.first! - XCTAssertEqual(toolCall.function.name, "get_weather") - if let location = toolCall.function.arguments["location"]?.asString { - XCTAssertTrue( - location.lowercased().contains("tokyo"), - "Expected location to contain 'Tokyo', got: \(location)" - ) - } - } - } - - func testMistral3MultipleToolCallGeneration() async throws { - guard let container = Self.mistral3Container else { - throw XCTSkip("Mistral3 model not available") - } - - let multiToolSchema: [[String: any Sendable]] = - Self.weatherToolSchema + [ - [ - "type": "function", - "function": [ - "name": "get_time", - "description": "Get the current time in a given timezone", - "parameters": [ - "type": "object", - "properties": [ - "timezone": [ - "type": "string", - "description": - "The timezone, e.g. America/New_York, Asia/Tokyo", - ] as [String: any Sendable] - ] as [String: any Sendable], - "required": ["timezone"], - ] as [String: any Sendable], - ] as [String: any Sendable], - ] - ] - - let input = UserInput( - chat: [ - .system( - "You are a helpful assistant with access to tools. Always use the available tools to answer questions. Call multiple tools in parallel when needed." - ), - .user( - "What's the weather in Tokyo and what time is it there?" - ), - ], - tools: multiToolSchema - ) - - let (result, toolCalls) = try await generateWithTools( - container: container, - input: input, - maxTokens: 150 - ) - - print("Mistral3 Output: \(result)") - print("Mistral3 Calls: \(toolCalls)") - - // Verify all returned tool calls have valid names from our schema - let validNames: Set = ["get_weather", "get_time"] - for toolCall in toolCalls { - XCTAssertTrue( - validNames.contains(toolCall.function.name), - "Unexpected tool call: \(toolCall.function.name)" - ) - } - - // If the model made multiple calls, verify we got more than one - if toolCalls.count > 1 { - print("Successfully parsed \(toolCalls.count) tool calls from Mistral3") - } - } - - // MARK: - Helper Methods - - /// Generate text and collect any tool calls - private func generateWithTools( - container: ModelContainer, - input: UserInput, - maxTokens: Int - ) async throws -> (text: String, toolCalls: [ToolCall]) { - let result = try await container.perform(nonSendable: input) { - (context: ModelContext, input) in - let lmInput = try await context.processor.prepare(input: input) - let parameters = GenerateParameters(maxTokens: maxTokens) - - let stream = try generate( - input: lmInput, - parameters: parameters, - context: context - ) - - var collectedText = "" - var collectedToolCalls: [ToolCall] = [] - - for try await generation in stream { - switch generation { - case .chunk(let text): - collectedText += text - case .toolCall(let toolCall): - collectedToolCalls.append(toolCall) - case .info: - break - } - } - - return (collectedText, collectedToolCalls) - } - - return result - } -} - -// MARK: - JSONValue Extension for Testing - -extension JSONValue { - var asString: String? { - if case .string(let s) = self { - return s - } - return nil - } -} diff --git a/Tests/MLXLMTests/BaseConfigurationTests.swift b/Tests/MLXLMTests/BaseConfigurationTests.swift index 4e77e63cd..1c871fa09 100644 --- a/Tests/MLXLMTests/BaseConfigurationTests.swift +++ b/Tests/MLXLMTests/BaseConfigurationTests.swift @@ -21,7 +21,6 @@ public class BaseConfigurationTests: XCTestCase { let config = try JSONDecoder().decode( BaseConfiguration.self, from: json.data(using: .utf8)!) - XCTAssertEqual(config.quantization, .init(groupSize: 128, bits: 4)) XCTAssertEqual( config.perLayerQuantization?.quantization(layer: "x"), .init(groupSize: 128, bits: 4)) } @@ -48,8 +47,6 @@ public class BaseConfigurationTests: XCTestCase { let config = try JSONDecoder().decode( BaseConfiguration.self, from: json.data(using: .utf8)!) - XCTAssertEqual(config.quantization, .init(groupSize: 64, bits: 4)) - // a random layer -- no specific configuration gets default XCTAssertEqual( config.perLayerQuantization?.quantization(layer: "x"), diff --git a/Tests/MLXLMTests/ChatSessionTests.swift b/Tests/MLXLMTests/ChatSessionTests.swift index b98f8e425..7b8055cd8 100644 --- a/Tests/MLXLMTests/ChatSessionTests.swift +++ b/Tests/MLXLMTests/ChatSessionTests.swift @@ -5,7 +5,6 @@ import MLX import MLXLLM import MLXNN import MLXOptimizers -import Tokenizers import XCTest @testable import MLXLMCommon diff --git a/Tests/MLXLMTests/ResolveTests.swift b/Tests/MLXLMTests/ResolveTests.swift new file mode 100644 index 000000000..67fea00e3 --- /dev/null +++ b/Tests/MLXLMTests/ResolveTests.swift @@ -0,0 +1,204 @@ +import Foundation +import MLXLMCommon +import Testing + +/// A mock downloader that records every call for later assertion. +private struct MockDownloader: Downloader { + + struct Call: Equatable, Sendable { + let id: String + let revision: String? + let patterns: [String] + } + + let calls: LockIsolated<[Call]> + let directory: URL + + init(directory: URL = URL(filePath: "/mock")) { + self.calls = LockIsolated([]) + self.directory = directory + } + + func download( + id: String, + revision: String?, + matching patterns: [String], + useLatest: Bool, + progressHandler: @Sendable @escaping (Progress) -> Void + ) async throws -> URL { + calls.withLock { $0.append(Call(id: id, revision: revision, patterns: patterns)) } + // Return a unique directory per id so tests can distinguish model vs tokenizer paths. + return directory.appending(component: id.replacingOccurrences(of: "/", with: "_")) + } +} + +/// Minimal lock-based isolation for collecting values from async contexts. +private final class LockIsolated: @unchecked Sendable { + private var _value: Value + private let lock = NSLock() + + init(_ value: Value) { _value = value } + + func withLock(_ body: (inout Value) -> R) -> R { + lock.lock() + defer { lock.unlock() } + return body(&_value) + } + + var value: Value { + lock.lock() + defer { lock.unlock() } + return _value + } +} + +@Suite struct ResolveTests { + + @Test func nilTokenizerSourceUsesModelDirectory() async throws { + let downloader = MockDownloader() + let config = ModelConfiguration( + id: "org/model", revision: "abc123", tokenizerSource: nil) + + let resolved = try await resolve( + configuration: config, from: downloader, + useLatest: false, progressHandler: { _ in }) + + // Only one download call — the model itself. + #expect(downloader.calls.value.count == 1) + #expect(downloader.calls.value[0].id == "org/model") + #expect(downloader.calls.value[0].revision == "abc123") + #expect(downloader.calls.value[0].patterns.contains("*.jinja")) + + // No separate tokenizer download, so both point to the model directory. + #expect(resolved.modelDirectory == resolved.tokenizerDirectory) + } + + @Test func tokenizerSourceIDWithoutRevisionPassesNil() async throws { + let downloader = MockDownloader() + let config = ModelConfiguration( + id: "org/model", revision: "abc123", + tokenizerSource: .id("org/tokenizer")) + + let resolved = try await resolve( + configuration: config, from: downloader, + useLatest: false, progressHandler: { _ in }) + + #expect(downloader.calls.value.count == 2) + + // Model download uses model revision. + #expect(downloader.calls.value[0].id == "org/model") + #expect(downloader.calls.value[0].revision == "abc123") + + // Tokenizer download uses nil revision (provider default). + #expect(downloader.calls.value[1].id == "org/tokenizer") + #expect(downloader.calls.value[1].revision == nil) + #expect(downloader.calls.value[1].patterns.contains("*.jinja")) + + // Model and tokenizer come from different repos, so directories differ. + #expect(resolved.modelDirectory != resolved.tokenizerDirectory) + } + + @Test func tokenizerSourceIDWithExplicitRevision() async throws { + let downloader = MockDownloader() + let config = ModelConfiguration( + id: "org/model", revision: "v1.0", + tokenizerSource: .id("org/tokenizer", revision: "tok-v2")) + + let resolved = try await resolve( + configuration: config, from: downloader, + useLatest: false, progressHandler: { _ in }) + + #expect(downloader.calls.value.count == 2) + + #expect(downloader.calls.value[0].id == "org/model") + #expect(downloader.calls.value[0].revision == "v1.0") + + #expect(downloader.calls.value[1].id == "org/tokenizer") + #expect(downloader.calls.value[1].revision == "tok-v2") + #expect(downloader.calls.value[1].patterns.contains("*.jinja")) + + // Model and tokenizer come from different repos, so directories differ. + #expect(resolved.modelDirectory != resolved.tokenizerDirectory) + } + + @Test func localDirectorySkipsDownloader() async throws { + let downloader = MockDownloader() + let localDir = URL(filePath: "/local/org/model") + let config = ModelConfiguration(directory: localDir) + + let resolved = try await resolve( + configuration: config, from: downloader, + useLatest: false, progressHandler: { _ in }) + + // No downloads should occur for a local directory. + #expect(downloader.calls.value.isEmpty) + + // Both directories point to the local path. + #expect(resolved.modelDirectory == localDir) + #expect(resolved.tokenizerDirectory == localDir) + } + + @Test func localDirectoryWithRemoteTokenizerSource() async throws { + let downloader = MockDownloader() + let localDir = URL(filePath: "/local/org/model") + let config = ModelConfiguration( + directory: localDir, + tokenizerSource: .id("org/tokenizer", revision: "v3")) + + let resolved = try await resolve( + configuration: config, from: downloader, + useLatest: false, progressHandler: { _ in }) + + // Only the tokenizer is downloaded; the model directory is local. + #expect(downloader.calls.value.count == 1) + #expect(downloader.calls.value[0].id == "org/tokenizer") + #expect(downloader.calls.value[0].revision == "v3") + #expect(downloader.calls.value[0].patterns.contains("*.jinja")) + + #expect(resolved.modelDirectory == localDir) + #expect(resolved.tokenizerDirectory != localDir) + } + + @Test func localConfigurationExposesResolvedDirectories() throws { + let modelDir = URL(filePath: "/local/org/model") + let tokenizerDir = URL(filePath: "/local/org/tokenizer") + let config = ModelConfiguration( + directory: modelDir, + tokenizerSource: .directory(tokenizerDir)) + + #expect(try config.modelDirectory == modelDir) + #expect(try config.tokenizerDirectory == tokenizerDir) + } + + @Test func tokenizerDirectoryFallsBackToModelDirectory() throws { + let modelDir = URL(filePath: "/local/org/model") + let config = ModelConfiguration(directory: modelDir) + + #expect(try config.modelDirectory == modelDir) + #expect(try config.tokenizerDirectory == modelDir) + } + + @Test func unresolvedRemoteConfigurationThrowsForDirectories() { + let config = ModelConfiguration( + id: "org/model", + tokenizerSource: .id("org/tokenizer")) + + do { + _ = try config.modelDirectory + Issue.record("Expected modelDirectory to throw for unresolved remote config") + } catch let error as ModelConfiguration.DirectoryError { + #expect(error == .unresolvedModelDirectory("org/model")) + } catch { + Issue.record("Unexpected error: \(error)") + } + + do { + _ = try config.tokenizerDirectory + Issue.record("Expected tokenizerDirectory to throw for unresolved remote tokenizer") + } catch let error as ModelConfiguration.DirectoryError { + #expect(error == .unresolvedTokenizerDirectory("org/tokenizer")) + } catch { + Issue.record("Unexpected error: \(error)") + } + } +} diff --git a/Tests/MLXLMTests/TestTokenizer.swift b/Tests/MLXLMTests/TestTokenizer.swift index 4908e5f1e..a07d802f0 100644 --- a/Tests/MLXLMTests/TestTokenizer.swift +++ b/Tests/MLXLMTests/TestTokenizer.swift @@ -3,10 +3,9 @@ import Foundation import MLX import MLXLMCommon -import Tokenizers /// A test tokenizer -- this can be used in place of a real tokenizer for unit/integration tests. -struct TestTokenizer: Tokenizer { +struct TestTokenizer: MLXLMCommon.Tokenizer { let length = 8 let maxLength = 50 @@ -35,26 +34,18 @@ struct TestTokenizer: Tokenizer { ) } - func tokenize(text: String) -> [String] { - text.split(separator: " ").map { String($0) } - } - - func encode(text: String) -> [Int] { - (0 ..< length).enumerated().map { (index, _) in + func encode(text: String, addSpecialTokens: Bool) -> [Int] { + (0 ..< length).map { _ in Int.random(in: 1 ..< vocabularySize) } } - func encode(text: String, addSpecialTokens: Bool) -> [Int] { - encode(text: text) - } - - func decode(tokens: [Int], skipSpecialTokens: Bool) -> String { - var tokens = tokens - if tokens.count > maxLength { - tokens.append(_eosTokenId) + func decode(tokenIds: [Int], skipSpecialTokens: Bool) -> String { + var tokenIds = tokenIds + if tokenIds.count > maxLength { + tokenIds.append(_eosTokenId) } - return tokens.map { convertIdToToken($0) ?? "" }.joined(separator: " ") + return tokenIds.map { convertIdToToken($0) ?? "" }.joined(separator: " ") } func convertTokenToId(_ token: String) -> Int? { @@ -69,54 +60,16 @@ struct TestTokenizer: Tokenizer { } var bosToken: String? = nil - - var bosTokenId: Int? = 0 - var eosToken: String? = nil - var eosTokenId: Int? { _eosTokenId } var unknownToken: String? = nil var unknownTokenId: Int? { _unknownTokenId } - func applyChatTemplate(messages: [Tokenizers.Message]) throws -> [Int] { - encode(text: "") - } - - func applyChatTemplate(messages: [Tokenizers.Message], tools: [Tokenizers.ToolSpec]?) throws - -> [Int] - { - encode(text: "") - } - - func applyChatTemplate( - messages: [Tokenizers.Message], tools: [Tokenizers.ToolSpec]?, - additionalContext: [String: any Sendable]? - ) throws -> [Int] { - encode(text: "") - } - - func applyChatTemplate( - messages: [Tokenizers.Message], chatTemplate: Tokenizers.ChatTemplateArgument - ) throws -> [Int] { - encode(text: "") - } - - func applyChatTemplate(messages: [Tokenizers.Message], chatTemplate: String) throws -> [Int] { - encode(text: "") - } - - func applyChatTemplate( - messages: [Tokenizers.Message], chatTemplate: Tokenizers.ChatTemplateArgument?, - addGenerationPrompt: Bool, truncation: Bool, maxLength: Int?, tools: [Tokenizers.ToolSpec]? - ) throws -> [Int] { - encode(text: "") - } - func applyChatTemplate( - messages: [Tokenizers.Message], chatTemplate: Tokenizers.ChatTemplateArgument?, - addGenerationPrompt: Bool, truncation: Bool, maxLength: Int?, tools: [Tokenizers.ToolSpec]?, + messages: [[String: any Sendable]], + tools: [[String: any Sendable]]?, additionalContext: [String: any Sendable]? ) throws -> [Int] { encode(text: "") diff --git a/skills/mlx-swift-lm/SKILL.md b/skills/mlx-swift-lm/SKILL.md index 206ecbfb6..395afa318 100644 --- a/skills/mlx-swift-lm/SKILL.md +++ b/skills/mlx-swift-lm/SKILL.md @@ -65,8 +65,12 @@ MLXEmbedders - Embedding models and pooling utilities ```swift import MLXLLM import MLXLMCommon +import MLXLMHuggingFace // from swift-huggingface-mlx +import MLXLMTokenizers // from swift-tokenizers-mlx let modelContainer = try await LLMModelFactory.shared.loadContainer( + from: HubClient.default, + using: TokenizersLoader(), configuration: .init(id: "mlx-community/Qwen3-4B-4bit") ) @@ -85,8 +89,12 @@ for try await chunk in session.streamResponse(to: "Explain structured concurrenc ```swift import MLXVLM import MLXLMCommon +import MLXLMHuggingFace // from swift-huggingface-mlx +import MLXLMTokenizers // from swift-tokenizers-mlx let modelContainer = try await VLMModelFactory.shared.loadContainer( + from: HubClient.default, + using: TokenizersLoader(), configuration: .init(id: "mlx-community/Qwen2-VL-2B-Instruct-4bit") ) @@ -103,9 +111,13 @@ let response = try await session.respond( ### Embeddings ```swift -import Embedders +import MLXEmbedders +import MLXEmbeddersHuggingFace // from swift-huggingface-mlx +import MLXLMTokenizers // from swift-tokenizers-mlx let container = try await loadModelContainer( + from: HubClient.default, + using: TokenizersLoader(), configuration: ModelConfiguration(id: "mlx-community/bge-small-en-v1.5-mlx") ) diff --git a/skills/mlx-swift-lm/references/concurrency.md b/skills/mlx-swift-lm/references/concurrency.md index 6c8404898..24d4d757d 100644 --- a/skills/mlx-swift-lm/references/concurrency.md +++ b/skills/mlx-swift-lm/references/concurrency.md @@ -128,7 +128,11 @@ public final class ModelContainer: Sendable { ```swift // Multiple tasks can call perform() safely -let container = try await loadModelContainer() +let container = try await loadModelContainer( + from: HubClient.default, + using: TokenizersLoader(), // TokenizersLoader() from MLXLMTokenizers (swift-tokenizers-mlx) + id: "mlx-community/Qwen3-4B-4bit" +) Task { await container.perform { context in diff --git a/skills/mlx-swift-lm/references/embeddings.md b/skills/mlx-swift-lm/references/embeddings.md index 753c27e63..bf81529ff 100644 --- a/skills/mlx-swift-lm/references/embeddings.md +++ b/skills/mlx-swift-lm/references/embeddings.md @@ -48,30 +48,42 @@ The Embedders library provides text embedding models for semantic search, RAG, c ### Using Pre-registered Configuration ```swift -import Embedders +import MLXEmbedders let config = ModelConfiguration.bge_small -let container = try await loadModelContainer(configuration: config) +let container = try await loadModelContainer( + from: HubClient.default, + using: TokenizersLoader(), // TokenizersLoader() from MLXLMTokenizers (swift-tokenizers-mlx) + configuration: config +) ``` ### Using Custom Model ID ```swift let config = ModelConfiguration(id: "BAAI/bge-small-en-v1.5") -let container = try await loadModelContainer(configuration: config) +let container = try await loadModelContainer( + from: HubClient.default, + using: TokenizersLoader(), + configuration: config +) ``` ### From Local Directory ```swift -let config = ModelConfiguration(directory: localModelURL) -let container = try await loadModelContainer(configuration: config) +let container = try await loadModelContainer( + from: localModelURL, + using: TokenizersLoader() +) ``` ### With Progress Tracking ```swift let container = try await loadModelContainer( + from: HubClient.default, + using: TokenizersLoader(), configuration: config ) { progress in print("Download progress: \(progress.fractionCompleted)") diff --git a/skills/mlx-swift-lm/references/lora-adapters.md b/skills/mlx-swift-lm/references/lora-adapters.md index 7e99ffa7a..16116a3de 100644 --- a/skills/mlx-swift-lm/references/lora-adapters.md +++ b/skills/mlx-swift-lm/references/lora-adapters.md @@ -230,6 +230,8 @@ public protocol LoRALayer: Module { ```swift // Load base model let container = try await LLMModelFactory.shared.loadContainer( + from: HubClient.default, + using: TokenizersLoader(), // TokenizersLoader() from MLXLMTokenizers (swift-tokenizers-mlx) configuration: .init(id: "mlx-community/Llama-3.2-3B-Instruct-4bit") ) diff --git a/skills/mlx-swift-lm/references/model-container.md b/skills/mlx-swift-lm/references/model-container.md index 11d190679..4cba423f1 100644 --- a/skills/mlx-swift-lm/references/model-container.md +++ b/skills/mlx-swift-lm/references/model-container.md @@ -25,22 +25,28 @@ ```swift // Via factory (recommended) let container = try await LLMModelFactory.shared.loadContainer( + from: HubClient.default, + using: TokenizersLoader(), // TokenizersLoader() from MLXLMTokenizers (swift-tokenizers-mlx) configuration: .init(id: "mlx-community/Qwen3-4B-4bit") ) -// With custom hub -let hub = HubApi(hfToken: "your_token") +// With custom hub (from MLXLMHuggingFace) +let hub = HubClient(token: "hf_...") let container = try await LLMModelFactory.shared.loadContainer( - hub: hub, + from: hub, + using: TokenizersLoader(), configuration: .init(id: "private/model") ) // With progress tracking let container = try await LLMModelFactory.shared.loadContainer( - configuration: config -) { progress in - print("Downloaded: \(progress.fractionCompleted)") -} + from: HubClient.default, + using: TokenizersLoader(), + configuration: config, + progressHandler: { progress in + print("Downloaded: \(progress.fractionCompleted)") + } +) ``` ### Using ModelContainer @@ -173,6 +179,8 @@ let factory = LLMModelFactory.shared // Load container let container = try await factory.loadContainer( + from: HubClient.default, + using: TokenizersLoader(), configuration: LLMRegistry.llama3_2_3B_4bit ) @@ -189,6 +197,8 @@ let customFactory = LLMModelFactory( let factory = VLMModelFactory.shared let container = try await factory.loadContainer( + from: HubClient.default, + using: TokenizersLoader(), configuration: VLMRegistry.qwen2VL2BInstruct4Bit ) ``` @@ -237,7 +247,12 @@ Map `model_type` from config.json to model initializers: ```swift // Download location -let modelDir = configuration.modelDirectory(hub: HubApi()) +let resolved = try await resolve( + configuration: configuration, + from: HubClient.default, + progressHandler: { _ in } +) +let modelDir = resolved.modelDirectory // ~/.cache/huggingface/hub/models--mlx-community--Model-Name/... ``` diff --git a/skills/mlx-swift-lm/references/model-porting.md b/skills/mlx-swift-lm/references/model-porting.md index 3c9cf41be..17f57577a 100644 --- a/skills/mlx-swift-lm/references/model-porting.md +++ b/skills/mlx-swift-lm/references/model-porting.md @@ -365,6 +365,8 @@ If you need custom keys, override `loraDefaultKeys`. ```swift let modelContainer = try await LLMModelFactory.shared.loadContainer( + from: HubClient.default, + using: TokenizersLoader(), // TokenizersLoader() from MLXLMTokenizers (swift-tokenizers-mlx) configuration: ModelConfiguration(id: "mlx-community/YourModel-4bit") ) diff --git a/skills/mlx-swift-lm/references/supported-models.md b/skills/mlx-swift-lm/references/supported-models.md index 4b7f13bb9..8f7d45592 100644 --- a/skills/mlx-swift-lm/references/supported-models.md +++ b/skills/mlx-swift-lm/references/supported-models.md @@ -150,6 +150,8 @@ Models not in registries can be loaded by ID: // Any mlx-community model let config = ModelConfiguration(id: "mlx-community/SomeModel-4bit") let container = try await LLMModelFactory.shared.loadContainer( + from: HubClient.default, + using: TokenizersLoader(), // TokenizersLoader() from MLXLMTokenizers (swift-tokenizers-mlx) configuration: config ) diff --git a/skills/mlx-swift-lm/references/tokenizer-chat.md b/skills/mlx-swift-lm/references/tokenizer-chat.md index bc05f2222..e85609cb3 100644 --- a/skills/mlx-swift-lm/references/tokenizer-chat.md +++ b/skills/mlx-swift-lm/references/tokenizer-chat.md @@ -28,6 +28,8 @@ Tokenizers are loaded automatically by model factories: ```swift let container = try await LLMModelFactory.shared.loadContainer( + from: HubClient.default, + using: TokenizersLoader(), // TokenizersLoader() from MLXLMTokenizers (swift-tokenizers-mlx) configuration: config ) let tokenizer = await container.tokenizer @@ -35,26 +37,13 @@ let tokenizer = await container.tokenizer ### Manual Loading -```swift -let tokenizer = try await loadTokenizer( - configuration: config, - hub: HubApi() -) -``` - -### Loading Components +Tokenizer loading is handled by the `TokenizerLoader` protocol. Each integration +package provides a concrete loader: ```swift -// Load tokenizer config and data separately -let (tokenizerConfig, tokenizerData) = try await loadTokenizerConfig( - configuration: config, - hub: hub -) - -let tokenizer = try PreTrainedTokenizer( - tokenizerConfig: tokenizerConfig, - tokenizerData: tokenizerData -) +// Using TokenizersLoader from MLXLMTokenizers (swift-tokenizers-mlx) +let loader = TokenizersLoader() +let tokenizer = try await loader.load(from: modelDirectory) ``` ## Tokenizer Usage diff --git a/skills/mlx-swift-lm/references/training.md b/skills/mlx-swift-lm/references/training.md index dc1bc988e..978d22f33 100644 --- a/skills/mlx-swift-lm/references/training.md +++ b/skills/mlx-swift-lm/references/training.md @@ -50,6 +50,8 @@ import MLXLMCommon // Load base model let container = try await LLMModelFactory.shared.loadContainer( + from: HubClient.default, + using: TokenizersLoader(), // TokenizersLoader() from MLXLMTokenizers (swift-tokenizers-mlx) configuration: .init(id: "mlx-community/Llama-3.2-3B-Instruct-4bit") ) @@ -345,6 +347,8 @@ import MLXOptimizers func trainAdapter() async throws { // Load model let container = try await LLMModelFactory.shared.loadContainer( + from: HubClient.default, + using: TokenizersLoader(), configuration: .init(id: "mlx-community/Llama-3.2-1B-Instruct-4bit") )