-
Notifications
You must be signed in to change notification settings - Fork 251
Add profiling callbacks and fix logger init ordering #689
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,6 +1,9 @@ | ||
| import logging | ||
| import os | ||
| from contextlib import ExitStack | ||
| from dataclasses import dataclass | ||
| from dataclasses import dataclass, field | ||
|
|
||
| import torch | ||
|
|
||
| from olmo_core.distributed.parallel import ( | ||
| get_cp_mesh, | ||
|
|
@@ -186,3 +189,104 @@ def _on_trace_ready(self, prof): | |
| prof.export_chrome_trace(str(trace_path)) | ||
| final_path = self.trainer.persist_working_file(trace_path) | ||
| log.info(f"Chrome trace saved to '{final_path}'") | ||
|
|
||
|
|
||
| @dataclass | ||
| class NvidiaProfilerCallback(Callback): | ||
| """ | ||
| Wraps a window of training steps in the NVIDIA profiler (``cudaProfilerStart/Stop`` plus | ||
| NVTX ranges), for use with Nsight Systems. Profiling runs from step :data:`start` to | ||
| :data:`end` on the configured ranks. | ||
|
|
||
| .. note:: | ||
| This only produces output when the job is launched under an external Nsight Systems | ||
| session (e.g. ``nsys profile --capture-range=cudaProfilerApi ...``); on its own it just | ||
| toggles a capture range with nothing recording. | ||
| """ | ||
|
|
||
| start: int = 10 | ||
| """ | ||
| The step at which to start profiling. | ||
| """ | ||
| end: int = 12 | ||
| """ | ||
| The step at which to stop profiling. | ||
| """ | ||
| enabled: bool = True | ||
| """ | ||
| Set to ``False`` to disable profiling. | ||
| """ | ||
| profile_ranks: list[int] = field(default_factory=lambda: [0]) | ||
| """ | ||
| The ranks to profile. | ||
| """ | ||
|
|
||
| _nvtx_ctx = None | ||
|
|
||
| def pre_load_batch(self): | ||
| if self.enabled and get_rank() in self.profile_ranks: | ||
| if self.step == self.start: | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
In this trainer, Useful? React with 👍 / 👎. |
||
| log.info(f"Starting NVIDIA profiler at rank={get_rank()} step={self.step}...") | ||
| torch.cuda.cudart().cudaProfilerStart() | ||
| self._nvtx_ctx = torch.autograd.profiler.emit_nvtx(record_shapes=True) | ||
| self._nvtx_ctx.__enter__() | ||
|
|
||
| def post_train_batch(self): | ||
| if self.enabled and get_rank() in self.profile_ranks: | ||
| if self.step == self.end and self._nvtx_ctx is not None: | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Cleanup is only performed when a batch reaches exactly Useful? React with 👍 / 👎. |
||
| log.info(f"Stopping NVIDIA profiler at rank={get_rank()} step={self.step}...") | ||
| self._nvtx_ctx.__exit__(None, None, None) | ||
| self._nvtx_ctx = None | ||
| torch.cuda.cudart().cudaProfilerStop() | ||
|
|
||
|
|
||
| @dataclass | ||
| class TorchMemoryHistoryCallback(Callback): | ||
| """ | ||
| Records CUDA memory allocation history between steps :data:`start` and :data:`end` and | ||
| dumps a snapshot pickle (viewable at https://pytorch.org/memory_viz) on the configured ranks. | ||
| """ | ||
|
|
||
| start: int = 10 | ||
| """ | ||
| The step at which to start recording memory history. | ||
| """ | ||
| end: int = 12 | ||
| """ | ||
| The step at which to stop recording and dump the snapshot. | ||
| """ | ||
| enabled: bool = True | ||
| """ | ||
| Set to ``False`` to disable profiling. | ||
| """ | ||
| profile_ranks: list[int] = field(default_factory=lambda: [0]) | ||
| """ | ||
| The ranks to profile. | ||
| """ | ||
|
|
||
| max_entries: int = 500000 | ||
| """ | ||
| The maximum number of memory-history entries to record. | ||
| """ | ||
|
|
||
| output_dir: str = "." | ||
| """ | ||
| Directory to write the snapshot pickle(s) to. | ||
| """ | ||
|
|
||
| def pre_load_batch(self): | ||
| if self.enabled and get_rank() in self.profile_ranks: | ||
| if self.step == self.start: | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
This has the same off-by-one behavior as the NVIDIA profiler: Useful? React with 👍 / 👎. |
||
| log.info(f"Starting memory profiler at rank={get_rank()} step={self.step}...") | ||
| torch.cuda.memory._record_memory_history(max_entries=self.max_entries) | ||
|
|
||
| def post_train_batch(self): | ||
| if self.enabled and get_rank() in self.profile_ranks: | ||
| if self.step == self.end: | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
The memory recorder is disabled and the snapshot is written only at the exact Useful? React with 👍 / 👎. |
||
| log.info(f"Dumping memory profiler at rank={get_rank()} step={self.step}...") | ||
| os.makedirs(self.output_dir, exist_ok=True) | ||
| torch.cuda.memory._dump_snapshot( | ||
| os.path.join(self.output_dir, f"memsnapshot.{get_rank()}.pickle") | ||
| ) | ||
| torch.cuda.memory._record_memory_history(enabled=None) | ||
| log.info(f"Memory profiler stopped at rank={get_rank()} step={self.step}.") | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,37 @@ | ||
| from olmo_core.train.callbacks.checkpointer import CheckpointerCallback | ||
| from olmo_core.train.callbacks.comet import CometCallback | ||
| from olmo_core.train.callbacks.wandb import WandBCallback | ||
|
|
||
|
|
||
| def test_external_loggers_initialize_before_checkpointer(): | ||
| # Pre-train checkpoint saves can flush already-recorded metrics, so external | ||
| # metric sinks must be initialized before the checkpointer runs pre_train(). | ||
| assert WandBCallback.priority > CheckpointerCallback.priority | ||
| assert CometCallback.priority > CheckpointerCallback.priority | ||
|
|
||
|
|
||
| class MockWandB: | ||
| run = object() | ||
|
|
||
| def __init__(self): | ||
| self.finish_calls = [] | ||
|
|
||
| def finish(self, **kwargs): | ||
| self.finish_calls.append(kwargs) | ||
|
|
||
|
|
||
| def test_wandb_finalizes_on_close_not_post_train(): | ||
| # Final checkpoint saves happen from CheckpointerCallback.post_train(), after | ||
| # higher-priority callbacks have run their own post_train(). Keep W&B open | ||
| # until close(), because close() runs after final checkpoint metrics drain. | ||
| wandb = MockWandB() | ||
| callback = WandBCallback() | ||
| callback._wandb = wandb | ||
|
|
||
| callback.post_train() | ||
| assert not callback.finalized | ||
| assert wandb.finish_calls == [] | ||
|
|
||
| callback.close() | ||
| assert callback.finalized | ||
| assert wandb.finish_calls == [{"exit_code": 0, "quiet": True}] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
When Comet is enabled and the run ends on a step that still needs a final checkpoint, raising this priority makes
CometCallback.post_train()run beforeCheckpointerCallback.post_train()(the trainer iterates callbacks in priority order attrainer.py:759, and the checkpointer writes the final checkpoint inpost_train). That means Comet can send a “completed successfully” notification before the final checkpoint save has actually succeeded; if that save then fails, users get a false success signal and no failure notification from this post-train path. Consider separating Comet initialization ordering from its post-train ordering, or deferring the success notification until after checkpointing.Useful? React with 👍 / 👎.