@@ -6,8 +6,51 @@ import MLX
66import MLXNN
77import Tokenizers
88
9- struct EmbedderError : Error {
10- let message : String
9+ public enum EmbedderError : LocalizedError {
10+ case unsupportedModelType( String )
11+ case missingConfigurationFile( String , String )
12+ case configurationFileError( String , String , Error )
13+ case configurationDecodingError( String , String , DecodingError )
14+ case missingTokenizerConfig
15+
16+ public var errorDescription : String ? {
17+ switch self {
18+ case . unsupportedModelType( let type) :
19+ return " Unsupported model type: \( type) "
20+ case . missingConfigurationFile( let file, let modelName) :
21+ return " Missing or unreadable configuration file ' \( file) ' for model ' \( modelName) ' "
22+ case . configurationFileError( let file, let modelName, let error) :
23+ return " Error reading ' \( file) ' for model ' \( modelName) ': \( error. localizedDescription) "
24+ case . configurationDecodingError( let file, let modelName, let decodingError) :
25+ let errorDetail = extractDecodingErrorDetail ( decodingError)
26+ return " Failed to parse \( file) for model ' \( modelName) ': \( errorDetail) "
27+ case . missingTokenizerConfig:
28+ return " Missing tokenizer configuration "
29+ }
30+ }
31+
32+ private func extractDecodingErrorDetail( _ error: DecodingError ) -> String {
33+ switch error {
34+ case . keyNotFound( let key, let context) :
35+ let path = ( context. codingPath + [ key] ) . map { $0. stringValue } . joined ( separator: " . " )
36+ return " Missing field ' \( path) ' "
37+ case . typeMismatch( _, let context) :
38+ let path = context. codingPath. map { $0. stringValue } . joined ( separator: " . " )
39+ return " Type mismatch at ' \( path) ' "
40+ case . valueNotFound( _, let context) :
41+ let path = context. codingPath. map { $0. stringValue } . joined ( separator: " . " )
42+ return " Missing value at ' \( path) ' "
43+ case . dataCorrupted( let context) :
44+ if context. codingPath. isEmpty {
45+ return " Invalid JSON "
46+ } else {
47+ let path = context. codingPath. map { $0. stringValue } . joined ( separator: " . " )
48+ return " Invalid data at ' \( path) ' "
49+ }
50+ @unknown default :
51+ return error. localizedDescription
52+ }
53+ }
1154}
1255
1356func prepareModelDirectory(
@@ -53,20 +96,38 @@ public func load(
5396 // Start tokenizer loading asynchronously, then load model synchronously.
5497 // Both operations run in parallel because async let begins execution immediately.
5598 async let tokenizerTask = loadTokenizer ( configuration: configuration, hub: hub)
56- let model = try loadSynchronous ( modelDirectory: modelDirectory)
99+ let model = try loadSynchronous ( modelDirectory: modelDirectory, modelName : configuration . name )
57100 let tokenizer = try await tokenizerTask
58101
59102 return ( model, tokenizer)
60103}
61104
62- func loadSynchronous( modelDirectory: URL ) throws -> EmbeddingModel {
105+ func loadSynchronous( modelDirectory: URL , modelName : String ) throws -> EmbeddingModel {
63106 // Load config.json once and decode for both base config and model-specific config
64107 let configurationURL = modelDirectory. appending ( component: " config.json " )
65- let configData = try Data ( contentsOf: configurationURL)
66- let baseConfig = try JSONDecoder ( ) . decode ( BaseConfiguration . self, from: configData)
108+ let configData : Data
109+ do {
110+ configData = try Data ( contentsOf: configurationURL)
111+ } catch {
112+ throw EmbedderError . configurationFileError (
113+ configurationURL. lastPathComponent, modelName, error)
114+ }
115+ let baseConfig : BaseConfiguration
116+ do {
117+ baseConfig = try JSONDecoder ( ) . decode ( BaseConfiguration . self, from: configData)
118+ } catch let error as DecodingError {
119+ throw EmbedderError . configurationDecodingError (
120+ configurationURL. lastPathComponent, modelName, error)
121+ }
67122
68123 let modelType = ModelType ( rawValue: baseConfig. modelType)
69- let model = try modelType. createModel ( configuration: configData)
124+ let model : EmbeddingModel
125+ do {
126+ model = try modelType. createModel ( configuration: configData)
127+ } catch let error as DecodingError {
128+ throw EmbedderError . configurationDecodingError (
129+ configurationURL. lastPathComponent, modelName, error)
130+ }
70131
71132 // load the weights
72133 var weights = [ String: MLXArray] ( )
0 commit comments