@@ -4,8 +4,9 @@ use std::fs;
44use std:: path:: PathBuf ;
55
66use crate :: utils:: file;
7+ use crate :: utils:: format:: format_parameters;
78
8- #[ derive( Debug , Serialize , Deserialize , Clone ) ]
9+ #[ derive( Debug , Serialize , Deserialize , Clone , PartialEq ) ]
910pub struct ModelSpec {
1011 #[ serde( skip_serializing_if = "Option::is_none" ) ]
1112 pub model_type : Option < String > ,
@@ -17,6 +18,80 @@ pub struct ModelSpec {
1718 pub parameters : Option < String > ,
1819}
1920
21+ impl ModelSpec {
22+ /// Extract model spec from config.json
23+ pub fn from_config ( config : & serde_json:: Value ) -> Option < Self > {
24+ let model_type = config
25+ . get ( "model_type" )
26+ . and_then ( |v| v. as_str ( ) )
27+ . map ( |s| s. to_string ( ) ) ;
28+
29+ let architectures = config
30+ . get ( "architectures" )
31+ . and_then ( |v| v. as_array ( ) )
32+ . map ( |arr| {
33+ arr. iter ( )
34+ . filter_map ( |v| v. as_str ( ) . map ( |s| s. to_string ( ) ) )
35+ . collect :: < Vec < String > > ( )
36+ } )
37+ . filter ( |v| !v. is_empty ( ) ) ;
38+
39+ let context_window = config
40+ . get ( "n_positions" )
41+ . or_else ( || config. get ( "max_position_embeddings" ) )
42+ . or_else ( || config. get ( "n_ctx" ) )
43+ . and_then ( |v| v. as_u64 ( ) )
44+ . map ( |v| v as u32 ) ;
45+
46+ let parameters = Self :: estimate_parameters ( config) ;
47+
48+ if model_type. is_some ( )
49+ || architectures. is_some ( )
50+ || context_window. is_some ( )
51+ || parameters. is_some ( )
52+ {
53+ Some ( ModelSpec {
54+ model_type,
55+ architectures,
56+ context_window,
57+ parameters,
58+ } )
59+ } else {
60+ None
61+ }
62+ }
63+
64+ /// Estimate model parameters from config
65+ fn estimate_parameters ( config : & serde_json:: Value ) -> Option < String > {
66+ let n_layer = config
67+ . get ( "n_layer" )
68+ . or_else ( || config. get ( "num_hidden_layers" ) )
69+ . and_then ( |v| v. as_u64 ( ) ) ?;
70+
71+ let n_embd = config
72+ . get ( "n_embd" )
73+ . or_else ( || config. get ( "hidden_size" ) )
74+ . and_then ( |v| v. as_u64 ( ) ) ?;
75+
76+ let vocab_size = config. get ( "vocab_size" ) . and_then ( |v| v. as_u64 ( ) ) ?;
77+
78+ let n_positions = config
79+ . get ( "n_positions" )
80+ . or_else ( || config. get ( "max_position_embeddings" ) )
81+ . and_then ( |v| v. as_u64 ( ) )
82+ . unwrap_or ( 2048 ) ;
83+
84+ // Rough parameter estimation for transformer models
85+ // Each layer: ~12 * n_embd^2 (attention + FFN)
86+ // Embeddings: vocab_size * n_embd + n_positions * n_embd
87+ let layer_params = 12 * n_layer * n_embd * n_embd;
88+ let embedding_params = vocab_size * n_embd + n_positions * n_embd;
89+ let total_params = layer_params + embedding_params;
90+
91+ Some ( format_parameters ( total_params) )
92+ }
93+ }
94+
2095#[ derive( Debug , Serialize , Deserialize , Clone ) ]
2196pub struct ModelInfo {
2297 pub name : String ,
@@ -349,4 +424,75 @@ mod tests {
349424 assert_eq ! ( model_info. name, "test/legacy-model" ) ;
350425 assert ! ( model_info. spec. is_none( ) ) ;
351426 }
427+
#[test]
fn test_model_spec_from_config_gpt2() {
    // GPT-2 style config: uses the `n_layer` / `n_embd` / `n_positions` key names.
    let gpt2_config = serde_json::json!({
        "model_type": "gpt2",
        "architectures": ["GPT2LMHeadModel"],
        "n_layer": 5,
        "n_embd": 32,
        "vocab_size": 1000,
        "n_positions": 512
    });

    // 12 * 5 * 32^2 + (1000 + 512) * 32 = 109_824 estimated parameters.
    let expected = ModelSpec {
        model_type: Some(String::from("gpt2")),
        architectures: Some(vec![String::from("GPT2LMHeadModel")]),
        context_window: Some(512),
        parameters: Some(String::from("109.82K")),
    };

    assert_eq!(ModelSpec::from_config(&gpt2_config), Some(expected));
}
450+
#[test]
fn test_model_spec_from_config_bert_style() {
    // BERT style config: uses the `num_hidden_layers` / `hidden_size` /
    // `max_position_embeddings` aliases and has no `architectures` entry.
    let bert_config = serde_json::json!({
        "model_type": "bert",
        "num_hidden_layers": 12,
        "hidden_size": 768,
        "vocab_size": 30000,
        "max_position_embeddings": 512
    });

    let extracted = ModelSpec::from_config(&bert_config);
    assert!(extracted.is_some());

    let ModelSpec {
        model_type,
        context_window,
        parameters,
        ..
    } = extracted.unwrap();

    assert_eq!(model_type.as_deref(), Some("bert"));
    assert_eq!(context_window, Some(512));
    // Only the magnitude suffix is checked, not the exact figure.
    assert!(parameters.unwrap().contains('M'));
}
471+
#[test]
fn test_model_spec_from_config_partial() {
    // Only the model type and an `n_ctx` context size are present, so the
    // parameter estimate has nothing to work with.
    let sparse_config = serde_json::json!({
        "model_type": "llama",
        "n_ctx": 4096
    });

    let extracted = ModelSpec::from_config(&sparse_config);
    assert!(extracted.is_some());

    let ModelSpec {
        model_type,
        context_window,
        parameters,
        ..
    } = extracted.unwrap();

    assert_eq!(model_type.as_deref(), Some("llama"));
    assert_eq!(context_window, Some(4096));
    assert!(parameters.is_none());
}
489+
#[test]
fn test_model_spec_from_config_empty() {
    // An empty JSON object yields no fields at all, so no spec is produced.
    let empty_config = serde_json::json!({});
    assert!(ModelSpec::from_config(&empty_config).is_none());
}
352498}
0 commit comments