Skip to content

Commit 57a50b3

Browse files
committed
Adds quant stats logging support
Adds unpadded_tps to wandb charts
1 parent 213ef6e commit 57a50b3

12 files changed

Lines changed: 1096 additions & 1 deletion

File tree

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
example_fp8_tensor_stat_collection:
2+
enabled: True
3+
layers:
4+
# Match the te.Linear sublayers within MiniFormer blocks
5+
layer_types: [pi, gi, po, go, fc1, fc2]
6+
transformer_engine:
7+
LogFp8TensorStats:
8+
enabled: True
9+
tensors_struct:
10+
- tensor: activation
11+
stats: [underflows%, scale_inv_min, scale_inv_max, mse]
12+
freq: 10
13+
- tensor: gradient
14+
stats: [underflows%, scale_inv_min, scale_inv_max, mse]
15+
freq: 10
16+
- tensor: weight
17+
stats: [underflows%, scale_inv_min, scale_inv_max, mse]
18+
freq: 10
19+
LogTensorStats:
20+
enabled: True
21+
stats: [max, min, mean, std, l1_norm]
22+
tensors: [dgrad, wgrad]
23+
freq: 1

bionemo-recipes/recipes/esm2_minifold_te/hydra_config/L0_sanity.yaml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,12 @@ component_precision:
7373
seq_proj: false # Sequence/pair feature projections (fc_s_1, fc_s_2, fc_z_1, fc_z_2, seq_to_pair te.Linear layers)
7474
dist_head: false # Distogram output head (fc_out_1, fc_out_2 te.Linear layers) — kept in model's base precision per MXFP8 paper
7575

76+
quant_stats_config:
77+
enabled: false
78+
quant_stats_file: ./fp8_debugging_stats.yaml
79+
quant_log_dir: ./log_quant_stats
80+
log_to_wandb: false
81+
7682
# Log every step for sanity check
7783
logger:
7884
frequency: 1

bionemo-recipes/recipes/esm2_minifold_te/hydra_config/defaults.yaml

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,13 @@ component_precision:
8181
seq_proj: false # Sequence/pair feature projections (fc_s_1, fc_s_2, fc_z_1, fc_z_2, seq_to_pair te.Linear layers)
8282
dist_head: false # Distogram output head (fc_out_1, fc_out_2 te.Linear layers) — kept in model's base precision per MXFP8 paper
8383

84+
# Quantization stats logging (requires nvdlfw_inspect)
85+
quant_stats_config:
86+
enabled: false
87+
quant_stats_file: ./fp8_debugging_stats.yaml
88+
quant_log_dir: ./log_quant_stats
89+
log_to_wandb: false
90+
8491
# Logging
8592
logger:
8693
frequency: 100

bionemo-recipes/recipes/esm2_minifold_te/hydra_config/eval.yaml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,12 @@ component_precision:
4949
seq_proj: false # Sequence/pair feature projections (fc_s_1, fc_s_2, fc_z_1, fc_z_2, seq_to_pair te.Linear layers)
5050
dist_head: false # Distogram output head (fc_out_1, fc_out_2 te.Linear layers) — kept in model's base precision per MXFP8 paper
5151

52+
quant_stats_config:
53+
enabled: false
54+
quant_stats_file: ./fp8_debugging_stats.yaml
55+
quant_log_dir: ./log_quant_stats
56+
log_to_wandb: false
57+
5258
wandb_init_args:
5359
project: esm2_minifold_te
5460
name: eval_${now:%Y%m%d_%H%M%S}

bionemo-recipes/recipes/esm2_minifold_te/hydra_config/run_100.yaml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,12 @@ component_precision:
6464
seq_proj: false # Sequence/pair feature projections (fc_s_1, fc_s_2, fc_z_1, fc_z_2, seq_to_pair te.Linear layers)
6565
dist_head: false # Distogram output head (fc_out_1, fc_out_2 te.Linear layers) — kept in model's base precision per MXFP8 paper
6666

67+
quant_stats_config:
68+
enabled: false
69+
quant_stats_file: ./fp8_debugging_stats.yaml
70+
quant_log_dir: ./log_quant_stats
71+
log_to_wandb: false
72+
6773
logger:
6874
frequency: 5
6975

