Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 32 additions & 4 deletions openadapt_evals/training/trl_rollout.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,6 +417,7 @@ def make_waa_rollout_func(
_outlines_state: dict[str, Any] = {"generator": None, "attempted": False}
_prompt_logged: list[bool] = [False] # log the prompt once for diagnostics
_output_logged: list[bool] = [False] # log first generation output
_template_patched: list[bool] = [False] # patch chat template once

def rollout_func(prompts: list[str], trainer: Any) -> dict[str, list]:
"""TRL GRPOTrainer rollout function.
Expand All @@ -430,6 +431,27 @@ def rollout_func(prompts: list[str], trainer: Any) -> dict[str, list]:
"""
processor = trainer.processing_class
model = trainer.model

# --- Disable Qwen3.5 thinking mode at the template level ---
# Qwen3.5's chat template inserts a <think> tag, which makes the model
# emit opaque reasoning tokens instead of DSL actions. Stripping the tag
# from the rendered text alone is insufficient because TRL or the
# processor may re-apply the template. The fix: patch the template
# itself so <think> is never inserted, regardless of who calls it.
if not _template_patched[0]:
_template_patched[0] = True
for obj in [processor, getattr(processor, "tokenizer", None)]:
if obj is None:
continue
tpl = getattr(obj, "chat_template", None)
if tpl and "<think>" in tpl:
patched = tpl.replace("<think>", "").replace("</think>", "")
obj.chat_template = patched
logger.info(
"Patched chat_template on %s to remove <think>/<think> "
"tags (disables Qwen3.5 thinking mode)",
type(obj).__name__,
)
device = next(model.parameters()).device

num_generations = getattr(trainer.args, "num_generations", 8)
Expand Down Expand Up @@ -532,10 +554,16 @@ def generate_fn(screenshot_bytes: bytes, instruction: str):
messages, **chat_kwargs,
)

# Belt-and-suspenders: strip <think> tag if it slipped through
if "<think>" in text_input:
logger.info("Stripping <think> tag from prompt to disable thinking mode")
text_input = text_input.replace("<think>\n", "").replace("<think>", "")
# Belt-and-suspenders: strip thinking tags if they slipped through
if "<think>" in text_input or "</think>" in text_input:
logger.info("Stripping <think>/<think> tags from rendered prompt")
text_input = (
text_input
.replace("<think>\n", "")
.replace("<think>", "")
.replace("</think>\n", "")
.replace("</think>", "")
)

# Comprehensive prompt diagnostics on first call.
# This logs everything needed to debug prompt construction:
Expand Down
Loading