Skip to content

Commit 9f1b334

Browse files
committed
add macros to help people maintain parity with current mlx-swift-lm huggingface integration
1 parent 63fd72d commit 9f1b334

3 files changed

Lines changed: 302 additions & 1 deletion

File tree

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
import Foundation
2+
import MLXLMCommon
3+
4+
/// Wraps the given Hugging Face hub client in an `MLXLMCommon.Downloader`.
///
/// Implemented by `DownloaderMacro` in the `MLXHuggingFaceMacros` plugin. The
/// expansion refers to `HubClient`, so the call site needs `import HuggingFace`.
///
/// - Parameter hub: The hub client expression to bridge. It is declared as `Any`
///   here; the expansion passes it to a closure whose parameter is a `HubClient`.
@freestanding(expression)
public macro hubDownloader(_ hub: Any) -> MLXLMCommon.Downloader = #externalMacro(
    module: "MLXHuggingFaceMacros",
    type: "DownloaderMacro"
)
7+
8+
/// Creates an `MLXLMCommon.Downloader` backed by a default-constructed hub client.
///
/// Implemented by `DownloaderMacro` in the `MLXHuggingFaceMacros` plugin, which
/// substitutes `HubClient()` when no argument is supplied. The call site needs
/// `import HuggingFace`.
@freestanding(expression)
public macro hubDownloader() -> MLXLMCommon.Downloader = #externalMacro(
    module: "MLXHuggingFaceMacros",
    type: "DownloaderMacro"
)
11+
12+
/// Adapts a `Tokenizers.Tokenizer` to the `MLXLMCommon.Tokenizer` protocol.
///
/// Implemented by `TokenizerAdaptorMacro` in the `MLXHuggingFaceMacros` plugin.
/// The expansion refers to the `Tokenizers` module, so the call site needs
/// `import Tokenizers`.
///
/// - Parameter upstream: The tokenizer expression to bridge. It is declared as
///   `Any` here; the expansion passes it to a closure whose parameter is a
///   `Tokenizers.Tokenizer`.
@freestanding(expression)
public macro adaptHuggingFaceTokenizer(_ upstream: Any) -> MLXLMCommon.Tokenizer = #externalMacro(
    module: "MLXHuggingFaceMacros",
    type: "TokenizerAdaptorMacro"
)
15+
16+
/// Creates an `MLXLMCommon.TokenizerLoader` that loads tokenizers from a model folder.
///
/// Implemented by `TokenizerLoaderMacro` in the `MLXHuggingFaceMacros` plugin.
/// The expansion calls `AutoTokenizer.from(modelFolder:)` and wraps the result
/// with `#adaptHuggingFaceTokenizer`, so the call site needs `import Tokenizers`.
@freestanding(expression)
public macro huggingFaceTokenizerLoader() -> MLXLMCommon.TokenizerLoader = #externalMacro(
    module: "MLXHuggingFaceMacros",
    type: "TokenizerLoaderMacro"
)
19+
20+
/// Loads a `ModelContainer` for `configuration` using the Hugging Face integration.
///
/// Implemented by `LoadContainerMacro` in the `MLXHuggingFaceMacros` plugin. The
/// expansion calls `loadModelContainer(from:using:configuration:progressHandler:)`
/// with `#hubDownloader()`, `#huggingFaceTokenizerLoader()` and a no-op progress
/// handler.
///
/// - Parameter configuration: The model configuration forwarded to the loader.
@freestanding(expression)
public macro huggingFaceLoadModelContainer(configuration: ModelConfiguration) -> ModelContainer =
    #externalMacro(
        module: "MLXHuggingFaceMacros",
        type: "LoadContainerMacro"
    )
25+
26+
/// Loads a `ModelContainer` for `configuration`, reporting progress to `progressHandler`.
///
/// Implemented by `LoadContainerMacro` in the `MLXHuggingFaceMacros` plugin. The
/// expansion calls `loadModelContainer(from:using:configuration:progressHandler:)`
/// with `#hubDownloader()` and `#huggingFaceTokenizerLoader()`.
///
/// - Parameters:
///   - configuration: The model configuration forwarded to the loader.
///   - progressHandler: Closure forwarded as the loader's `progressHandler` argument.
@freestanding(expression)
public macro huggingFaceLoadModelContainer(
    configuration: ModelConfiguration,
    progressHandler: @Sendable @escaping (Progress) -> Void
) -> ModelContainer =
    #externalMacro(
        module: "MLXHuggingFaceMacros",
        type: "LoadContainerMacro"
    )
