Skip to content

Commit eac12c3

Browse files
committed
fixed quant stats init and adds fp32 master weights
Signed-off-by: Jonathan Mitchell <jomitchell@nvidia.com>
1 parent 4353055 commit eac12c3

2 files changed

Lines changed: 53 additions & 8 deletions

File tree

bionemo-recipes/recipes/llama3_native_te/train_fsdp2.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
import transformer_engine.pytorch
3434
from omegaconf import DictConfig, OmegaConf
3535
from torch.distributed.device_mesh import init_device_mesh
36-
from torch.distributed.fsdp import fully_shard
36+
from torch.distributed.fsdp import MixedPrecisionPolicy, fully_shard
3737
from torch.optim import AdamW
3838
from transformer_engine.common.recipe import Format
3939
from transformers.models.llama.configuration_llama import LlamaConfig
@@ -105,7 +105,11 @@ def main(args: DictConfig) -> float | None:
105105
model_class = LlamaForCausalLM
106106

107107
# --- Model Configuration ---
108-
config = config_class.from_pretrained(args.config_name_or_path, dtype=torch.bfloat16, **args.config_kwargs)
108+
config = config_class.from_pretrained(
109+
args.config_name_or_path,
110+
dtype=torch.float32 if args.use_fp32_master_weights else torch.bfloat16,
111+
**args.config_kwargs,
112+
)
109113

110114
# Resolve layer-wise quantization assignments and store on config.
111115
layer_precision = resolve_layer_precision(
@@ -153,10 +157,20 @@ def main(args: DictConfig) -> float | None:
153157
logger.info("Initialized Model:\n%s", model)
154158

155159
# --- Distributed Wrapping (FSDP2) ---
160+
if args.use_fp32_master_weights:
161+
mp_policy = MixedPrecisionPolicy(
162+
param_dtype=torch.bfloat16,
163+
reduce_dtype=torch.float32,
164+
output_dtype=torch.bfloat16,
165+
cast_forward_inputs=False,
166+
)
167+
else:
168+
mp_policy = MixedPrecisionPolicy()
169+
156170
# Each decoder layer should be individually sharded before sharding the full model.
157171
for layer in model.model.layers:
158-
fully_shard(layer, mesh=device_mesh["dp"])
159-
fully_shard(model, mesh=device_mesh["dp"])
172+
fully_shard(layer, mesh=device_mesh["dp"], mp_policy=mp_policy)
173+
fully_shard(model, mesh=device_mesh["dp"], mp_policy=mp_policy)
160174

161175
# Attach quantization recipes to the model (layer precision is already on config).
162176
if isinstance(model, NVLlamaForCausalLM):

bionemo-recipes/recipes/llama3_native_te/train_fsdp2_cp.py

Lines changed: 35 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,12 +29,13 @@
2929
from pathlib import Path
3030

3131
import hydra
32+
import nvdlfw_inspect.api as debug_api
3233
import nvtx
3334
import torch
3435
import transformer_engine.pytorch
3536
from omegaconf import DictConfig, OmegaConf
3637
from torch.distributed.device_mesh import init_device_mesh
37-
from torch.distributed.fsdp import fully_shard
38+
from torch.distributed.fsdp import MixedPrecisionPolicy, fully_shard
3839
from torch.optim import AdamW
3940
from transformer_engine.common.recipe import Format
4041

@@ -50,7 +51,7 @@
5051
from distributed_config import DistributedConfig
5152
from modeling_llama_te import NVLlamaConfig, NVLlamaForCausalLM
5253
from perf_logger import PerfLogger
53-
from quantization import resolve_layer_precision
54+
from quantization import initialize_quant_stats_logging, resolve_layer_precision
5455
from scheduler import get_cosine_annealing_schedule_with_warmup
5556

5657

@@ -80,6 +81,7 @@ def main(args: DictConfig) -> float | None:
8081
logger.info("Created device mesh: %s", device_mesh)
8182

8283
# --- Model Configuration ---
84+
<<<<<<< HEAD
8385
<<<<<<< HEAD
8486
# Create quantization recipes -- only used if FP8/FP4 is enabled in the config.
8587
fp8_recipe = None
@@ -99,6 +101,13 @@ def main(args: DictConfig) -> float | None:
99101
model = NVLlamaForCausalLM(config, fp8_recipe=fp8_recipe, fp4_recipe=fp4_recipe)
100102
=======
101103
config = NVLlamaConfig.from_pretrained(args.config_name_or_path, dtype=torch.bfloat16, **args.config_kwargs)
104+
=======
105+
config = NVLlamaConfig.from_pretrained(
106+
args.config_name_or_path,
107+
dtype=torch.float32 if args.use_fp32_master_weights else torch.bfloat16,
108+
**args.config_kwargs,
109+
)
110+
>>>>>>> 80e4897e (fixed quant stats init and adds fp32 master weights)
102111

103112
# Resolve layer-wise quantization assignments and store on config.
104113
layer_precision = resolve_layer_precision(
@@ -110,6 +119,14 @@ def main(args: DictConfig) -> float | None:
110119
)
111120
config.layer_precision = layer_precision
112121

122+
if args.quant_stats_config.enabled:
123+
initialize_quant_stats_logging(
124+
quant_stats_file=args.quant_stats_config.quant_stats_file,
125+
quant_log_dir=args.quant_stats_config.quant_log_dir,
126+
rank=dist_config.rank,
127+
layer_precision=layer_precision,
128+
)
129+
113130
# Create quantization recipes -- these are only used if FP8/FP4 is enabled in the config.
114131
fp8_recipe = None
115132
fp4_recipe = None
@@ -140,11 +157,21 @@ def main(args: DictConfig) -> float | None:
140157
# --- Distributed Wrapping (FSDP2 + CP) ---
141158
cp_dp_mesh = device_mesh["dp", "cp"]._flatten(mesh_dim_name="dp_shard_cp")
142159

160+
if args.use_fp32_master_weights:
161+
mp_policy = MixedPrecisionPolicy(
162+
param_dtype=torch.bfloat16,
163+
reduce_dtype=torch.float32,
164+
output_dtype=torch.bfloat16,
165+
cast_forward_inputs=False,
166+
)
167+
else:
168+
mp_policy = MixedPrecisionPolicy()
169+
143170
# Shard the transformer layers with FSDP. For Llama3, the transformer stack is in model.model.layers.
144171
# Each decoder layer should be individually sharded before sharding the full model.
145172
for layer in model.model.layers:
146-
fully_shard(layer, mesh=cp_dp_mesh)
147-
fully_shard(model, mesh=cp_dp_mesh)
173+
fully_shard(layer, mesh=cp_dp_mesh, mp_policy=mp_policy)
174+
fully_shard(model, mesh=cp_dp_mesh, mp_policy=mp_policy)
148175

149176
# Attach the CP group to the model.
150177
for layer in model.model.layers:
@@ -161,6 +188,10 @@ def main(args: DictConfig) -> float | None:
161188
# TE layers require special handling to initialize the weights from the meta device.
162189
model.init_empty_weights()
163190

191+
# Assign names to layers so debug API can identify them
192+
if args.quant_stats_config.enabled:
193+
debug_api.infer_and_assign_layer_names(model)
194+
164195
# --- Optimizer & Scheduler ---
165196
# Convert OmegaConf to regular dict to avoid serialization issues (BIONEMO-2873).
166197
optimizer = AdamW(model.parameters(), **OmegaConf.to_container(args.adamw_kwargs, resolve=True)) # type: ignore

0 commit comments

Comments
 (0)