Add profiling callbacks and fix logger init ordering#689

Open
AkshitaB wants to merge 2 commits into main from akshitab/callback-improvements

@AkshitaB
Contributor

Summary

A few general training-callback improvements.

New profiling callbacks (callbacks/profiler.py)

  • NvidiaProfilerCallback: wraps a window of training steps (start to end, on configured ranks) in cudaProfilerStart/Stop plus emit_nvtx NVTX ranges, for capture under an external Nsight Systems session. (Distinct from the existing ProfilerCallback, which uses torch.profiler and writes its own Chrome trace; this one drives an external nsys capture range and annotates the timeline with PyTorch-op NVTX.)
  • TorchMemoryHistoryCallback: records CUDA memory allocation history and dumps a snapshot pickle (viewable at https://pytorch.org/memory_viz) for OOM/fragmentation debugging.
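
For readers unfamiliar with how the capture range and the NVTX annotations fit together, here is a minimal sketch. It is not the code in callbacks/profiler.py; the helper name `nsys_capture_range` and its injectable arguments are assumptions made for illustration. The defaults use `torch.cuda.profiler.start`/`stop` and `torch.autograd.profiler.emit_nvtx`, imported lazily so the shape of the logic can be exercised without a GPU.

```python
from contextlib import contextmanager, nullcontext


@contextmanager
def nsys_capture_range(start_fn=None, stop_fn=None, nvtx_factory=None):
    """Open a cudaProfilerStart/Stop range with NVTX op ranges emitted inside it.

    Hypothetical helper: all three arguments are injectable so the control flow
    can be tested with fakes. Unset arguments fall back to the PyTorch calls.
    """
    if start_fn is None or stop_fn is None or nvtx_factory is None:
        import torch  # lazy import: only needed when running against real CUDA

        start_fn = start_fn or torch.cuda.profiler.start
        stop_fn = stop_fn or torch.cuda.profiler.stop
        nvtx_factory = nvtx_factory or torch.autograd.profiler.emit_nvtx

    start_fn()  # nsys begins capturing here when run with --capture-range=cudaProfilerApi
    try:
        with nvtx_factory():  # annotate PyTorch ops on the timeline
            yield
    finally:
        stop_fn()  # always close the capture range, even if the body raises


# Usage with fakes (no GPU required):
events = []
with nsys_capture_range(
    start_fn=lambda: events.append("start"),
    stop_fn=lambda: events.append("stop"),
    nvtx_factory=nullcontext,
):
    events.append("step")
```

The callback in this PR spreads the same start/stop pair across trainer hooks rather than a single `with` block, which is why the window timing and cleanup matter.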

Logger init ordering (callbacks/wandb.py, callbacks/comet.py)

WandBCallback and CometCallback now declare a higher callback priority so they initialize before the checkpointer. Pre-train checkpoint saves can flush already-recorded metrics; if the external logger hasn't started yet those metrics are dropped. New test in callback_order_test.py covers the ordering and the wandb finalize-on-close lifecycle.
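
A small sketch of the ordering this relies on, assuming the trainer runs callbacks from highest `priority` to lowest. The class bodies and the checkpointer's priority value below are placeholders for illustration, not the real olmo_core implementations.

```python
from typing import ClassVar, List


class Callback:
    priority: ClassVar[int] = 0  # assumption: higher priority runs first

    def __init__(self, log: List[str]) -> None:
        self.log = log

    def pre_train(self) -> None:
        self.log.append(type(self).__name__)


class CheckpointerCallback(Callback):
    priority = 1  # placeholder value


class WandBCallback(Callback):
    priority = 3  # higher than the checkpointer, so the logger starts first


def run_pre_train(callbacks: List[Callback]) -> None:
    # Sort descending so higher-priority callbacks initialize first.
    for cb in sorted(callbacks, key=lambda c: c.priority, reverse=True):
        cb.pre_train()


log: List[str] = []
run_pre_train([CheckpointerCallback(log), WandBCallback(log)])
```

With this ordering the logger is already running when a pre-train checkpoint save flushes metrics.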

Speed monitor (callbacks/speed_monitor.py)

Adds a throughput/device/TFLOPs_per_GPU metric and an RTX PRO 6000 entry for peak-FLOPs / MFU estimation.
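
As a rough illustration of the arithmetic behind a per-GPU TFLOPs metric and an MFU estimate (a sketch only; the function names and all numbers are made up, and the peak value is a placeholder, not the RTX PRO 6000 figure added in this PR):

```python
def tflops_per_gpu(flops_per_step: float, step_seconds: float, num_gpus: int) -> float:
    """Achieved TFLOPs per GPU: total FLOPs per step, per second, per device."""
    return flops_per_step / step_seconds / num_gpus / 1e12


def mfu(achieved_tflops: float, peak_tflops: float) -> float:
    """Model FLOPs utilization: achieved throughput as a fraction of peak."""
    return achieved_tflops / peak_tflops


# Made-up example: 8e15 FLOPs per step, 2 s per step, 8 GPUs.
achieved = tflops_per_gpu(flops_per_step=8e15, step_seconds=2.0, num_gpus=8)
PLACEHOLDER_PEAK_TFLOPS = 1000.0  # hypothetical peak, for illustration only
utilization = mfu(achieved, PLACEHOLDER_PEAK_TFLOPS)
```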

Notes

  • NvidiaProfilerCallback only produces output when the job is launched under nsys (e.g. nsys profile --capture-range=cudaProfilerApi ...); standalone it just toggles a capture range. Noted in its docstring.
  • Stray print() debugging in the profiler callbacks was converted to log calls.

Tests

pytest src/test/train/callbacks/ → 40 passed (incl. the 2 new ordering/lifecycle tests). make checks (isort/black/ruff/mypy) clean.

🤖 Generated with Claude Code

Add two general-purpose profiling callbacks:
- `NvidiaProfilerCallback`: wraps a window of steps in `cudaProfilerStart/Stop`
  + NVTX ranges for use with an external Nsight Systems session.
- `TorchMemoryHistoryCallback`: records CUDA memory allocation history and
  dumps a snapshot pickle (viewable at https://pytorch.org/memory_viz).

Give `WandBCallback` and `CometCallback` a higher callback `priority` so they
initialize before the checkpointer — pre-train checkpoint saves can flush
already-recorded metrics, which would otherwise be dropped if the external
logger hasn't started yet.

Also add a `throughput/device/TFLOPs_per_GPU` metric and an RTX PRO 6000
peak-FLOPs entry to `SpeedMonitorCallback`.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
AkshitaB requested a review from farhatkevin on May 27, 2026 at 05:11
@farhatkevin
Contributor

@codex review


@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 8b3e3ff6c2

ℹ️ About Codex in GitHub

Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".

of :data:`Trainer.metrics_collect_interval <olmo_core.train.Trainer.metrics_collect_interval>`.
"""

priority: ClassVar[int] = 3

[P2] Keep Comet success notification after final checkpoint

When Comet is enabled and the run ends on a step that still needs a final checkpoint, raising this priority makes CometCallback.post_train() run before CheckpointerCallback.post_train() (the trainer iterates callbacks in priority order at trainer.py:759, and the checkpointer writes the final checkpoint in post_train). That means Comet can send a “completed successfully” notification before the final checkpoint save has actually succeeded; if that save then fails, users get a false success signal and no failure notification from this post-train path. Consider separating Comet initialization ordering from its post-train ordering, or deferring the success notification until after checkpointing.
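
One way to picture the first suggestion (separating initialization order from post-train order) is sketched below. The `post_train_priority` attribute is hypothetical and does not exist in the PR; the sketch also assumes callbacks run from highest priority to lowest.

```python
from typing import ClassVar, List, Optional


class Callback:
    priority: ClassVar[int] = 0
    # Hypothetical attribute: separate ordering for post_train; None reuses `priority`.
    post_train_priority: ClassVar[Optional[int]] = None

    def __init__(self, events: List[str]) -> None:
        self.events = events

    def pre_train(self) -> None:
        self.events.append(f"pre:{type(self).__name__}")

    def post_train(self) -> None:
        self.events.append(f"post:{type(self).__name__}")


class CheckpointerCallback(Callback):
    priority = 1  # placeholder value


class CometCallback(Callback):
    priority = 3             # initialize before the checkpointer
    post_train_priority = 0  # but notify only after the final checkpoint is saved


def _post_key(cb: Callback) -> int:
    return cb.priority if cb.post_train_priority is None else cb.post_train_priority


def run(callbacks: List[Callback]) -> None:
    for cb in sorted(callbacks, key=lambda c: c.priority, reverse=True):
        cb.pre_train()
    for cb in sorted(callbacks, key=_post_key, reverse=True):
        cb.post_train()


events: List[str] = []
run([CheckpointerCallback(events), CometCallback(events)])
```

The alternative mentioned above, deferring the notification to a later hook, avoids adding a second ordering attribute but depends on that hook running after every `post_train()`.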

Useful? React with 👍 / 👎.

@farhatkevin
Copy link
Copy Markdown
Contributor

@codex review. Check for profiler windows being off by one and profiler cleanup


@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 8b3e3ff6c2



def pre_load_batch(self):
    if self.enabled and get_rank() in self.profile_ranks:
        if self.step == self.start:

[P2] Start the NVIDIA profiler before the requested step

In this trainer, pre_load_batch() runs before fetching the next batch, while global_step is still the number of completed steps; the loop increments global_step only after the batch is yielded. With start=10, this branch therefore starts profiling after step 10 has already completed, so the captured window is steps 11..end rather than 10..end (and start == end starts after the stop point, so it never stops). Please start on start - 1 here or move the start hook to a point that runs after the step counter advances but before that step’s work.
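
The timing described above can be checked with a toy simulation. This is a sketch under the stated assumptions (steps are 1-indexed, `pre_load_batch()` sees the count of completed steps, `post_train_batch()` sees the step that just ran); `profiled_steps` and `start_offset` are names invented here, not part of the trainer.

```python
from typing import List


def profiled_steps(start: int, end: int, num_steps: int, start_offset: int) -> List[int]:
    """Return the steps whose work ran while profiling was active."""
    global_step = 0  # number of completed steps
    active = False
    captured: List[int] = []
    for _ in range(num_steps):
        # pre_load_batch: global_step still counts completed steps only.
        if not active and global_step == start + start_offset:
            active = True
        global_step += 1  # counter advances once the batch is yielded
        if active:
            captured.append(global_step)  # this step's work is inside the window
        # post_train_batch: global_step is the step that just finished.
        if active and global_step == end:
            active = False
    return captured


as_written = profiled_steps(start=10, end=12, num_steps=20, start_offset=0)
with_fix = profiled_steps(start=10, end=12, num_steps=20, start_offset=-1)
single_step_as_written = profiled_steps(start=10, end=10, num_steps=20, start_offset=0)
```

With `start_offset=0` the window misses step 10, and for `start == end` it opens after the stop point and runs to the end of training. With `start_offset=-1` the window covers exactly steps 10 through 12.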



def pre_load_batch(self):
    if self.enabled and get_rank() in self.profile_ranks:
        if self.step == self.start:

[P2] Start memory history before the requested step

This has the same off-by-one behavior as the NVIDIA profiler: pre_load_batch() sees the previous completed global_step, so start=10 begins recording only before loading step 11. The snapshot dumped at end will omit the requested start step, and start == end starts recording after the dump opportunity has passed. Please align the start condition with the trainer’s step timing so the configured window actually includes start.



def post_train_batch(self):
    if self.enabled and get_rank() in self.profile_ranks:
        if self.step == self.end and self._nvtx_ctx is not None:

[P2] Stop the NVIDIA profiler during callback cleanup

Cleanup is only performed when a batch reaches exactly end; if training is canceled, errors, or simply completes before end after the profiler has started, the trainer will call on_error()/close() but this callback never exits _nvtx_ctx or calls cudaProfilerStop(). For nsys --capture-range=cudaProfilerApi this can leave the capture range open and lose or corrupt the intended trace, so please stop the profiler from close()/on_error() whenever _nvtx_ctx is active.
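
A possible shape for that cleanup is an idempotent `stop()` shared by the normal path and the error/shutdown hooks. The sketch below uses a stand-in class with injectable start/stop functions; the class name and constructor are assumptions, and the real callback would also need to exit `_nvtx_ctx` in the same place.

```python
from typing import Callable


class WindowedProfiler:
    """Stand-in for the callback: stop() is safe to call from any hook, any number of times."""

    def __init__(self, start_fn: Callable[[], None], stop_fn: Callable[[], None]) -> None:
        self._start_fn = start_fn
        self._stop_fn = stop_fn
        self._active = False

    def start(self) -> None:
        if not self._active:
            self._start_fn()
            self._active = True

    def stop(self) -> None:
        if self._active:
            self._active = False  # clear the flag first so a failing stop is not retried
            self._stop_fn()

    def post_train_batch(self, step: int, end: int) -> None:
        if step == end:
            self.stop()

    def on_error(self, exc: BaseException) -> None:
        self.stop()

    def close(self) -> None:
        self.stop()


calls = []
profiler = WindowedProfiler(lambda: calls.append("start"), lambda: calls.append("stop"))
profiler.start()
profiler.on_error(RuntimeError("training failed before `end`"))
profiler.close()  # second cleanup call is a no-op
```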



def post_train_batch(self):
    if self.enabled and get_rank() in self.profile_ranks:
        if self.step == self.end:

[P2] Dump and disable memory history on early exit

The memory recorder is disabled and the snapshot is written only at the exact end step; if the profiled run hits an OOM/error or stops before end, Trainer._shutdown() calls callback close() but this callback does not dump the snapshot or call _record_memory_history(enabled=None). That loses the failure window this callback is meant to debug and leaves recording enabled until process teardown, so please add cleanup that dumps/disables when recording is active.
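
A sketch of the dump-then-disable cleanup, again with a stand-in class and injectable functions so it runs without CUDA. The class name and constructor are assumptions. With PyTorch, `record_fn` could wrap `torch.cuda.memory._record_memory_history(enabled=...)` and `dump_fn` could be `torch.cuda.memory._dump_snapshot`; both are private APIs whose signatures may change between releases.

```python
from typing import Callable, Optional


class MemoryHistoryWindow:
    """Stand-in for the callback: finish() dumps a snapshot, then disables recording."""

    def __init__(
        self,
        record_fn: Callable[[Optional[str]], None],
        dump_fn: Callable[[str], None],
        path: str,
    ) -> None:
        self._record_fn = record_fn
        self._dump_fn = dump_fn
        self._path = path
        self._recording = False

    def start(self) -> None:
        if not self._recording:
            self._record_fn("all")  # begin recording allocation history
            self._recording = True

    def finish(self) -> None:
        if not self._recording:
            return  # nothing to clean up; safe to call repeatedly
        self._recording = False
        try:
            self._dump_fn(self._path)  # keep the failure window for debugging
        finally:
            self._record_fn(None)  # disable recording even if the dump fails

    def close(self) -> None:
        self.finish()


calls = []
window = MemoryHistoryWindow(
    record_fn=lambda mode: calls.append(("record", mode)),
    dump_fn=lambda path: calls.append(("dump", path)),
    path="snapshot.pickle",
)
window.start()
window.close()  # early exit before `end`
window.close()  # repeated shutdown is a no-op
```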


2 participants