32+
33+
/// Loads a `ModelContext` for `configuration` using the Hugging Face integration.
///
/// Implemented by `LoadContextMacro` in the `MLXHuggingFaceMacros` plugin. The
/// expansion calls `loadModel(from:using:configuration:progressHandler:)` with
/// `#hubDownloader()`, `#huggingFaceTokenizerLoader()` and a no-op progress handler.
///
/// - Parameter configuration: The model configuration forwarded to the loader.
@freestanding(expression)
public macro huggingFaceLoadModel(configuration: ModelConfiguration) -> ModelContext =
    #externalMacro(
        module: "MLXHuggingFaceMacros",
        type: "LoadContextMacro"
    )
38+
39+
/// Loads a `ModelContext` for `configuration`, reporting progress to `progressHandler`.
///
/// Implemented by `LoadContextMacro` in the `MLXHuggingFaceMacros` plugin. The
/// expansion calls `loadModel(from:using:configuration:progressHandler:)` with
/// `#hubDownloader()` and `#huggingFaceTokenizerLoader()`.
///
/// - Parameters:
///   - configuration: The model configuration forwarded to the loader.
///   - progressHandler: Closure forwarded as the loader's `progressHandler` argument.
@freestanding(expression)
public macro huggingFaceLoadModel(
    configuration: ModelConfiguration,
    progressHandler: @Sendable @escaping (Progress) -> Void
) -> ModelContext =
    #externalMacro(
        module: "MLXHuggingFaceMacros",
        type: "LoadContextMacro"
    )
45+
46+
/// Errors thrown by the downloader bridge that `#hubDownloader` expands to.
public enum HuggingFaceDownloaderError: LocalizedError {
    /// The supplied string could not be turned into a repository ID.
    /// The associated value is the rejected string.
    case invalidRepositoryID(String)

