File tree Expand file tree Collapse file tree
bionemo-recipes/recipes/llama3_native_te Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -119,7 +119,11 @@ def main(args: DictConfig) -> float | None:
119119 with transformer_engine .pytorch .quantized_model_init (
120120 recipe = fp8_recipe , ** args .fp8_config .quantized_model_init_kwargs
121121 ):
122- model = model_class (config )
122+ model = (
123+ model_class (config , fp8_recipe = fp8_recipe , fp4_recipe = fp4_recipe )
124+ if model_class is NVLlamaForCausalLM
125+ else model_class (config )
126+ )
123127
124128 logger .info ("Initialized Model:\n %s" , model )
125129
Original file line number Diff line number Diff line change @@ -128,7 +128,11 @@ def main(args: DictConfig) -> float | None:
128128 recipe = fp8_recipe , ** args .fp8_config .quantized_model_init_kwargs
129129 ),
130130 ):
131- model = model_class (config )
131+ model = (
132+ model_class (config , fp8_recipe = fp8_recipe , fp4_recipe = fp4_recipe )
133+ if model_class is NVLlamaForCausalLM
134+ else model_class (config )
135+ )
132136
133137 logger .info ("Initialized Model:\n %s" , model )
134138
Original file line number Diff line number Diff line change @@ -127,7 +127,7 @@ def main(args: DictConfig) -> float | None:
127127 recipe = fp8_recipe , ** args .fp8_config .quantized_model_init_kwargs
128128 ),
129129 ):
130- model = NVLlamaForCausalLM (config )
130+ model = NVLlamaForCausalLM (config , fp8_recipe = fp8_recipe , fp4_recipe = fp4_recipe )
131131
132132 logger .info ("Initialized Model:\n %s" , model )
133133
You can’t perform that action at this time.
0 commit comments