464 changes: 464 additions & 0 deletions Libraries/BenchmarkHelpers/BenchmarkHelpers.swift

Large diffs are not rendered by default.

614 changes: 614 additions & 0 deletions Libraries/IntegrationTestHelpers/IntegrationTestHelpers.swift

Large diffs are not rendered by default.

10 changes: 10 additions & 0 deletions Libraries/IntegrationTestHelpers/README.md
@@ -0,0 +1,10 @@
# Integration Test Helpers

`IntegrationTestHelpers` and `BenchmarkHelpers` provide shared test logic for verifying end-to-end model loading, inference, tokenizer performance, and download performance. They are designed to be used by integration packages that supply their own `Downloader` and `TokenizerLoader` implementations.

## Integration packages

- [Swift Tokenizers MLX](https://github.com/DePasqualeOrg/swift-tokenizers-mlx): Uses [Swift Tokenizers](https://github.com/DePasqualeOrg/swift-tokenizers) and [Swift HF API](https://github.com/DePasqualeOrg/swift-hf-api)
- [Swift Transformers MLX](https://github.com/DePasqualeOrg/swift-transformers-mlx): Uses [Swift Transformers](https://github.com/huggingface/swift-transformers) and [Swift Hugging Face](https://github.com/huggingface/swift-huggingface)

Integration tests and benchmarks are run from those packages.
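
As a rough sketch of what supplying a `Downloader` and `TokenizerLoader` could look like, the example below uses hypothetical stand-in protocols (`SketchDownloader`, `SketchTokenizerLoader`, `SketchTokenizer`). Their method shapes are inferred from the call sites in this change; the real protocols may have different requirements, so treat every name here as an assumption.

```swift
import Foundation

// Hypothetical stand-ins: the real protocols used by this change
// may declare different requirements.
protocol SketchTokenizer: Sendable {
    func encode(text: String) -> [Int]
}

protocol SketchDownloader: Sendable {
    func download(
        id: String, revision: String, matching patterns: [String],
        useLatest: Bool, progressHandler: @Sendable @escaping (Progress) -> Void
    ) async throws -> URL
}

protocol SketchTokenizerLoader: Sendable {
    func load(from directory: URL) async throws -> any SketchTokenizer
}

// A downloader that resolves repositories inside a local cache directory
// without touching the network.
struct LocalCacheDownloader: SketchDownloader {
    let cacheRoot: URL

    func download(
        id: String, revision: String, matching patterns: [String],
        useLatest: Bool, progressHandler: @Sendable @escaping (Progress) -> Void
    ) async throws -> URL {
        // Report a single completed unit so callers see progress finish.
        let progress = Progress(totalUnitCount: 1)
        progress.completedUnitCount = 1
        progressHandler(progress)
        return cacheRoot.appendingPathComponent(id).appendingPathComponent(revision)
    }
}

// A toy tokenizer that maps each space-separated word to its length.
struct WordLengthTokenizer: SketchTokenizer {
    func encode(text: String) -> [Int] {
        text.split(separator: " ").map(\.count)
    }
}

// A loader that ignores the directory and always returns the toy tokenizer.
struct WordLengthTokenizerLoader: SketchTokenizerLoader {
    func load(from directory: URL) async throws -> any SketchTokenizer {
        WordLengthTokenizer()
    }
}
```

An integration package would replace the toy implementations with ones backed by its chosen tokenizer and Hub client libraries.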
19 changes: 8 additions & 11 deletions Libraries/MLXEmbedders/EmbeddingModel.swift
@@ -1,10 +1,9 @@
// Copyright © 2024 Apple Inc.

import Foundation
@preconcurrency import Hub
import MLX
import MLXLMCommon
import MLXNN
import Tokenizers

/// Container for models that guarantees single threaded access.
///
@@ -44,23 +43,21 @@ public actor ModelContainer {
self.pooler = pooler
}

/// build the model and tokenizer without passing non-sendable data over isolation barriers
/// Build the model and tokenizer without passing non-sendable data over isolation barriers
public init(
hub: HubApi,
modelDirectory: URL,
configuration: ModelConfiguration
tokenizerDirectory: URL,
configuration: ModelConfiguration,
tokenizerLoader: any TokenizerLoader
) async throws {
// Load tokenizer config and model in parallel using async let.
async let tokenizerConfigTask = loadTokenizerConfig(
configuration: configuration, hub: hub)
// Load tokenizer and model in parallel
async let tokenizerTask = tokenizerLoader.load(from: tokenizerDirectory)

self.model = try loadSynchronous(
modelDirectory: modelDirectory, modelName: configuration.name)
self.pooler = loadPooling(modelDirectory: modelDirectory, model: model)

let (tokenizerConfig, tokenizerData) = try await tokenizerConfigTask
self.tokenizer = try PreTrainedTokenizer(
tokenizerConfig: tokenizerConfig, tokenizerData: tokenizerData)
self.tokenizer = try await tokenizerTask
}

/// Perform an action on the model and/or tokenizer. Callers _must_ eval any `MLXArray` before returning as
144 changes: 97 additions & 47 deletions Libraries/MLXEmbedders/Load.swift
@@ -1,11 +1,9 @@
// Copyright © 2024 Apple Inc.

import Foundation
@preconcurrency import Hub
import MLX
import MLXLMCommon
import MLXNN
import Tokenizers

/// Errors encountered during the model loading and initialization process.
///
@@ -26,9 +24,6 @@ public enum EmbedderError: LocalizedError {
/// The configuration file exists but contains invalid JSON or missing required fields.
case configurationDecodingError(String, String, DecodingError)

/// Thrown when the tokenizer configuration is missing from the model bundle or Hub.
case missingTokenizerConfig

/// A human-readable description of the error.
public var errorDescription: String? {
switch self {
@@ -39,8 +34,6 @@ public enum EmbedderError: LocalizedError {
case .configurationDecodingError(let file, let modelName, let decodingError):
let errorDetail = extractDecodingErrorDetail(decodingError)
return "Failed to parse \(file) for model '\(modelName)': \(errorDetail)"
case .missingTokenizerConfig:
return "Missing tokenizer configuration"
}
}

@@ -70,43 +63,48 @@ public enum EmbedderError: LocalizedError {
}
}

/// Prepares the local model directory by downloading files from the Hub or resolving a local path.
///
/// If the `ModelConfiguration` identifies a remote repo, this function downloads weights
/// (`.safetensors`) and config files. It includes a fallback mechanism: if the user is
/// offline or unauthorized, it attempts to resolve the files from the local cache.
/// Resolve model and tokenizer directories from a ``ModelConfiguration``
/// using a ``Downloader``.
///
/// - Parameters:
/// - hub: The `HubApi` instance for managing downloads.
/// - downloader: The downloader to use for fetching remote resources.
/// - configuration: The configuration identifying the model.
/// - useLatest: When true, always checks the provider for updates.
/// - progressHandler: A closure to monitor download progress.
/// - Returns: A `URL` pointing to the directory containing model files.
func prepareModelDirectory(
hub: HubApi,
/// - Returns: A tuple of (modelDirectory, tokenizerDirectory).
func resolveDirectories(
Collaborator

This code looks nearly identical to ModelAdapterFactory.load(). I wonder if these can/should share implementation?

It looks like MLXEmbedders retains the top-level load() function while the LLM/VLM side migrates this into ModelFactory. Comparing against main, it looks like that is the pattern there too, but maybe we should take this opportunity to consolidate some of this.

Contributor Author

I investigated this and it seems it's not so straightforward, since MLXEmbedders has its own ModelConfiguration type.

Collaborator

Yeah, I guess it does -- the ModelConfiguration in MLXLMCommon has properties for detokenizing, so that kind of makes sense. Perhaps a task for after this PR. With a major version bump we could take a look at various APIs and see if any could be simplified or cleaned up.

from downloader: any Downloader,
configuration: ModelConfiguration,
useLatest: Bool = false,
progressHandler: @Sendable @escaping (Progress) -> Void
) async throws -> URL {
do {
switch configuration.id {
case .id(let id):
let repo = Hub.Repo(id: id)
let modelFiles = ["*.safetensors", "config.json", "*/config.json"]
return try await hub.snapshot(
from: repo, matching: modelFiles, progressHandler: progressHandler)

case .directory(let directory):
return directory
}
} catch Hub.HubClientError.authorizationRequired {
return configuration.modelDirectory(hub: hub)
} catch {
let nserror = error as NSError
if nserror.domain == NSURLErrorDomain && nserror.code == NSURLErrorNotConnectedToInternet {
return configuration.modelDirectory(hub: hub)
} else {
throw error
}
) async throws -> (modelDirectory: URL, tokenizerDirectory: URL) {
let modelDirectory: URL
switch configuration.id {
case .id(let id, let revision):
modelDirectory = try await downloader.download(
id: id, revision: revision,
matching: modelDownloadPatterns,
useLatest: useLatest,
progressHandler: progressHandler)
case .directory(let directory):
modelDirectory = directory
}

let tokenizerDirectory: URL
switch configuration.tokenizerSource {
case .id(let id, let revision):
tokenizerDirectory = try await downloader.download(
id: id, revision: revision,
matching: tokenizerDownloadPatterns,
useLatest: useLatest,
progressHandler: { _ in })
case .directory(let directory):
tokenizerDirectory = directory
case nil:
tokenizerDirectory = modelDirectory
}

return (modelDirectory, tokenizerDirectory)
}

/// Asynchronously loads the `EmbeddingModel` and its associated `Tokenizer`.
@@ -116,19 +114,23 @@ func prepareModelDirectory(
/// structure is being built synchronously.
///
/// - Parameters:
/// - hub: The `HubApi` instance (defaults to a new instance).
/// - downloader: The ``Downloader`` to use for fetching remote resources.
/// - configuration: The model configuration.
/// - useLatest: When true, always checks the provider for updates.
/// - progressHandler: A closure for tracking download progress.
/// - Returns: A tuple containing the initialized `EmbeddingModel` and `Tokenizer`.
public func load(
hub: HubApi = defaultHubApi,
from downloader: any Downloader,
using tokenizerLoader: any TokenizerLoader,
configuration: ModelConfiguration,
useLatest: Bool = false,
progressHandler: @Sendable @escaping (Progress) -> Void = { _ in }
) async throws -> (EmbeddingModel, Tokenizer) {
let modelDirectory = try await prepareModelDirectory(
hub: hub, configuration: configuration, progressHandler: progressHandler)
let (modelDirectory, tokenizerDirectory) = try await resolveDirectories(
from: downloader, configuration: configuration, useLatest: useLatest,
progressHandler: progressHandler)

async let tokenizerTask = loadTokenizer(configuration: configuration, hub: hub)
async let tokenizerTask = tokenizerLoader.load(from: tokenizerDirectory)
let model = try loadSynchronous(modelDirectory: modelDirectory, modelName: configuration.name)
let tokenizer = try await tokenizerTask

@@ -213,17 +215,65 @@ func loadSynchronous(modelDirectory: URL, modelName: String) throws -> Embedding
/// or tasks may need to access the embedding model simultaneously.
///
/// - Parameters:
/// - hub: The `HubApi` instance.
/// - downloader: The ``Downloader`` to use for fetching remote resources.
/// - configuration: The model configuration.
/// - useLatest: When true, always checks the provider for updates.
/// - progressHandler: A closure for tracking download progress.
/// - Returns: A thread-safe `ModelContainer` instance.
public func loadModelContainer(
hub: HubApi = defaultHubApi,
from downloader: any Downloader,
using tokenizerLoader: any TokenizerLoader,
configuration: ModelConfiguration,
useLatest: Bool = false,
progressHandler: @Sendable @escaping (Progress) -> Void = { _ in }
) async throws -> ModelContainer {
let modelDirectory = try await prepareModelDirectory(
hub: hub, configuration: configuration, progressHandler: progressHandler)
let (modelDirectory, tokenizerDirectory) = try await resolveDirectories(
from: downloader, configuration: configuration, useLatest: useLatest,
progressHandler: progressHandler)

return try await ModelContainer(
hub: hub, modelDirectory: modelDirectory, configuration: configuration)
modelDirectory: modelDirectory,
tokenizerDirectory: tokenizerDirectory,
configuration: configuration,
tokenizerLoader: tokenizerLoader)
}

/// Load an embedding model from a local directory.
///
/// No downloader is needed — the model and tokenizer are loaded from
/// the given directory.
///
/// - Parameter directory: The local directory containing model files.
/// - Returns: A tuple containing the initialized `EmbeddingModel` and `Tokenizer`.
public func load(
from directory: URL,
using tokenizerLoader: any TokenizerLoader
) async throws -> (EmbeddingModel, Tokenizer) {
let name =
directory.deletingLastPathComponent().lastPathComponent + "/"
+ directory.lastPathComponent
async let tokenizerTask = tokenizerLoader.load(from: directory)
let model = try loadSynchronous(modelDirectory: directory, modelName: name)
let tokenizer = try await tokenizerTask
return (model, tokenizer)
}

/// Load an embedding model container from a local directory.
///
/// No downloader is needed — the model and tokenizer are loaded from
/// the given directory.
///
/// - Parameters:
/// - directory: The local directory containing model files.
/// - tokenizerLoader: The ``TokenizerLoader`` to use for loading the tokenizer.
/// - Returns: A thread-safe `ModelContainer` instance.
public func loadModelContainer(
from directory: URL,
using tokenizerLoader: any TokenizerLoader
) async throws -> ModelContainer {
try await ModelContainer(
modelDirectory: directory,
tokenizerDirectory: directory,
configuration: ModelConfiguration(directory: directory),
tokenizerLoader: tokenizerLoader)
}
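
The tokenizer-directory fallback that `resolveDirectories` applies can be sketched in isolation. This is a minimal illustration under stated assumptions: `SketchTokenizerSource` and the `fetch` closure are hypothetical stand-ins for the tokenizer source type and the downloader call, not the types used in this change.

```swift
import Foundation

// Hypothetical stand-in for the tokenizer source used in this change.
enum SketchTokenizerSource {
    case id(String, revision: String)
    case directory(URL)
}

/// Picks the tokenizer directory: an explicit source wins, otherwise the
/// model directory is reused. `fetch` stands in for a downloader call.
func tokenizerDirectory(
    for source: SketchTokenizerSource?,
    modelDirectory: URL,
    fetch: (String, String) -> URL
) -> URL {
    switch source {
    case .id(let id, let revision)?:
        // Remote source: resolve it through the (stand-in) downloader.
        return fetch(id, revision)
    case .directory(let directory)?:
        // Local source: use the given path directly.
        return directory
    case nil:
        // No separate source: the tokenizer lives next to the model.
        return modelDirectory
    }
}
```

Keeping `nil` as the "same as model" case means callers that ship the tokenizer alongside the weights do not need to specify anything extra.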
57 changes: 17 additions & 40 deletions Libraries/MLXEmbedders/Models.swift
@@ -1,7 +1,7 @@
// Copyright © 2024 Apple Inc.

import Foundation
import Hub
import MLXLMCommon

/// A registry and configuration provider for embedding models.
///
@@ -22,7 +22,7 @@ public struct ModelConfiguration: Sendable {
/// The backing storage for the model's location.
public enum Identifier: Sendable {
/// A Hugging Face Hub repository identifier (e.g., "BAAI/bge-small-en-v1.5").
case id(String)
case id(String, revision: String = "main")
/// A file system URL pointing to a local model directory.
case directory(URL)
}
@@ -36,67 +36,44 @@ public struct ModelConfiguration: Sendable {
/// it returns a path-based name (e.g., "ParentDir/ModelDir").
public var name: String {
switch id {
case .id(let string):
case .id(let string, _):
string
case .directory(let url):
url.deletingLastPathComponent().lastPathComponent + "/" + url.lastPathComponent
}
}

/// An optional alternate Hub ID to use specifically for loading the tokenizer.
/// Where to load the tokenizer from when it differs from the model directory.
///
/// Use this if the model weights and tokenizer configuration are hosted in different repositories.
public let tokenizerId: String?

/// An optional override string for specifying a specific tokenizer implementation.
///
/// This is useful for providing compatibility hints to `swift-tokenizers` before
/// official support is updated.
public let overrideTokenizer: String?
/// - `.id`: download from a remote provider (requires a ``Downloader``)
/// - `.directory`: load from a local path
/// - `nil`: use the same directory as the model
public let tokenizerSource: TokenizerSource?

/// Initializes a configuration using a Hub repository ID.
/// - Parameters:
/// - id: The Hugging Face repo ID.
/// - tokenizerId: Optional alternate repo for the tokenizer.
/// - overrideTokenizer: Optional specific tokenizer implementation name.
/// - revision: The Git revision to use (defaults to "main").
/// - tokenizerSource: Optional alternate source for the tokenizer.
public init(
id: String,
tokenizerId: String? = nil,
overrideTokenizer: String? = nil
revision: String = "main",
tokenizerSource: TokenizerSource? = nil
) {
self.id = .id(id)
self.tokenizerId = tokenizerId
self.overrideTokenizer = overrideTokenizer
self.id = .id(id, revision: revision)
self.tokenizerSource = tokenizerSource
}

/// Initializes a configuration using a local directory.
/// - Parameters:
/// - directory: The `URL` of the model on disk.
/// - tokenizerId: Optional alternate repo for the tokenizer.
/// - overrideTokenizer: Optional specific tokenizer implementation name.
/// - tokenizerSource: Optional alternate source for the tokenizer.
public init(
directory: URL,
tokenizerId: String? = nil,
overrideTokenizer: String? = nil
tokenizerSource: TokenizerSource? = nil
) {
self.id = .directory(directory)
self.tokenizerId = tokenizerId
self.overrideTokenizer = overrideTokenizer
}

/// Resolves the local file system URL where the model is (or will be) stored.
///
/// - Parameter hub: The `HubApi` used to resolve Hub paths.
/// - Returns: A `URL` pointing to the local directory.
public func modelDirectory(hub: HubApi = HubApi()) -> URL {
switch id {
case .id(let id):
let repo = Hub.Repo(id: id)
return hub.localRepoLocation(repo)

case .directory(let directory):
return directory
}
self.tokenizerSource = tokenizerSource
}

// MARK: - Registry Management