-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathtrl_wrapper.py
More file actions
282 lines (248 loc) · 11 KB
/
trl_wrapper.py
File metadata and controls
282 lines (248 loc) · 11 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
"""TRL-backed GRPO trainer with clean config separation.

Our TrainingConfig handles OpenAdapt-specific concerns (WAA server,
task loading, constrained decoding, callbacks). TRL's GRPOConfig
handles training concerns (learning rate, loss type, batch size).
No duplication — each config owns its domain.

Usage:
    from trl import GRPOConfig
    from openadapt_evals.training.trl_wrapper import GRPOTrainer
    from openadapt_evals.training.standalone.config import TrainingConfig

    trainer = GRPOTrainer(
        TrainingConfig(
            task_dir="tasks/",
            server_url="http://localhost:5001",
            constrained_decoding=True,
        ),
        trl_config=GRPOConfig(
            output_dir="./checkpoints",
            loss_type="dapo",
            num_generations=4,
            learning_rate=5e-6,
            bf16=True,
        ),
        on_step_complete=my_logger,
    )
    trainer.train()
"""
from __future__ import annotations
import logging
from typing import Any
logger = logging.getLogger(__name__)
class GRPOTrainer:
    """TRL-backed GRPO trainer.

    Tasks, model loading and the WAA rollout function are driven by
    ``config``; optimisation settings are driven by ``trl_config``.

    Args:
        config: Our TrainingConfig — WAA server, task_dir, model loading,
            constrained decoding. Handles everything OpenAdapt-specific.
        trl_config: TRL's GRPOConfig — learning rate, loss type, batch
            size, gradient accumulation, vLLM, W&B reporting. Passed
            directly to TRL with zero translation. Optional — defaults
            are derived from ``config`` if omitted.
        on_model_loaded: ``(model, processor) -> None``
        on_before_collect: ``(task_id, env) -> None``
        on_rollout_complete: ``(rollout, index) -> None``
        on_step_complete: ``(step, rollouts, metrics) -> None``. Bridged
            to the trainer callback's ``on_step_end``; in this wrapper
            ``rollouts`` is always an empty list and ``metrics`` is the
            ``metrics`` keyword handed to ``on_step_end`` (``{}`` when
            absent).
    """

    def __init__(
        self,
        config: Any,
        *,
        trl_config: Any = None,
        on_model_loaded=None,
        on_before_collect=None,
        on_rollout_complete=None,
        on_step_complete=None,
    ):
        self._config = config
        self._trl_config = trl_config
        self._on_model_loaded = on_model_loaded
        self._on_before_collect = on_before_collect
        self._on_rollout_complete = on_rollout_complete
        self._on_step_complete = on_step_complete

    def train(self) -> str:
        """Run GRPO training via TRL.

        Returns:
            The output directory the final model was saved to
            (``trl_config.output_dir``).

        Raises:
            ValueError: If no tasks are available after loading from
                ``config.task_dir``.
            ImportError: If ``config.use_unsloth`` is set but unsloth is
                not installed.
        """
        # Imported lazily so importing this module does not pull in the
        # training stack.
        from datasets import Dataset
        from trl import GRPOConfig, GRPOTrainer as _TRLTrainer

        from openadapt_evals.task_config import TaskConfig
        from openadapt_evals.training.trl_rollout import make_waa_rollout_func

        cfg = self._config

        # --- Tasks (from our config) ---
        task_configs = TaskConfig.from_dir(cfg.task_dir) if cfg.task_dir else []
        task_configs = self._apply_task_id_filter(task_configs)
        if not task_configs:
            raise ValueError("No tasks. Set task_dir in TrainingConfig.")

        dataset = Dataset.from_dict({
            "prompt": [tc.name for tc in task_configs],
            "task_id": [tc.id for tc in task_configs],
        })

        # --- Model (from our config) ---
        model, processor = self._load_model()
        if self._on_model_loaded:
            self._on_model_loaded(model, processor)

        # --- Rollout function (from our config) ---
        from openadapt_evals.adapters.waa.live import WAALiveAdapter, WAALiveConfig

        adapter = WAALiveAdapter(WAALiveConfig(
            server_url=cfg.server_url,
            evaluate_url=getattr(cfg, "evaluate_url", None),
            # Training-appropriate timeouts: fail fast, don't block the
            # training loop. Benchmark defaults (180s, 3 retries) are for
            # one-shot evaluation where thoroughness matters. Training does
            # thousands of evaluations where speed matters.
            evaluate_timeout=15.0,
            evaluate_retries=1,
        ))
        # on_before_collect and on_rollout_complete are handed to the
        # rollout function because TRL has no pre-rollout callback.
        rollout_func = make_waa_rollout_func(
            adapter=adapter,
            task_configs=task_configs,
            max_steps=cfg.max_steps_per_episode,
            constrained_decoding=getattr(cfg, "constrained_decoding", False),
            max_new_tokens=cfg.max_new_tokens,
            temperature=cfg.temperature,
            on_before_collect=self._on_before_collect,
            on_rollout_complete=self._on_rollout_complete,
        )

        # --- Reward ---
        def env_reward_fn(completions, **kwargs):
            # Pass through the "env_reward" values when present (presumably
            # supplied by the rollout function — confirm against
            # make_waa_rollout_func); otherwise score every completion 0.0.
            return kwargs.get("env_reward", [0.0] * len(completions))

        callbacks = self._build_callbacks()
        self._init_weave()

        # --- TRL config: use provided or build sensible defaults ---
        trl_config = self._trl_config
        if trl_config is None:
            trl_config = self._default_trl_config(GRPOConfig)

        # No dataset padding needed: with batch_size=1, even a single-task
        # dataset works. TRL iterates one prompt per step, each getting
        # num_generations rollouts via rollout_func.

        # --- Train ---
        trainer = _TRLTrainer(
            model=model,
            processing_class=processor,
            args=trl_config,
            train_dataset=dataset,
            reward_funcs=[env_reward_fn],
            rollout_func=rollout_func,
            callbacks=callbacks,
        )
        logger.info(
            "Starting TRL GRPO: model=%s tasks=%d output=%s loss=%s",
            cfg.model_name, len(task_configs),
            trl_config.output_dir, trl_config.loss_type,
        )
        trainer.train()
        trainer.save_model(trl_config.output_dir)
        logger.info("Training complete. Checkpoint: %s", trl_config.output_dir)
        return trl_config.output_dir

    @staticmethod
    def _match_task_ids(task_configs, task_ids) -> list:
        """Return the task configs whose ``id`` or ``name`` is in ``task_ids``."""
        allowed = set(task_ids)
        return [tc for tc in task_configs if tc.id in allowed or tc.name in allowed]

    def _apply_task_id_filter(self, task_configs) -> list:
        """Restrict ``task_configs`` to ``config.task_ids`` when that is set.

        Without this, ALL tasks from task_dir end up in the TRL dataset
        regardless of what the user requested. If nothing matches, a
        warning is logged and the unfiltered list is returned.
        """
        task_ids = getattr(self._config, "task_ids", None)
        if not task_ids:
            return task_configs

        matched = self._match_task_ids(task_configs, task_ids)
        if matched:
            # Report against the size of the unfiltered list, which is
            # still intact here.
            logger.info(
                "Filtered tasks by task_ids: %d/%d tasks selected",
                len(matched), len(task_configs),
            )
            return matched

        logger.warning(
            "task_ids=%s matched no tasks from task_dir=%s. "
            "Available: %s. Using all tasks.",
            task_ids, self._config.task_dir,
            [tc.id for tc in task_configs],
        )
        return task_configs

    def _load_model(self):
        """Load ``(model, processor)`` via Unsloth or the standalone loader."""
        cfg = self._config
        if getattr(cfg, "use_unsloth", False):
            try:
                from unsloth import FastVisionModel
            except ImportError:
                raise ImportError(
                    "use_unsloth=True but unsloth is not installed. "
                    "Install with: pip install openadapt-evals[unsloth]"
                ) from None
            logger.info("Loading with Unsloth: %s", cfg.model_name)
            model, processor = FastVisionModel.from_pretrained(
                cfg.model_name,
                load_in_4bit=cfg.load_in_4bit,
                fast_inference=True,
                gpu_memory_utilization=0.6,
            )
            model = FastVisionModel.get_peft_model(
                model,
                r=cfg.lora_r,
                lora_alpha=cfg.lora_alpha,
                target_modules=[
                    "q_proj", "k_proj", "v_proj", "o_proj",
                    "gate_proj", "up_proj", "down_proj",
                ],
            )
            return model, processor

        from openadapt_evals.training.standalone.model_loader import (
            load_model_and_processor,
        )

        model, processor = load_model_and_processor(
            cfg.model_name,
            load_in_4bit=cfg.load_in_4bit,
            lora_r=cfg.lora_r,
            lora_alpha=cfg.lora_alpha,
            lora_checkpoint=getattr(cfg, "lora_checkpoint", None),
        )
        return model, processor

    def _build_callbacks(self) -> list:
        """Assemble trainer callbacks; optional integrations are best-effort."""
        callbacks = []
        try:
            from openadapt_evals.integrations.trl_callbacks import (
                DiagnosticsCallback,
                TelemetryCallback,
            )
            callbacks.append(TelemetryCallback())
            callbacks.append(DiagnosticsCallback())
        except ImportError:
            logger.debug(
                "Telemetry/diagnostics callbacks unavailable; continuing "
                "without them",
                exc_info=True,
            )

        # Only on_step_complete maps onto a trainer callback (on_step_end).
        if self._on_step_complete:
            try:
                from transformers import TrainerCallback
            except ImportError:
                logger.warning(
                    "transformers is not importable; on_step_complete "
                    "will not be called"
                )
            else:
                class HookBridge(TrainerCallback):
                    def __init__(self, on_step_complete):
                        self._on_step_complete = on_step_complete

                    def on_step_end(self, args, state, control, **kwargs):
                        # No rollouts are available at this hook, so an
                        # empty list is passed in their place.
                        self._on_step_complete(
                            state.global_step, [],
                            kwargs.get("metrics", {}),
                        )

                callbacks.append(HookBridge(self._on_step_complete))
        return callbacks

    def _init_weave(self) -> None:
        """Start Weave tracing if ``config.weave_project`` is set (best-effort)."""
        weave_project = getattr(self._config, "weave_project", "")
        if not weave_project:
            return
        try:
            from openadapt_evals.integrations.weave_integration import weave_init
            weave_init(weave_project)
        except Exception:
            # Tracing is optional: never abort training over it, but leave
            # a trace of why it is missing.
            logger.warning(
                "Weave tracing could not be initialised for project %r; "
                "continuing without it",
                weave_project,
                exc_info=True,
            )

    def _default_trl_config(self, grpo_config_cls):
        """Build a GRPOConfig from ``config`` when none was supplied.

        TRL constraints:
          - generation_batch_size must be divisible by num_generations
          - per_device_train_batch_size must be <= len(dataset)

        For RL with few tasks: batch_size=1 (one unique prompt per step)
        and generation_batch_size=num_generations. This produces exactly
        num_generations rollouts per step — matching the standalone
        trainer. The previous approach (batch_size=num_gen, padded
        dataset) caused 4× over-generation: 4 identical prompts × 4
        generations = 16 rollouts when only 4 were needed.
        """
        cfg = self._config
        num_gen = cfg.num_rollouts_per_step
        return grpo_config_cls(
            output_dir=cfg.output_dir,
            num_generations=num_gen,
            max_completion_length=cfg.max_new_tokens,
            max_steps=cfg.num_training_steps,
            learning_rate=cfg.learning_rate,
            save_steps=cfg.save_every_steps,
            logging_steps=1,
            bf16=True,
            loss_type="grpo",
            num_train_epochs=1,
            per_device_train_batch_size=1,
            # Equal to num_generations: satisfies the divisibility
            # constraint while keeping one unique prompt per step.
            generation_batch_size=num_gen,
        )