|
| 1 | +import SwiftCompilerPlugin |
| 2 | +import SwiftSyntax |
| 3 | +import SwiftSyntaxBuilder |
| 4 | +import SwiftSyntaxMacros |
| 5 | + |
/// Compiler plugin entry point for this module.
///
/// Lists every macro implementation type defined in this file so the compiler
/// can resolve them when expanding macro invocations.
@main
struct Macros: CompilerPlugin {
    /// The macro implementation types provided by this plugin.
    var providingMacros: [Macro.Type] {
        [
            DownloaderMacro.self,
            TokenizerAdaptorMacro.self,
            TokenizerLoaderMacro.self,
            LoadContainerMacro.self,
            LoadContextMacro.self,
        ]
    }
}
| 16 | + |
/// Expression macro that expands to an immediately-invoked closure wrapping a
/// `HubClient` in a local `HubBridge` struct conforming to `MLXLMCommon.Downloader`.
///
/// The first macro argument, when present, is passed to the closure as the client;
/// otherwise the expansion passes `HubClient()`.
///
/// The generated `download` method substitutes `"main"` for a `nil` revision and
/// forwards to `downloadSnapshot`. Its `useLatest` parameter is accepted but is not
/// forwarded to the upstream call.
public struct DownloaderMacro: ExpressionMacro {
    /// Builds the expansion for an invocation of the macro.
    ///
    /// - Parameters:
    ///   - node: The macro invocation; only its first argument is read.
    ///   - context: The expansion context (unused).
    /// - Returns: The closure-call expression described on the type.
    public static func expansion(
        of node: some FreestandingMacroExpansionSyntax,
        in context: some MacroExpansionContext
    ) throws -> ExprSyntax {
        // Use the trimmed text of the argument. The untrimmed description carries the
        // argument's trailing trivia, so a trailing `// comment` at the call site would
        // be spliced in front of the closing parenthesis below and comment it out.
        let argument = node.arguments.first?.expression.trimmedDescription ?? "HubClient()"

        return
            """
            // make sure you:
            //
            // import HuggingFace
            //
            { (hubApi: HubClient) -> MLXLMCommon.Downloader in
                struct HubBridge: MLXLMCommon.Downloader {
                    private let upstream: HubClient

                    init(_ upstream: HubClient) {
                        self.upstream = upstream
                    }

                    public func download(
                        id: String,
                        revision: String?,
                        matching patterns: [String],
                        useLatest: Bool,
                        progressHandler: @Sendable @escaping (Progress) -> Void
                    ) async throws -> URL {
                        guard let repoID = HuggingFace.Repo.ID(rawValue: id) else {
                            throw HuggingFaceDownloaderError.invalidRepositoryID(id)
                        }
                        let revision = revision ?? "main"

                        return try await upstream.downloadSnapshot(
                            of: repoID,
                            revision: revision,
                            matching: patterns,
                            progressHandler: { @MainActor progress in
                                progressHandler(progress)
                            }
                        )
                    }
                }

                return HubBridge(hubApi)
            }(\(raw: argument))
            """
    }
}
| 66 | + |
/// Expression macro that expands to an immediately-invoked closure wrapping a
/// `Tokenizers.Tokenizer` in a local `TokenizerBridge` struct conforming to
/// `MLXLMCommon.Tokenizer`.
///
/// The generated bridge forwards each requirement to the wrapped tokenizer and
/// rethrows `Tokenizers.TokenizerError.missingChatTemplate` as
/// `MLXLMCommon.TokenizerError.missingChatTemplate`.
public struct TokenizerAdaptorMacro: ExpressionMacro {
    /// Builds the expansion for an invocation of the macro.
    ///
    /// - Parameters:
    ///   - node: The macro invocation; its first argument is the tokenizer to wrap.
    ///   - context: The expansion context (unused).
    /// - Returns: The closure-call expression described on the type.
    /// - Throws: `MacroExpansionError.message` when the invocation has no arguments.
    public static func expansion(
        of node: some FreestandingMacroExpansionSyntax,
        in context: some MacroExpansionContext
    ) throws -> ExprSyntax {
        // Trim the argument's trivia before splicing it in. A trailing `// comment` on
        // the argument would otherwise sit directly in front of the closing parenthesis
        // of the generated call and comment it out.
        guard let argument = node.arguments.first?.expression.trimmed else {
            throw MacroExpansionError.message("#adaptHuggingFaceTokenizer requires an argument")
        }

        return
            """
            // make sure you:
            //
            // import Tokenizers
            //
            { (huggingFaceTokenizer: Tokenizers.Tokenizer) -> MLXLMCommon.Tokenizer in
                struct TokenizerBridge: MLXLMCommon.Tokenizer {
                    private let upstream: any Tokenizers.Tokenizer

                    init(_ upstream: any Tokenizers.Tokenizer) {
                        self.upstream = upstream
                    }

                    func encode(text: String, addSpecialTokens: Bool) -> [Int] {
                        upstream.encode(text: text, addSpecialTokens: addSpecialTokens)
                    }

                    // swift-transformers uses `decode(tokens:)` instead of `decode(tokenIds:)`.
                    func decode(tokenIds: [Int], skipSpecialTokens: Bool) -> String {
                        upstream.decode(tokens: tokenIds, skipSpecialTokens: skipSpecialTokens)
                    }

                    func convertTokenToId(_ token: String) -> Int? {
                        upstream.convertTokenToId(token)
                    }

                    func convertIdToToken(_ id: Int) -> String? {
                        upstream.convertIdToToken(id)
                    }

                    var bosToken: String? { upstream.bosToken }
                    var eosToken: String? { upstream.eosToken }
                    var unknownToken: String? { upstream.unknownToken }

                    func applyChatTemplate(
                        messages: [[String: any Sendable]],
                        tools: [[String: any Sendable]]?,
                        additionalContext: [String: any Sendable]?
                    ) throws -> [Int] {
                        do {
                            return try upstream.applyChatTemplate(
                                messages: messages, tools: tools, additionalContext: additionalContext)
                        } catch Tokenizers.TokenizerError.missingChatTemplate {
                            throw MLXLMCommon.TokenizerError.missingChatTemplate
                        }
                    }
                }

                return TokenizerBridge(huggingFaceTokenizer)
            }(\(argument))
            """
    }
}
| 130 | + |
/// Expression macro that expands to an immediately-invoked closure returning a
/// local `TransformersLoader` struct conforming to `MLXLMCommon.TokenizerLoader`.
///
/// The generated `load(from:)` obtains a tokenizer via
/// `AutoTokenizer.from(modelFolder:)` and wraps it with `#adaptHuggingFaceTokenizer`.
public struct TokenizerLoaderMacro: ExpressionMacro {
    /// Builds the expansion for an invocation of the macro.
    ///
    /// - Parameters:
    ///   - node: The macro invocation (its arguments are not read).
    ///   - context: The expansion context (unused).
    /// - Returns: The closure-call expression described on the type.
    public static func expansion(
        of node: some FreestandingMacroExpansionSyntax,
        in context: some MacroExpansionContext
    ) throws -> ExprSyntax {
        // The expansion is fixed text; nothing from the invocation is spliced in.
        let loaderExpression: ExprSyntax =
            """
            { () -> MLXLMCommon.TokenizerLoader in
                struct TransformersLoader: MLXLMCommon.TokenizerLoader {
                    public init() {}

                    public func load(from directory: URL) async throws -> any MLXLMCommon.Tokenizer {
                        // make sure you:
                        //
                        // import Tokenizers
                        //
                        let upstream = try await AutoTokenizer.from(modelFolder: directory)
                        return #adaptHuggingFaceTokenizer(upstream)
                    }
                }

                return TransformersLoader()
            }()
            """
        return loaderExpression
    }
}
| 157 | + |
/// Expression macro that expands to a `loadModelContainer(from:using:configuration:progressHandler:)`
/// call, using `#hubDownloader()` as the downloader and `#huggingFaceTokenizerLoader()`
/// as the tokenizer loader.
///
/// The first macro argument is used as the configuration. The progress handler is
/// taken from the `progressHandler:` argument, or from a trailing closure, and
/// defaults to `{ _ in }` when neither is present.
public struct LoadContainerMacro: ExpressionMacro {
    /// Builds the expansion for an invocation of the macro.
    ///
    /// - Parameters:
    ///   - node: The macro invocation supplying the configuration and optional handler.
    ///   - context: The expansion context (unused).
    /// - Returns: The `loadModelContainer(...)` call expression.
    /// - Throws: `MacroExpansionError.message` when the invocation has no arguments.
    public static func expansion(
        of node: some FreestandingMacroExpansionSyntax,
        in context: some MacroExpansionContext
    ) throws -> ExprSyntax {
        // Trim trivia from everything spliced into the call: a trailing `// comment`
        // carried along with an argument would comment out the `,` or `)` after it.
        guard let configuration = node.arguments.first?.expression.trimmed else {
            throw MacroExpansionError.message(
                "#huggingFaceLoadModelContainer requires a configuration")
        }

        // A handler written as a trailing closure is not part of `arguments`, so fall
        // back to `trailingClosure` before using the no-op default.
        let progress =
            node.arguments
            .first(where: { $0.label?.text == "progressHandler" })?
            .expression.trimmedDescription
            ?? node.trailingClosure?.trimmedDescription
            ?? "{ _ in }"

        return
            """
            loadModelContainer(
                from: #hubDownloader(),
                using: #huggingFaceTokenizerLoader(),
                configuration: \(configuration),
                progressHandler: \(raw: progress))
            """
    }
}
| 187 | + |
/// Expression macro that expands to a `loadModel(from:using:configuration:progressHandler:)`
/// call, using `#hubDownloader()` as the downloader and `#huggingFaceTokenizerLoader()`
/// as the tokenizer loader.
///
/// The first macro argument is used as the configuration. The progress handler is
/// taken from the `progressHandler:` argument, or from a trailing closure, and
/// defaults to `{ _ in }` when neither is present.
public struct LoadContextMacro: ExpressionMacro {
    /// Builds the expansion for an invocation of the macro.
    ///
    /// - Parameters:
    ///   - node: The macro invocation supplying the configuration and optional handler.
    ///   - context: The expansion context (unused).
    /// - Returns: The `loadModel(...)` call expression.
    /// - Throws: `MacroExpansionError.message` when the invocation has no arguments.
    public static func expansion(
        of node: some FreestandingMacroExpansionSyntax,
        in context: some MacroExpansionContext
    ) throws -> ExprSyntax {
        // Trim trivia from everything spliced into the call: a trailing `// comment`
        // carried along with an argument would comment out the `,` or `)` after it.
        guard let configuration = node.arguments.first?.expression.trimmed else {
            throw MacroExpansionError.message("#huggingFaceLoadModel requires a configuration")
        }

        // A handler written as a trailing closure is not part of `arguments`, so fall
        // back to `trailingClosure` before using the no-op default.
        let progress =
            node.arguments
            .first(where: { $0.label?.text == "progressHandler" })?
            .expression.trimmedDescription
            ?? node.trailingClosure?.trimmedDescription
            ?? "{ _ in }"

        return
            """
            loadModel(
                from: #hubDownloader(),
                using: #huggingFaceTokenizerLoader(),
                configuration: \(configuration),
                progressHandler: \(raw: progress))
            """
    }
}
| 216 | + |
/// Error thrown by the macro implementations in this file when an invocation
/// cannot be expanded.
enum MacroExpansionError: Error, CustomStringConvertible {
    /// A failure carrying a human-readable explanation.
    case message(String)

    /// The explanation carried by the error, returned unchanged.
    var description: String {
        switch self {
        case let .message(explanation): explanation
        }
    }
}
0 commit comments