Skip to content

Commit 4ec7d51

Browse files
abrichr and claude authored
fix: constrained decoding cache bug, task rotation, add trainer tests (#199)
Constrained decoding:
- Remove (.|\n)* prefix from action regex — Outlines can't compile it into a DFA efficiently. Model must output action directly.
- Fix cache sentinel: use False for failure (not []) so subsequent calls correctly return None instead of empty logits_processor list. Prior bug: [] cached as "success" → model generated unconstrained.
- Upgrade warning to error level for visibility.

Task rotation:
- Fix _load_task_configs: check `not task_ids` once BEFORE the loop (was checking inside loop — only first task ever appended).

Tests (21 new):
- TestActionRegex: 8 valid actions match, 6 invalid texts rejected
- TestConstrainedDecodingCache: sentinel logic, regression for [] bug
- TestTaskRotation: all tasks loaded, explicit ids preserved, rotation

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent eb7b158 commit 4ec7d51

2 files changed

Lines changed: 189 additions & 18 deletions

File tree

openadapt_evals/training/standalone/trainer.py

Lines changed: 30 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -90,59 +90,70 @@ def __init__(
9090

9191
# --- Constrained decoding -------------------------------------------
9292

93-
# Regex that matches ALL valid action formats. Allows a free-form
94-
# "Thought: ..." prefix (the model's chain-of-thought) followed by
95-
# exactly one action. Outlines converts this to a token-level DFA.
93+
# Regex matching valid action formats. No free-text prefix — the
94+
# model MUST output an action as its very first token. This is
95+
# intentional: constrained decoding forces structured output.
96+
# If the model needs chain-of-thought, disable constrained_decoding
97+
# and rely on prompt instructions instead.
9698
_ACTION_REGEX = (
97-
r"(.|\n)*" # allow any Thought prefix
98-
r"(CLICK\(x=0\.\d{1,3},\s*y=0\.\d{1,3}\)"
99+
r"CLICK\(x=0\.\d{1,3},\s*y=0\.\d{1,3}\)"
99100
r'|TYPE\(text="[^"]{0,200}"\)'
100101
r"|WAIT\(\)"
101-
r"|DONE\(\))"
102+
r"|DONE\(\)"
102103
)
104+
# Sentinel: None = not yet attempted, list = success, False = failed
103105
_constrained_processor_cache: Any = None
104106

105107
def _get_constrained_logits_processor(self) -> list | None:
106108
"""Build an Outlines RegexLogitsProcessor for the action format.
107109
108110
Returns a ``[LogitsProcessor]`` list suitable for passing to
109111
``model.generate(logits_processor=...)``, or ``None`` if Outlines
110-
is not installed.
112+
is not installed or compilation fails.
111113
112114
The processor is cached after first creation (the DFA compilation
113115
is expensive — ~2 seconds — but only happens once).
114116
"""
115-
if self._constrained_processor_cache is not None:
117+
# Already attempted and failed
118+
if self._constrained_processor_cache is False:
119+
return None
120+
# Already compiled successfully
121+
if isinstance(self._constrained_processor_cache, list):
116122
return self._constrained_processor_cache
117123

118124
try:
119125
from outlines.processors import RegexLogitsProcessor
126+
tokenizer = (
127+
self._processor.tokenizer
128+
if hasattr(self._processor, "tokenizer")
129+
else self._processor
130+
)
120131
processor = RegexLogitsProcessor(
121132
self._ACTION_REGEX,
122-
tokenizer=self._processor.tokenizer
123-
if hasattr(self._processor, "tokenizer")
124-
else self._processor,
133+
tokenizer=tokenizer,
125134
)
126135
self._constrained_processor_cache = [processor]
127136
logger.info(
128137
"Outlines constrained decoding enabled "
129-
"(action format regex compiled)"
138+
"(action format regex compiled successfully)"
130139
)
131140
return self._constrained_processor_cache
132141
except ImportError:
133-
logger.warning(
142+
logger.error(
134143
"constrained_decoding=True but 'outlines' is not installed. "
135144
"Install with: pip install outlines>=0.1.0"
136145
)
137-
self._constrained_processor_cache = [] # don't retry
146+
self._constrained_processor_cache = False
138147
return None
139148
except Exception as exc:
140-
logger.warning(
149+
logger.error(
141150
"Outlines RegexLogitsProcessor creation failed: %s. "
142-
"Falling back to unconstrained generation.",
151+
"Falling back to unconstrained generation. "
152+
"This may be a tokenizer compatibility issue — try "
153+
"updating outlines: pip install -U outlines",
143154
exc,
144155
)
145-
self._constrained_processor_cache = []
156+
self._constrained_processor_cache = False
146157
return None
147158

148159
# --- Task loading -----------------------------------------------------
@@ -156,9 +167,10 @@ def _load_task_configs(self) -> None:
156167
if not task_dir.exists():
157168
logger.warning("Task dir not found: %s", task_dir)
158169
return
170+
auto_populate = not self._config.task_ids
159171
for tc in TaskConfig.from_dir(str(task_dir)):
160172
self._task_configs[tc.id] = tc
161-
if not self._config.task_ids:
173+
if auto_populate:
162174
self._config.task_ids.append(tc.id)
163175
logger.info("Loaded %d task configs from %s", len(self._task_configs), task_dir)
164176

tests/test_standalone_trainer.py

Lines changed: 159 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,159 @@
1+
"""Tests for the standalone GRPO trainer.
2+
3+
Covers constrained decoding logic, task rotation, and config handling.
4+
No GPU or WAA server required — tests use mocks.
5+
"""
6+
7+
from __future__ import annotations
8+
9+
import re
10+
11+
import pytest
12+
13+
from openadapt_evals.training.standalone.config import TrainingConfig
14+
from openadapt_evals.training.standalone.trainer import GRPOTrainer
15+
16+
17+
# ---------------------------------------------------------------------------
18+
# Action regex tests
19+
# ---------------------------------------------------------------------------
20+
21+
22+
class TestActionRegex:
23+
"""Verify the action format regex matches valid actions and rejects junk."""
24+
25+
regex = GRPOTrainer._ACTION_REGEX
26+
27+
@pytest.mark.parametrize(
28+
"action",
29+
[
30+
"CLICK(x=0.50, y=0.30)",
31+
"CLICK(x=0.0, y=0.0)",
32+
"CLICK(x=0.999, y=0.123)",
33+
'TYPE(text="hello world")',
34+
'TYPE(text="")',
35+
'TYPE(text="notepad")',
36+
"WAIT()",
37+
"DONE()",
38+
],
39+
)
40+
def test_valid_actions_match(self, action: str) -> None:
41+
assert re.match(self.regex, action), f"Expected match: {action!r}"
42+
43+
@pytest.mark.parametrize(
44+
"text",
45+
[
46+
"** Let me think about this...",
47+
"1. Analyze the user's goal",
48+
"The user wants to open Task Manager",
49+
"",
50+
"CLICK",
51+
"click(0.5, 0.3)",
52+
],
53+
)
54+
def test_invalid_text_rejected(self, text: str) -> None:
55+
assert not re.match(self.regex, text), f"Should NOT match: {text!r}"
56+
57+
58+
# ---------------------------------------------------------------------------
59+
# Constrained decoding cache tests
60+
# ---------------------------------------------------------------------------
61+
62+
63+
class TestConstrainedDecodingCache:
64+
"""Test the caching logic for the Outlines logits processor."""
65+
66+
def test_cache_starts_as_none(self) -> None:
67+
config = TrainingConfig()
68+
trainer = GRPOTrainer(config)
69+
assert trainer._constrained_processor_cache is None
70+
71+
def test_failed_cache_returns_none(self) -> None:
72+
"""When compilation fails, subsequent calls return None (not [])."""
73+
config = TrainingConfig(constrained_decoding=True)
74+
trainer = GRPOTrainer(config)
75+
# Simulate a failed compilation
76+
trainer._constrained_processor_cache = False
77+
result = trainer._get_constrained_logits_processor()
78+
assert result is None
79+
80+
def test_successful_cache_returns_list(self) -> None:
81+
"""When compilation succeeds, subsequent calls return the list."""
82+
config = TrainingConfig(constrained_decoding=True)
83+
trainer = GRPOTrainer(config)
84+
# Simulate a successful compilation
85+
trainer._constrained_processor_cache = ["mock_processor"]
86+
result = trainer._get_constrained_logits_processor()
87+
assert result == ["mock_processor"]
88+
89+
def test_empty_list_no_longer_caches_as_success(self) -> None:
90+
"""Regression test: empty list [] should NOT be treated as success.
91+
92+
Prior bug: failure cached [], which is falsy but still passes an
93+
`is not None` check, causing subsequent calls to return [] (no processors applied).
94+
"""
95+
config = TrainingConfig(constrained_decoding=True)
96+
trainer = GRPOTrainer(config)
97+
# The old buggy behavior would cache [] on failure
98+
# Verify the sentinel is False (not []) for failures
99+
trainer._constrained_processor_cache = False
100+
assert trainer._get_constrained_logits_processor() is None
101+
# And a non-empty list is a valid success cache (with a processor in it)
102+
trainer._constrained_processor_cache = ["real_processor"]
103+
assert trainer._get_constrained_logits_processor() == ["real_processor"]
104+
105+
106+
# ---------------------------------------------------------------------------
107+
# Task rotation tests
108+
# ---------------------------------------------------------------------------
109+
110+
111+
class TestTaskRotation:
112+
"""Test that all tasks from task_dir are loaded, not just the first."""
113+
114+
def test_all_tasks_loaded_from_dir(self, tmp_path) -> None:
115+
"""Create multiple task YAMLs and verify all are loaded."""
116+
import yaml
117+
118+
for i in range(3):
119+
task = {
120+
"name": f"Task {i}",
121+
"id": f"task-{i}",
122+
"setup": [],
123+
"evaluate": [{"check": "screenshot", "description": "done"}],
124+
}
125+
(tmp_path / f"task_{i}.yaml").write_text(yaml.dump(task))
126+
127+
config = TrainingConfig(task_dir=str(tmp_path))
128+
trainer = GRPOTrainer(config)
129+
trainer._load_task_configs()
130+
131+
assert len(config.task_ids) == 3
132+
assert set(config.task_ids) == {"task-0", "task-1", "task-2"}
133+
134+
def test_explicit_task_ids_not_overwritten(self, tmp_path) -> None:
135+
"""When task_ids is set explicitly, task_dir doesn't override it."""
136+
import yaml
137+
138+
for i in range(3):
139+
task = {"name": f"Task {i}", "id": f"task-{i}", "setup": [], "evaluate": []}
140+
(tmp_path / f"task_{i}.yaml").write_text(yaml.dump(task))
141+
142+
config = TrainingConfig(
143+
task_dir=str(tmp_path),
144+
task_ids=["task-1"], # explicit
145+
)
146+
trainer = GRPOTrainer(config)
147+
trainer._load_task_configs()
148+
149+
# Should keep the explicit list, not auto-populate
150+
assert config.task_ids == ["task-1"]
151+
# But task_configs should still have all 3 loaded (for setup/eval)
152+
assert len(trainer._task_configs) == 3
153+
154+
def test_task_rotation_in_training_loop(self) -> None:
155+
"""Verify step % len(task_ids) produces rotation."""
156+
task_ids = ["a", "b", "c"]
157+
num_steps = 9
158+
selected = [task_ids[step % len(task_ids)] for step in range(num_steps)]
159+
assert selected == ["a", "b", "c", "a", "b", "c", "a", "b", "c"]

0 commit comments

Comments
 (0)