diff --git a/CHANGELOG.md b/CHANGELOG.md index 1942f9a5c..ee3ae7fbe 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,10 +14,13 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added `PowerLR`, a power-law learning rate scheduler with linear warmup, power-decay phase (`lr = initial_lr * (current / warmup) ** b` for negative `b`, making the LR independent of the training horizon), and an optional linear decay tail. Registered as `"power_lr"`. - Added `ComposableScheduler`, a piecewise LR scheduler built from `ComposableSchedulerStage` segments (linear/cosine interpolation between endpoint LRs) on an absolute time axis. Registered as `"composable"`. Note: `ComposableScheduler` ignores the `t_max` passed to `get_lr` and emits a once-per-instance `UserWarning` to that effect. - Added `OverrideDecay`, a late-stage decay override usable on both `ComposableScheduler` and `SequentialScheduler` via an `override_decay` field. When `current >= override_decay.start`, the main schedule is interrupted mid-flight and the LR decays from the value the main schedule would have produced at `start` to a target LR over `duration` (linear or cosine). `SequentialScheduler` additionally warns that `t_max` is ignored once the override becomes active. +- Added `NvidiaProfilerCallback` (wraps a window of training steps in `cudaProfilerStart/Stop` + NVTX ranges for Nsight Systems) and `TorchMemoryHistoryCallback` (records CUDA memory history and dumps a snapshot pickle for https://pytorch.org/memory_viz). +- `SpeedMonitorCallback` now logs a `throughput/device/TFLOPs_per_GPU` metric and recognizes the RTX PRO 6000 device for peak-FLOPs / MFU estimation. ### Fixed +- `WandBCallback` and `CometCallback` now initialize before the checkpointer (via a higher callback `priority`) so that pre-train checkpoint saves no longer drop already-recorded metrics. - Fixed LM in-loop evaluator data-order drift across repeated runs by resetting loader bookkeeping before each pass and making deterministic reshuffling the default. - Fixed Qwen3 implementation to match HuggingFace by applying RoPE in the input dtype (bf16) rather than upcasting to fp32. - Fixed Beaker secret existence check to use the case-insensitive HTTP endpoint, avoiding spurious "secret not found" errors when secret names differ only in case. diff --git a/src/olmo_core/train/callbacks/__init__.py b/src/olmo_core/train/callbacks/__init__.py index 6286558c6..609b9e446 100644 --- a/src/olmo_core/train/callbacks/__init__.py +++ b/src/olmo_core/train/callbacks/__init__.py @@ -22,7 +22,11 @@ from .metric_saver import MetricSaverCallback from .model_merger import ModelMergeCallback from .monkey_patcher import MonkeyPatcherCallback -from .profiler import ProfilerCallback +from .profiler import ( + NvidiaProfilerCallback, + ProfilerCallback, + TorchMemoryHistoryCallback, +) from .sequence_length_scheduler import SequenceLengthSchedulerCallback from .slack_notifier import SlackNotificationSetting, SlackNotifierCallback from .speed_monitor import SpeedMonitorCallback @@ -46,6 +50,8 @@ "GPUMemoryMonitorCallback", "HFConverterCallback", "ProfilerCallback", + "NvidiaProfilerCallback", + "TorchMemoryHistoryCallback", "SlackNotifierCallback", "SlackNotificationSetting", "SequenceLengthSchedulerCallback", diff --git a/src/olmo_core/train/callbacks/comet.py b/src/olmo_core/train/callbacks/comet.py index ba2e8fc6b..45810f82d 100644 --- a/src/olmo_core/train/callbacks/comet.py +++ b/src/olmo_core/train/callbacks/comet.py @@ -1,7 +1,7 @@ import logging import os from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Any, Dict, List, Optional, cast +from typing import TYPE_CHECKING, Any, ClassVar, Dict, List, Optional, cast from olmo_core.config import StrEnum from olmo_core.distributed.utils import get_rank @@ -57,6 +57,11 @@ class CometCallback(Callback): of :data:`Trainer.metrics_collect_interval `. """ + priority: ClassVar[int] = 3 + """ + Initialize before checkpointing, since pre-train checkpoint saves may flush metrics. + """ + enabled: bool = True """ Set to false to disable this callback. diff --git a/src/olmo_core/train/callbacks/profiler.py b/src/olmo_core/train/callbacks/profiler.py index 61c4421b9..63b597b3e 100644 --- a/src/olmo_core/train/callbacks/profiler.py +++ b/src/olmo_core/train/callbacks/profiler.py @@ -1,6 +1,9 @@ import logging +import os from contextlib import ExitStack -from dataclasses import dataclass +from dataclasses import dataclass, field + +import torch from olmo_core.distributed.parallel import ( get_cp_mesh, @@ -186,3 +189,104 @@ def _on_trace_ready(self, prof): prof.export_chrome_trace(str(trace_path)) final_path = self.trainer.persist_working_file(trace_path) log.info(f"Chrome trace saved to '{final_path}'") + + +@dataclass +class NvidiaProfilerCallback(Callback): + """ + Wraps a window of training steps in the NVIDIA profiler (``cudaProfilerStart/Stop`` plus + NVTX ranges), for use with Nsight Systems. Profiling runs from step :data:`start` to + :data:`end` on the configured ranks. + + .. note:: + This only produces output when the job is launched under an external Nsight Systems + session (e.g. ``nsys profile --capture-range=cudaProfilerApi ...``); on its own it just + toggles a capture range with nothing recording. + """ + + start: int = 10 + """ + The step at which to start profiling. + """ + end: int = 12 + """ + The step at which to stop profiling. + """ + enabled: bool = True + """ + Set to ``False`` to disable profiling. + """ + profile_ranks: list[int] = field(default_factory=lambda: [0]) + """ + The ranks to profile. + """ + + _nvtx_ctx = None + + def pre_load_batch(self): + if self.enabled and get_rank() in self.profile_ranks: + if self.step == self.start: + log.info(f"Starting NVIDIA profiler at rank={get_rank()} step={self.step}...") + torch.cuda.cudart().cudaProfilerStart() + self._nvtx_ctx = torch.autograd.profiler.emit_nvtx(record_shapes=True) + self._nvtx_ctx.__enter__() + + def post_train_batch(self): + if self.enabled and get_rank() in self.profile_ranks: + if self.step == self.end and self._nvtx_ctx is not None: + log.info(f"Stopping NVIDIA profiler at rank={get_rank()} step={self.step}...") + self._nvtx_ctx.__exit__(None, None, None) + self._nvtx_ctx = None + torch.cuda.cudart().cudaProfilerStop() + + +@dataclass +class TorchMemoryHistoryCallback(Callback): + """ + Records CUDA memory allocation history between steps :data:`start` and :data:`end` and + dumps a snapshot pickle (viewable at https://pytorch.org/memory_viz) on the configured ranks. + """ + + start: int = 10 + """ + The step at which to start recording memory history. + """ + end: int = 12 + """ + The step at which to stop recording and dump the snapshot. + """ + enabled: bool = True + """ + Set to ``False`` to disable profiling. + """ + profile_ranks: list[int] = field(default_factory=lambda: [0]) + """ + The ranks to profile. + """ + + max_entries: int = 500000 + """ + The maximum number of memory-history entries to record. + """ + + output_dir: str = "." + """ + Directory to write the snapshot pickle(s) to. + """ + + def pre_load_batch(self): + if self.enabled and get_rank() in self.profile_ranks: + if self.step == self.start: + log.info(f"Starting memory profiler at rank={get_rank()} step={self.step}...") + torch.cuda.memory._record_memory_history(max_entries=self.max_entries) + + def post_train_batch(self): + if self.enabled and get_rank() in self.profile_ranks: + if self.step == self.end: + log.info(f"Dumping memory profiler at rank={get_rank()} step={self.step}...") + os.makedirs(self.output_dir, exist_ok=True) + torch.cuda.memory._dump_snapshot( + os.path.join(self.output_dir, f"memsnapshot.{get_rank()}.pickle") + ) + torch.cuda.memory._record_memory_history(enabled=None) + log.info(f"Memory profiler stopped at rank={get_rank()} step={self.step}.") diff --git a/src/olmo_core/train/callbacks/speed_monitor.py b/src/olmo_core/train/callbacks/speed_monitor.py index 184e2677b..fa9aee592 100644 --- a/src/olmo_core/train/callbacks/speed_monitor.py +++ b/src/olmo_core/train/callbacks/speed_monitor.py @@ -108,6 +108,9 @@ def pre_train(self): elif "B200" in device_name: # data from https://www.nvidia.com/en-us/data-center/hgx/ self.device_peak_flops_per_second = int(4.5e15 * dense_correction) + elif "RTX PRO 6000" in device_name: + # https://www.nvidia.com/content/dam/en-zz/Solutions/design-visualization/quadro-product-literature/NVIDIA-RTX-Blackwell-PRO-GPU-Architecture-v1.0.pdf + self.device_peak_flops_per_second = int(1008e12 * dense_correction) else: # for other GPU types, assume A100 # data from https://www.nvidia.com/en-us/data-center/a100/ self.device_peak_flops_per_second = int(624e12 * dense_correction) @@ -213,3 +216,4 @@ def post_step(self): self._mfu_avg = mfu_avg self.trainer.record_metric("throughput/device/MFU", mfu) self.trainer.record_metric("throughput/device/MFU (actual avg)", mfu_avg) + self.trainer.record_metric("throughput/device/TFLOPs_per_GPU", flops_ps / 1e12) diff --git a/src/olmo_core/train/callbacks/wandb.py b/src/olmo_core/train/callbacks/wandb.py index 86e3dc441..2a50be638 100644 --- a/src/olmo_core/train/callbacks/wandb.py +++ b/src/olmo_core/train/callbacks/wandb.py @@ -1,7 +1,7 @@ import logging import os from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Any, Dict, List, Optional +from typing import TYPE_CHECKING, Any, ClassVar, Dict, List, Optional from olmo_core.distributed.utils import get_rank from olmo_core.exceptions import OLMoEnvironmentError @@ -30,6 +30,11 @@ class WandBCallback(Callback): of :data:`Trainer.metrics_collect_interval `. """ + priority: ClassVar[int] = 3 + """ + Initialize before checkpointing, since pre-train checkpoint saves may flush metrics. + """ + enabled: bool = True """ Set to false to disable this callback. @@ -153,6 +158,7 @@ def post_step(self): def on_error(self, exc: BaseException): del exc if self.enabled and get_rank() == 0 and self.run is not None: + log.warning("Finalizing failed W&B run...") self.finalize(exit_code=1) def close(self): diff --git a/src/test/train/callbacks/callback_order_test.py b/src/test/train/callbacks/callback_order_test.py new file mode 100644 index 000000000..581c075c6 --- /dev/null +++ b/src/test/train/callbacks/callback_order_test.py @@ -0,0 +1,37 @@ +from olmo_core.train.callbacks.checkpointer import CheckpointerCallback +from olmo_core.train.callbacks.comet import CometCallback +from olmo_core.train.callbacks.wandb import WandBCallback + + +def test_external_loggers_initialize_before_checkpointer(): + # Pre-train checkpoint saves can flush already-recorded metrics, so external + # metric sinks must be initialized before the checkpointer runs pre_train(). + assert WandBCallback.priority > CheckpointerCallback.priority + assert CometCallback.priority > CheckpointerCallback.priority + + +class MockWandB: + run = object() + + def __init__(self): + self.finish_calls = [] + + def finish(self, **kwargs): + self.finish_calls.append(kwargs) + + +def test_wandb_finalizes_on_close_not_post_train(): + # Final checkpoint saves happen from CheckpointerCallback.post_train(), after + # higher-priority callbacks have run their own post_train(). Keep W&B open + # until close(), because close() runs after final checkpoint metrics drain. + wandb = MockWandB() + callback = WandBCallback() + callback._wandb = wandb + + callback.post_train() + assert not callback.finalized + assert wandb.finish_calls == [] + + callback.close() + assert callback.finalized + assert wandb.finish_calls == [{"exit_code": 0, "quiet": True}]