1- """Drop-in TRL-backed GRPOTrainer with the same API as the standalone trainer .
1+ """TRL-backed GRPO trainer with clean config separation.
22
3- Usage (identical to standalone trainer):
3+ Our TrainingConfig handles OpenAdapt-specific concerns (WAA server,
4+ task loading, constrained decoding, callbacks). TRL's GRPOConfig
5+ handles training concerns (learning rate, loss type, batch size).
6+ No duplication — each config owns its domain.
47
8+ Usage:
9+ from trl import GRPOConfig
510 from openadapt_evals.training.trl_wrapper import GRPOTrainer
611 from openadapt_evals.training.standalone.config import TrainingConfig
712
813 trainer = GRPOTrainer(
9- TrainingConfig(model_name="Qwen/Qwen3.5-9B", task_dir="tasks/"),
14+ TrainingConfig(
15+ task_dir="tasks/",
16+ server_url="http://localhost:5001",
17+ constrained_decoding=True,
18+ ),
19+ trl_config=GRPOConfig(
20+ output_dir="./checkpoints",
21+ loss_type="dapo",
22+ num_generations=4,
23+ learning_rate=5e-6,
24+ bf16=True,
25+ ),
1026 on_step_complete=my_logger,
1127 )
1228 trainer.train()
13-
14- Internally uses TRL's GRPOTrainer + rollout_func. Falls back to the
15- standalone trainer if TRL is not installed.
1629"""
1730
1831from __future__ import annotations
2437
2538
2639class GRPOTrainer :
27- """TRL-backed GRPO trainer with the standalone trainer's API.
28-
29- Same constructor signature: TrainingConfig + 4 callback hooks.
30- Same train() → str return (checkpoint path).
40+ """TRL-backed GRPO trainer.
41+
42+ Args:
43+ config: Our TrainingConfig — WAA server, task_dir, model loading,
44+ constrained decoding. Handles everything OpenAdapt-specific.
45+ trl_config: TRL's GRPOConfig — learning rate, loss type, batch
46+ size, gradient accumulation, vLLM, W&B reporting. Passed
47+ directly to TRL with zero translation. Optional — if omitted, a
48+ GRPOConfig is built from TrainingConfig (bf16, loss_type="grpo").
49+ on_model_loaded: ``(model, processor) -> None``
50+ on_before_collect: ``(task_id, env) -> None``
51+ on_rollout_complete: ``(rollout, index) -> None``
51+ on_step_complete: ``(step, rollouts, metrics) -> None`` (``rollouts`` is currently always ``[]``)
3153 """
3254
3355 def __init__ (
3456 self ,
3557 config ,
3658 * ,
59+ trl_config = None ,
3760 on_model_loaded = None ,
3861 on_before_collect = None ,
3962 on_rollout_complete = None ,
4063 on_step_complete = None ,
4164 ):
4265 self ._config = config
66+ self ._trl_config = trl_config
4367 self ._on_model_loaded = on_model_loaded
4468 self ._on_before_collect = on_before_collect
4569 self ._on_rollout_complete = on_rollout_complete
4670 self ._on_step_complete = on_step_complete
4771
4872 def train (self ) -> str :
4973 """Run GRPO training via TRL. Returns path to final checkpoint."""
50- from pathlib import Path
51-
5274 from datasets import Dataset
5375 from trl import GRPOConfig , GRPOTrainer as _TRLTrainer
5476
5577 from openadapt_evals .task_config import TaskConfig
5678 from openadapt_evals .training .trl_rollout import make_waa_rollout_func
5779
58- # Load tasks
80+ # --- Tasks (from our config) ---
5981 task_configs = []
6082 if self ._config .task_dir :
6183 task_configs = TaskConfig .from_dir (self ._config .task_dir )
62- if self ._config .task_ids and not task_configs :
63- logger .warning ("task_ids set but no task_dir — using task_ids as prompts" )
64-
6584 if not task_configs :
6685 raise ValueError ("No tasks. Set task_dir in TrainingConfig." )
6786
@@ -70,20 +89,22 @@ def train(self) -> str:
7089 "task_id" : [tc .id for tc in task_configs ],
7190 })
7291
73- # Load model
74- try :
92+ # --- Model (from our config) ---
93+ if getattr ( self . _config , "use_unsloth" , False ) :
7594 from unsloth import FastVisionModel
7695 logger .info ("Loading with Unsloth: %s" , self ._config .model_name )
7796 model , processor = FastVisionModel .from_pretrained (
7897 self ._config .model_name ,
7998 load_in_4bit = self ._config .load_in_4bit ,
99+ fast_inference = True ,
100+ gpu_memory_utilization = 0.6 ,
80101 )
81102 model = FastVisionModel .get_peft_model (
82103 model , r = self ._config .lora_r , lora_alpha = self ._config .lora_alpha ,
83104 target_modules = ["q_proj" , "k_proj" , "v_proj" , "o_proj" ,
84105 "gate_proj" , "up_proj" , "down_proj" ],
85106 )
86- except ImportError :
107+ else :
87108 from openadapt_evals .training .standalone .model_loader import (
88109 load_model_and_processor ,
89110 )
@@ -92,16 +113,17 @@ def train(self) -> str:
92113 load_in_4bit = self ._config .load_in_4bit ,
93114 lora_r = self ._config .lora_r ,
94115 lora_alpha = self ._config .lora_alpha ,
95- lora_checkpoint = self ._config . lora_checkpoint ,
116+ lora_checkpoint = getattr ( self ._config , " lora_checkpoint" , None ) ,
96117 )
97118
98119 if self ._on_model_loaded :
99120 self ._on_model_loaded (model , processor )
100121
101- # Create rollout function
122+ # --- Rollout function (from our config) ---
102123 from openadapt_evals .adapters .waa .live import WAALiveAdapter , WAALiveConfig
103124 adapter = WAALiveAdapter (WAALiveConfig (
104125 server_url = self ._config .server_url ,
126+ evaluate_url = getattr (self ._config , "evaluate_url" , None ),
105127 ))
106128 rollout_func = make_waa_rollout_func (
107129 adapter = adapter ,
@@ -112,21 +134,19 @@ def train(self) -> str:
112134 temperature = self ._config .temperature ,
113135 )
114136
115- # Reward function
137+ # --- Reward ---
116138 def env_reward_fn (completions , ** kwargs ):
117139 return kwargs .get ("env_reward" , [0.0 ] * len (completions ))
118140
119- # Build callbacks
141+ # --- Callbacks ---
120142 callbacks = []
121143
122- # Telemetry callback
123144 try :
124145 from openadapt_evals .integrations .trl_callbacks import TelemetryCallback
125146 callbacks .append (TelemetryCallback ())
126147 except ImportError :
127148 pass
128149
129- # Map our callback hooks to TRL TrainerCallback
130150 if any ([self ._on_before_collect , self ._on_rollout_complete ,
131151 self ._on_step_complete ]):
132152 try :
@@ -137,11 +157,9 @@ def __init__(self, hooks):
137157 self ._hooks = hooks
138158
139159 def on_step_end (self , args , state , control , ** kwargs ):
140- if self ._hooks .get ("on_step_complete" ):
141- metrics = kwargs .get ("metrics" , {})
142- self ._hooks ["on_step_complete" ](
143- state .global_step , [], metrics ,
144- )
160+ fn = self ._hooks .get ("on_step_complete" )
161+ if fn :
162+ fn (state .global_step , [], kwargs .get ("metrics" , {}))
145163
146164 callbacks .append (HookBridge ({
147165 "on_before_collect" : self ._on_before_collect ,
@@ -151,28 +169,33 @@ def on_step_end(self, args, state, control, **kwargs):
151169 except ImportError :
152170 pass
153171
154- # Weave tracing
155- try :
156- from openadapt_evals .integrations .weave_integration import weave_init
157- weave_init ("openadapt-evals" )
158- except Exception :
159- pass
172+ # --- Weave tracing ---
173+ weave_project = getattr (self ._config , "weave_project" , "" )
174+ if weave_project :
175+ try :
176+ from openadapt_evals .integrations .weave_integration import weave_init
177+ weave_init (weave_project )
178+ except Exception :
179+ pass
160180
161- # TRL config
162- output_dir = self ._config .output_dir
163- trl_config = GRPOConfig (
164- output_dir = output_dir ,
165- num_generations = self ._config .num_rollouts_per_step ,
166- max_completion_length = self ._config .max_new_tokens ,
167- num_train_epochs = 1 ,
168- max_steps = self ._config .num_training_steps ,
169- learning_rate = self ._config .learning_rate ,
170- save_steps = self ._config .save_every_steps ,
171- logging_steps = 1 ,
172- bf16 = True ,
173- loss_type = "grpo" ,
174- )
181+ # --- TRL config: use provided or build sensible defaults ---
182+ if self ._trl_config is not None :
183+ trl_config = self ._trl_config
184+ else :
185+ trl_config = GRPOConfig (
186+ output_dir = self ._config .output_dir ,
187+ num_generations = self ._config .num_rollouts_per_step ,
188+ max_completion_length = self ._config .max_new_tokens ,
189+ max_steps = self ._config .num_training_steps ,
190+ learning_rate = self ._config .learning_rate ,
191+ save_steps = self ._config .save_every_steps ,
192+ logging_steps = 1 ,
193+ bf16 = True ,
194+ loss_type = "grpo" ,
195+ num_train_epochs = 1 ,
196+ )
175197
198+ # --- Train ---
176199 trainer = _TRLTrainer (
177200 model = model ,
178201 processing_class = processor ,
@@ -184,13 +207,13 @@ def on_step_end(self, args, state, control, **kwargs):
184207 )
185208
186209 logger .info (
187- "Starting TRL GRPO training : model=%s tasks=%d rollouts=%d steps=%d " ,
210+ "Starting TRL GRPO: model=%s tasks=%d output=%s loss=%s " ,
188211 self ._config .model_name , len (task_configs ),
189- self . _config . num_rollouts_per_step , self . _config . num_training_steps ,
212+ trl_config . output_dir , trl_config . loss_type ,
190213 )
191214
192215 trainer .train ()
193- trainer .save_model (output_dir )
216+ trainer .save_model (trl_config . output_dir )
194217
195- logger .info ("Training complete. Checkpoint: %s" , output_dir )
196- return output_dir
218+ logger .info ("Training complete. Checkpoint: %s" , trl_config . output_dir )
219+ return trl_config . output_dir
0 commit comments