33import ArgumentParser
44import CoreImage
55import Foundation
6- import Hub
6+ import HuggingFace
77import MLX
8+ import MLXHuggingFace
89import MLXLLM
910import MLXLMCommon
1011import MLXVLM
@@ -30,6 +31,17 @@ struct ModelArguments: ParsableArguments, Sendable {
3031 @Option ( help: " Hub download directory " )
3132 var download : URL ?
3233
34+ var downloader : any Downloader {
35+ let client =
36+ if let download {
37+ HubClient ( cache: HubCache ( cacheDirectory: download) )
38+ } else {
39+ HubClient ( )
40+ }
41+ let downloader = #hubDownloader( client)
42+ return downloader
43+ }
44+
3345 @Sendable
3446 func load( defaultModel: String , modelFactory: ModelFactory ) async throws -> ModelContainer {
3547 let modelConfiguration : ModelConfiguration
@@ -46,14 +58,10 @@ struct ModelArguments: ParsableArguments, Sendable {
4658 modelConfiguration = modelFactory. configuration ( id: modelName)
4759 }
4860
49- let hub =
50- if let download {
51- HubApi ( downloadBase: download)
52- } else {
53- HubApi ( )
54- }
55-
56- return try await modelFactory. loadContainer ( hub: hub, configuration: modelConfiguration)
61+ return try await loadModelContainer (
62+ from: self . downloader,
63+ using: #huggingFaceTokenizerLoader( ) ,
64+ configuration: modelConfiguration)
5765 }
5866}
5967
@@ -157,6 +165,9 @@ struct GenerateArguments: ParsableArguments, Sendable {
157165 @Flag ( name: . shortAndLong, help: " If true only print the generated output " )
158166 var quiet = false
159167
168+ @Flag ( name: . customLong( " tool-time " ) , help: " Enable time telling tool " )
169+ var useTimeTool = false
170+
160171 var generateParameters : GenerateParameters {
161172 GenerateParameters (
162173 maxTokens: maxTokens,
@@ -167,6 +178,23 @@ struct GenerateArguments: ParsableArguments, Sendable {
167178 repetitionContextSize: repetitionContextSize)
168179 }
169180
181+ var toolSpecs : [ MLXLMCommon . ToolSpec ] {
182+ var tools = [ MLXLMCommon . ToolSpec] ( )
183+
184+ if useTimeTool {
185+ tools. append ( timeTool. schema)
186+ }
187+
188+ return tools
189+ }
190+
191+ func call( toolCall: ToolCall ) async throws -> String {
192+ if useTimeTool && toolCall. function. name == timeTool. name {
193+ return try await toolCall. execute ( with: timeTool) . toolResult
194+ }
195+ return " Unknown tool: \( toolCall. function. name) "
196+ }
197+
170198 func prepare(
171199 _ context: inout ModelContext
172200 ) {
@@ -188,7 +216,14 @@ struct GenerateArguments: ParsableArguments, Sendable {
188216 print ( string, terminator: " " )
189217 case . info( let info) :
190218 return ( info, output)
191- case . toolCall:
219+ case . toolCall( let toolCall) :
220+ do {
221+ // TODO maybe just use ChatSession here?
222+ let x = try await call ( toolCall: toolCall)
223+ print ( " TOOL RESULT: \( x) " )
224+ } catch {
225+ print ( " \n Error executing tool: \( error. localizedDescription) " )
226+ }
192227 break
193228 }
194229 }
@@ -323,7 +358,8 @@ struct EvaluateCommand: AsyncParsableCommand {
323358 modelContainer,
324359 instructions: generate. system,
325360 generateParameters: generate. generateParameters,
326- processing: media. processing
361+ processing: media. processing,
362+ tools: generate. toolSpecs
327363 )
328364
329365 if !generate. quiet {
0 commit comments