    /// A human-readable message describing the failure.
    public var errorDescription: String? {
        // Exhaustive switch on purpose: adding a case must force a new message here.
        switch self {
        case let .invalidRepositoryID(rejectedID):
            "Invalid Hugging Face repository ID: '\(rejectedID)'. Expected format 'namespace/name'."
        }
    }
}
Lines changed: 225 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,225 @@
1+
import SwiftCompilerPlugin
2+
import SwiftSyntax
3+
import SwiftSyntaxBuilder
4+
import SwiftSyntaxMacros
5+
6+
/// Compiler-plugin entry point that registers the macro implementations in this module.
@main
struct Macros: CompilerPlugin {
    /// The macro types this plugin makes available to the compiler.
    var providingMacros: [Macro.Type] {
        [
            DownloaderMacro.self,
            TokenizerAdaptorMacro.self,
            TokenizerLoaderMacro.self,
            LoadContainerMacro.self,
            LoadContextMacro.self,
        ]
    }
}
16+
17+
/// Expression macro behind `#hubDownloader`.
///
/// The expansion is an immediately-invoked closure that wraps a `HubClient` in a
/// local `HubBridge` type conforming to `MLXLMCommon.Downloader`. The generated
/// `download` method:
///
/// - throws `HuggingFaceDownloaderError.invalidRepositoryID` when
///   `HuggingFace.Repo.ID(rawValue:)` returns `nil` for the given `id`;
/// - falls back to the `"main"` revision when `revision` is `nil`;
/// - accepts `useLatest` but does not forward it to `downloadSnapshot`.
public struct DownloaderMacro: ExpressionMacro {
    /// Builds the bridging expression for a `#hubDownloader` use site.
    ///
    /// - Parameters:
    ///   - node: The macro expansion site. Only its first argument is read; when
    ///     there is none, `HubClient()` is used instead.
    ///   - context: The expansion context (unused).
    /// - Returns: A closure-call expression that evaluates to the downloader.
    public static func expansion(
        of node: some FreestandingMacroExpansionSyntax,
        in context: some MacroExpansionContext
    ) throws -> ExprSyntax {
        // Use the trivia-free text of the argument. `description` keeps leading and
        // trailing trivia, so a trailing `// comment` after the argument would be
        // pasted in front of the closing parenthesis below and comment it out.
        let argument = node.arguments.first?.expression.trimmedDescription ?? "HubClient()"

        return
            """
            // make sure you:
            //
            // import HuggingFace
            //
            { (hubApi: HubClient) -> MLXLMCommon.Downloader in
                struct HubBridge: MLXLMCommon.Downloader {
                    private let upstream: HubClient

                    init(_ upstream: HubClient) {
                        self.upstream = upstream
                    }

                    public func download(
                        id: String,
                        revision: String?,
                        matching patterns: [String],
                        useLatest: Bool,
                        progressHandler: @Sendable @escaping (Progress) -> Void
                    ) async throws -> URL {
                        guard let repoID = HuggingFace.Repo.ID(rawValue: id) else {
                            throw HuggingFaceDownloaderError.invalidRepositoryID(id)
                        }
                        let revision = revision ?? "main"

                        return try await upstream.downloadSnapshot(
                            of: repoID,
                            revision: revision,
                            matching: patterns,
                            progressHandler: { @MainActor progress in
                                progressHandler(progress)
                            }
                        )
                    }
                }

                return HubBridge(hubApi)
            }(\(raw: argument))
            """
    }
}
66+
67+
/// Expression macro behind `#adaptHuggingFaceTokenizer`.
///
/// The expansion is an immediately-invoked closure that wraps a
/// `Tokenizers.Tokenizer` in a local `TokenizerBridge` type conforming to
/// `MLXLMCommon.Tokenizer`. Each protocol requirement forwards to the wrapped
/// tokenizer; `Tokenizers.TokenizerError.missingChatTemplate` is rethrown as
/// `MLXLMCommon.TokenizerError.missingChatTemplate`.
public struct TokenizerAdaptorMacro: ExpressionMacro {
    /// Builds the bridging expression for an `#adaptHuggingFaceTokenizer` use site.
    ///
    /// - Parameters:
    ///   - node: The macro expansion site. Its first argument is the tokenizer to wrap.
    ///   - context: The expansion context (unused).
    /// - Returns: A closure-call expression that evaluates to the adapted tokenizer.
    /// - Throws: `MacroExpansionError.message` when the macro has no argument.
    public static func expansion(
        of node: some FreestandingMacroExpansionSyntax,
        in context: some MacroExpansionContext
    ) throws -> ExprSyntax {
        // Strip surrounding trivia before splicing: interpolating the untrimmed node
        // would carry a trailing `// comment` into the expansion, where it would
        // comment out the closing parenthesis of the closure call.
        guard let argument = node.arguments.first?.expression.trimmed else {
            throw MacroExpansionError.message("#adaptHuggingFaceTokenizer requires an argument")
        }

        return
            """
            // make sure you:
            //
            // import Tokenizers
            //
            { (huggingFaceTokenizer: Tokenizers.Tokenizer) -> MLXLMCommon.Tokenizer in
                struct TokenizerBridge: MLXLMCommon.Tokenizer {
                    private let upstream: any Tokenizers.Tokenizer

                    init(_ upstream: any Tokenizers.Tokenizer) {
                        self.upstream = upstream
                    }

                    func encode(text: String, addSpecialTokens: Bool) -> [Int] {
                        upstream.encode(text: text, addSpecialTokens: addSpecialTokens)
                    }

                    // swift-transformers uses `decode(tokens:)` instead of `decode(tokenIds:)`.
                    func decode(tokenIds: [Int], skipSpecialTokens: Bool) -> String {
                        upstream.decode(tokens: tokenIds, skipSpecialTokens: skipSpecialTokens)
                    }

                    func convertTokenToId(_ token: String) -> Int? {
                        upstream.convertTokenToId(token)
                    }

                    func convertIdToToken(_ id: Int) -> String? {
                        upstream.convertIdToToken(id)
                    }

                    var bosToken: String? { upstream.bosToken }
                    var eosToken: String? { upstream.eosToken }
                    var unknownToken: String? { upstream.unknownToken }

                    func applyChatTemplate(
                        messages: [[String: any Sendable]],
                        tools: [[String: any Sendable]]?,
                        additionalContext: [String: any Sendable]?
                    ) throws -> [Int] {
                        do {
                            return try upstream.applyChatTemplate(
                                messages: messages, tools: tools, additionalContext: additionalContext)
                        } catch Tokenizers.TokenizerError.missingChatTemplate {
                            throw MLXLMCommon.TokenizerError.missingChatTemplate
                        }
                    }
                }

                return TokenizerBridge(huggingFaceTokenizer)
            }(\(argument))
            """
    }
}
130+
131+
/// Expression macro behind `#huggingFaceTokenizerLoader`.
///
/// The expansion is an immediately-invoked closure that returns a local
/// `TransformersLoader` conforming to `MLXLMCommon.TokenizerLoader`. Its `load`
/// method calls `AutoTokenizer.from(modelFolder:)` and adapts the result with
/// `#adaptHuggingFaceTokenizer`.
public struct TokenizerLoaderMacro: ExpressionMacro {
    /// Returns the fixed loader expression; the macro takes no arguments.
    ///
    /// - Parameters:
    ///   - node: The macro expansion site (unused).
    ///   - context: The expansion context (unused).
    /// - Returns: A closure-call expression that evaluates to the tokenizer loader.
    public static func expansion(
        of node: some FreestandingMacroExpansionSyntax,
        in context: some MacroExpansionContext
    ) throws -> ExprSyntax {
        let loaderExpression: ExprSyntax =
            """
            { () -> MLXLMCommon.TokenizerLoader in
                struct TransformersLoader: MLXLMCommon.TokenizerLoader {
                    public init() {}

                    public func load(from directory: URL) async throws -> any MLXLMCommon.Tokenizer {
                        // make sure you:
                        //
                        // import Tokenizers
                        //
                        let upstream = try await AutoTokenizer.from(modelFolder: directory)
                        return #adaptHuggingFaceTokenizer(upstream)
                    }
                }

                return TransformersLoader()
            }()
            """
        return loaderExpression
    }
}
157+
158+
/// Expression macro behind `#huggingFaceLoadModelContainer`.
///
/// Expands to a call to
/// `loadModelContainer(from:using:configuration:progressHandler:)`, supplying
/// `#hubDownloader()` and `#huggingFaceTokenizerLoader()` for the first two
/// arguments.
public struct LoadContainerMacro: ExpressionMacro {
    /// Builds the loader call for a `#huggingFaceLoadModelContainer` use site.
    ///
    /// - Parameters:
    ///   - node: The macro expansion site. Its first argument is the configuration.
    ///     The progress handler is taken from the `progressHandler:` argument or,
    ///     failing that, from a trailing closure.
    ///   - context: The expansion context (unused).
    /// - Returns: The `loadModelContainer(...)` call expression.
    /// - Throws: `MacroExpansionError.message` when the macro has no arguments.
    public static func expansion(
        of node: some FreestandingMacroExpansionSyntax,
        in context: some MacroExpansionContext
    ) throws -> ExprSyntax {
        // Trim trivia so a trailing `// comment` on the argument cannot comment out
        // the comma that follows it in the expansion.
        guard let configuration = node.arguments.first?.expression.trimmed else {
            throw MacroExpansionError.message(
                "#huggingFaceLoadModelContainer requires a configuration")
        }

        let labeledHandler = node.arguments
            .first(where: { $0.label?.text == "progressHandler" })?
            .expression.trimmedDescription

        // A handler written with trailing-closure syntax is not part of
        // `node.arguments`; without this fallback it would be silently replaced by
        // the no-op handler.
        let progress =
            labeledHandler
            ?? node.trailingClosure?.trimmedDescription
            ?? "{ _ in }"

        return
            """
            loadModelContainer(
                from: #hubDownloader(),
                using: #huggingFaceTokenizerLoader(),
                configuration: \(configuration),
                progressHandler: \(raw: progress))
            """
    }
}
187+
188+
/// Expression macro behind `#huggingFaceLoadModel`.
///
/// Expands to a call to `loadModel(from:using:configuration:progressHandler:)`,
/// supplying `#hubDownloader()` and `#huggingFaceTokenizerLoader()` for the first
/// two arguments.
public struct LoadContextMacro: ExpressionMacro {
    /// Builds the loader call for a `#huggingFaceLoadModel` use site.
    ///
    /// - Parameters:
    ///   - node: The macro expansion site. Its first argument is the configuration.
    ///     The progress handler is taken from the `progressHandler:` argument or,
    ///     failing that, from a trailing closure.
    ///   - context: The expansion context (unused).
    /// - Returns: The `loadModel(...)` call expression.
    /// - Throws: `MacroExpansionError.message` when the macro has no arguments.
    public static func expansion(
        of node: some FreestandingMacroExpansionSyntax,
        in context: some MacroExpansionContext
    ) throws -> ExprSyntax {
        // Trim trivia so a trailing `// comment` on the argument cannot comment out
        // the comma that follows it in the expansion.
        guard let configuration = node.arguments.first?.expression.trimmed else {
            throw MacroExpansionError.message("#huggingFaceLoadModel requires a configuration")
        }

        let labeledHandler = node.arguments
            .first(where: { $0.label?.text == "progressHandler" })?
            .expression.trimmedDescription

        // A handler written with trailing-closure syntax is not part of
        // `node.arguments`; without this fallback it would be silently replaced by
        // the no-op handler.
        let progress =
            labeledHandler
            ?? node.trailingClosure?.trimmedDescription
            ?? "{ _ in }"

        return
            """
            loadModel(
                from: #hubDownloader(),
                using: #huggingFaceTokenizerLoader(),
                configuration: \(configuration),
                progressHandler: \(raw: progress))
            """
    }
}
216+
217+
/// Error thrown by the macro implementations in this module when a use site is malformed.
enum MacroExpansionError: Error, CustomStringConvertible {
    /// A failure described by free-form text.
    case message(String)

