"""Inference module for calling LLMs to perform various tasks."""
import json
import time
from typing import Any, Callable, Optional, Union
from pydantic import BaseModel
from stagehand.llm.prompts import (
build_extract_system_prompt,
build_extract_user_prompt,
build_metadata_prompt,
build_metadata_system_prompt,
build_observe_system_prompt,
build_observe_user_message,
)
from stagehand.types import (
MetadataSchema,
ObserveInferenceSchema,
)
# TODO: kwargs
async def observe(
    instruction: str,
    tree_elements: str,
    llm_client: Any,
    user_provided_instructions: Optional[str] = None,
    logger: Optional[Callable] = None,
    log_inference_to_file: bool = False,
    from_act: bool = False,
) -> dict[str, Any]:
    """
    Call LLM to find elements in the DOM/accessibility tree based on an instruction.

    Args:
        instruction: The instruction to follow when finding elements
        tree_elements: String representation of DOM/accessibility tree elements
        llm_client: Client for calling LLM
        user_provided_instructions: Optional custom system instructions
        logger: Optional logger function
        log_inference_to_file: Whether to log inference to file
        from_act: Whether this observe call is part of an act operation

    Returns:
        dict containing elements found and token usage information
    """
    # Build the prompts
    system_prompt = build_observe_system_prompt(
        user_provided_instructions=user_provided_instructions,
    )
    user_prompt = build_observe_user_message(
        instruction=instruction,
        tree_elements=tree_elements,
    )
    messages = [
        system_prompt,
        user_prompt,
    ]
    start_time = time.time()
    try:
        # Call the LLM (guard all logger calls, since logger is Optional)
        if logger:
            logger.info("Calling LLM")
        response = await llm_client.create_response(
            model=llm_client.default_model,
            messages=messages,
            response_format=ObserveInferenceSchema,
            temperature=0.1,
            function_name="ACT" if from_act else "OBSERVE",
        )
        inference_time_ms = int((time.time() - start_time) * 1000)

        # Extract token counts
        prompt_tokens = response.usage.prompt_tokens
        completion_tokens = response.usage.completion_tokens

        # Parse the response
        content = response.choices[0].message.content
        if logger:
            logger.info("Got LLM response")
            logger.debug(
                "LLM Response",
                auxiliary={
                    "content": content,
                    "prompt_tokens": prompt_tokens,
                    "completion_tokens": completion_tokens,
                    "inference_time_ms": inference_time_ms,
                },
            )

        if isinstance(content, str):
            try:
                parsed_response = json.loads(content)
            except json.JSONDecodeError:
                if logger:
                    logger.error(f"Failed to parse JSON response: {content}")
                parsed_response = {"elements": []}
        else:
            parsed_response = content

        elements = parsed_response.get("elements", [])
        return {
            "elements": elements,
            "prompt_tokens": prompt_tokens,
            "completion_tokens": completion_tokens,
            "inference_time_ms": inference_time_ms,
        }
    except Exception as e:
        if logger:
            logger.error(f"Error in observe inference: {str(e)}")
        # Return an empty response on error
        return {
            "elements": [],
            "prompt_tokens": 0,
            "completion_tokens": 0,
            "inference_time_ms": int((time.time() - start_time) * 1000),
        }
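

# A hedged usage sketch for `observe` (not in the original file). The
# instruction and tree string are illustrative placeholders, and
# `my_llm_client` stands in for any client exposing the `default_model`
# and `create_response` interface used above.
#
#   import asyncio
#
#   async def _demo_observe():
#       result = await observe(
#           instruction="Find the login button",
#           tree_elements="[0] button 'Log in'\n[1] link 'Sign up'",
#           llm_client=my_llm_client,
#       )
#       print(result["elements"], result["inference_time_ms"])
#
#   asyncio.run(_demo_observe())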


async def extract(
    instruction: str,
    tree_elements: str,
    schema: Optional[Union[type[BaseModel], dict]] = None,
    llm_client: Any = None,
    user_provided_instructions: Optional[str] = None,
    logger: Optional[Callable] = None,
    log_inference_to_file: bool = False,
    is_using_text_extract: bool = False,
    **kwargs: Any,
) -> dict[str, Any]:
    """
    Call LLM to extract structured data from the page based on the provided instruction and schema.

    Args:
        instruction: The instruction for what data to extract
        tree_elements: The DOM or accessibility tree representation
        schema: Pydantic model defining the structure of the data to extract
        llm_client: The LLM client to use for the request
        user_provided_instructions: Optional custom system instructions
        logger: Optional logger function
        log_inference_to_file: Whether to log inference to file
        is_using_text_extract: Whether using text extraction (vs. DOM/a11y tree)
        **kwargs: Additional parameters to pass to the LLM client

    Returns:
        dict containing the extraction results and metadata
    """
    if logger:
        logger.info("Calling LLM")

    # Create system and user messages for extraction
    system_message = build_extract_system_prompt(
        is_using_text_extract=is_using_text_extract,
        user_provided_instructions=user_provided_instructions,
    )
    user_message = build_extract_user_prompt(instruction, tree_elements)
    extract_messages = [
        system_message,
        user_message,
    ]
    # Call LLM for extraction
    start_time = time.time()

    # Determine if we need to use schema-based response format
    # TODO: if schema is json, return json
    response_format = {"type": "json_object"}
    if schema:
        # If schema is a Pydantic model, use it directly
        response_format = schema

    # Call the LLM with appropriate parameters
    try:
        extract_response = await llm_client.create_response(
            model=llm_client.default_model,
            messages=extract_messages,
            response_format=response_format,
            temperature=0.1,
            function_name="EXTRACT",  # Always set to EXTRACT
            **kwargs,
        )
        extract_time_ms = int((time.time() - start_time) * 1000)

        # Extract token counts
        prompt_tokens = extract_response.usage.prompt_tokens
        completion_tokens = extract_response.usage.completion_tokens

        # Parse the response
        extract_content = extract_response.choices[0].message.content
        if isinstance(extract_content, str):
            try:
                extracted_data = json.loads(extract_content)
            except json.JSONDecodeError:
                if logger:
                    logger.error(
                        f"Failed to parse JSON extraction response: {extract_content}"
                    )
                extracted_data = {}
        else:
            extracted_data = extract_content
    except Exception as e:
        if logger:
            logger.error(f"Error in extract inference: {str(e)}")
        # In case of failure, return empty data
        extracted_data = {}
        prompt_tokens = 0
        completion_tokens = 0
        extract_time_ms = int((time.time() - start_time) * 1000)
    # Generate metadata about the extraction
    metadata_system_message = build_metadata_system_prompt()
    # The trailing (1, 1) arguments are presumably the chunk index and chunk
    # count for a single-chunk extraction.
    metadata_user_message = build_metadata_prompt(instruction, extracted_data, 1, 1)
    metadata_messages = [
        metadata_system_message,
        metadata_user_message,
    ]

    # Define the metadata schema
    metadata_schema = MetadataSchema

    # Call LLM for metadata
    try:
        metadata_start_time = time.time()
        metadata_response = await llm_client.create_response(
            model=llm_client.default_model,
            messages=metadata_messages,
            response_format=metadata_schema,
            temperature=0.1,
            function_name="EXTRACT",  # Metadata for extraction should also be tracked as EXTRACT
        )
        metadata_end_time = time.time()
        metadata_time_ms = int((metadata_end_time - metadata_start_time) * 1000)
        if logger:
            logger.info("Got LLM response")

        # Extract metadata content
        metadata_content = metadata_response.choices[0].message.content

        # Parse metadata content
        if isinstance(metadata_content, str):
            try:
                metadata = json.loads(metadata_content)
            except json.JSONDecodeError:
                if logger:
                    logger.error(f"Failed to parse metadata response: {metadata_content}")
                metadata = {"completed": False, "progress": "Failed to parse metadata"}
        else:
            metadata = metadata_content

        # Get token usage for metadata
        metadata_prompt_tokens = metadata_response.usage.prompt_tokens
        metadata_completion_tokens = metadata_response.usage.completion_tokens
    except Exception as e:
        if logger:
            logger.error(f"Error in metadata inference: {str(e)}")
        # In case of failure, use default metadata
        metadata = {"completed": False, "progress": "Metadata generation failed"}
        metadata_prompt_tokens = 0
        metadata_completion_tokens = 0
        metadata_time_ms = 0

    # Calculate total tokens and time
    total_prompt_tokens = prompt_tokens + metadata_prompt_tokens
    total_completion_tokens = completion_tokens + metadata_completion_tokens
    total_inference_time_ms = extract_time_ms + metadata_time_ms

    # Create the final result
    result = {
        "data": extracted_data,
        "metadata": metadata,
        "prompt_tokens": total_prompt_tokens,
        "completion_tokens": total_completion_tokens,
        "inference_time_ms": total_inference_time_ms,
    }
    if logger:
        logger.debug(
            "LLM response",
            auxiliary={
                "metadata": metadata,
                "prompt_tokens": total_prompt_tokens,
                "completion_tokens": total_completion_tokens,
                "inference_time_ms": total_inference_time_ms,
            },
        )
    return result
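

# A hedged, self-contained usage sketch (not part of the original module):
# it shows how `extract` might be called with a Pydantic schema. The stub
# client below is a fake standing in for a real Stagehand LLM client; it
# only mimics the `default_model` / `create_response` interface used above,
# so the demo runs without network access.
if __name__ == "__main__":
    import asyncio
    from types import SimpleNamespace

    class Product(BaseModel):
        name: str
        price: float

    class _StubClient:
        """Fake client returning a canned JSON payload with zero token usage."""

        default_model = "stub-model"

        async def create_response(self, **kwargs):
            # Shape mirrors the attributes `extract` reads: usage.*_tokens
            # and choices[0].message.content.
            return SimpleNamespace(
                usage=SimpleNamespace(prompt_tokens=0, completion_tokens=0),
                choices=[
                    SimpleNamespace(
                        message=SimpleNamespace(
                            content='{"name": "Widget", "price": 9.99, '
                            '"completed": true, "progress": "done"}'
                        )
                    )
                ],
            )

    async def _demo() -> None:
        result = await extract(
            instruction="Extract the product name and price",
            tree_elements="<div>Widget - $9.99</div>",
            schema=Product,
            llm_client=_StubClient(),
        )
        print(result["data"], result["metadata"])

    asyncio.run(_demo())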