
Commit fecf461

abrichr and claude authored
fix: increase max_new_tokens to 2048 and make configurable via GRPOConfig (#62)
* fix: align GRPO prompt format with SFT training format

  The GRPO rollout prompt was missing the "Thought:" line and action history
  that the SFT training uses. Models fine-tuned via SFT output
  "Thought: ...\nAction: CLICK(...)" but the GRPO prompt didn't prompt for
  this format, causing verbose free-form output that couldn't be parsed →
  reward 0.0 → zero gradients.

  Changes:
  - Add "Thought:" and "Action:" prompt lines matching SFT format
  - Add action_history parameter for step context
  - Parser extracts action from "Action: ..." line before regex matching
  - Parser handles JSON format {"action_type": "click", "coordinate": [x, y]}
  - Debug logging of raw VLM output for zero-reward diagnosis

  Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* fix: increase max_new_tokens to 2048 and make configurable

  The default of 100 tokens truncated reasoning models mid-thought, producing
  unparseable output → DONE → reward 0.0 → zero gradients. Caused 4 failed
  training runs (~20 GPU-hours wasted).

  - Add max_new_tokens to GRPOConfig (default 2048)
  - Use config value instead of hardcoded 100
  - Add truncation warning when generation hits the limit

  Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

---------

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 04e6e9f commit fecf461
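
For reference, the parser behavior described in the first commit message (extract the "Action: ..." line first, then try JSON, then a CLICK regex) might look roughly like the following minimal sketch. This is a hypothetical reconstruction, not the repository's actual parser; the function name, return shape, and exact regexes are assumptions.

    import json
    import re

    def parse_action_line(raw: str) -> dict | None:
        """Hypothetical sketch: pull an action out of VLM output shaped like
        'Thought: ...\\nAction: CLICK(x, y)' or a JSON action object."""
        # Prefer the text after an explicit "Action:" marker, if present.
        match = re.search(r"Action:\s*(.+)", raw)
        payload = match.group(1).strip() if match else raw.strip()

        # JSON form: {"action_type": "click", "coordinate": [x, y]}
        try:
            obj = json.loads(payload)
            if isinstance(obj, dict) and "action_type" in obj:
                return obj
        except json.JSONDecodeError:
            pass

        # Function-call form: CLICK(x, y)
        m = re.match(r"CLICK\((\d+)\s*,\s*(\d+)\)", payload)
        if m:
            return {"action_type": "click",
                    "coordinate": [int(m.group(1)), int(m.group(2))]}

        return None  # unparseable; upstream this meant reward 0.0

Isolating the "Action:" line before any format matching is what keeps verbose free-form output from swallowing the action, which is the failure mode the commit describes.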

2 files changed

Lines changed: 17 additions & 1 deletion

openadapt_ml/training/grpo/config.py

Lines changed: 7 additions & 0 deletions
@@ -86,5 +86,12 @@ class GRPOConfig:
     save_every_steps: int = 50
     output_dir: str = "checkpoints/grpo"
 
+    # Generation
+    max_new_tokens: int = 2048  # Token budget per step. Reasoning models need
+    # 1000+ tokens (thought + action). 100 truncates mid-reasoning → unparseable.
+
+    # Task configs
+    task_dir: str | None = None  # Directory of TaskConfig YAMLs for milestone rewards
+
     # Stuck detection
     stuck_window: int = 3
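
A quick usage sketch of the new fields, assuming GRPOConfig is constructible with keyword arguments and its other fields have defaults; the task directory path here is made up:

    from openadapt_ml.training.grpo.config import GRPOConfig

    config = GRPOConfig(
        max_new_tokens=4096,       # raise above the 2048 default for long reasoning chains
        task_dir="configs/tasks",  # TaskConfig YAMLs for milestone rewards (hypothetical path)
    )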

openadapt_ml/training/grpo/trainer.py

Lines changed: 10 additions & 1 deletion
@@ -500,7 +500,7 @@ def agent_fn(obs: Any) -> BenchmarkAction:
         with torch.no_grad():
             outputs = model.generate(
                 **inputs,
-                max_new_tokens=100,
+                max_new_tokens=self._config.max_new_tokens,
                 temperature=temperature,
                 do_sample=True,
             )
@@ -517,6 +517,15 @@ def agent_fn(obs: Any) -> BenchmarkAction:
         except Exception:
             pass
 
+        # Warn if output was likely truncated (hit max_new_tokens)
+        gen_len = outputs[0].shape[0] - inputs["input_ids"].shape[1]
+        if gen_len >= self._config.max_new_tokens - 1:
+            logger.warning(
+                "Generation hit max_new_tokens=%d — output may be truncated. "
+                "Increase config.max_new_tokens if actions aren't parsed.",
+                self._config.max_new_tokens,
+            )
+
         action = _parse_vlm_output_to_action(decoded, screen_size=screen_size)
 
         # Store raw VLM output for accurate loss computation
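
The added check measures generated length as total output length minus prompt length, and fires at max_new_tokens - 1 rather than exact equality, presumably to allow one token of slack (e.g. an EOS emitted right at the boundary). A standalone sketch of the same arithmetic, with toy tensor shapes standing in for a real model.generate() call:

    import torch

    # Toy stand-ins for a real generate() call: a 10-token prompt plus a
    # generation that ran all the way to the budget.
    max_new_tokens = 2048
    input_ids = torch.zeros(1, 10, dtype=torch.long)
    outputs = torch.zeros(1, 10 + max_new_tokens, dtype=torch.long)

    gen_len = outputs[0].shape[0] - input_ids.shape[1]  # tokens actually generated
    if gen_len >= max_new_tokens - 1:  # one token of slack, as in the trainer diff
        print(f"likely truncated: generated {gen_len} of {max_new_tokens} tokens")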
