1- """TRL TrainerCallback implementations for telemetry and Weave tracing.
1+ """TRL TrainerCallback implementations for telemetry, diagnostics, and Weave tracing.
22
33Provides callbacks that integrate with TRL's GRPOTrainer to automatically
4- track training events via our telemetry system and optionally log to Weave.
4+ track training events via our telemetry system, emit rich diagnostic logs
5+ matching the standalone trainer output, and optionally log to Weave.
56
67Usage::
78
89 from trl import GRPOConfig, GRPOTrainer
9- from openadapt_evals.integrations.trl_callbacks import TelemetryCallback
10+ from openadapt_evals.integrations.trl_callbacks import (
11+ DiagnosticsCallback,
12+ TelemetryCallback,
13+ )
1014
1115 trainer = GRPOTrainer(
1216 model=model,
1317 args=config,
14- callbacks=[TelemetryCallback(model_name="Qwen/Qwen2.5-VL-7B-Instruct")],
18+ callbacks=[
19+ TelemetryCallback(model_name="Qwen/Qwen2.5-VL-7B-Instruct"),
20+ DiagnosticsCallback(),
21+ ],
1522 ...
1623 )
1724 trainer.train()
@@ -161,6 +168,42 @@ def on_train_end(
161168 logger .debug ("Telemetry on_train_end failed: %s" , exc )
162169
163170
class DiagnosticsCallback:
    """Rich training diagnostics matching standalone trainer output.

    Emits per-step log lines with loss, |loss|, grad_norm, and reward in the
    same format as the standalone GRPOTrainer. This makes it easy for operators
    to monitor TRL-based training runs with the same tooling (grep, dashboards)
    used for the standalone path.

    All values are read from TRL's ``state.log_history``. The most recent
    entry that carries at least one training metric is used, so an evaluation
    or summary entry appended after it does not blank out the line. If a
    metric is missing or not numeric, it defaults to 0.0.

    NOTE(review): the values come from the latest *logged* entry, which may
    lag ``state.global_step`` when logging happens less often than stepping
    -- confirm against the trainer's logging cadence.
    """

    # Keys whose presence marks a log_history entry as a training-step record.
    _TRAIN_KEYS = ("loss", "grad_norm", "reward", "reward_mean")

    @staticmethod
    def _as_float(value: Any, default: float = 0.0) -> float:
        """Return *value* as a float, or *default* if it cannot be converted."""
        try:
            return float(value)
        except (TypeError, ValueError):
            return default

    @classmethod
    def _latest_metrics(cls, log_history: Any) -> tuple:
        """Return ``(loss, grad_norm, reward)`` from the newest training entry.

        Scans *log_history* from newest to oldest and picks the first dict
        containing any of ``_TRAIN_KEYS``. When no such entry exists, all
        three values are 0.0.
        """
        chosen: dict = {}
        for entry in reversed(log_history):
            if isinstance(entry, dict) and any(k in entry for k in cls._TRAIN_KEYS):
                chosen = entry
                break
        loss = cls._as_float(chosen.get("loss"))
        grad_norm = cls._as_float(chosen.get("grad_norm"))
        # "reward" takes precedence; "reward_mean" is the fallback key.
        reward = cls._as_float(chosen.get("reward", chosen.get("reward_mean")))
        return loss, grad_norm, reward

    def on_step_end(
        self,
        args: Any,
        state: Any,
        control: Any,
        **kwargs: Any,
    ) -> None:
        """Log diagnostic metrics at the end of each training step.

        Does nothing when ``state.log_history`` is absent or empty. Never
        raises on malformed metric values, so diagnostics cannot interrupt
        training.
        """
        history = getattr(state, "log_history", None)
        if not history:
            return
        loss, grad_norm, reward = self._latest_metrics(history)
        logger.info(
            "Step %d: loss=%+.2e |loss|=%.2e grad_norm=%.4f reward=%.4f",
            state.global_step,
            loss,
            abs(loss),
            grad_norm,
            reward,
        )
205+
206+
164207# Register as a TrainerCallback subclass at import time so TRL recognizes it.
165208# If transformers is installed, wrap with proper inheritance.
166209# We can't patch __bases__ after the fact (Python doesn't allow it when
@@ -172,7 +215,12 @@ class _TelemetryCallbackWithBase(_TrainerCallback, TelemetryCallback):
172215 """TelemetryCallback with proper TrainerCallback inheritance."""
173216 pass
174217
175- # Replace the module-level name so imports get the subclass
class _DiagnosticsCallbackWithBase(_TrainerCallback, DiagnosticsCallback):
    """DiagnosticsCallback variant that also inherits from TrainerCallback.

    Adds no behavior of its own; it only combines the two base classes.
    """
221+
222+ # Replace the module-level names so imports get the subclasses
176223 TelemetryCallback = _TelemetryCallbackWithBase # type: ignore[misc]
224+ DiagnosticsCallback = _DiagnosticsCallbackWithBase # type: ignore[misc]
177225except ImportError :
178- pass # TelemetryCallback works as duck-typed callback without inheritance
226+ pass # Callbacks work as duck-typed without inheritance
0 commit comments