Skip to content

Commit 6a4abbe

Browse files
committed
added support for computer-use models and newer gemini models
1 parent 6ae1d7f commit 6a4abbe

File tree

3 files changed

+322
-12
lines changed

3 files changed

+322
-12
lines changed

app/models/factory.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from models.gpt4o import GPT4o
22
from models.gpt4v import GPT4v
33
from models.gpt5 import GPT5
4+
from models.openai_computer_use import OpenAIComputerUse
45
from models.gemini import Gemini
56

67

@@ -10,6 +11,8 @@ def create_model(model_name, *args):
1011
try:
1112
if model_name == 'gpt-4o' or model_name == 'gpt-4o-mini':
1213
return GPT4o(model_name, *args)
14+
elif model_name == 'computer-use-preview':
15+
return OpenAIComputerUse(model_name, *args)
1316
elif model_name.startswith('gpt-5'):
1417
return GPT5(model_name, *args)
1518
elif model_name == 'gpt-4-vision-preview' or model_name == 'gpt-4-turbo':

app/models/openai_computer_use.py

Lines changed: 259 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,259 @@
1+
from typing import Any
2+
3+
from models.model import Model
4+
from utils.screen import Screen
5+
6+
7+
class OpenAIComputerUse(Model):
    """Adapter for the OpenAI ``computer-use-preview`` model (Responses API).

    Every turn sends the current screenshot to the model and converts the
    returned ``computer_call`` action into the app's instruction format::

        {'steps': [{'function': ..., 'parameters': {...}}, ...], 'done': ...}

    Conversation state is carried between turns through
    ``previous_response_id`` and the ``call_id`` of the last computer call.
    """

    # Alternate key spellings mapped to the names emitted in steps. Any key
    # not listed here is passed through lower-cased.
    _KEY_ALIASES = {
        'control': 'ctrl',
        'cmd': 'command',
        'return': 'enter',
        'arrowleft': 'left',
        'arrowright': 'right',
        'arrowup': 'up',
        'arrowdown': 'down',
    }

    # The model names the scroll-wheel button 'wheel'; pyautogui-style
    # executors call it 'middle'.
    _BUTTON_ALIASES = {'wheel': 'middle'}
    _SUPPORTED_BUTTONS = frozenset({'left', 'middle', 'right', 'primary', 'secondary'})

    def __init__(self, model_name, base_url, api_key, context):
        """Initialise the base model and start with an empty conversation."""
        super().__init__(model_name, base_url, api_key, context)
        self._reset_conversation()

    def _reset_conversation(self) -> None:
        """Forget the response chain, the last call id and pending safety checks."""
        self.previous_response_id = None
        self.last_call_id = None
        self.pending_safety_checks = []

    def get_instructions_for_objective(self, original_user_request: str, step_num: int = 0) -> dict[str, Any]:
        """Ask the model for the next action and return it as instructions.

        Args:
            original_user_request: The objective text entered by the user.
            step_num: Index of the current step; ``0`` starts a fresh
                conversation and discards any state from a previous one.

        Returns:
            A dict with ``'steps'`` (list of step dicts) and ``'done'``
            (``None`` while work remains, otherwise the final message).
        """
        if step_num == 0:
            self._reset_conversation()

        llm_response = self.send_message_to_llm(original_user_request)
        return self.convert_llm_response_to_json_instructions(llm_response)

    def send_message_to_llm(self, original_user_request: str) -> Any:
        """Send the current screenshot (and, on the first turn, the request).

        On follow-up turns the screenshot is sent as the ``computer_call_output``
        of the previous call, chained through ``previous_response_id``.

        Returns:
            The raw response object from ``self.client.responses.create``.
        """
        screen = Screen()
        base64_img = screen.get_screenshot_in_base64()
        screenshot_url = f'data:image/png;base64,{base64_img}'
        screen_width, screen_height = screen.get_size()

        tools = [{
            'type': 'computer_use_preview',
            'display_width': screen_width,
            'display_height': screen_height,
            # NOTE(review): the tool also accepts 'mac', 'windows' and 'ubuntu'.
            # The emitted steps look like desktop (pyautogui-style) actions, so
            # confirm that 'browser' is the intended environment.
            'environment': 'browser'
        }]

        if self.previous_response_id and self.last_call_id:
            computer_call_output: dict[str, Any] = {
                'type': 'computer_call_output',
                'call_id': self.last_call_id,
                'output': {
                    'type': 'input_image',
                    'image_url': screenshot_url
                }
            }

            if self.pending_safety_checks:
                # NOTE(review): safety checks are acknowledged automatically,
                # without asking the user. Confirm this is acceptable.
                computer_call_output['acknowledged_safety_checks'] = self.pending_safety_checks

            response = self.client.responses.create(
                model=self.model_name,
                previous_response_id=self.previous_response_id,
                tools=tools,
                input=[computer_call_output],
                truncation='auto'
            )
            # Clear only after the request succeeded, so a failed request can
            # be retried with the same acknowledgements.
            self.pending_safety_checks = []
            return response

        return self.client.responses.create(
            model=self.model_name,
            tools=tools,
            input=[{
                'role': 'user',
                'content': [
                    {'type': 'input_text', 'text': original_user_request},
                    {'type': 'input_image', 'image_url': screenshot_url}
                ]
            }],
            reasoning={'summary': 'concise'},
            truncation='auto'
        )

    def convert_llm_response_to_json_instructions(self, llm_response: Any) -> dict[str, Any]:
        """Translate a model response into the app's instruction dict.

        The first ``computer_call`` output item is converted into steps and
        its ``call_id`` / pending safety checks are remembered for the next
        turn. When the response contains no ``computer_call``, the task is
        treated as finished and ``output_text`` (or ``'Done.'`` when empty)
        becomes the ``done`` message.
        """
        self.previous_response_id = self.read_obj(llm_response, 'id')

        for item in self.read_obj(llm_response, 'output') or []:
            if self.read_obj(item, 'type') != 'computer_call':
                continue

            self.last_call_id = self.read_obj(item, 'call_id')
            self.pending_safety_checks = self.serialize_safety_checks(
                self.read_obj(item, 'pending_safety_checks') or []
            )

            action = self.read_obj(item, 'action') or {}
            return {
                'steps': self.convert_action_to_steps(action),
                'done': None
            }

        done_message = (self.read_obj(llm_response, 'output_text') or '').strip() or 'Done.'

        self.last_call_id = None
        self.pending_safety_checks = []
        return {
            'steps': [],
            'done': done_message
        }

    def serialize_safety_checks(self, checks: list[Any]) -> list[dict[str, Any]]:
        """Convert pending safety checks into plain dicts for acknowledgement.

        A check is kept whenever it has an ``id``; ``code`` and ``message``
        are copied only when present. Dropping a check because one of those
        fields is empty would leave it unacknowledged on the next request.
        """
        serialized = []
        for check in checks:
            check_id = self.read_obj(check, 'id')
            if not check_id:
                continue
            entry = {'id': check_id}
            for field in ('code', 'message'):
                value = self.read_obj(check, field)
                if value is not None:
                    entry[field] = value
            serialized.append(entry)
        return serialized

    def _read_point(self, point: Any) -> tuple[Any, Any]:
        """Return ``(x, y)`` from an ``{x, y}`` mapping/object or an ``[x, y]`` pair."""
        x = self.read_obj(point, 'x')
        y = self.read_obj(point, 'y')
        if x is None and y is None:
            # Fall back to positional pairs such as [x, y] or (x, y).
            x = self.read_obj(point, 0)
            y = self.read_obj(point, 1)
        return x, y

    @classmethod
    def _normalize_button(cls, button: Any) -> Any:
        """Map a model button name to a supported one, or ``None`` if unsupported."""
        if not button:
            return 'left'
        name = str(button).lower()
        name = cls._BUTTON_ALIASES.get(name, name)
        return name if name in cls._SUPPORTED_BUTTONS else None

    def convert_action_to_steps(self, action: Any) -> list[dict[str, Any]]:
        """Convert one computer-use action into a list of executable steps.

        Args:
            action: The ``action`` of a ``computer_call`` item, as a dict or
                an attribute-style object.

        Returns:
            Zero or more ``{'function': ..., 'parameters': {...}}`` dicts. An
            empty list is returned for ``screenshot``, for actions that carry
            no usable data, and for unsupported action types.
        """
        action_type = self.read_obj(action, 'type')

        if action_type == 'click':
            requested_button = self.read_obj(action, 'button')
            button = self._normalize_button(requested_button)
            if button is None:
                # e.g. 'back' / 'forward', which have no equivalent mouse button here.
                print(f'Unsupported computer_use click button: {requested_button}')
                return []
            return [{
                'function': 'click',
                'parameters': {
                    'x': self.read_obj(action, 'x'),
                    'y': self.read_obj(action, 'y'),
                    'button': button,
                    'clicks': 1
                }
            }]

        if action_type == 'double_click':
            return [{
                'function': 'click',
                'parameters': {
                    'x': self.read_obj(action, 'x'),
                    'y': self.read_obj(action, 'y'),
                    'button': 'left',
                    'clicks': 2
                }
            }]

        if action_type == 'move':
            return [{
                'function': 'moveTo',
                'parameters': {
                    'x': self.read_obj(action, 'x'),
                    'y': self.read_obj(action, 'y')
                }
            }]

        if action_type == 'scroll':
            scroll_y = self.read_obj(action, 'scroll_y') or 0
            # NOTE(review): only vertical scrolling is forwarded; any
            # horizontal amount and the pointer position are ignored.
            return [{
                'function': 'scroll',
                'parameters': {
                    # Browser coordinate systems usually use positive Y for scrolling down;
                    # pyautogui.scroll uses negative values for down.
                    'clicks': int(-scroll_y)
                }
            }]

        if action_type == 'type':
            return [{
                'function': 'write',
                'parameters': {
                    'string': self.read_obj(action, 'text') or '',
                    'interval': 0.03
                }
            }]

        if action_type == 'wait':
            return [{
                'function': 'sleep',
                'parameters': {
                    'secs': 1
                }
            }]

        if action_type == 'keypress':
            keys = self.read_obj(action, 'keys') or []
            normalized_keys = [self.normalize_key_name(key) for key in keys if key]
            if not normalized_keys:
                return []
            if len(normalized_keys) == 1:
                return [{
                    'function': 'press',
                    'parameters': {
                        'key': normalized_keys[0]
                    }
                }]
            return [{
                'function': 'hotkey',
                'parameters': {
                    'keys': normalized_keys
                }
            }]

        if action_type == 'drag':
            path = self.read_obj(action, 'path') or []
            if len(path) < 2:
                return []

            # Only the endpoints are used; intermediate path points are skipped.
            start_x, start_y = self._read_point(path[0])
            end_x, end_y = self._read_point(path[-1])

            if None in [start_x, start_y, end_x, end_y]:
                return []

            return [
                {
                    'function': 'moveTo',
                    'parameters': {'x': start_x, 'y': start_y}
                },
                {
                    'function': 'dragTo',
                    'parameters': {'x': end_x, 'y': end_y, 'duration': 0.2, 'button': 'left'}
                }
            ]

        if action_type == 'screenshot':
            # Nothing to execute: a fresh screenshot is sent on every turn anyway.
            return []

        print(f'Unsupported computer_use action type: {action_type}')
        return []

    @staticmethod
    def read_obj(obj: Any, key: Any, default=None) -> Any:
        """Read ``key`` from a dict, a list/tuple (integer index) or an object.

        Returns ``default`` when ``obj`` is ``None``, the key/index/attribute
        is missing, or the key cannot be used as an attribute name.
        """
        if obj is None:
            return default
        if isinstance(obj, dict):
            return obj.get(key, default)
        if isinstance(obj, (list, tuple)) and isinstance(key, int):
            if 0 <= key < len(obj):
                return obj[key]
            return default
        if not isinstance(key, str):
            # getattr() raises TypeError for non-string attribute names.
            return default
        return getattr(obj, key, default)

    @staticmethod
    def normalize_key_name(key: str) -> str:
        """Lower-case ``key`` and translate known aliases (e.g. ``'Return'`` -> ``'enter'``)."""
        key_l = str(key).lower()
        return OpenAIComputerUse._KEY_ALIASES.get(key_l, key_l)

    def cleanup(self):
        """Drop all conversation state."""
        self._reset_conversation()