bionemo-recipes/recipes/esm2_minifold_te/hydra_config/run_100_real.yaml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,12 @@ component_precision:
7171
seq_proj: false # Sequence/pair feature projections (fc_s_1, fc_s_2, fc_z_1, fc_z_2, seq_to_pair te.Linear layers)
7272
dist_head: false # Distogram output head (fc_out_1, fc_out_2 te.Linear layers) — kept in model's base precision per MXFP8 paper
7373

74+
quant_stats_config:
75+
enabled: false
76+
quant_stats_file: ./fp8_debugging_stats.yaml
77+
quant_log_dir: ./log_quant_stats
78+
log_to_wandb: false
79+
7480
logger:
7581
frequency: 5
7682

bionemo-recipes/recipes/esm2_minifold_te/hydra_config/run_100_real_3B.yaml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,12 @@ component_precision:
7070
seq_proj: false # Sequence/pair feature projections (fc_s_1, fc_s_2, fc_z_1, fc_z_2, seq_to_pair te.Linear layers)
7171
dist_head: false # Distogram output head (fc_out_1, fc_out_2 te.Linear layers) — kept in model's base precision per MXFP8 paper
7272

73+
quant_stats_config:
74+
enabled: false
75+
quant_stats_file: ./fp8_debugging_stats.yaml
76+
quant_log_dir: ./log_quant_stats
77+
log_to_wandb: false
78+
7379
logger:
7480
frequency: 5
7581

bionemo-recipes/recipes/esm2_minifold_te/perf_logger.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import logging
1919
import time
2020

21+
import nvdlfw_inspect.api as debug_api
2122
import torch
2223
import torchmetrics
2324
from omegaconf import DictConfig, OmegaConf
@@ -44,6 +45,7 @@ def __init__(self, dist_config: DistributedConfig, args: DictConfig):
4445

4546
self.min_loss = torch.tensor(float("inf"), device=torch.device(f"cuda:{dist_config.local_rank}"))
4647
self.logging_frequency = args.logger.frequency
48+
self.quant_stats_enabled = args.quant_stats_config.enabled
4749

4850
metrics_dict = {
4951
"train/loss": torchmetrics.MeanMetric(),
@@ -57,6 +59,7 @@ def __init__(self, dist_config: DistributedConfig, args: DictConfig):
5759
"train/contact_recall_8A": torchmetrics.MeanMetric(),
5860
"train/lddt_from_distogram": torchmetrics.MeanMetric(),
5961
"train/mean_distance_error": torchmetrics.MeanMetric(),
62+
"train/unpadded_tokens_per_sec": torchmetrics.MeanMetric(),
6063
}
6164

6265
self.metrics = torchmetrics.MetricCollection(metrics_dict)
@@ -75,6 +78,7 @@ def log_step(
7578
grad_norm: torch.Tensor | DTensor | float = 0.0,
7679
lr: float = 0.0,
7780
structure_metrics: dict[str, torch.Tensor] | None = None,
81+
unpadded_tokens: float = 0.0,
7882
):
7983
"""Log a training step."""
8084
with torch.no_grad():
@@ -95,6 +99,8 @@ def log_step(
9599
self.metrics["train/learning_rate"].update(lr)
96100
self.metrics["train/grad_norm"].update(grad_norm)
97101
self.metrics["train/step_time"].update(step_time)
102+
if unpadded_tokens > 0 and step_time > 0:
103+
self.metrics["train/unpadded_tokens_per_sec"].update(unpadded_tokens / step_time)
98104

99105
if structure_metrics is not None:
100106
for key, value in structure_metrics.items():
@@ -121,8 +127,13 @@ def log_step(
121127
if self._dist_config.local_rank == 0:
122128
logger.info(", ".join([f"{k.split('/')[1]}: {v:.3g}" for k, v in metrics.items()]))
123129

130+
if self.quant_stats_enabled:
131+
debug_api.step()
132+
124133
def finish(self):
125134
"""Finish the logger."""
135+
if self.quant_stats_enabled:
136+
debug_api.end_debug()
126137
if not self._dist_config.is_main_process():
127138
return
128139
wandb.finish()

0 commit comments

Comments
 (0)