Skip to content

Commit aca70a2

Browse files
abrichrclaude
andauthored
fix: use Outlines Generator API instead of logits_processor kwarg (#204)
Outlines v1.2 does NOT work by passing a processor to model.generate(logits_processor=[...]). It uses its own Generator: model = outlines.from_transformers(hf_model, hf_processor) gen = outlines.Generator(model, outlines.regex(pattern)) output = gen(prompt, max_new_tokens=512) The Generator wraps the model and handles tokenization, constrained generation, and decoding internally. Prior approach compiled the processor successfully but it was never actually applied to generation. Also fixes max_tokens → max_new_tokens (transformers kwarg name). Tests (35, all pass in 0.09s): - test_outlines_api_imports: verifies from_transformers, regex, Generator - test_outlines_regex_compiles: verifies action regex compiles - test_outlines_generator_api_contract: verifies Generator and SteerableGenerator signatures match what the trainer calls - No slow model download — API contract checks only Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent d3d8c8a commit aca70a2

2 files changed

Lines changed: 131 additions & 110 deletions

File tree

openadapt_evals/training/standalone/trainer.py

Lines changed: 59 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -110,68 +110,55 @@ def __init__(
110110
_ACTION_REGEX = (
111111
r"Thought: [^\n]+\nAction: (" + _ACTION_RE + r")"
112112
)
113-
# Sentinel: None = not yet attempted, list = success, False = failed
114-
_constrained_processor_cache: Any = None
113+
# Cached outlines Generator (created once, reused for all generate calls)
114+
# None = not yet attempted, False = failed, Generator = success
115+
_outlines_generator: Any = None
115116

116-
def _get_constrained_logits_processor(self) -> list | None:
117-
"""Build an Outlines RegexLogitsProcessor for the action format.
117+
def _get_outlines_generator(self) -> Any | None:
118+
"""Build an Outlines Generator for constrained generation.
118119
119-
Returns a ``[LogitsProcessor]`` list suitable for passing to
120-
``model.generate(logits_processor=...)``, or ``None`` if Outlines
121-
is not installed or compilation fails.
120+
Outlines v1.2 uses its own Generator API — NOT model.generate()
121+
with a logits_processor kwarg. The Generator wraps the model and
122+
handles tokenization, generation, and decoding internally.
122123
123-
The processor is cached after first creation (the DFA compilation
124-
is expensive — ~2 seconds — but only happens once).
124+
Returns the Generator, or None if creation fails.
125125
"""
126-
# Already attempted and failed
127-
if self._constrained_processor_cache is False:
126+
if self._outlines_generator is False:
128127
return None
129-
# Already compiled successfully
130-
if isinstance(self._constrained_processor_cache, list):
131-
return self._constrained_processor_cache
128+
if self._outlines_generator is not None:
129+
return self._outlines_generator
132130

133131
try:
134-
# Outlines v1.2+ API:
135-
# 1. Wrap HF model+tokenizer in outlines.Transformers
136-
# 2. Call get_regex_logits_processor(None, wrapped, regex)
137-
# The processor is then passed to model.generate(logits_processor=[p])
138-
from outlines import Transformers
139-
from outlines.generator import get_regex_logits_processor
140-
141-
raw_tokenizer = (
142-
self._processor.tokenizer
143-
if hasattr(self._processor, "tokenizer")
144-
else self._processor
145-
)
146-
wrapped_model = Transformers(self._model, raw_tokenizer)
147-
processor = get_regex_logits_processor(
148-
None, # use default backend
149-
wrapped_model,
150-
self._ACTION_REGEX,
132+
import outlines
133+
134+
wrapped_model = outlines.from_transformers(
135+
self._model, self._processor,
151136
)
152-
self._constrained_processor_cache = [processor]
137+
constraint = outlines.regex(self._ACTION_REGEX)
138+
generator = outlines.Generator(wrapped_model, constraint)
139+
140+
self._outlines_generator = generator
153141
logger.info(
154142
"Outlines constrained decoding enabled "
155-
"(regex compiled via %s, processor=%s)",
143+
"(model=%s, regex compiled successfully)",
156144
type(wrapped_model).__name__,
157-
type(processor).__name__,
158145
)
159-
return self._constrained_processor_cache
146+
return generator
160147
except ImportError:
161148
logger.error(
162149
"constrained_decoding=True but 'outlines' is not installed. "
163150
"Install with: uv sync --extra training"
164151
)
165-
self._constrained_processor_cache = False
152+
self._outlines_generator = False
166153
return None
167154
except Exception as exc:
168155
logger.error(
169-
"Outlines logits processor creation failed: %s. "
156+
"Outlines Generator creation failed: %s. "
170157
"Falling back to unconstrained generation. "
171158
"Try: uv pip install -U outlines",
172159
exc,
173160
)
174-
self._constrained_processor_cache = False
161+
self._outlines_generator = False
175162
return None
176163

177164
# --- Task loading -----------------------------------------------------
@@ -221,23 +208,43 @@ def _collect_rollout(self, task_id: str, instruction: str) -> Rollout:
221208
else:
222209
text_input = messages[-1]["content"]
223210

224-
inputs = self._processor(text=[text_input], images=[image], return_tensors="pt")
225-
inputs = {k: v.to(self._model.device) for k, v in inputs.items()}
226-
with torch.no_grad():
227-
generate_kwargs: dict[str, Any] = dict(
211+
# --- Generation: constrained (Outlines) or unconstrained (HF) ---
212+
outlines_gen = (
213+
self._get_outlines_generator()
214+
if self._config.constrained_decoding
215+
else None
216+
)
217+
if outlines_gen is not None:
218+
# Outlines v1.2 Generator API: handles tokenization,
219+
# generation, and decoding internally. For multimodal
220+
# models, pass a dict with "text" + image keys.
221+
model_input = {"text": text_input, "images": [image]}
222+
decoded = outlines_gen(
223+
model_input,
228224
max_new_tokens=self._config.max_new_tokens,
229225
temperature=self._config.temperature,
230-
do_sample=True,
231226
)
232-
# Constrained decoding: force output to match the
233-
# action format regex, eliminating unparseable output.
234-
if self._config.constrained_decoding:
235-
logits_proc = self._get_constrained_logits_processor()
236-
if logits_proc is not None:
237-
generate_kwargs["logits_processor"] = logits_proc
238-
outputs = self._model.generate(**inputs, **generate_kwargs)
239-
decoded = self._processor.decode(
240-
outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
227+
gen_len = len(self._processor.tokenizer.encode(
228+
decoded, add_special_tokens=False,
229+
)) if decoded else 0
230+
else:
231+
# Standard HF generate (no constrained decoding)
232+
inputs = self._processor(
233+
text=[text_input], images=[image], return_tensors="pt",
234+
)
235+
inputs = {k: v.to(self._model.device) for k, v in inputs.items()}
236+
with torch.no_grad():
237+
outputs = self._model.generate(
238+
**inputs,
239+
max_new_tokens=self._config.max_new_tokens,
240+
temperature=self._config.temperature,
241+
do_sample=True,
242+
)
243+
decoded = self._processor.decode(
244+
outputs[0][inputs["input_ids"].shape[1]:],
245+
skip_special_tokens=True,
246+
)
247+
gen_len = outputs[0].shape[0] - inputs["input_ids"].shape[1]
241248
gen_len = outputs[0].shape[0] - inputs["input_ids"].shape[1]
242249
action = parse_vlm_output_to_action(decoded, screen_size=self._config.screen_size)
243250

tests/test_standalone_trainer.py

Lines changed: 72 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -106,98 +106,112 @@ def test_no_bounded_quantifiers_in_regex(self) -> None:
106106
class TestConstrainedDecodingCache:
107107
"""Test the caching logic for the Outlines logits processor."""
108108

109-
def test_cache_starts_as_none(self) -> None:
109+
def test_generator_cache_starts_as_none(self) -> None:
110110
config = TrainingConfig()
111111
trainer = GRPOTrainer(config)
112-
assert trainer._constrained_processor_cache is None
112+
assert trainer._outlines_generator is None
113113

114-
def test_failed_cache_returns_none(self) -> None:
115-
"""When compilation fails, subsequent calls return None (not [])."""
114+
def test_failed_generator_returns_none(self) -> None:
115+
"""When creation fails, subsequent calls return None."""
116116
config = TrainingConfig(constrained_decoding=True)
117117
trainer = GRPOTrainer(config)
118-
# Simulate a failed compilation
119-
trainer._constrained_processor_cache = False
120-
result = trainer._get_constrained_logits_processor()
118+
trainer._outlines_generator = False
119+
result = trainer._get_outlines_generator()
121120
assert result is None
122121

123-
def test_successful_cache_returns_list(self) -> None:
124-
"""When compilation succeeds, subsequent calls return the list."""
122+
def test_successful_generator_returns_cached(self) -> None:
123+
"""When creation succeeds, subsequent calls return the cached generator."""
125124
config = TrainingConfig(constrained_decoding=True)
126125
trainer = GRPOTrainer(config)
127-
# Simulate a successful compilation
128-
trainer._constrained_processor_cache = ["mock_processor"]
129-
result = trainer._get_constrained_logits_processor()
130-
assert result == ["mock_processor"]
126+
trainer._outlines_generator = "mock_generator"
127+
result = trainer._get_outlines_generator()
128+
assert result == "mock_generator"
131129

132130
def test_outlines_api_imports(self) -> None:
133131
"""Verify the outlines API the trainer depends on is importable.
134132
135133
The trainer uses:
136-
- outlines.Transformers (model wrapper)
137-
- outlines.generator.get_regex_logits_processor (factory)
134+
- outlines.from_transformers (model wrapper factory)
135+
- outlines.regex (constraint factory)
136+
- outlines.Generator (generation with constraints)
138137
"""
139138
try:
140139
import outlines # noqa: F401
141140
except ImportError:
142141
pytest.skip("outlines not installed")
143142

144-
from outlines import Transformers
145-
from outlines.generator import get_regex_logits_processor
146-
assert callable(Transformers)
147-
assert callable(get_regex_logits_processor)
143+
assert callable(outlines.from_transformers)
144+
assert callable(outlines.regex)
145+
assert callable(outlines.Generator)
148146

149-
def test_outlines_processor_creation(self) -> None:
150-
"""Verify a regex logits processor can actually be created.
147+
def test_outlines_regex_compiles(self) -> None:
148+
"""Verify the action regex can be compiled by Outlines.
151149
152-
This is the integration test that would have caught the prior bugs:
153-
- Wrong class name (RegexLogitsProcessor vs OutlinesLogitsProcessor)
154-
- Wrong constructor args (tokenizer= kwarg didn't exist)
155-
156-
Requires a real model, so we use a tiny one or skip.
150+
This catches DFA state explosion (bounded quantifiers) and
151+
syntax errors in the regex.
157152
"""
158153
try:
159154
import outlines
160-
import torch
161-
from transformers import AutoTokenizer
162155
except ImportError:
163-
pytest.skip("outlines/torch/transformers not installed")
156+
pytest.skip("outlines not installed")
164157

165-
try:
166-
# Use the smallest possible tokenizer for fast test
167-
tokenizer = AutoTokenizer.from_pretrained(
168-
"hf-internal-testing/tiny-random-LlamaForCausalLM",
169-
trust_remote_code=True,
170-
)
171-
except Exception:
172-
pytest.skip("Could not load test tokenizer")
158+
# This should NOT raise — if it does, the regex is too complex
159+
constraint = outlines.regex(GRPOTrainer._ACTION_REGEX)
160+
assert constraint is not None
173161

174-
from outlines.generator import get_regex_logits_processor
162+
def test_outlines_generator_api_contract(self) -> None:
163+
"""Verify the Outlines Generator API contract the trainer depends on.
175164
176-
# Verify the factory function signature matches what the trainer expects:
177-
# get_regex_logits_processor(backend_name, model, regex)
178-
import inspect
179-
sig = inspect.signature(get_regex_logits_processor)
165+
Checks that:
166+
1. outlines.from_transformers accepts (model, processor) args
167+
2. outlines.regex returns an object Generator accepts
168+
3. outlines.Generator returns a callable
169+
4. The callable accepts (prompt, max_new_tokens=N) kwargs
170+
171+
Does NOT load a real model (too slow for CI). Instead verifies
172+
the API signatures match what the trainer calls.
173+
"""
174+
try:
175+
import outlines
176+
import inspect
177+
except ImportError:
178+
pytest.skip("outlines not installed")
179+
180+
# 1. from_transformers signature
181+
sig = inspect.signature(outlines.from_transformers)
180182
params = list(sig.parameters.keys())
181-
assert len(params) >= 3, (
182-
f"get_regex_logits_processor signature changed: {sig}. "
183-
f"Expected (backend_name, model, regex), got {params}"
183+
assert "model" in params, f"from_transformers missing 'model' param: {params}"
184+
assert "tokenizer_or_processor" in params or len(params) >= 2, (
185+
f"from_transformers signature changed: {sig}"
184186
)
185187

186-
def test_empty_list_no_longer_caches_as_success(self) -> None:
187-
"""Regression test: empty list [] should NOT be treated as success.
188-
189-
Prior bug: failure cached [] which is truthy for `is not None`,
190-
causing subsequent calls to return [] (no processors applied).
191-
"""
188+
# 2. regex returns something
189+
constraint = outlines.regex(r"DONE\(\)")
190+
assert constraint is not None
191+
192+
# 3. Generator signature
193+
sig_gen = inspect.signature(outlines.Generator)
194+
params_gen = list(sig_gen.parameters.keys())
195+
assert "model" in params_gen, f"Generator missing 'model' param: {params_gen}"
196+
197+
# 4. SteerableGenerator.__call__ accepts **inference_kwargs
198+
from outlines.generator import SteerableGenerator
199+
sig_call = inspect.signature(SteerableGenerator.__call__)
200+
params_call = list(sig_call.parameters.keys())
201+
assert "inference_kwargs" in params_call or any(
202+
p.startswith("**") or sig_call.parameters[p].kind == inspect.Parameter.VAR_KEYWORD
203+
for p in params_call
204+
), f"SteerableGenerator.__call__ doesn't accept **kwargs: {sig_call}"
205+
206+
def test_false_sentinel_not_confused_with_none(self) -> None:
207+
"""Regression: False sentinel must return None, not be treated as uninitialized."""
192208
config = TrainingConfig(constrained_decoding=True)
193209
trainer = GRPOTrainer(config)
194-
# The old buggy behavior would cache [] on failure
195-
# Verify the sentinel is False (not []) for failures
196-
trainer._constrained_processor_cache = False
197-
assert trainer._get_constrained_logits_processor() is None
198-
# And [] is actually a valid success cache (with a processor in it)
199-
trainer._constrained_processor_cache = ["real_processor"]
200-
assert trainer._get_constrained_logits_processor() == ["real_processor"]
210+
trainer._outlines_generator = False
211+
assert trainer._get_outlines_generator() is None
212+
# A real generator object should be returned as-is
213+
trainer._outlines_generator = "real_generator"
214+
assert trainer._get_outlines_generator() == "real_generator"
201215

202216

203217
# ---------------------------------------------------------------------------

0 commit comments

Comments
 (0)