Skip to content

Commit 1c23f0b

Browse files
abrichrclaude
andauthored
fix: clean config separation — our config + TRL's config, no duplication (#230)
TrainingConfig owns OpenAdapt concerns: model, task_dir, server_url, constrained_decoding, max_new_tokens, use_unsloth, weave_project. TRL's GRPOConfig owns training concerns: loss_type, learning_rate, batch_size, gradient_accumulation, vLLM, bf16, W&B reporting. The wrapper accepts both via trl_config kwarg: trainer = GRPOTrainer( TrainingConfig(task_dir="tasks/", constrained_decoding=True), trl_config=GRPOConfig(loss_type="dapo", num_generations=4), on_step_complete=my_logger, ) If trl_config is omitted, sensible defaults are built from TrainingConfig. Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 9759cdc commit 1c23f0b

2 files changed

Lines changed: 84 additions & 55 deletions

File tree

openadapt_evals/training/standalone/config.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ class TrainingConfig:
4242
constrained_decoding: bool = False
4343

4444
server_url: str = "http://localhost:5001"
45+
evaluate_url: str | None = None
4546
task_ids: list[str] = field(default_factory=list)
4647
task_dir: str | None = None
4748
screen_size: tuple[int, int] = (1920, 1080)
@@ -51,3 +52,8 @@ class TrainingConfig:
5152
save_every_steps: int = 50
5253
output_dir: str = "checkpoints/grpo"
5354
eval_model: str = "gpt-4.1-mini"
55+
56+
# Use Unsloth for 90% VRAM reduction (requires pip install unsloth)
57+
use_unsloth: bool = False
58+
# Weave project for LLM tracing (empty = disabled)
59+
weave_project: str = ""
Lines changed: 78 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,31 @@
1-
"""Drop-in TRL-backed GRPOTrainer with the same API as the standalone trainer.
1+
"""TRL-backed GRPO trainer with clean config separation.
22
3-
Usage (identical to standalone trainer):
3+
Our TrainingConfig handles OpenAdapt-specific concerns (WAA server,
4+
task loading, constrained decoding, callbacks). TRL's GRPOConfig
5+
handles training concerns (learning rate, loss type, batch size).
6+
No duplication — each config owns its domain.
47
8+
Usage:
9+
from trl import GRPOConfig
510
from openadapt_evals.training.trl_wrapper import GRPOTrainer
611
from openadapt_evals.training.standalone.config import TrainingConfig
712
813
trainer = GRPOTrainer(
9-
TrainingConfig(model_name="Qwen/Qwen3.5-9B", task_dir="tasks/"),
14+
TrainingConfig(
15+
task_dir="tasks/",
16+
server_url="http://localhost:5001",
17+
constrained_decoding=True,
18+
),
19+
trl_config=GRPOConfig(
20+
output_dir="./checkpoints",
21+
loss_type="dapo",
22+
num_generations=4,
23+
learning_rate=5e-6,
24+
bf16=True,
25+
),
1026
on_step_complete=my_logger,
1127
)
1228
trainer.train()
13-
14-
Internally uses TRL's GRPOTrainer + rollout_func. Falls back to the
15-
standalone trainer if TRL is not installed.
1629
"""
1730

1831
from __future__ import annotations
@@ -24,44 +37,50 @@
2437

2538

2639
class GRPOTrainer:
27-
"""TRL-backed GRPO trainer with the standalone trainer's API.
28-
29-
Same constructor signature: TrainingConfig + 4 callback hooks.
30-
Same train() → str return (checkpoint path).
40+
"""TRL-backed GRPO trainer.
41+
42+
Args:
43+
config: Our TrainingConfig — WAA server, task_dir, model loading,
44+
constrained decoding. Handles everything OpenAdapt-specific.
45+
trl_config: TRL's GRPOConfig — learning rate, loss type, batch
46+
size, gradient accumulation, vLLM, W&B reporting. Passed
47+
directly to TRL with zero translation. Optional — sensible
48+
defaults are used if omitted.
49+
on_model_loaded: ``(model, processor) -> None``
50+
on_before_collect: ``(task_id, env) -> None``
51+
on_rollout_complete: ``(rollout, index) -> None``
52+
on_step_complete: ``(step, rollouts, metrics) -> None``
3153
"""
3254

