Skip to content

Commit c1d3588

Browse files
abrichr and claude authored
fix: patch chat_template to remove <think> tags at the source (#250)
Stripping <think> from rendered text was insufficient — TRL or the processor may re-apply the template, re-inserting the tags. The fix: patch processor.chat_template and processor.tokenizer.chat_template on first rollout call, removing <think>/</think> from the Jinja template itself. This ensures no code path can re-insert thinking mode. Also strips </think> (was missed in #249). Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent bd3acaf commit c1d3588

1 file changed

Lines changed: 32 additions & 4 deletions

File tree

openadapt_evals/training/trl_rollout.py

Lines changed: 32 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -417,6 +417,7 @@ def make_waa_rollout_func(
417417
_outlines_state: dict[str, Any] = {"generator": None, "attempted": False}
418418
_prompt_logged: list[bool] = [False] # log the prompt once for diagnostics
419419
_output_logged: list[bool] = [False] # log first generation output
420+
_template_patched: list[bool] = [False] # patch chat template once
420421

421422
def rollout_func(prompts: list[str], trainer: Any) -> dict[str, list]:
422423
"""TRL GRPOTrainer rollout function.
@@ -430,6 +431,27 @@ def rollout_func(prompts: list[str], trainer: Any) -> dict[str, list]:
430431
"""
431432
processor = trainer.processing_class
432433
model = trainer.model
434+
435+
# --- Disable Qwen3.5 thinking mode at the template level ---
436+
# Qwen3.5's chat template inserts <think> which produces opaque
437+
# reasoning tokens instead of DSL actions. Stripping from the
438+
# rendered text is insufficient because TRL or the processor may
439+
# re-apply the template. The fix: patch the template itself so
440+
# <think> is never inserted, regardless of who calls it.
441+
if not _template_patched[0]:
442+
_template_patched[0] = True
443+
for obj in [processor, getattr(processor, "tokenizer", None)]:
444+
if obj is None:
445+
continue
446+
tpl = getattr(obj, "chat_template", None)
447+
if tpl and "<think>" in tpl:
448+
patched = tpl.replace("<think>", "").replace("</think>", "")
449+
obj.chat_template = patched
450+
logger.info(
451+
"Patched chat_template on %s to remove <think>/<think> "
452+
"tags (disables Qwen3.5 thinking mode)",
453+
type(obj).__name__,
454+
)
433455
device = next(model.parameters()).device
434456

435457
num_generations = getattr(trainer.args, "num_generations", 8)
@@ -532,10 +554,16 @@ def generate_fn(screenshot_bytes: bytes, instruction: str):
532554
messages, **chat_kwargs,
533555
)
534556

535-
# Belt-and-suspenders: strip <think> tag if it slipped through
536-
if "<think>" in text_input:
537-
logger.info("Stripping <think> tag from prompt to disable thinking mode")
538-
text_input = text_input.replace("<think>\n", "").replace("<think>", "")
557+
# Belt-and-suspenders: strip thinking tags if they slipped through
558+
if "<think>" in text_input or "</think>" in text_input:
559+
logger.info("Stripping <think>/<think> tags from rendered prompt")
560+
text_input = (
561+
text_input
562+
.replace("<think>\n", "")
563+
.replace("<think>", "")
564+
.replace("</think>\n", "")
565+
.replace("</think>", "")
566+
)
539567

540568
# Comprehensive prompt diagnostics on first call.
541569
# This logs everything needed to debug prompt construction:

0 commit comments

Comments
 (0)