
Commit 807782a

Add wandb support for quant stats logging (#1526)
Signed-off-by: Jonathan Mitchell <jomitchell@nvidia.com>
1 parent 736d3a3 commit 807782a

4 files changed

Lines changed: 27 additions & 1 deletion

bionemo-recipes/recipes/esm2_native_te/checkpoint.py

Lines changed: 1 addition & 0 deletions
@@ -358,6 +358,7 @@ class AppState(Stateful):
         default_factory=lambda: StateDictOptions(
             full_state_dict=False,
             cpu_offload=True,
+            strict=False,
         )
     )


bionemo-recipes/recipes/esm2_native_te/hydra_config/defaults.yaml

Lines changed: 1 addition & 0 deletions
@@ -87,6 +87,7 @@ quant_stats_config:
   enabled: false
   quant_stats_file: ./fp8_debugging_stats.yaml
   quant_log_dir: ./log_quant_stats
+  log_to_wandb: false

 # Note: The layers are going to come in 1 indexed and we convert them to be 0 indexed at runtime.
 fp8_layers: null

bionemo-recipes/recipes/esm2_native_te/quantization.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,26 @@
2020
from pathlib import Path
2121

2222
import yaml
23+
from nvdlfw_inspect.logging import BaseLogger
2324

2425

2526
logger = logging.getLogger(__name__)
2627

2728

29+
class WandBQuantLogger(BaseLogger):
30+
"""Forward nvdlfw_inspect quant stats to WandB as scalars.
31+
32+
Each stat is logged under the ``quant/`` prefix so it appears alongside
33+
training metrics (loss, perplexity, etc.) in a single WandB dashboard.
34+
"""
35+
36+
def log_scalar(self, name: str, value: float | int, iteration: int, **kwargs):
37+
"""Log a single quant stat to WandB."""
38+
import wandb
39+
40+
wandb.log({f"quant/{name}": value}, step=iteration)
41+
42+
2843
def generate_layer_regex(layer_numbers: list[int] | None) -> str:
2944
"""Generate a regex pattern to match specific layer numbers (1-indexed).
3045
@@ -99,6 +114,7 @@ def initialize_quant_stats_logging(
99114
quant_log_dir: str,
100115
rank: int,
101116
layer_precision: list[str | None],
117+
statistics_logger: BaseLogger | None = None,
102118
) -> None:
103119
"""Set up quantization stats logging via nvdlfw_inspect.
104120
@@ -111,6 +127,9 @@ def initialize_quant_stats_logging(
111127
rank: The global rank of this process.
112128
layer_precision: Per-layer precision list (0-indexed by position). Each element is
113129
``"fp8"``, ``"fp4"``, or ``None``.
130+
statistics_logger: Optional custom logger (e.g. :class:`WandBQuantLogger`) that receives
131+
every ``log_scalar`` call from the debug API. When provided together with
132+
``default_logging_enabled=True`` the file logger is kept as well.
114133
"""
115134
import nvdlfw_inspect.api as debug_api
116135
import transformer_engine
@@ -133,6 +152,7 @@ def initialize_quant_stats_logging(
133152
config_file=updated_config,
134153
feature_dirs=[te_features_dir],
135154
log_dir=rank_log_dir,
155+
statistics_logger=statistics_logger,
136156
default_logging_enabled=True,
137157
)
138158
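
The forwarding pattern used by `WandBQuantLogger` can be tried out without `wandb` or `nvdlfw_inspect` installed. The sketch below is illustrative only: `ScalarLoggerBase`, `PrefixedScalarLogger`, `fake_sink`, and the two stat names are hypothetical stand-ins and are not part of this commit or of either library. The only thing taken from the diff is the `log_scalar(name, value, iteration, **kwargs)` signature and the `quant/` prefix.

```python
from __future__ import annotations

from typing import Callable


class ScalarLoggerBase:
    """Hypothetical stand-in for nvdlfw_inspect.logging.BaseLogger (assumption)."""

    def log_scalar(self, name: str, value: float | int, iteration: int, **kwargs) -> None:
        raise NotImplementedError


class PrefixedScalarLogger(ScalarLoggerBase):
    """Forward each scalar to `sink` under a fixed prefix, as WandBQuantLogger does."""

    def __init__(self, sink: Callable[..., None], prefix: str = "quant/") -> None:
        self._sink = sink
        self._prefix = prefix

    def log_scalar(self, name: str, value: float | int, iteration: int, **kwargs) -> None:
        # Same call shape as wandb.log({...}, step=iteration) in the diff;
        # extra keyword arguments are accepted and ignored.
        self._sink({f"{self._prefix}{name}": value}, step=iteration)


records: list[tuple[int, dict]] = []


def fake_sink(payload: dict, step: int) -> None:
    """Record what would have been sent to the dashboard."""
    records.append((step, payload))


demo_logger = PrefixedScalarLogger(fake_sink)
# The stat names below are made up for illustration.
demo_logger.log_scalar("layer_1_underflows", 0.25, iteration=10)
demo_logger.log_scalar("layer_1_scale_inv_max", 2, iteration=10, extra="ignored")
```

Injecting the sink keeps the sketch testable offline. The class in the commit instead imports `wandb` inside `log_scalar`, which defers the dependency until the first stat is actually logged.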

bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
from distributed_config import DistributedConfig
3535
from modeling_esm_te import NVEsmConfig, NVEsmForMaskedLM
3636
from perf_logger import PerfLogger
37-
from quantization import initialize_quant_stats_logging, resolve_layer_precision
37+
from quantization import WandBQuantLogger, initialize_quant_stats_logging, resolve_layer_precision
3838
from scheduler import get_linear_schedule_with_warmup
3939

4040

@@ -82,11 +82,15 @@ def main(args: DictConfig) -> float | None:
8282
)
8383
config.layer_precision = layer_precision
8484
if args.quant_stats_config.enabled:
85+
wandb_logger = None
86+
if args.quant_stats_config.log_to_wandb and dist_config.is_main_process():
87+
wandb_logger = WandBQuantLogger()
8588
initialize_quant_stats_logging(
8689
quant_stats_file=args.quant_stats_config.quant_stats_file,
8790
quant_log_dir=args.quant_stats_config.quant_log_dir,
8891
rank=dist_config.rank,
8992
layer_precision=layer_precision,
93+
statistics_logger=wandb_logger,
9094
)
9195

9296
# Create quantization recipes -- these are only used if FP8/FP4 is enabled in the config.
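
The gating in this hunk means that only the main process builds a `WandBQuantLogger`; every other rank passes `statistics_logger=None` while still calling `initialize_quant_stats_logging` with its own `rank`. A minimal sketch of that selection logic follows. `select_statistics_logger` and `DummyLogger` are hypothetical names introduced for illustration, and treating rank 0 as the main process is an assumption about `dist_config.is_main_process()` that this diff does not confirm.

```python
from __future__ import annotations

from typing import Callable, Optional


def select_statistics_logger(
    log_to_wandb: bool,
    is_main_process: bool,
    factory: Callable[[], object],
) -> Optional[object]:
    """Build a logger only when WandB logging is requested on the main process."""
    if log_to_wandb and is_main_process:
        return factory()
    return None


class DummyLogger:
    """Placeholder for WandBQuantLogger so the sketch runs without wandb."""


# Simulate four ranks, assuming rank 0 is the main process.
per_rank = {rank: select_statistics_logger(True, rank == 0, DummyLogger) for rank in range(4)}

# With log_to_wandb disabled (the default in defaults.yaml), no rank gets a logger.
disabled = {rank: select_statistics_logger(False, rank == 0, DummyLogger) for rank in range(4)}
```

One consequence of this design, if the assumption holds, is that the WandB dashboard would show quant stats from the main process only, while stats from the other ranks would remain available through the per-rank file logs.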
