2929from pathlib import Path
3030
3131import hydra
32+ import nvdlfw_inspect .api as debug_api
3233import nvtx
3334import torch
3435import transformer_engine .pytorch
3536from omegaconf import DictConfig , OmegaConf
3637from torch .distributed .device_mesh import init_device_mesh
37- from torch .distributed .fsdp import fully_shard
38+ from torch .distributed .fsdp import MixedPrecisionPolicy , fully_shard
3839from torch .optim import AdamW
3940from transformer_engine .common .recipe import Format
4041
5051from distributed_config import DistributedConfig
5152from modeling_llama_te import NVLlamaConfig , NVLlamaForCausalLM
5253from perf_logger import PerfLogger
53- from quantization import resolve_layer_precision
54+ from quantization import initialize_quant_stats_logging , resolve_layer_precision
5455from scheduler import get_cosine_annealing_schedule_with_warmup
5556
5657
@@ -80,6 +81,7 @@ def main(args: DictConfig) -> float | None:
8081 logger .info ("Created device mesh: %s" , device_mesh )
8182
8283 # --- Model Configuration ---
84+ < << << << HEAD
8385< << << << HEAD
8486 # Create quantization recipes -- only used if FP8/FP4 is enabled in the config.
8587 fp8_recipe = None
@@ -99,6 +101,13 @@ def main(args: DictConfig) -> float | None:
99101 model = NVLlamaForCausalLM (config , fp8_recipe = fp8_recipe , fp4_recipe = fp4_recipe )
100102== == == =
101103 config = NVLlamaConfig .from_pretrained (args .config_name_or_path , dtype = torch .bfloat16 , ** args .config_kwargs )
104+ == == == =
105+ config = NVLlamaConfig .from_pretrained (
106+ args .config_name_or_path ,
107+ dtype = torch .float32 if args .use_fp32_master_weights else torch .bfloat16 ,
108+ ** args .config_kwargs ,
109+ )
110+ >> >> >> > 80e4897 e (fixed quant stats init and adds fp32 master weights )
102111
103112 # Resolve layer-wise quantization assignments and store on config.
104113 layer_precision = resolve_layer_precision (
@@ -110,6 +119,14 @@ def main(args: DictConfig) -> float | None:
110119 )
111120 config .layer_precision = layer_precision
112121
122+ if args .quant_stats_config .enabled :
123+ initialize_quant_stats_logging (
124+ quant_stats_file = args .quant_stats_config .quant_stats_file ,
125+ quant_log_dir = args .quant_stats_config .quant_log_dir ,
126+ rank = dist_config .rank ,
127+ layer_precision = layer_precision ,
128+ )
129+
113130 # Create quantization recipes -- these are only used if FP8/FP4 is enabled in the config.
114131 fp8_recipe = None
115132 fp4_recipe = None
@@ -140,11 +157,21 @@ def main(args: DictConfig) -> float | None:
140157 # --- Distributed Wrapping (FSDP2 + CP) ---
141158 cp_dp_mesh = device_mesh ["dp" , "cp" ]._flatten (mesh_dim_name = "dp_shard_cp" )
142159
160+ if args .use_fp32_master_weights :
161+ mp_policy = MixedPrecisionPolicy (
162+ param_dtype = torch .bfloat16 ,
163+ reduce_dtype = torch .float32 ,
164+ output_dtype = torch .bfloat16 ,
165+ cast_forward_inputs = False ,
166+ )
167+ else :
168+ mp_policy = MixedPrecisionPolicy ()
169+
143170 # Shard the transformer layers with FSDP. For Llama3, the transformer stack is in model.model.layers.
144171 # Each decoder layer should be individually sharded before sharding the full model.
145172 for layer in model .model .layers :
146- fully_shard (layer , mesh = cp_dp_mesh )
147- fully_shard (model , mesh = cp_dp_mesh )
173+ fully_shard (layer , mesh = cp_dp_mesh , mp_policy = mp_policy )
174+ fully_shard (model , mesh = cp_dp_mesh , mp_policy = mp_policy )
148175
149176 # Attach the CP group to the model.
150177 for layer in model .model .layers :
@@ -161,6 +188,10 @@ def main(args: DictConfig) -> float | None:
161188 # TE layers require special handling to initialize the weights from the meta device.
162189 model .init_empty_weights ()
163190
191+ # Assign names to layers so debug API can identify them
192+ if args .quant_stats_config .enabled :
193+ debug_api .infer_and_assign_layer_names (model )
194+
164195 # --- Optimizer & Scheduler ---
165196 # Convert OmegaConf to regular dict to avoid serialization issues (BIONEMO-2873).
166197 optimizer = AdamW (model .parameters (), ** OmegaConf .to_container (args .adamw_kwargs , resolve = True )) # type: ignore
0 commit comments