
from transformers import AutoConfig, PretrainedConfig

26- _HYBRID_ARCHITECTURES = frozenset ({
27- "NemotronHForCausalLM" ,
28- "NemotronHybridForCausalLM" ,
29- })
26+ _HYBRID_ARCHITECTURES = frozenset (
27+ {
28+ "NemotronHForCausalLM" ,
29+ "NemotronHybridForCausalLM" ,
30+ }
31+ )
3032
# Number of extra embedding rows the SpeechLM adds on top of the backbone's
# native vocab during training: ``<|audio|>`` locator plus headroom for other
@@ -70,9 +72,7 @@ def __init__(
7072 self .pretrained_weights = pretrained_weights
7173 self .lora = lora
7274
73- self .text_config = AutoConfig .from_pretrained (
74- pretrained_llm , trust_remote_code = True
75- )
75+ self .text_config = AutoConfig .from_pretrained (pretrained_llm , trust_remote_code = True )
7676
7777 raw_archs = getattr (self .text_config , "architectures" , [])
7878 if len (raw_archs ) != 1 :
@@ -91,17 +91,10 @@ def __init__(
9191 # downstream ``init_vllm_registered_model(architectures=...)`` call
9292 # that threads this text_config through resolves correctly.
9393 self .text_config .architectures = ["NemotronHForCausalLM" ]
94- if (
95- not hasattr (self .text_config , "total_num_kv_heads" )
96- or self .text_config .total_num_kv_heads is None
97- ):
98- self .text_config .total_num_kv_heads = getattr (
99- self .text_config , "num_key_value_heads" , 2
100- )
94+ if not hasattr (self .text_config , "total_num_kv_heads" ) or self .text_config .total_num_kv_heads is None :
95+ self .text_config .total_num_kv_heads = getattr (self .text_config , "num_key_value_heads" , 2 )
10196 if not hasattr (self .text_config , "rms_norm_eps" ):
102- self .text_config .rms_norm_eps = getattr (
103- self .text_config , "layer_norm_epsilon" , 1e-5
104- )
97+ self .text_config .rms_norm_eps = getattr (self .text_config , "layer_norm_epsilon" , 1e-5 )
10598
10699 self .text_config .vocab_size += _SPEECHLM_EMBED_EXTRA_ROWS
107100
@@ -156,6 +149,4 @@ def __getattr__(self, name):
156149 return getattr (self .text_config , name )
157150 except AttributeError :
158151 pass
159- raise AttributeError (
160- f"'{ type (self ).__name__ } ' has no attribute '{ name } '"
161- )
152+ raise AttributeError (f"'{ type (self ).__name__ } ' has no attribute '{ name } '" )
0 commit comments