Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 44 additions & 22 deletions openadapt_evals/training/trl_rollout.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@
from openadapt_evals.training.standalone.prompt import SYSTEM_PROMPT # noqa: E402

# ---------------------------------------------------------------------------
# Constrained decoding regex ported from standalone trainer
# Constrained decoding regex -- ported from standalone trainer
# ---------------------------------------------------------------------------
# Matches the ``Thought: <reasoning>\nAction: <action>`` format.
# All repetitions use unbounded quantifiers (+, *) instead of bounded ({1,N})
Expand All @@ -106,12 +106,12 @@
# ---------------------------------------------------------------------------
# When the model is SFT'd on JSON format (not DSL), switch constrained
# decoding to: outlines.json(model, _AgentOutput) instead of regex.
# This is NOT the default the default uses DSL regex (ACTION_REGEX).
# This is NOT the default -- the default uses DSL regex (ACTION_REGEX).
class _AgentOutput(BaseModel):
"""Pydantic schema for Outlines JSON-mode constrained decoding.

Use with: ``outlines.json(model, _AgentOutput)`` once the model has
been SFT'd on JSON action format. Currently unused default is DSL
been SFT'd on JSON action format. Currently unused -- default is DSL
regex via ACTION_REGEX.
"""

Expand Down Expand Up @@ -349,7 +349,7 @@ def _run_episode(
if step_result.done:
break

# Evaluate dense rewards if milestones, binary otherwise
# Evaluate -- dense rewards if milestones, binary otherwise
reward = env.evaluate_dense()

return prompt_ids, all_completion_ids, all_logprobs, reward
Expand All @@ -365,6 +365,8 @@ def make_waa_rollout_func(
screenshot_retries: int = 3,
screenshot_retry_delay: float = 1.0,
stuck_threshold: int = 3,
on_before_collect: Optional[Callable] = None,
on_rollout_complete: Optional[Callable] = None,
) -> Callable:
"""Create a TRL-compatible rollout_func for WAA environments.

Expand All @@ -389,6 +391,14 @@ def make_waa_rollout_func(
stuck_threshold: Number of consecutive identical screenshots before
breaking an episode early. Set to 0 to disable stuck detection.
Ported from the standalone trainer's WAADirect.is_stuck().
on_before_collect: ``(task_id, env) -> None`` callback fired before
each episode begins. Useful for health checks, logging, or
pre-rollout setup. A raised exception is caught and logged as
a warning (does not abort the episode).
on_rollout_complete: ``(rollout, index) -> None`` callback fired
after each episode completes. ``rollout`` is a dict with keys
``prompt``, ``task_id``, ``reward``, ``gen_idx``. A raised
exception is caught and logged as a warning.

Returns:
A callable suitable for GRPOTrainer(rollout_func=...).
Expand Down Expand Up @@ -423,10 +433,6 @@ def rollout_func(prompts: list[str], trainer: Any) -> dict[str, list]:
num_generations = getattr(trainer.args, "num_generations", 8)

# --- Pre-rollout health check (P0) ---
# Verify WAA server is responsive before committing GPU time to a
# full batch of rollouts. Ported from standalone trainer's
# _collect_group() which calls probe() before each group.
# Skip for mock adapters (unittest.mock.MagicMock or WAAMockAdapter).
_mod = getattr(type(adapter), "__module__", "") or ""
_name = type(adapter).__name__.lower()
_is_mock = "mock" in _name or "mock" in _mod
Expand Down Expand Up @@ -470,10 +476,6 @@ def generate_fn(screenshot_bytes: bytes, instruction: str):
"""Generate action tokens from screenshot + instruction."""
from PIL import Image

# --- Corrupt screenshot retry (P0) ---
# On Azure VMs with QEMU, ~1-5% of screenshots are corrupt.
# Retry with a brief delay rather than crashing the entire
# rollout. Ported from standalone trainer's _collect_rollout().
img = None
for attempt in range(screenshot_retries):
try:
Expand Down Expand Up @@ -507,7 +509,6 @@ def generate_fn(screenshot_bytes: bytes, instruction: str):
]},
]

# Tokenize with processor
import torch

text_input = processor.apply_chat_template(
Expand All @@ -524,14 +525,10 @@ def generate_fn(screenshot_bytes: bytes, instruction: str):
max_new_tokens=max_new_tokens,
temperature=temperature,
)
# Tokenize the decoded text to get token IDs
inner_tok = getattr(processor, "tokenizer", processor)
completion_ids = inner_tok.encode(
decoded, add_special_tokens=False,
)
# Outlines does not return per-token logprobs, so we
# return empty logprobs. TRL recomputes logprobs from
# the model during the training step anyway.
logprobs: list[float] = []

# Truncation warning — detect when output was cut off
Expand All @@ -551,7 +548,6 @@ def generate_fn(screenshot_bytes: bytes, instruction: str):
return_tensors="pt", padding=True,
).to(device)

# Generate
with torch.no_grad():
outputs = model.generate(
**inputs,
Expand All @@ -562,19 +558,16 @@ def generate_fn(screenshot_bytes: bytes, instruction: str):
output_scores=True,
)

# Extract completion tokens (everything after prompt)
prompt_len = inputs["input_ids"].shape[1]
completion_ids = outputs.sequences[0][prompt_len:].tolist()

# Compute per-token logprobs from scores
logprobs = []
if hasattr(outputs, "scores") and outputs.scores:
for i, score in enumerate(outputs.scores):
probs = torch.nn.functional.log_softmax(score[0], dim=-1)
if i < len(completion_ids):
logprobs.append(probs[completion_ids[i]].item())

# Decode text
text = processor.decode(completion_ids, skip_special_tokens=True)

# Truncation warning — detect when output was cut off
Expand All @@ -589,14 +582,24 @@ def generate_fn(screenshot_bytes: bytes, instruction: str):
return text, completion_ids, logprobs

for prompt in prompts:
# Find matching task config
tc = config_map.get(prompt)

for gen_idx in range(num_generations):
env = RLEnvironment(adapter, task_config=tc)

task_id = tc.id if tc else "default"

# --- on_before_collect callback ---
if on_before_collect is not None:
try:
on_before_collect(task_id, env)
except Exception as exc:
logger.warning(
"on_before_collect callback raised for "
"task_id=%s gen=%d: %s",
task_id, gen_idx, exc,
)

try:
p_ids, c_ids, lps, reward = _run_episode(
env, generate_fn, prompt, task_id, max_steps,
Expand All @@ -608,6 +611,25 @@ def generate_fn(screenshot_bytes: bytes, instruction: str):
)
p_ids, c_ids, lps, reward = [], [], [], 0.0

# --- on_rollout_complete callback ---
if on_rollout_complete is not None:
try:
on_rollout_complete(
{
"prompt": prompt,
"task_id": task_id,
"reward": reward,
"gen_idx": gen_idx,
},
gen_idx,
)
except Exception as exc:
logger.warning(
"on_rollout_complete callback raised for "
"task_id=%s gen=%d: %s",
task_id, gen_idx, exc,
)

all_prompt_ids.append(p_ids)
all_completion_ids.append(c_ids)
all_logprobs.append(lps)
Expand Down
28 changes: 15 additions & 13 deletions openadapt_evals/training/trl_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,8 @@ def train(self) -> str:
constrained_decoding=getattr(self._config, "constrained_decoding", False),
max_new_tokens=self._config.max_new_tokens,
temperature=self._config.temperature,
on_before_collect=self._on_before_collect,
on_rollout_complete=self._on_rollout_complete,
)

# --- Reward ---
Expand All @@ -185,25 +187,25 @@ def env_reward_fn(completions, **kwargs):
except ImportError:
pass

if any([self._on_before_collect, self._on_rollout_complete,
self._on_step_complete]):
# on_before_collect and on_rollout_complete are passed directly to
# make_waa_rollout_func (above) because TRL has no pre-rollout
# callback. Only on_step_complete maps to TRL's on_step_end.
if self._on_step_complete:
try:
from transformers import TrainerCallback

class HookBridge(TrainerCallback):
def __init__(self, hooks):
self._hooks = hooks
def __init__(self, on_step_complete):
self._on_step_complete = on_step_complete

def on_step_end(self, args, state, control, **kwargs):
fn = self._hooks.get("on_step_complete")
if fn:
fn(state.global_step, [], kwargs.get("metrics", {}))

callbacks.append(HookBridge({
"on_before_collect": self._on_before_collect,
"on_rollout_complete": self._on_rollout_complete,
"on_step_complete": self._on_step_complete,
}))
if self._on_step_complete:
self._on_step_complete(
state.global_step, [],
kwargs.get("metrics", {}),
)

callbacks.append(HookBridge(self._on_step_complete))
except ImportError:
pass

Expand Down
55 changes: 55 additions & 0 deletions tests/test_trl_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,3 +239,58 @@ def test_callback_fires_events(self):
with patch("openadapt_evals.telemetry.capture_event"):
cb.on_train_begin(args, state, control)
cb.on_step_end(args, state, control)


# ---------------------------------------------------------------------------
# Wrapper callback passthrough tests
# ---------------------------------------------------------------------------


class TestWrapperPassesCallbacks:
    """Verify GRPOTrainer passes on_before_collect and on_rollout_complete
    through to make_waa_rollout_func, not into HookBridge."""

    def test_wrapper_passes_callbacks_to_rollout_func(self):
        """Verify on_before_collect and on_rollout_complete are forwarded
        to make_waa_rollout_func as keyword arguments.

        Since train() has local imports of heavy dependencies (datasets, trl,
        torch), we verify by inspecting the source code of train() to confirm
        the kwargs are passed. This avoids needing GPU/torch in CI.
        """
        import inspect

        from openadapt_evals.training.standalone.config import TrainingConfig
        from openadapt_evals.training.trl_wrapper import GRPOTrainer

        # Plain defs instead of assigned lambdas (PEP 8 / E731); identity
        # is all that matters for the stored-callback checks below.
        def before_fn(task_id, env):
            return None

        def complete_fn(rollout, index):
            return None

        trainer = GRPOTrainer(
            TrainingConfig(task_dir="tasks/"),
            on_before_collect=before_fn,
            on_rollout_complete=complete_fn,
        )

        # 1. The wrapper must store exactly the callables it was given.
        assert trainer._on_before_collect is before_fn
        assert trainer._on_rollout_complete is complete_fn

        # 2. train() must forward both callbacks to make_waa_rollout_func.
        train_src = inspect.getsource(GRPOTrainer.train)
        assert "on_before_collect=self._on_before_collect" in train_src, (
            "train() must pass on_before_collect to make_waa_rollout_func"
        )
        assert "on_rollout_complete=self._on_rollout_complete" in train_src, (
            "train() must pass on_rollout_complete to make_waa_rollout_func"
        )

        # 3. HookBridge must no longer reference the rollout-level callbacks.
        # Guard clause replaces the original inline conditional expression.
        if "class HookBridge" in train_src:
            bridge_src = train_src.split("class HookBridge")[1]
            bridge_src = bridge_src.split("callbacks.append")[0]
        else:
            bridge_src = ""
        assert "on_before_collect" not in bridge_src, (
            "HookBridge should not store on_before_collect"
        )
        assert "on_rollout_complete" not in bridge_src, (
            "HookBridge should not store on_rollout_complete"
        )
Loading
Loading