Skip to content

Commit d7896d5

Browse files
abrichr and claude authored
feat: add DiagnosticsCallback and TRL robustness tests (#238)
- Add DiagnosticsCallback to trl_callbacks.py: logs loss, |loss|, grad_norm, reward in scientific notation (matches standalone trainer diagnostic output)
- Register DiagnosticsCallback in trl_wrapper.py alongside TelemetryCallback
- Add test_trl_robustness.py: 19 tests covering health check, corrupt screenshot retry, stuck detection, truncation warning, diagnostics callback, and empty rollout result shape

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent fb7e87f commit d7896d5

3 files changed

Lines changed: 690 additions & 7 deletions

File tree

openadapt_evals/integrations/trl_callbacks.py

Lines changed: 54 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,24 @@
1-
"""TRL TrainerCallback implementations for telemetry and Weave tracing.
1+
"""TRL TrainerCallback implementations for telemetry, diagnostics, and Weave tracing.
22
33
Provides callbacks that integrate with TRL's GRPOTrainer to automatically
4-
track training events via our telemetry system and optionally log to Weave.
4+
track training events via our telemetry system, emit rich diagnostic logs
5+
matching the standalone trainer output, and optionally log to Weave.
56
67
Usage::
78
89
from trl import GRPOConfig, GRPOTrainer
9-
from openadapt_evals.integrations.trl_callbacks import TelemetryCallback
10+
from openadapt_evals.integrations.trl_callbacks import (
11+
DiagnosticsCallback,
12+
TelemetryCallback,
13+
)
1014
1115
trainer = GRPOTrainer(
1216
model=model,
1317
args=config,
14-
callbacks=[TelemetryCallback(model_name="Qwen/Qwen2.5-VL-7B-Instruct")],
18+
callbacks=[
19+
TelemetryCallback(model_name="Qwen/Qwen2.5-VL-7B-Instruct"),
20+
DiagnosticsCallback(),
21+
],
1522
...
1623
)
1724
trainer.train()
@@ -161,6 +168,42 @@ def on_train_end(
161168
logger.debug("Telemetry on_train_end failed: %s", exc)
162169

163170

171+
class DiagnosticsCallback:
    """Rich training diagnostics matching standalone trainer output.

    Emits per-step log lines with loss, |loss|, grad_norm, and reward in the
    same format as the standalone GRPOTrainer. This makes it easy for operators
    to monitor TRL-based training runs with the same tooling (grep, dashboards)
    used for the standalone path.

    All values are read from the most recent entry of TRL's
    ``state.log_history``. A metric that is missing, ``None``, or not
    convertible to ``float`` is reported as 0.0. Like the telemetry hooks in
    this module, the callback is best-effort: any unexpected failure is logged
    at debug level instead of propagating into the training loop.
    """

    @staticmethod
    def _metric(entry: dict, *keys: str) -> float:
        """Return the first usable numeric value found under *keys*.

        Args:
            entry: A single ``log_history`` record (metric name -> value).
            *keys: Candidate metric names, tried in order.

        Returns:
            The first value that is present, not ``None``, and convertible
            to ``float``; 0.0 when no candidate qualifies.
        """
        for key in keys:
            value = entry.get(key)
            if value is None:
                continue
            try:
                return float(value)
            except (TypeError, ValueError):
                # Non-numeric payload (e.g. a string); try the next candidate.
                continue
        return 0.0

    def on_step_end(
        self,
        args: Any,
        state: Any,
        control: Any,
        **kwargs: Any,
    ) -> None:
        """Log diagnostic metrics at the end of each training step.

        Args:
            args: Training arguments supplied by the trainer (unused).
            state: Trainer state; ``log_history`` and ``global_step`` are read.
            control: Trainer control object (unused).
            **kwargs: Extra keyword arguments from the trainer (unused).
        """
        try:
            history = state.log_history
            if not history:
                return
            # NOTE(review): the trainer may append the current step's metrics
            # to log_history only after on_step_end fires, and the last entry
            # can be an eval record without "loss" -- confirm against the
            # TrainerCallback event order if exact per-step values matter.
            latest = history[-1]
            loss = self._metric(latest, "loss")
            grad_norm = self._metric(latest, "grad_norm")
            reward = self._metric(latest, "reward", "reward_mean")
            logger.info(
                "Step %d: loss=%+.2e |loss|=%.2e grad_norm=%.4f reward=%.4f",
                state.global_step,
                loss,
                abs(loss),
                grad_norm,
                reward,
            )
        except Exception as exc:  # noqa: BLE001 - diagnostics must never break training
            logger.debug("Diagnostics on_step_end failed: %s", exc)
205+
206+
164207
# Register as a TrainerCallback subclass at import time so TRL recognizes it.
165208
# If transformers is installed, wrap with proper inheritance.
166209
# We can't patch __bases__ after the fact (Python doesn't allow it when
@@ -172,7 +215,12 @@ class _TelemetryCallbackWithBase(_TrainerCallback, TelemetryCallback):
172215
"""TelemetryCallback with proper TrainerCallback inheritance."""
173216
pass
174217

175-
# Replace the module-level name so imports get the subclass
218+
class _DiagnosticsCallbackWithBase(_TrainerCallback, DiagnosticsCallback):
    """Variant of DiagnosticsCallback that also inherits from TrainerCallback.

    Adds no behavior of its own; it only combines the two base classes so the
    callback is a genuine TrainerCallback subclass.
    """
221+
222+
# Replace the module-level names so imports get the subclasses
176223
TelemetryCallback = _TelemetryCallbackWithBase # type: ignore[misc]
224+
DiagnosticsCallback = _DiagnosticsCallbackWithBase # type: ignore[misc]
177225
except ImportError:
178-
pass # TelemetryCallback works as duck-typed callback without inheritance
226+
pass # Callbacks work as duck-typed without inheritance

openadapt_evals/training/trl_wrapper.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -170,8 +170,12 @@ def env_reward_fn(completions, **kwargs):
170170
callbacks = []
171171

172172
try:
173-
from openadapt_evals.integrations.trl_callbacks import TelemetryCallback
173+
from openadapt_evals.integrations.trl_callbacks import (
174+
DiagnosticsCallback,
175+
TelemetryCallback,
176+
)
174177
callbacks.append(TelemetryCallback())
178+
callbacks.append(DiagnosticsCallback())
175179
except ImportError:
176180
pass
177181

0 commit comments

Comments
 (0)