Commit 7d095da

abrichr and claude authored
feat: add --task-dir support for milestone-based rewards in standalone GRPO trainer (#60)
* fix: include image placeholder in chat template for VLM GRPO

  Qwen2.5-VL requires <|image_pad|> tokens in the input. These are inserted
  by apply_chat_template only when messages include {"type": "image"}
  content blocks. Fixed both agent_fn and _compute_rollout_loss.

* feat: add --task-dir support for milestone-based rewards in standalone GRPO trainer

  - GRPOConfig: add task_dir field
  - reward.py: evaluate_milestones_screenshot() for client-side reward
  - trainer.py: load TaskConfigs, auto-populate task_ids, override rewards
  - rollout_collector.py: pass task_configs to env
  - No WAA evaluate endpoint needed — rewards computed via VLM judge

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 6617e02 commit 7d095da

File tree: 5 files changed, +254 −4 lines

openadapt_ml/training/grpo/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -47,6 +47,7 @@
 from openadapt_ml.training.grpo.reward import (
     binary_task_success,
     compute_group_advantages,
+    evaluate_milestones_screenshot,
 )
 from openadapt_ml.training.grpo.rollout_collector import (
     GRPORolloutCollector,
@@ -86,6 +87,7 @@ def __getattr__(name: str):
     "Rollout",
     "binary_task_success",
     "compute_group_advantages",
+    "evaluate_milestones_screenshot",
     "policy_gradient_loss",
     "grpo_loss",
     "parse_vlm_output_to_action",

openadapt_ml/training/grpo/config.py

Lines changed: 11 additions & 0 deletions
@@ -37,6 +37,11 @@ class GRPOConfig:
         server_url: URL of the WAA server for live environment interaction.
         evaluate_url: URL of the evaluate server. If None, defaults to server_url.
         task_ids: List of WAA task IDs to train on.
+        task_dir: Path to a directory of YAML task config files. When set,
+            the trainer loads TaskConfig objects and uses milestone-based
+            reward evaluation locally (no /evaluate endpoint needed).
+            If task_ids is empty, task IDs are auto-populated from the
+            loaded configs.
         learning_rate: Optimizer learning rate for LoRA parameter updates.
         num_training_steps: Total number of GRPO training steps (outer loop).
         save_every_steps: Checkpoint frequency.
@@ -69,6 +74,12 @@ class GRPOConfig:
     task_ids: list[str] = field(default_factory=list)
     screen_size: tuple[int, int] = (1920, 1080)  # (width, height)

+    # Task configuration directory (YAML files with milestones for dense rewards).
+    # When set, the trainer loads TaskConfig objects from this directory and
+    # uses milestone-based reward evaluation locally, without needing the
+    # WAA /evaluate endpoint. Requires openadapt-evals to be installed.
+    task_dir: str | None = None
+
     # Training
     learning_rate: float = 5e-6
     num_training_steps: int = 1000
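For reference, a minimal sketch of setting the new field from code rather than the CLI. The URL and path values are placeholders, and any constructor arguments not shown in this diff are assumed to have defaults:

from openadapt_ml.training.grpo.config import GRPOConfig

config = GRPOConfig(
    server_url="http://localhost:5000",  # placeholder WAA server URL
    task_dir="waa_tasks/",               # placeholder directory of YAML task configs
    # task_ids left empty on purpose: the trainer auto-populates it
    # from the configs found in task_dir.
)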

openadapt_ml/training/grpo/reward.py

Lines changed: 96 additions & 0 deletions
@@ -6,10 +6,19 @@
 GRPO computes advantages relative to the group mean rather than using
 a learned value function, which is simpler and works well for sparse
 binary rewards (task success/failure).
+
+Also provides ``evaluate_milestones_screenshot``, a standalone utility
+that evaluates milestone-based rewards from a screenshot without needing
+the WAA /evaluate endpoint. This is the local-evaluation path used by
+the standalone GRPO trainer when ``--task-dir`` is set.
 """

 from __future__ import annotations

+import logging
+
+logger = logging.getLogger(__name__)
+

 def binary_task_success(score: float, threshold: float = 0.5) -> float:
     """Convert evaluator score to binary reward.
@@ -54,3 +63,90 @@ def compute_group_advantages(rewards: list[float]) -> list[float]:
         return [0.0] * n

     return [(r - mean) / (std + eps) for r in rewards]
+
+
+def evaluate_milestones_screenshot(
+    task_config: object,
+    screenshot_bytes: bytes,
+    vlm_model: str = "gpt-4.1-mini",
+    vlm_provider: str = "openai",
+) -> float:
+    """Evaluate milestone-based rewards from a screenshot (no server needed).
+
+    Iterates over the milestones in a TaskConfig and evaluates each
+    ``screenshot``-type milestone using a VLM judge. Non-screenshot
+    milestones are skipped (they require a live server).
+
+    This is a standalone utility that can be called independently of the
+    trainer, e.g.::
+
+        from openadapt_ml.training.grpo.reward import evaluate_milestones_screenshot
+        reward = evaluate_milestones_screenshot(task_config, screenshot_bytes)
+
+    Args:
+        task_config: A ``TaskConfig`` instance (from ``openadapt_evals.task_config``).
+            Must have a ``milestones`` attribute (list of ``Milestone`` objects).
+        screenshot_bytes: PNG screenshot bytes to evaluate against.
+        vlm_model: VLM model name for the judge.
+        vlm_provider: VLM provider (``"openai"`` or ``"anthropic"``).
+
+    Returns:
+        Fraction of screenshot milestones that passed (0.0 to 1.0).
+        Returns 0.0 if there are no milestones or no screenshot milestones.
+    """
+    milestones = getattr(task_config, "milestones", None)
+    if not milestones:
+        return 0.0
+
+    # Only evaluate screenshot-type milestones locally
+    screenshot_milestones = [
+        ms for ms in milestones
+        if getattr(ms.check, "check", None) == "screenshot"
+    ]
+    if not screenshot_milestones:
+        return 0.0
+
+    try:
+        from openadapt_evals.vlm_evaluator import vlm_judge
+    except ImportError:
+        logger.warning(
+            "openadapt-evals is not installed; cannot evaluate screenshot "
+            "milestones. Install with: pip install openadapt-evals"
+        )
+        return 0.0
+
+    passed = 0
+    for ms in screenshot_milestones:
+        description = getattr(ms.check, "description", None) or ""
+        if not description:
+            continue
+        try:
+            success, _confidence = vlm_judge(
+                screenshot_bytes,
+                description,
+                model=vlm_model,
+                provider=vlm_provider,
+            )
+            if success:
+                passed += 1
+            logger.debug(
+                "Milestone '%s': %s",
+                getattr(ms, "name", "?"),
+                "PASS" if success else "FAIL",
+            )
+        except Exception as exc:
+            logger.warning(
+                "Milestone '%s' evaluation failed: %s",
+                getattr(ms, "name", "?"),
+                exc,
+            )
+
+    total = len(screenshot_milestones)
+    score = passed / total if total > 0 else 0.0
+    logger.info(
+        "Milestone evaluation: %d/%d screenshot milestones passed (%.2f)",
+        passed,
+        total,
+        score,
+    )
+    return score
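Because the function reads the task config purely through getattr, any object with the right shape works. A minimal sketch using illustrative stand-in dataclasses (FakeCheck, FakeMilestone, and FakeTaskConfig are not real openadapt-evals types); if openadapt-evals is missing, the call logs a warning and returns 0.0 rather than raising:

from dataclasses import dataclass
from pathlib import Path

from openadapt_ml.training.grpo.reward import evaluate_milestones_screenshot

@dataclass
class FakeCheck:
    check: str        # must be "screenshot" to be evaluated locally
    description: str  # natural-language criterion passed to the VLM judge

@dataclass
class FakeMilestone:
    name: str
    check: FakeCheck

@dataclass
class FakeTaskConfig:
    milestones: list

task = FakeTaskConfig(milestones=[
    FakeMilestone("file-saved", FakeCheck("screenshot", "A save dialog is visible")),
])

png_bytes = Path("screenshot.png").read_bytes()  # placeholder screenshot
reward = evaluate_milestones_screenshot(task, png_bytes)  # fraction in 0.0..1.0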

openadapt_ml/training/grpo/rollout_collector.py

Lines changed: 14 additions & 1 deletion
@@ -68,19 +68,27 @@ class GRPORolloutCollector:

     Args:
         config: GRPO training configuration.
+        task_configs: Optional dict mapping task_id -> TaskConfig. When
+            provided, task configs are loaded into the RLEnvironment for
+            milestone-based dense reward evaluation.

     Raises:
         ImportError: If openadapt-evals is not installed.
     """

-    def __init__(self, config: GRPOConfig) -> None:
+    def __init__(
+        self,
+        config: GRPOConfig,
+        task_configs: dict[str, Any] | None = None,
+    ) -> None:
         if RLEnvironment is None:
             raise ImportError(
                 "openadapt-evals is required for rollout collection. "
                 "Install it with: uv add openadapt-evals"
             )

         self._config = config
+        self._task_configs = task_configs or {}
         self._adapter = WAALiveAdapter(
             WAALiveConfig(
                 server_url=config.server_url,
@@ -123,6 +131,11 @@ def collect_group(

        rollouts: list[Rollout] = []

+        # Load task config into the environment for dense milestone rewards
+        if task_id in self._task_configs:
+            tc = self._task_configs[task_id]
+            self._env.load_task_config(tc)
+
        for i in range(self._config.num_rollouts_per_step):
            logger.info(
                "Collecting rollout %d/%d for task %s",

openadapt_ml/training/grpo/trainer.py

Lines changed: 131 additions & 3 deletions
@@ -45,12 +45,24 @@

 from openadapt_ml.datasets.next_action import SYSTEM_PROMPT
 from openadapt_ml.training.grpo.config import GRPOConfig
-from openadapt_ml.training.grpo.reward import compute_group_advantages
+from openadapt_ml.training.grpo.reward import (
+    compute_group_advantages,
+    evaluate_milestones_screenshot,
+)
 from openadapt_ml.training.grpo.rollout_collector import (
     GRPORolloutCollector,
     Rollout,
 )

+# Optional import for TaskConfig (openadapt-evals may not be installed)
+try:
+    from openadapt_evals.task_config import TaskConfig
+
+    _HAS_TASK_CONFIG = True
+except ImportError:
+    TaskConfig = None  # type: ignore[assignment, misc]
+    _HAS_TASK_CONFIG = False
+
 logger = logging.getLogger(__name__)

 DEFAULT_SCREEN_SIZE: tuple[int, int] = (1920, 1080)
@@ -301,6 +313,106 @@ def __init__(self, config: GRPOConfig) -> None:
         self._optimizer: Any = None
         self._collector: GRPORolloutCollector | None = None
         self._step: int = 0
+        self._task_configs: dict[str, Any] = {}
+
+        # Load task configs from --task-dir if specified
+        if config.task_dir:
+            self._load_task_configs(config.task_dir)
+
+    def _load_task_configs(self, task_dir: str) -> None:
+        """Load TaskConfig YAMLs from a directory.
+
+        Populates ``self._task_configs`` (keyed by task ID) and auto-fills
+        ``config.task_ids`` if it was left empty.
+
+        Args:
+            task_dir: Path to directory containing YAML/JSON task configs.
+
+        Raises:
+            ImportError: If openadapt-evals is not installed.
+            FileNotFoundError: If the directory does not exist.
+        """
+        if not _HAS_TASK_CONFIG:
+            raise ImportError(
+                "openadapt-evals is required for --task-dir support. "
+                "Install with: pip install openadapt-evals"
+            )
+
+        task_dir_path = Path(task_dir)
+        if not task_dir_path.is_dir():
+            raise FileNotFoundError(f"Task directory not found: {task_dir}")
+
+        configs = TaskConfig.from_dir(str(task_dir_path))
+        if not configs:
+            raise ValueError(f"No task configs found in {task_dir}")
+
+        for tc in configs:
+            self._task_configs[tc.id] = tc
+            logger.info(
+                "Loaded task config: %s (%s) — %d milestones",
+                tc.id,
+                tc.name[:50],
+                len(tc.milestones),
+            )
+
+        # Auto-populate task_ids if empty
+        if not self._config.task_ids:
+            self._config.task_ids = list(self._task_configs.keys())
+            logger.info(
+                "Auto-populated task_ids from task_dir: %s",
+                self._config.task_ids,
+            )
+
+    def _compute_milestone_reward(
+        self,
+        task_id: str,
+        screenshot_bytes: bytes,
+    ) -> float:
+        """Compute milestone-based reward for a task using VLM judge.
+
+        Evaluates screenshot-type milestones locally without needing the
+        WAA /evaluate endpoint. Falls back to 0.0 if the task has no
+        milestones or the task_id is not found in loaded configs.
+
+        Args:
+            task_id: The task ID to look up in loaded configs.
+            screenshot_bytes: PNG screenshot bytes to evaluate.
+
+        Returns:
+            Fraction of screenshot milestones passed (0.0 to 1.0).
+        """
+        task_config = self._task_configs.get(task_id)
+        if task_config is None:
+            return 0.0
+        return evaluate_milestones_screenshot(task_config, screenshot_bytes)
+
+    def _compute_milestone_reward_from_rollout(
+        self,
+        rollout: Rollout,
+    ) -> float | None:
+        """Extract the last screenshot from a rollout and compute milestone reward.
+
+        Returns None if no task config or no screenshot is available,
+        signalling the caller to keep the existing reward.
+        """
+        task_config = self._task_configs.get(rollout.task_id)
+        if task_config is None or not getattr(task_config, "milestones", None):
+            return None
+
+        # Find the last step with a screenshot
+        screenshot_bytes: bytes | None = None
+        for step in reversed(rollout.steps):
+            obs = getattr(step, "observation", None)
+            if obs is not None:
+                ss = getattr(obs, "screenshot", None)
+                if ss:
+                    screenshot_bytes = ss
+                    break
+
+        if not screenshot_bytes:
+            return None
+
+        return evaluate_milestones_screenshot(task_config, screenshot_bytes)

     def _make_agent_fn(self) -> Callable:
         """Create agent closure: observation -> BenchmarkAction.
@@ -381,20 +493,26 @@ def train(self) -> str:
         if not self._config.task_ids:
             raise ValueError(
                 "config.task_ids must be non-empty. Provide at least one "
-                "WAA task ID to train on."
+                "WAA task ID to train on, or use --task-dir to load from "
+                "YAML files."
             )

         logger.info("Starting GRPO training")
         logger.info("  Model: %s", self._config.model_name)
         logger.info("  Tasks: %s", self._config.task_ids)
+        logger.info("  Task dir: %s", self._config.task_dir or "(none)")
+        logger.info("  Task configs loaded: %d", len(self._task_configs))
         logger.info("  Rollouts/step: %d", self._config.num_rollouts_per_step)
         logger.info("  Training steps: %d", self._config.num_training_steps)

         # Setup
         self._model, self._processor = _load_model_and_processor(self._config)
         trainable = [p for p in self._model.parameters() if p.requires_grad]
         self._optimizer = torch.optim.AdamW(trainable, lr=self._config.learning_rate)
-        self._collector = GRPORolloutCollector(self._config)
+        self._collector = GRPORolloutCollector(
+            self._config,
+            task_configs=self._task_configs if self._task_configs else None,
+        )

         Path(self._config.output_dir).mkdir(parents=True, exist_ok=True)
         agent_fn = self._make_agent_fn()
@@ -409,6 +527,16 @@
             self._model.eval()
             rollouts = self._collector.collect_group(agent_fn=agent_fn, task_id=task_id)

+            # If task configs with milestones are loaded, override the
+            # binary rewards with milestone-based dense rewards.
+            if self._task_configs:
+                for rollout in rollouts:
+                    milestone_reward = self._compute_milestone_reward_from_rollout(
+                        rollout
+                    )
+                    if milestone_reward is not None:
+                        rollout.reward = max(rollout.reward, milestone_reward)
+
             # Train (gradient update)
             self._model.train()
             metrics = self._training_step(rollouts)
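Putting it together, a hedged end-to-end sketch. The trainer class name GRPOTrainer is an assumption (this diff shows only its methods, not the class statement), and the URL and path values are placeholders:

from openadapt_ml.training.grpo.config import GRPOConfig
from openadapt_ml.training.grpo.trainer import GRPOTrainer  # assumed class name

config = GRPOConfig(
    server_url="http://localhost:5000",  # placeholder WAA server
    task_dir="waa_tasks/",               # task_ids auto-populated from here
    num_training_steps=100,
)

trainer = GRPOTrainer(config)
# Each step collects a rollout group, replaces each binary reward with
# max(binary, milestone_fraction) when milestones are loaded, then updates.
checkpoint_path = trainer.train()  # train() returns a path string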
