-
Notifications
You must be signed in to change notification settings - Fork 111
Expand file tree
/
Copy pathextract_handler.py
More file actions
188 lines (158 loc) · 6.77 KB
/
extract_handler.py
File metadata and controls
188 lines (158 loc) · 6.77 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
"""Extract handler for performing data extraction from page elements using LLMs."""
from typing import Optional, TypeVar
from pydantic import BaseModel
from stagehand.a11y.utils import get_accessibility_tree
from stagehand.llm.inference import extract as extract_inference
from stagehand.metrics import StagehandFunctionName # Changed import location
from stagehand.types import (
DefaultExtractSchema,
EmptyExtractSchema,
ExtractOptions,
ExtractResult,
)
from stagehand.utils import (
convert_dict_keys_to_snake_case,
inject_urls,
transform_url_strings_to_ids,
)
# Generic type variable bound to pydantic models for typed extraction results.
# NOTE(review): ``T`` is not referenced anywhere in this module — presumably
# intended for a generic ``extract`` signature; candidate for removal. Confirm
# no external module imports it before deleting.
T = TypeVar("T", bound=BaseModel)
class ExtractHandler:
    """Handler for processing extract operations locally.

    Coordinates the full local-extraction pipeline: wait for the DOM to
    settle, snapshot the page's accessibility tree, run LLM inference over
    it, record token/latency metrics on the Stagehand client, and validate
    the returned payload against the caller-provided pydantic schema.
    """

    def __init__(
        self, stagehand_page, stagehand_client, user_provided_instructions=None
    ):
        """
        Initialize the ExtractHandler.

        Args:
            stagehand_page: StagehandPage instance
            stagehand_client: Stagehand client instance
            user_provided_instructions: Optional custom system instructions
        """
        self.stagehand_page = stagehand_page
        self.stagehand = stagehand_client
        # Logger is borrowed from the client so all handlers share one sink.
        self.logger = stagehand_client.logger
        self.user_provided_instructions = user_provided_instructions

    async def extract(
        self,
        options: Optional[ExtractOptions] = None,
        schema: Optional[type[BaseModel]] = None,
    ) -> ExtractResult:
        """
        Execute an extraction operation locally.

        Args:
            options: ExtractOptions containing the instruction and other parameters.
                When omitted, the entire page text is extracted instead.
            schema: Optional Pydantic model for structured output. Defaults to
                DefaultExtractSchema when not provided.

        Returns:
            ExtractResult instance whose ``data`` field holds either a validated
            schema instance or, if validation fails, the raw response dict.
        """
        if not options:
            # If no options provided, extract the entire page text
            self.logger.info("Extracting entire page text")
            return await self._extract_page_text()

        instruction = options.instruction
        # TODO add targeted extract
        # selector = options.selector

        # TODO: add schema to log
        self.logger.debug(
            "extract",
            category="extract",
            auxiliary={"instruction": instruction},
        )
        self.logger.info(
            f"Starting extraction with instruction: '{instruction}'", category="extract"
        )

        # Start inference timer if available
        if hasattr(self.stagehand, "start_inference_timer"):
            self.stagehand.start_inference_timer()

        # Wait for DOM to settle before snapshotting the accessibility tree.
        await self.stagehand_page._wait_for_settled_dom()

        # TODO add targeted extract
        # target_xpath = (
        #     selector.replace("xpath=", "")
        #     if selector and selector.startswith("xpath=")
        #     else ""
        # )

        # Get accessibility tree data (log before the call so the message
        # reflects what is about to happen, not what already did).
        self.logger.info("Getting accessibility tree data")
        tree = await get_accessibility_tree(self.stagehand_page, self.logger)
        output_string = tree["simplified"]
        id_to_url_mapping = tree.get("idToUrl", {})

        # Transform schema URL fields to numeric IDs if necessary
        transformed_schema = schema
        url_paths = []
        if schema:
            # TODO: Remove this once we have a better way to handle URLs
            transformed_schema, url_paths = transform_url_strings_to_ids(schema)
        else:
            schema = transformed_schema = DefaultExtractSchema

        # Use inference to call the LLM
        extraction_response = await extract_inference(
            instruction=instruction,
            tree_elements=output_string,
            schema=transformed_schema,
            llm_client=self.stagehand.llm,
            user_provided_instructions=self.user_provided_instructions,
            logger=self.logger,
            log_inference_to_file=False,  # TODO: Implement logging to file if needed
        )

        # Extract metrics from the response and update them directly
        # on the Stagehand client.
        self.stagehand.update_metrics(
            StagehandFunctionName.EXTRACT,
            extraction_response.get("prompt_tokens", 0),
            extraction_response.get("completion_tokens", 0),
            extraction_response.get("inference_time_ms", 0),
        )

        # Process extraction response
        raw_data_dict = extraction_response.get("data", {})
        metadata = extraction_response.get("metadata", {})

        # Inject URLs back into result if necessary
        if url_paths:
            inject_urls(
                raw_data_dict, url_paths, id_to_url_mapping
            )  # Modifies raw_data_dict in place

        if metadata.get("completed"):
            self.logger.debug(
                "Extraction completed successfully",
                auxiliary={"result": raw_data_dict},
            )
        else:
            self.logger.debug(
                "Extraction incomplete after processing all data",
                auxiliary={"result": raw_data_dict},
            )

        processed_data_payload = self._validate_payload(raw_data_dict, schema)

        # Create ExtractResult object
        return ExtractResult(data=processed_data_payload)

    def _validate_payload(
        self, raw_data_dict, schema: Optional[type[BaseModel]]
    ):
        """Validate the raw LLM payload against *schema*.

        Tries direct validation first, then retries after normalizing
        camelCase keys to snake_case (LLMs frequently emit camelCase even
        when the schema uses snake_case fields). Falls back to returning
        the raw dict unchanged when both attempts fail, so callers always
        receive a usable payload.

        Args:
            raw_data_dict: The ``data`` dict from the inference response.
            schema: Pydantic model type to validate against, or None.

        Returns:
            A validated schema instance, or ``raw_data_dict`` on failure.
        """
        if not (schema and isinstance(raw_data_dict, dict)):
            return raw_data_dict
        # Try direct validation first
        try:
            return schema.model_validate(raw_data_dict)
        except Exception as first_error:
            # Fallback: attempt camelCase→snake_case key normalization, then re-validate
            try:
                normalized = convert_dict_keys_to_snake_case(raw_data_dict)
                return schema.model_validate(normalized)
            except Exception as second_error:
                self.logger.error(
                    f"Failed to validate extracted data against schema {schema.__name__}: {first_error}. "
                    f"Normalization retry also failed: {second_error}. Keeping raw data dict in .data field."
                )
                return raw_data_dict

    async def _extract_page_text(self) -> ExtractResult:
        """Extract just the text content from the page.

        Returns:
            ExtractResult whose ``data`` field is an EmptyExtractSchema
            instance carrying the simplified accessibility-tree text under
            ``page_text``.
        """
        await self.stagehand_page._wait_for_settled_dom()
        tree = await get_accessibility_tree(self.stagehand_page, self.logger)
        output_dict = {"page_text": tree["simplified"]}
        validated_model = EmptyExtractSchema.model_validate(output_dict)
        # BUG FIX: previously returned ``ExtractResult(data=...).data`` — the
        # bare schema instance — contradicting this method's ``-> ExtractResult``
        # annotation and the declared return type of ``extract()`` on the
        # no-options path. Return the ExtractResult wrapper as declared.
        return ExtractResult(data=validated_model)