@@ -87,11 +87,9 @@ fn match_hub_layout(
8787}
8888
8989fn resolve_local ( folder : & Path ) -> Option < ResolvedPaths > {
90- // Native model2vec.
9190 if let r @ Some ( _) = match_local_layout ( folder, folder, "config.json" , ModelLayout :: Native ) {
9291 return r;
9392 }
94- // Sentence Transformers root layout.
9593 if let r @ Some ( _) = match_local_layout (
9694 folder,
9795 folder,
@@ -121,11 +119,9 @@ fn resolve_local(folder: &Path) -> Option<ResolvedPaths> {
121119}
122120
123121fn resolve_hub ( repo : & ApiRepo , prefix : & str ) -> Result < ResolvedPaths > {
124- // Native model2vec.
125122 if let Some ( r) = match_hub_layout ( repo, prefix, prefix, "config.json" , ModelLayout :: Native ) {
126123 return r;
127124 }
128- // Sentence Transformers root layout.
129125 if let Some ( r) = match_hub_layout (
130126 repo,
131127 prefix,
@@ -291,6 +287,13 @@ impl StaticModel {
291287 /// * `normalize` - Whether to L2-normalize output embeddings
292288 /// * `weights` - Optional per-token weights for quantized models
293289 /// * `token_mapping` - Optional token ID mapping for quantized models
290+ fn check_shape ( len : usize , rows : usize , cols : usize ) -> Result < ( ) > {
291+ if len != rows * cols {
292+ return Err ( anyhow ! ( "embeddings length {} != rows {} * cols {}" , len, rows, cols) ) ;
293+ }
294+ Ok ( ( ) )
295+ }
296+
294297 pub fn from_owned (
295298 tokenizer : Tokenizer ,
296299 embeddings : Vec < f32 > ,
@@ -300,14 +303,7 @@ impl StaticModel {
300303 weights : Option < Vec < f32 > > ,
301304 token_mapping : Option < Vec < usize > > ,
302305 ) -> Result < Self > {
303- if embeddings. len ( ) != rows * cols {
304- return Err ( anyhow ! (
305- "embeddings length {} != rows {} * cols {}" ,
306- embeddings. len( ) ,
307- rows,
308- cols
309- ) ) ;
310- }
306+ Self :: check_shape ( embeddings. len ( ) , rows, cols) ?;
311307
312308 let ( median_token_length, unk_token_id) = Self :: compute_metadata ( & tokenizer) ?;
313309
@@ -345,14 +341,7 @@ impl StaticModel {
345341 weights : Option < & ' static [ f32 ] > ,
346342 token_mapping : Option < & ' static [ usize ] > ,
347343 ) -> Result < Self > {
348- if embeddings. len ( ) != rows * cols {
349- return Err ( anyhow ! (
350- "embeddings length {} != rows {} * cols {}" ,
351- embeddings. len( ) ,
352- rows,
353- cols
354- ) ) ;
355- }
344+ Self :: check_shape ( embeddings. len ( ) , rows, cols) ?;
356345
357346 let ( median_token_length, unk_token_id) = Self :: compute_metadata ( & tokenizer) ?;
358347
@@ -375,10 +364,7 @@ impl StaticModel {
375364 lens. sort_unstable ( ) ;
376365 let median_token_length = lens. get ( lens. len ( ) / 2 ) . copied ( ) . unwrap_or ( 1 ) ;
377366
378- let spec_json = tokenizer
379- . to_string ( false )
380- . map_err ( |e| anyhow ! ( "tokenizer -> JSON failed: {e}" ) ) ?;
381- let spec: Value = serde_json:: from_str ( & spec_json) ?;
367+ let spec: Value = serde_json:: to_value ( tokenizer) . context ( "failed to serialize tokenizer" ) ?;
382368 let unk_token = spec
383369 . get ( "model" )
384370 . and_then ( |m| m. get ( "unk_token" ) )
@@ -430,7 +416,7 @@ impl StaticModel {
430416 . tokenizer
431417 . encode_batch_fast :: < String > (
432418 truncated. into_iter ( ) . map ( Into :: into) . collect ( ) ,
433- /* add_special_tokens = */ false ,
419+ false ,
434420 )
435421 . expect ( "tokenization failed" ) ;
436422 for encoding in encodings {
0 commit comments