    /// The text carried by the error, returned unchanged.
    var description: String {
        switch self {
        case let .message(details): details
        }
    }
}

Package.swift

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
// swift-tools-version: 5.12
22

3+
import CompilerPluginSupport
34
import PackageDescription
45

56
let package = Package(
@@ -23,6 +24,9 @@ let package = Package(
2324
.library(
2425
name: "MLXEmbedders",
2526
targets: ["MLXEmbedders"]),
27+
.library(
28+
name: "MLXHuggingFace",
29+
targets: ["MLXHuggingFace"]),
2630
.library(
2731
name: "BenchmarkHelpers",
2832
targets: ["BenchmarkHelpers"]),
@@ -31,7 +35,8 @@ let package = Package(
3135
targets: ["IntegrationTestHelpers"]),
3236
],
3337
dependencies: [
34-
.package(url: "https://github.com/ml-explore/mlx-swift", .upToNextMinor(from: "0.31.1"))
38+
.package(url: "https://github.com/ml-explore/mlx-swift", .upToNextMinor(from: "0.31.1")),
39+
.package(url: "https://github.com/swiftlang/swift-syntax.git", from: "600.0.0-latest"),
3540
],
3641
targets: [
3742
.target(
@@ -124,6 +129,22 @@ let package = Package(
124129
],
125130
resources: [.process("Resources/1080p_30.mov"), .process("Resources/audio_only.mov")]
126131
),
132+
.macro(
133+
name: "MLXHuggingFaceMacros",
134+
dependencies: [
135+
.product(name: "SwiftSyntaxMacros", package: "swift-syntax"),
136+
.product(name: "SwiftCompilerPlugin", package: "swift-syntax"),
137+
],
138+
path: "Libraries/MLXHuggingFaceMacros"
139+
),
140+
.target(
141+
name: "MLXHuggingFace",
142+
dependencies: [
143+
"MLXHuggingFaceMacros",
144+
"MLXLMCommon",
145+
],
146+
path: "Libraries/MLXHuggingFace"
147+
),
127148
]
128149
)
129150

0 commit comments

Comments
 (0)