3355
def __init__(
3456
self,
3557
config,
3658
*,
59+
trl_config=None,
3760
on_model_loaded=None,
3861
on_before_collect=None,
3962
on_rollout_complete=None,
4063
on_step_complete=None,
4164
):
4265
self._config = config
66+
self._trl_config = trl_config
4367
self._on_model_loaded = on_model_loaded
4468
self._on_before_collect = on_before_collect
4569
self._on_rollout_complete = on_rollout_complete
4670
self._on_step_complete = on_step_complete
4771

4872
def train(self) -> str:
4973
"""Run GRPO training via TRL. Returns path to final checkpoint."""
50-
from pathlib import Path
51-
5274
from datasets import Dataset
5375
from trl import GRPOConfig, GRPOTrainer as _TRLTrainer
5476

5577
from openadapt_evals.task_config import TaskConfig
5678
from openadapt_evals.training.trl_rollout import make_waa_rollout_func
5779

58-
# Load tasks
80+
# --- Tasks (from our config) ---
5981
task_configs = []
6082
if self._config.task_dir:
6183
task_configs = TaskConfig.from_dir(self._config.task_dir)
62-
if self._config.task_ids and not task_configs:
63-
logger.warning("task_ids set but no task_dir — using task_ids as prompts")
64-
6584
if not task_configs:
6685
raise ValueError("No tasks. Set task_dir in TrainingConfig.")
6786

@@ -70,20 +89,22 @@ def train(self) -> str:
7089
"task_id": [tc.id for tc in task_configs],
7190
})
7291

73-
# Load model
74-
try:
92+
# --- Model (from our config) ---
93+
if getattr(self._config, "use_unsloth", False):
7594
from unsloth import FastVisionModel
7695
logger.info("Loading with Unsloth: %s", self._config.model_name)
7796
model, processor = FastVisionModel.from_pretrained(
7897
self._config.model_name,
7998
load_in_4bit=self._config.load_in_4bit,
99+
fast_inference=True,
100+
gpu_memory_utilization=0.6,
80101
)
81102
model = FastVisionModel.get_peft_model(
82103
model, r=self._config.lora_r, lora_alpha=self._config.lora_alpha,
83104
target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
84105
"gate_proj", "up_proj", "down_proj"],
85106
)
86-
except ImportError:
107+
else:
87108
from openadapt_evals.training.standalone.model_loader import (
88109
load_model_and_processor,
89110
)
@@ -92,16 +113,17 @@ def train(self) -> str:
92113
load_in_4bit=self._config.load_in_4bit,
93114
lora_r=self._config.lora_r,
94115
lora_alpha=self._config.lora_alpha,
95-
lora_checkpoint=self._config.lora_checkpoint,
116+
lora_checkpoint=getattr(self._config, "lora_checkpoint", None),
96117
)
97118

98119
if self._on_model_loaded:
99120
self._on_model_loaded(model, processor)
100121

101-
# Create rollout function
122+
# --- Rollout function (from our config) ---
102123
from openadapt_evals.adapters.waa.live import WAALiveAdapter, WAALiveConfig
103124
adapter = WAALiveAdapter(WAALiveConfig(
104125
server_url=self._config.server_url,
126+
evaluate_url=getattr(self._config, "evaluate_url", None),
105127
))
106128
rollout_func = make_waa_rollout_func(
107129
adapter=adapter,
@@ -112,21 +134,19 @@ def train(self) -> str:
112134
temperature=self._config.temperature,
113135
)
114136

115-
# Reward function
137+
# --- Reward ---
116138
def env_reward_fn(completions, **kwargs):
117139
return kwargs.get("env_reward", [0.0] * len(completions))
118140

