Skip to content

Commit 6079013

Browse files
committed
rebase
Signed-off-by: Jonathan Mitchell <jomitchell@nvidia.com>
1 parent 89830c0 commit 6079013

3 files changed

Lines changed: 11 additions & 3 deletions

File tree

bionemo-recipes/recipes/llama3_native_te/train_ddp.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,11 @@ def main(args: DictConfig) -> float | None:
119119
with transformer_engine.pytorch.quantized_model_init(
120120
recipe=fp8_recipe, **args.fp8_config.quantized_model_init_kwargs
121121
):
122-
model = model_class(config)
122+
model = (
123+
model_class(config, fp8_recipe=fp8_recipe, fp4_recipe=fp4_recipe)
124+
if model_class is NVLlamaForCausalLM
125+
else model_class(config)
126+
)
123127

124128
logger.info("Initialized Model:\n%s", model)
125129

bionemo-recipes/recipes/llama3_native_te/train_fsdp2.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,11 @@ def main(args: DictConfig) -> float | None:
128128
recipe=fp8_recipe, **args.fp8_config.quantized_model_init_kwargs
129129
),
130130
):
131-
model = model_class(config)
131+
model = (
132+
model_class(config, fp8_recipe=fp8_recipe, fp4_recipe=fp4_recipe)
133+
if model_class is NVLlamaForCausalLM
134+
else model_class(config)
135+
)
132136

133137
logger.info("Initialized Model:\n%s", model)
134138

bionemo-recipes/recipes/llama3_native_te/train_fsdp2_cp.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ def main(args: DictConfig) -> float | None:
127127
recipe=fp8_recipe, **args.fp8_config.quantized_model_init_kwargs
128128
),
129129
):
130-
model = NVLlamaForCausalLM(config)
130+
model = NVLlamaForCausalLM(config, fp8_recipe=fp8_recipe, fp4_recipe=fp4_recipe)
131131

132132
logger.info("Initialized Model:\n%s", model)
133133

0 commit comments

Comments
 (0)