Skip to content

Commit aeac459

Browse files
abrichr and claude authored
Align PolicyAgent prompt with training format (#31)
Align PolicyAgent prompt with training format (#31)

* Align PolicyAgent prompt with training format from convert_demos.py
  - Import SYSTEM_PROMPT from convert_demos (canonical source)
  - Add system message to SFT sample
  - Change "Goal:" label to "Instruction:" (training format)
  - Remove a11y tree, URL, window title injection (not in training data)
  - Add <think> instruction matching training tail prompt
  - Format history as " Step {i}: {action}" (0-indexed, indented)
  - Track previous actions across steps (reset on reset())

  Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* Fix PolicyAgent to call predict_action_from_sample (not predict)

  AgentPolicy has predict_action_from_sample() which returns a 4-tuple
  (Action, thought, state, raw_text). The previous code called predict()
  which doesn't exist on AgentPolicy.

  Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* Fix _action_to_string to match training format from convert_demos

  Replace UPPERCASE/normalized format (CLICK(0.500, 0.300)) with
  training-aligned format (click(x=500, y=300)): lowercase function names,
  [0,1000] coordinates, named parameters, press() for keys, finished() for done.

  Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* fix(modal): increase inference timeout from 300s to 600s

  Vision model inference with large screenshots can take 3+ minutes on A10G,
  especially on cold start. 300s was causing premature timeouts.

  Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* fix: remove dead system prompt from PolicyAgent._build_sample()

  QwenVLAdapter.generate() only extracts user role messages, dropping the
  system prompt. Since training also ignores it, removing it at inference
  keeps behaviour consistent and eliminates misleading code.

  Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* style: ruff format agent.py

  Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

---------

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
1 parent afad981 commit aeac459

File tree

2 files changed

+78
-75
lines changed

2 files changed

+78
-75
lines changed

openadapt_ml/benchmarks/agent.py

Lines changed: 77 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -52,21 +52,21 @@ class PolicyAgent(BenchmarkAgent):
5252
Converts between BenchmarkObservation/BenchmarkAction and the
5353
SFT sample format expected by AgentPolicy.
5454
55+
Prompt format is aligned with convert_demos.py training data.
56+
5557
Args:
5658
policy: AgentPolicy instance to wrap.
57-
use_accessibility_tree: Whether to include accessibility tree in prompt.
58-
use_history: Whether to include action history in prompt.
59+
use_thinking: Whether to include <think> instruction in prompts.
5960
"""
6061

6162
def __init__(
6263
self,
6364
policy: AgentPolicy,
64-
use_accessibility_tree: bool = True,
65-
use_history: bool = True,
65+
use_thinking: bool = True,
6666
):
6767
self.policy = policy
68-
self.use_accessibility_tree = use_accessibility_tree
69-
self.use_history = use_history
68+
self.use_thinking = use_thinking
69+
self._previous_actions: list[str] = []
7070

7171
def act(
7272
self,
@@ -84,42 +84,63 @@ def act(
8484
Returns:
8585
BenchmarkAction from policy.
8686
"""
87-
# Build SFT-style sample
88-
sample = self._build_sample(observation, task, history)
87+
# Build SFT-style sample (aligned with training format)
88+
sample = self._build_sample(observation, task)
8989

9090
# Get action from policy
91-
action, thought = self.policy.predict(sample)
91+
action, thought, _state, _raw = self.policy.predict_action_from_sample(sample)
9292

9393
# Convert to BenchmarkAction
94-
return self._to_benchmark_action(action, thought)
94+
benchmark_action = self._to_benchmark_action(action, thought)
95+
96+
# Track action for next step's "Previous actions" section
97+
self._previous_actions.append(self._action_to_string(benchmark_action))
98+
99+
return benchmark_action
95100

96101
def _build_sample(
97102
self,
98103
observation: BenchmarkObservation,
99104
task: BenchmarkTask,
100-
history: list[tuple[BenchmarkObservation, BenchmarkAction]] | None,
101105
) -> dict:
102-
"""Build SFT-style sample from benchmark observation."""
103-
content_parts = [f"Goal: {task.instruction}"]
104-
105-
if self.use_accessibility_tree and observation.accessibility_tree:
106-
tree_str = self._format_accessibility_tree(observation.accessibility_tree)
107-
content_parts.append(f"UI Elements:\n{tree_str}")
106+
"""Build SFT-style sample aligned with convert_demos.py training format.
108107
109-
if observation.url:
110-
content_parts.append(f"URL: {observation.url}")
111-
if observation.window_title:
112-
content_parts.append(f"Window: {observation.window_title}")
108+
NOTE: No system message is included here because
109+
``QwenVLAdapter.generate()`` only extracts the user role message
110+
and drops any system role. The model was trained under the same
111+
conditions (no system prompt), so omitting it at inference keeps
112+
behaviour consistent.
113113
114-
if self.use_history and history:
115-
history_str = self._format_history(history)
116-
content_parts.append(f"Previous actions:\n{history_str}")
114+
Format::
117115
118-
content_parts.append("What action should be taken next?")
116+
user: <image>
117+
Instruction: {instruction}
118+
...previous actions...
119+
First reason about what you see in <think>...</think> tags,
120+
then output exactly one action.
121+
"""
122+
# Build user content matching training format
123+
parts = ["<image>"]
124+
parts.append(f"Instruction: {task.instruction}")
125+
126+
if self._previous_actions:
127+
parts.append("")
128+
parts.append("Previous actions:")
129+
for i, act in enumerate(self._previous_actions):
130+
parts.append(f" Step {i}: {act}")
131+
132+
parts.append("")
133+
if self.use_thinking:
134+
parts.append(
135+
"First reason about what you see in <think>...</think> "
136+
"tags, then output exactly one action."
137+
)
138+
else:
139+
parts.append("Output exactly one action.")
119140

120141
sample = {
121142
"messages": [
122-
{"role": "user", "content": "\n\n".join(content_parts)},
143+
{"role": "user", "content": "\n".join(parts)},
123144
],
124145
}
125146

@@ -128,57 +149,39 @@ def _build_sample(
128149

129150
return sample
130151

131-
def _format_accessibility_tree(self, tree: dict, indent: int = 0) -> str:
132-
"""Format accessibility tree for prompt."""
133-
lines = []
134-
prefix = " " * indent
135-
136-
role = tree.get("role", "unknown")
137-
name = tree.get("name", "")
138-
node_id = tree.get("id", tree.get("node_id", ""))
152+
@staticmethod
153+
def _action_to_string(action: BenchmarkAction) -> str:
154+
"""Format action matching convert_demos._format_action_qwen training format.
139155
140-
line = f"{prefix}[{node_id}] {role}"
141-
if name:
142-
line += f": {name}"
143-
lines.append(line)
144-
145-
for child in tree.get("children", []):
146-
lines.append(self._format_accessibility_tree(child, indent + 1))
156+
Uses [0, 1000] coordinate range and lowercase function-call style
157+
to match what the model was trained on.
158+
"""
147159

148-
return "\n".join(lines)
160+
def _to_1000(v: float | None) -> int:
161+
return round((v or 0.0) * 1000)
149162

150-
def _format_history(
151-
self, history: list[tuple[BenchmarkObservation, BenchmarkAction]]
152-
) -> str:
153-
"""Format action history for prompt."""
154-
lines = []
155-
for i, (obs, action) in enumerate(history[-5:], 1):
156-
action_str = self._action_to_string(action)
157-
lines.append(f"{i}. {action_str}")
158-
return "\n".join(lines)
159-
160-
def _action_to_string(self, action: BenchmarkAction) -> str:
161-
"""Convert BenchmarkAction to string representation."""
162163
if action.type == "click":
163-
if action.target_name:
164-
return f"CLICK({action.target_name})"
165-
return f"CLICK(x={action.x:.3f}, y={action.y:.3f})"
166-
elif action.type == "type":
167-
return f"TYPE({action.text!r})"
168-
elif action.type == "key":
169-
mods = "+".join(action.modifiers or [])
170-
key = action.key
171-
if mods:
172-
return f"KEY({mods}+{key})"
173-
return f"KEY({key})"
174-
elif action.type == "scroll":
175-
return f"SCROLL({action.scroll_direction})"
176-
elif action.type == "done":
177-
return "DONE()"
178-
elif action.type == "answer":
179-
return f"ANSWER({action.answer!r})"
180-
else:
181-
return f"{action.type.upper()}()"
164+
return f"click(x={_to_1000(action.x)}, y={_to_1000(action.y)})"
165+
if action.type == "double_click":
166+
return f"double_click(x={_to_1000(action.x)}, y={_to_1000(action.y)})"
167+
if action.type == "right_click":
168+
return f"right_click(x={_to_1000(action.x)}, y={_to_1000(action.y)})"
169+
if action.type == "type":
170+
return f'type(text="{action.text or ""}")'
171+
if action.type == "key":
172+
keys = (action.modifiers or []) + ([action.key] if action.key else [])
173+
keys_fmt = ", ".join(f'"{k}"' for k in keys)
174+
return f"press(keys=[{keys_fmt}])"
175+
if action.type == "scroll":
176+
return f'scroll(direction="{action.scroll_direction or "down"}", amount=3)'
177+
if action.type == "drag":
178+
return (
179+
f"drag(from_coord=[{_to_1000(action.x)}, {_to_1000(action.y)}], "
180+
f"to_coord=[{_to_1000(action.end_x)}, {_to_1000(action.end_y)}])"
181+
)
182+
if action.type == "done":
183+
return "finished()"
184+
return f"# unknown: {action.type}"
182185

183186
def _to_benchmark_action(
184187
self, action: Action, thought: str | None
@@ -233,7 +236,7 @@ def _to_benchmark_action(
233236

234237
def reset(self) -> None:
235238
"""Reset agent state."""
236-
pass
239+
self._previous_actions = []
237240

238241

239242
class APIBenchmarkAgent(BenchmarkAgent):

openadapt_ml/cloud/modal_cloud.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -305,7 +305,7 @@ def _build_inference_app(
305305
gpu=gpu,
306306
image=inference_image,
307307
volumes={VOLUME_MOUNT: vol},
308-
timeout=300,
308+
timeout=600,
309309
serialized=True,
310310
scaledown_window=600,
311311
)

0 commit comments

Comments (0)