Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,13 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added `PowerLR`, a power-law learning rate scheduler with linear warmup, power-decay phase (`lr = initial_lr * (current / warmup) ** b` for negative `b`, making the LR independent of the training horizon), and an optional linear decay tail. Registered as `"power_lr"`.
- Added `ComposableScheduler`, a piecewise LR scheduler built from `ComposableSchedulerStage` segments (linear/cosine interpolation between endpoint LRs) on an absolute time axis. Registered as `"composable"`. Note: `ComposableScheduler` ignores the `t_max` passed to `get_lr` and emits a once-per-instance `UserWarning` to that effect.
- Added `OverrideDecay`, a late-stage decay override usable on both `ComposableScheduler` and `SequentialScheduler` via an `override_decay` field. When `current >= override_decay.start`, the main schedule is interrupted mid-flight and the LR decays from the value the main schedule would have produced at `start` to a target LR over `duration` (linear or cosine). `SequentialScheduler` additionally warns that `t_max` is ignored once the override becomes active.
- Added `NvidiaProfilerCallback` (wraps a window of training steps in `cudaProfilerStart/Stop` + NVTX ranges for Nsight Systems) and `TorchMemoryHistoryCallback` (records CUDA memory history and dumps a snapshot pickle for https://pytorch.org/memory_viz).
- `SpeedMonitorCallback` now logs a `throughput/device/TFLOPs_per_GPU` metric and recognizes the RTX PRO 6000 device for peak-FLOPs / MFU estimation.


### Fixed

- `WandBCallback` and `CometCallback` now initialize before the checkpointer (via a higher callback `priority`) so that pre-train checkpoint saves no longer drop already-recorded metrics.
- Fixed LM in-loop evaluator data-order drift across repeated runs by resetting loader bookkeeping before each pass and making deterministic reshuffling the default.
- Fixed Qwen3 implementation to match HuggingFace by applying RoPE in the input dtype (bf16) rather than upcasting to fp32.
- Fixed Beaker secret existence check to use the case-insensitive HTTP endpoint, avoiding spurious "secret not found" errors when secret names differ only in case.
Expand Down
8 changes: 7 additions & 1 deletion src/olmo_core/train/callbacks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,11 @@
from .metric_saver import MetricSaverCallback
from .model_merger import ModelMergeCallback
from .monkey_patcher import MonkeyPatcherCallback
from .profiler import ProfilerCallback
from .profiler import (
NvidiaProfilerCallback,
ProfilerCallback,
TorchMemoryHistoryCallback,
)
from .sequence_length_scheduler import SequenceLengthSchedulerCallback
from .slack_notifier import SlackNotificationSetting, SlackNotifierCallback
from .speed_monitor import SpeedMonitorCallback
Expand All @@ -46,6 +50,8 @@
"GPUMemoryMonitorCallback",
"HFConverterCallback",
"ProfilerCallback",
"NvidiaProfilerCallback",
"TorchMemoryHistoryCallback",
"SlackNotifierCallback",
"SlackNotificationSetting",
"SequenceLengthSchedulerCallback",
Expand Down
7 changes: 6 additions & 1 deletion src/olmo_core/train/callbacks/comet.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import logging
import os
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Dict, List, Optional, cast
from typing import TYPE_CHECKING, Any, ClassVar, Dict, List, Optional, cast

from olmo_core.config import StrEnum
from olmo_core.distributed.utils import get_rank
Expand Down Expand Up @@ -57,6 +57,11 @@ class CometCallback(Callback):
of :data:`Trainer.metrics_collect_interval <olmo_core.train.Trainer.metrics_collect_interval>`.
"""

priority: ClassVar[int] = 3
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Keep Comet success notification after final checkpoint

When Comet is enabled and the run ends on a step that still needs a final checkpoint, raising this priority makes CometCallback.post_train() run before CheckpointerCallback.post_train() (the trainer iterates callbacks in priority order at trainer.py:759, and the checkpointer writes the final checkpoint in post_train). That means Comet can send a “completed successfully” notification before the final checkpoint save has actually succeeded; if that save then fails, users get a false success signal and no failure notification from this post-train path. Consider separating Comet initialization ordering from its post-train ordering, or deferring the success notification until after checkpointing.

Useful? React with 👍 / 👎.

"""
Initialize before checkpointing, since pre-train checkpoint saves may flush metrics.
"""

enabled: bool = True
"""
Set to false to disable this callback.
Expand Down
106 changes: 105 additions & 1 deletion src/olmo_core/train/callbacks/profiler.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import logging
import os
from contextlib import ExitStack
from dataclasses import dataclass
from dataclasses import dataclass, field

import torch

from olmo_core.distributed.parallel import (
get_cp_mesh,
Expand Down Expand Up @@ -186,3 +189,104 @@ def _on_trace_ready(self, prof):
prof.export_chrome_trace(str(trace_path))
final_path = self.trainer.persist_working_file(trace_path)
log.info(f"Chrome trace saved to '{final_path}'")


@dataclass
class NvidiaProfilerCallback(Callback):
"""
Wraps a window of training steps in the NVIDIA profiler (``cudaProfilerStart/Stop`` plus
NVTX ranges), for use with Nsight Systems. Profiling runs from step :data:`start` to
:data:`end` on the configured ranks.

.. note::
This only produces output when the job is launched under an external Nsight Systems
session (e.g. ``nsys profile --capture-range=cudaProfilerApi ...``); on its own it just
toggles a capture range with nothing recording.
"""

start: int = 10
"""
The step at which to start profiling.
"""
end: int = 12
"""
The step at which to stop profiling.
"""
enabled: bool = True
"""
Set to ``False`` to disable profiling.
"""
profile_ranks: list[int] = field(default_factory=lambda: [0])
"""
The ranks to profile.
"""

_nvtx_ctx = None

def pre_load_batch(self):
if self.enabled and get_rank() in self.profile_ranks:
if self.step == self.start:
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Start the NVIDIA profiler before the requested step

In this trainer, pre_load_batch() runs before fetching the next batch, while global_step is still the number of completed steps; the loop increments global_step only after the batch is yielded. With start=10, this branch therefore starts profiling after step 10 has already completed, so the captured window is steps 11..end rather than 10..end (and start == end starts after the stop point, so it never stops). Please start on start - 1 here or move the start hook to a point that runs after the step counter advances but before that step’s work.

Useful? React with 👍 / 👎.

log.info(f"Starting NVIDIA profiler at rank={get_rank()} step={self.step}...")
torch.cuda.cudart().cudaProfilerStart()
self._nvtx_ctx = torch.autograd.profiler.emit_nvtx(record_shapes=True)
self._nvtx_ctx.__enter__()

def post_train_batch(self):
if self.enabled and get_rank() in self.profile_ranks:
if self.step == self.end and self._nvtx_ctx is not None:
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Stop the NVIDIA profiler during callback cleanup

Cleanup is only performed when a batch reaches exactly end; if training is canceled, errors, or simply completes before end after the profiler has started, the trainer will call on_error()/close() but this callback never exits _nvtx_ctx or calls cudaProfilerStop(). For nsys --capture-range=cudaProfilerApi this can leave the capture range open and lose or corrupt the intended trace, so please stop the profiler from close()/on_error() whenever _nvtx_ctx is active.

Useful? React with 👍 / 👎.

log.info(f"Stopping NVIDIA profiler at rank={get_rank()} step={self.step}...")
self._nvtx_ctx.__exit__(None, None, None)
self._nvtx_ctx = None
torch.cuda.cudart().cudaProfilerStop()


@dataclass
class TorchMemoryHistoryCallback(Callback):
"""
Records CUDA memory allocation history between steps :data:`start` and :data:`end` and
dumps a snapshot pickle (viewable at https://pytorch.org/memory_viz) on the configured ranks.
"""

start: int = 10
"""
The step at which to start recording memory history.
"""
end: int = 12
"""
The step at which to stop recording and dump the snapshot.
"""
enabled: bool = True
"""
Set to ``False`` to disable profiling.
"""
profile_ranks: list[int] = field(default_factory=lambda: [0])
"""
The ranks to profile.
"""

max_entries: int = 500000
"""
The maximum number of memory-history entries to record.
"""

output_dir: str = "."
"""
Directory to write the snapshot pickle(s) to.
"""

def pre_load_batch(self):
if self.enabled and get_rank() in self.profile_ranks:
if self.step == self.start:
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Start memory history before the requested step

This has the same off-by-one behavior as the NVIDIA profiler: pre_load_batch() sees the previous completed global_step, so start=10 begins recording only before loading step 11. The snapshot dumped at end will omit the requested start step, and start == end starts recording after the dump opportunity has passed. Please align the start condition with the trainer’s step timing so the configured window actually includes start.

Useful? React with 👍 / 👎.

log.info(f"Starting memory profiler at rank={get_rank()} step={self.step}...")
torch.cuda.memory._record_memory_history(max_entries=self.max_entries)

def post_train_batch(self):
if self.enabled and get_rank() in self.profile_ranks:
if self.step == self.end:
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Dump and disable memory history on early exit

The memory recorder is disabled and the snapshot is written only at the exact end step; if the profiled run hits an OOM/error or stops before end, Trainer._shutdown() calls callback close() but this callback does not dump the snapshot or call _record_memory_history(enabled=None). That loses the failure window this callback is meant to debug and leaves recording enabled until process teardown, so please add cleanup that dumps/disables when recording is active.

Useful? React with 👍 / 👎.

log.info(f"Dumping memory profiler at rank={get_rank()} step={self.step}...")
os.makedirs(self.output_dir, exist_ok=True)
torch.cuda.memory._dump_snapshot(
os.path.join(self.output_dir, f"memsnapshot.{get_rank()}.pickle")
)
torch.cuda.memory._record_memory_history(enabled=None)
log.info(f"Memory profiler stopped at rank={get_rank()} step={self.step}.")
4 changes: 4 additions & 0 deletions src/olmo_core/train/callbacks/speed_monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,9 @@ def pre_train(self):
elif "B200" in device_name:
# data from https://www.nvidia.com/en-us/data-center/hgx/
self.device_peak_flops_per_second = int(4.5e15 * dense_correction)
elif "RTX PRO 6000" in device_name:
# https://www.nvidia.com/content/dam/en-zz/Solutions/design-visualization/quadro-product-literature/NVIDIA-RTX-Blackwell-PRO-GPU-Architecture-v1.0.pdf
self.device_peak_flops_per_second = int(1008e12 * dense_correction)
else: # for other GPU types, assume A100
# data from https://www.nvidia.com/en-us/data-center/a100/
self.device_peak_flops_per_second = int(624e12 * dense_correction)
Expand Down Expand Up @@ -213,3 +216,4 @@ def post_step(self):
self._mfu_avg = mfu_avg
self.trainer.record_metric("throughput/device/MFU", mfu)
self.trainer.record_metric("throughput/device/MFU (actual avg)", mfu_avg)
self.trainer.record_metric("throughput/device/TFLOPs_per_GPU", flops_ps / 1e12)
8 changes: 7 additions & 1 deletion src/olmo_core/train/callbacks/wandb.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import logging
import os
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Dict, List, Optional
from typing import TYPE_CHECKING, Any, ClassVar, Dict, List, Optional

from olmo_core.distributed.utils import get_rank
from olmo_core.exceptions import OLMoEnvironmentError
Expand Down Expand Up @@ -30,6 +30,11 @@ class WandBCallback(Callback):
of :data:`Trainer.metrics_collect_interval <olmo_core.train.Trainer.metrics_collect_interval>`.
"""

priority: ClassVar[int] = 3
"""
Initialize before checkpointing, since pre-train checkpoint saves may flush metrics.
"""

enabled: bool = True
"""
Set to false to disable this callback.
Expand Down Expand Up @@ -153,6 +158,7 @@ def post_step(self):
def on_error(self, exc: BaseException):
del exc
if self.enabled and get_rank() == 0 and self.run is not None:
log.warning("Finalizing failed W&B run...")
self.finalize(exit_code=1)

def close(self):
Expand Down
37 changes: 37 additions & 0 deletions src/test/train/callbacks/callback_order_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
from olmo_core.train.callbacks.checkpointer import CheckpointerCallback
from olmo_core.train.callbacks.comet import CometCallback
from olmo_core.train.callbacks.wandb import WandBCallback


def test_external_loggers_initialize_before_checkpointer():
# Pre-train checkpoint saves can flush already-recorded metrics, so external
# metric sinks must be initialized before the checkpointer runs pre_train().
assert WandBCallback.priority > CheckpointerCallback.priority
assert CometCallback.priority > CheckpointerCallback.priority


class MockWandB:
run = object()

def __init__(self):
self.finish_calls = []

def finish(self, **kwargs):
self.finish_calls.append(kwargs)


def test_wandb_finalizes_on_close_not_post_train():
# Final checkpoint saves happen from CheckpointerCallback.post_train(), after
# higher-priority callbacks have run their own post_train(). Keep W&B open
# until close(), because close() runs after final checkpoint metrics drain.
wandb = MockWandB()
callback = WandBCallback()
callback._wandb = wandb

callback.post_train()
assert not callback.finalized
assert wandb.finish_calls == []

callback.close()
assert callback.finalized
assert wandb.finish_calls == [{"exit_code": 0, "quiet": True}]
Loading