119-
# Build callbacks
141+
# --- Callbacks ---
120142
callbacks = []
121143

122-
# Telemetry callback
123144
try:
124145
from openadapt_evals.integrations.trl_callbacks import TelemetryCallback
125146
callbacks.append(TelemetryCallback())
126147
except ImportError:
127148
pass
128149

129-
# Map our callback hooks to TRL TrainerCallback
130150
if any([self._on_before_collect, self._on_rollout_complete,
131151
self._on_step_complete]):
132152
try:
@@ -137,11 +157,9 @@ def __init__(self, hooks):
137157
self._hooks = hooks
138158

139159
def on_step_end(self, args, state, control, **kwargs):
140-
if self._hooks.get("on_step_complete"):
141-
metrics = kwargs.get("metrics", {})
142-
self._hooks["on_step_complete"](
143-
state.global_step, [], metrics,
144-
)
160+
fn = self._hooks.get("on_step_complete")
161+
if fn:
162+
fn(state.global_step, [], kwargs.get("metrics", {}))
145163

146164
callbacks.append(HookBridge({
147165
"on_before_collect": self._on_before_collect,
@@ -151,28 +169,33 @@ def on_step_end(self, args, state, control, **kwargs):
151169
except ImportError:
152170
pass
153171

154-
# Weave tracing
155-
try:
156-
from openadapt_evals.integrations.weave_integration import weave_init
157-
weave_init("openadapt-evals")
158-
except Exception:
159-
pass
172+
# --- Weave tracing ---
173+
weave_project = getattr(self._config, "weave_project", "")
174+
if weave_project:
175+
try:
176+
from openadapt_evals.integrations.weave_integration import weave_init
177+
weave_init(weave_project)
178+
except Exception:
179+
pass
160180

161-
# TRL config
162-
output_dir = self._config.output_dir
163-
trl_config = GRPOConfig(
164-
output_dir=output_dir,
165-
num_generations=self._config.num_rollouts_per_step,
166-
max_completion_length=self._config.max_new_tokens,
167-
num_train_epochs=1,
168-
max_steps=self._config.num_training_steps,
169-
learning_rate=self._config.learning_rate,
170-
save_steps=self._config.save_every_steps,
171-
logging_steps=1,
172-
bf16=True,
173-
loss_type="grpo",
174-
)
181+
# --- TRL config: use provided or build sensible defaults ---
182+
if self._trl_config is not None:
183+
trl_config = self._trl_config
184+
else:
185+
trl_config = GRPOConfig(
186+
output_dir=self._config.output_dir,
187+
num_generations=self._config.num_rollouts_per_step,
188+
max_completion_length=self._config.max_new_tokens,
189+
max_steps=self._config.num_training_steps,
190+
learning_rate=self._config.learning_rate,
191+
save_steps=self._config.save_every_steps,
192+
logging_steps=1,
193+
bf16=True,
194+
loss_type="grpo",
195+
num_train_epochs=1,
196+
)
175197

198+
# --- Train ---
176199
trainer = _TRLTrainer(
177200
model=model,
178201
processing_class=processor,
@@ -184,13 +207,13 @@ def on_step_end(self, args, state, control, **kwargs):
184207
)
185208

186209
logger.info(
187-
"Starting TRL GRPO training: model=%s tasks=%d rollouts=%d steps=%d",
210+
"Starting TRL GRPO: model=%s tasks=%d output=%s loss=%s",
188211
self._config.model_name, len(task_configs),
189-
self._config.num_rollouts_per_step, self._config.num_training_steps,
212+
trl_config.output_dir, trl_config.loss_type,
190213
)
191214

192215
trainer.train()
193-
trainer.save_model(output_dir)
216+
trainer.save_model(trl_config.output_dir)
194217

195-
logger.info("Training complete. Checkpoint: %s", output_dir)
196-
return output_dir
218+
logger.info("Training complete. Checkpoint: %s", trl_config.output_dir)
219+
return trl_config.output_dir

0 commit comments

Comments
 (0)