app/ui.py

Lines changed: 60 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -59,26 +59,64 @@ def create_widgets(self) -> None:
5959
radio_frame = ttk.Frame(self)
6060
radio_frame.pack(padx=20, pady=10) # Add padding around the frame
6161

62-
models = [
62+
openai_models = [
6363
('GPT-5.2 (Default)', 'gpt-5.2'),
64+
('OpenAI computer-use-preview (GUI actions)', 'computer-use-preview'),
65+
]
66+
67+
gemini_models = [
68+
('Gemini gemini-3-pro-preview', 'gemini-3-pro-preview'),
69+
('Gemini gemini-3-flash-preview', 'gemini-3-flash-preview'),
70+
]
71+
72+
deprecated_models = [
6473
('GPT-4o (Medium-Accurate, Medium-Fast)', 'gpt-4o'),
6574
('GPT-4o-mini (Cheapest, Fastest)', 'gpt-4o-mini'),
66-
('GPT-4v (Deprecated. Most-Accurate, Slowest)', 'gpt-4-vision-preview'),
75+
('GPT-4v (Most-Accurate, Slowest)', 'gpt-4-vision-preview'),
6776
('GPT-4-Turbo (Least Accurate, Fast)', 'gpt-4-turbo'),
68-
('', ''),
69-
('Gemini gemini-2.0-flash (Free, Fast)', 'gemini-2.0-flash'),
77+
('Gemini gemini-2.5-pro', 'gemini-2.5-pro'),
78+
('Gemini gemini-2.5-flash', 'gemini-2.5-flash'),
79+
('Gemini gemini-2.5-flash-lite', 'gemini-2.5-flash-lite'),
80+
('Gemini gemini-2.0-flash', 'gemini-2.0-flash'),
7081
('Gemini gemini-2.0-flash-lite', 'gemini-2.0-flash-lite'),
7182
('Gemini gemini-2.0-flash-thinking-exp', 'gemini-2.0-flash-thinking-exp'),
7283
('Gemini gemini-2.0-pro-exp-02-05', 'gemini-2.0-pro-exp-02-05'),
73-
('', ''),
74-
('Custom (Specify Settings Below)', 'custom')
7584
]
76-
for text, value in models:
77-
if text == '' and value == '':
78-
ttk.Separator(radio_frame, orient='horizontal').pack(fill='x', pady=10)
79-
else:
80-
ttk.Radiobutton(radio_frame, text=text, value=value, variable=self.model_var, bootstyle="info").pack(
81-
anchor=ttk.W, pady=5)
85+
86+
for text, value in openai_models:
87+
ttk.Radiobutton(radio_frame, text=text, value=value, variable=self.model_var, bootstyle="info").pack(
88+
anchor=ttk.W, pady=5)
89+
90+
ttk.Separator(radio_frame, orient='horizontal').pack(fill='x', pady=8)
91+
92+
for text, value in gemini_models:
93+
ttk.Radiobutton(radio_frame, text=text, value=value, variable=self.model_var, bootstyle="info").pack(
94+
anchor=ttk.W, pady=5)
95+
96+
ttk.Separator(radio_frame, orient='horizontal').pack(fill='x', pady=10)
97+
98+
self.deprecated_expanded = False
99+
self.deprecated_toggle_button = ttk.Button(
100+
radio_frame,
101+
text='Older Models ▸',
102+
bootstyle='secondary-link',
103+
command=self.toggle_deprecated_section
104+
)
105+
self.deprecated_toggle_button.pack(anchor=ttk.W, pady=(0, 5))
106+
107+
self.deprecated_frame = ttk.Frame(radio_frame)
108+
for text, value in deprecated_models:
109+
ttk.Radiobutton(self.deprecated_frame, text=text, value=value, variable=self.model_var, bootstyle="info").pack(
110+
anchor=ttk.W, pady=5)
111+
112+
ttk.Separator(radio_frame, orient='horizontal').pack(fill='x', pady=10)
113+
ttk.Radiobutton(
114+
radio_frame,
115+
text='Custom (Specify Settings Below)',
116+
value='custom',
117+
variable=self.model_var,
118+
bootstyle="info"
119+
).pack(anchor=ttk.W, pady=5)
82120

83121
label_base_url = ttk.Label(self, text='Custom OpenAI-Like API Model Base URL', bootstyle="secondary")
84122
label_base_url.pack(pady=10)
@@ -104,6 +142,16 @@ def create_widgets(self) -> None:
104142
font=('Helvetica', 10))
105143
restart_app_label.pack(pady=(0, 20))
106144

145+
def toggle_deprecated_section(self) -> None:
    """Expand or collapse the "Older Models" list and update the toggle label."""
    expand = not self.deprecated_expanded

    if expand:
        # Re-insert the frame directly below the toggle button.
        self.deprecated_frame.pack(anchor=ttk.W, padx=(12, 0), pady=(0, 6), after=self.deprecated_toggle_button)
    else:
        self.deprecated_frame.pack_forget()

    label = 'Older Models ▾' if expand else 'Older Models ▸'
    self.deprecated_toggle_button.config(text=label)
    self.deprecated_expanded = expand
154+
107155
def save_button(self) -> None:
108156
base_url = self.base_url_entry.get().strip()
109157
model = self.model_var.get() if self.model_var.get() != 'custom' else self.model_entry.get().strip()

0 commit comments

Comments
 (0)