Skip to content

Commit 380f5e0

Browse files
miguelg719 and pirate authored
Support Pydantic models on Extract (#330)
* Support pydantic models as schemas for extract
* pop schema
* bump pydantic dep and linting fixes
* update tests to remove pydantic v1 tests
* Move custom code to dedicated files and cleanup v1 vs v2 checks (#331)
* move custom code to dedicated files and cleanup v1 vs v2 checks
* hard-require pydantic v2 or greater
* lint
* lint fixes after merge
* Centralize extract patching in sessions helpers
* Bound extract validation schema cache

---------

Co-authored-by: Nick Sweeting <github@sweeting.me>
Co-authored-by: Nick Sweeting <git@sweeting.me>
1 parent 99b6033 commit 380f5e0

File tree

6 files changed

+605
-329
lines changed

6 files changed

+605
-329
lines changed

pyproject.toml

Lines changed: 2 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ authors = [
1010

1111
dependencies = [
1212
"httpx>=0.23.0, <1",
13-
"pydantic>=1.9.0, <3",
13+
"pydantic>=2.0.0, <3",
1414
"typing-extensions>=4.14, <5",
1515
"anyio>=3.5.0, <5",
1616
"distro>=1.7.0, <2",
@@ -46,12 +46,7 @@ aiohttp = ["aiohttp", "httpx_aiohttp>=0.1.9"]
4646
[tool.uv]
4747
managed = true
4848
required-version = ">=0.9"
49-
conflicts = [
50-
[
51-
{ group = "pydantic-v1" },
52-
{ group = "pydantic-v2" },
53-
],
54-
]
49+
conflicts = []
5550

5651
[dependency-groups]
5752
# version pins are in uv.lock
@@ -69,13 +64,6 @@ dev = [
6964
"pytest-xdist>=3.6.1",
7065
"dotenv>=0.9.9",
7166
]
72-
pydantic-v1 = [
73-
"pydantic>=1.9.0,<2",
74-
]
75-
pydantic-v2 = [
76-
"pydantic~=2.0 ; python_full_version < '3.14'",
77-
"pydantic~=2.12 ; python_full_version >= '3.14'",
78-
]
7967

8068
[build-system]
8169
requires = ["hatchling==1.26.3", "hatch-fancy-pypi-readme", "packaging"]

scripts/test

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,8 @@ PY_VERSION_MIN=">=3.9.0"
1414
PY_VERSION_MAX=">=3.14.0"
1515

1616
function run_tests() {
17-
echo "==> Running tests with Pydantic v2"
17+
echo "==> Running tests"
1818
uv run --isolated --all-extras pytest "$@"
19-
20-
# Skip Pydantic v1 tests on latest Python (not supported)
21-
if [[ "$UV_PYTHON" != "$PY_VERSION_MAX" ]]; then
22-
echo "==> Running tests with Pydantic v1"
23-
uv run --isolated --all-extras --group=pydantic-v1 pytest "$@"
24-
fi
2519
}
2620

2721
# If UV_PYTHON is already set in the environment, just run the command once

src/stagehand/_compat.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,12 @@
1919

2020
PYDANTIC_V1 = pydantic.VERSION.startswith("1.")
2121

22+
if PYDANTIC_V1:
23+
raise ImportError(
24+
f"stagehand requires Pydantic v2 or newer; found Pydantic {pydantic.VERSION}. "
25+
"Install `pydantic>=2,<3`."
26+
)
27+
2228
if TYPE_CHECKING:
2329

2430
def parse_date(value: date | StrBytesIntFloat) -> date: # noqa: ARG001

src/stagehand/resources/sessions_helpers.py

Lines changed: 191 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,17 @@
22

33
from __future__ import annotations
44

5-
from typing_extensions import Literal, override
5+
import inspect
6+
import logging
7+
from typing import Any, Type, Mapping, cast
8+
from typing_extensions import Unpack, Literal, override
69

710
import httpx
11+
from pydantic import BaseModel, ConfigDict
812

9-
from ..types import session_start_params
13+
from ..types import session_start_params, session_extract_params
1014
from .._types import Body, Omit, Query, Headers, NotGiven, omit, not_given
15+
from .._utils import lru_cache
1116
from .._compat import cached_property
1217
from ..session import Session, AsyncSession
1318
from .sessions import (
@@ -25,6 +30,190 @@
2530
async_to_streamed_response_wrapper,
2631
)
2732
from ..types.session_start_response import SessionStartResponse
33+
from ..types.session_extract_response import SessionExtractResponse
34+
35+
# Module-level logger; used by validate_extract_response() to report
# extract results that fail schema validation.
logger = logging.getLogger(__name__)

# Keep references to the extract methods as they exist at import time so the
# wrappers below can delegate to them after the classes are patched.
# NOTE(review): if this module were re-executed after the patch is installed,
# these names would capture the wrappers themselves — confirm it is never reloaded.
_ORIGINAL_SESSION_EXTRACT = Session.extract
_ORIGINAL_ASYNC_SESSION_EXTRACT = AsyncSession.extract
39+
40+
41+
def install_pydantic_extract_patch() -> None:
    """Replace ``Session.extract`` / ``AsyncSession.extract`` with the Pydantic-aware wrappers.

    The call is a no-op when ``Session.extract`` already carries a truthy
    ``__stagehand_pydantic_extract_patch__`` attribute, so invoking it more
    than once does not re-assign the methods.
    """
    already_patched = getattr(Session.extract, "__stagehand_pydantic_extract_patch__", False)
    if not already_patched:
        Session.extract = _sync_extract  # type: ignore[assignment]
        AsyncSession.extract = _async_extract  # type: ignore[assignment]
47+
48+
49+
def is_pydantic_model(schema: Any) -> bool:
    """Return True when ``schema`` is a class deriving from ``pydantic.BaseModel``.

    Args:
        schema: Any object a caller passed as an extract schema (a dict, a
            class, ``None``, ...).

    Returns:
        ``True`` only for ``BaseModel`` subclasses (including ``BaseModel``
        itself); ``False`` for every other value. Never raises.
    """
    if not inspect.isclass(schema):
        return False
    try:
        return issubclass(schema, BaseModel)
    except TypeError:
        # On Python 3.9/3.10 a parameterized generic such as ``list[int]``
        # passes ``inspect.isclass`` yet is rejected by ``issubclass`` with
        # "arg 1 must be a class". Such a value is simply not a model.
        return False
51+
52+
53+
def pydantic_model_to_json_schema(schema: Type[BaseModel]) -> dict[str, object]:
    """Produce the JSON Schema dictionary for a Pydantic model class.

    ``model_rebuild()`` is invoked first, then the result of
    ``model_json_schema()`` is returned (typed as ``dict[str, object]``).
    """
    schema.model_rebuild()
    json_schema = schema.model_json_schema()
    return cast(dict[str, object], json_schema)
56+
57+
58+
def validate_extract_response(
    result: object, schema: Type[BaseModel], *, strict_response_validation: bool
) -> object:
    """Validate an extract result against a Pydantic model, best-effort.

    Two attempts are made against the validation subclass built by
    ``_validation_schema``: first with ``result`` as received, then with its
    string dict keys converted to snake_case.

    Args:
        result: The raw extracted payload.
        schema: The Pydantic model class the caller supplied.
        strict_response_validation: Forwarded to ``_validation_schema`` to
            choose between forbidding and allowing extra fields.

    Returns:
        The validated model instance, or ``result`` unchanged when both
        attempts fail. Validation failures are never raised to the caller.
    """
    validation_schema = _validation_schema(schema, strict_response_validation)
    try:
        return validation_schema.model_validate(result)
    except Exception:
        # Deliberately broad: this is a best-effort step and must not turn a
        # successful extract call into an error.
        try:
            normalized = _convert_dict_keys_to_snake_case(result)
            return validation_schema.model_validate(normalized)
        except Exception:
            logger.warning(
                "Failed to validate extracted data against schema %s. Returning raw data.",
                schema.__name__,
            )
            # Keep the reason available for debugging instead of discarding it;
            # the traceback includes the first attempt's error as context.
            logger.debug(
                "Extract validation error for schema %s",
                schema.__name__,
                exc_info=True,
            )
            return result
74+
75+
76+
@lru_cache(maxsize=256)
def _validation_schema(schema: Type[BaseModel], strict_response_validation: bool) -> Type[BaseModel]:
    """Derive the model class used to validate extract results for ``schema``.

    A subclass of ``schema`` named ``<SchemaName>ExtractValidation`` is created
    in the same ``__module__``, with ``model_config`` set so that extra fields
    are forbidden when ``strict_response_validation`` is true and allowed
    otherwise. The subclass is force-rebuilt before being returned. Calls go
    through ``lru_cache(maxsize=256)``.
    """
    extra_behavior: Literal["allow", "forbid"]
    if strict_response_validation:
        extra_behavior = "forbid"
    else:
        extra_behavior = "allow"

    class_name = f"{schema.__name__}ExtractValidation"
    namespace = {
        "__module__": schema.__module__,
        "model_config": ConfigDict(extra=extra_behavior),
    }
    derived = cast(Type[BaseModel], type(class_name, (schema,), namespace))
    derived.model_rebuild(force=True)
    return derived
92+
93+
94+
def _camel_to_snake(name: str) -> str:
95+
chars: list[str] = []
96+
for i, ch in enumerate(name):
97+
if ch.isupper() and i != 0 and not name[i - 1].isupper():
98+
chars.append("_")
99+
chars.append(ch.lower())
100+
return "".join(chars)
101+
102+
103+
def _convert_dict_keys_to_snake_case(data: Any) -> Any:
    """Return a copy of ``data`` with every string dict key converted via ``_camel_to_snake``.

    Dicts and lists are walked recursively and rebuilt; non-string keys are
    kept as they are, and any value that is neither a dict nor a list is
    returned untouched.
    """
    if isinstance(data, list):
        return [_convert_dict_keys_to_snake_case(element) for element in cast(list[object], data)]
    if not isinstance(data, dict):
        return data

    converted: dict[object, object] = {}
    for key, value in cast(dict[object, object], data).items():
        new_key = _camel_to_snake(key) if isinstance(key, str) else key
        converted[new_key] = _convert_dict_keys_to_snake_case(value)
    return converted
113+
114+
115+
def _with_schema(
    params: Mapping[str, object],
    schema: dict[str, object] | type | None,
) -> session_extract_params.SessionExtractParamsNonStreaming:
    """Copy ``params`` into a new dict and add ``schema`` under the ``"schema"`` key.

    The key is only set when ``schema`` is not ``None``; the input mapping is
    never modified. The result is returned typed as the non-streaming extract
    params.
    """
    merged = dict(params)
    if schema is None:
        return cast(session_extract_params.SessionExtractParamsNonStreaming, merged)
    merged["schema"] = cast(Any, schema)
    return cast(session_extract_params.SessionExtractParamsNonStreaming, merged)
123+
124+
125+
def _sync_extract(  # type: ignore[override, misc]
    self: Session,
    *,
    schema: dict[str, object] | type | None = None,
    page: Any | None = None,
    extra_headers: Headers | None = None,
    extra_query: Query | None = None,
    extra_body: Body | None = None,
    timeout: float | httpx.Timeout | None | NotGiven = not_given,
    **params: Unpack[session_extract_params.SessionExtractParamsNonStreaming],  # pyright: ignore[reportGeneralTypeIssues]
) -> SessionExtractResponse:
    # Wrapper around the original ``Session.extract`` that also accepts a
    # Pydantic model class as ``schema``: the class is sent as its JSON Schema
    # and the returned ``data.result`` is validated against it afterwards.
    # (Plain comments rather than a docstring, since ``__doc__`` is replaced
    # with the original method's docstring at module level.)

    # NOTE(review): ``schema`` is a named keyword-only parameter, so a
    # "schema" key cannot arrive through **params; this pop looks purely
    # defensive.
    fallback_schema = params.pop("schema", None)  # type: ignore[misc]
    wire_schema = fallback_schema if schema is None else schema

    model_cls: Type[BaseModel] | None = None
    if is_pydantic_model(wire_schema):
        model_cls = cast(Type[BaseModel], wire_schema)
        wire_schema = pydantic_model_to_json_schema(model_cls)

    request_params = _with_schema(params, wire_schema)
    response = _ORIGINAL_SESSION_EXTRACT(
        self,
        page=page,
        extra_headers=extra_headers,
        extra_query=extra_query,
        extra_body=extra_body,
        timeout=timeout,
        **request_params,
    )

    # Dict schemas (or no schema) need no post-processing.
    if model_cls is None:
        return response

    if response.data and response.data.result is not None:
        response.data.result = validate_extract_response(
            response.data.result,
            model_cls,
            strict_response_validation=self._client._strict_response_validation,
        )
    return response
162+
163+
164+
async def _async_extract(  # type: ignore[override, misc]
    self: AsyncSession,
    *,
    schema: dict[str, object] | type | None = None,
    page: Any | None = None,
    extra_headers: Headers | None = None,
    extra_query: Query | None = None,
    extra_body: Body | None = None,
    timeout: float | httpx.Timeout | None | NotGiven = not_given,
    **params: Unpack[session_extract_params.SessionExtractParamsNonStreaming],  # pyright: ignore[reportGeneralTypeIssues]
) -> SessionExtractResponse:
    # Async wrapper around the original ``AsyncSession.extract`` that also
    # accepts a Pydantic model class as ``schema``: the class is sent as its
    # JSON Schema and the returned ``data.result`` is validated against it.
    # (Plain comments rather than a docstring, since ``__doc__`` is replaced
    # with the original method's docstring at module level.)

    # NOTE(review): ``schema`` is a named keyword-only parameter, so a
    # "schema" key cannot arrive through **params; this pop looks purely
    # defensive.
    fallback_schema = params.pop("schema", None)  # type: ignore[misc]
    wire_schema = fallback_schema if schema is None else schema

    model_cls: Type[BaseModel] | None = None
    if is_pydantic_model(wire_schema):
        model_cls = cast(Type[BaseModel], wire_schema)
        wire_schema = pydantic_model_to_json_schema(model_cls)

    request_params = _with_schema(params, wire_schema)
    response = await _ORIGINAL_ASYNC_SESSION_EXTRACT(
        self,
        page=page,
        extra_headers=extra_headers,
        extra_query=extra_query,
        extra_body=extra_body,
        timeout=timeout,
        **request_params,
    )

    # Dict schemas (or no schema) need no post-processing.
    if model_cls is None:
        return response

    if response.data and response.data.result is not None:
        response.data.result = validate_extract_response(
            response.data.result,
            model_cls,
            strict_response_validation=self._client._strict_response_validation,
        )
    return response
201+
202+
203+
# Give each wrapper the identity (module, name, qualname, docstring) of the
# method it stands in for, then tag it with the marker attribute that
# install_pydantic_extract_patch() checks before patching.
for _wrapper, _wrapped in (
    (_sync_extract, _ORIGINAL_SESSION_EXTRACT),
    (_async_extract, _ORIGINAL_ASYNC_SESSION_EXTRACT),
):
    for _attr in ("__module__", "__name__", "__qualname__", "__doc__"):
        setattr(_wrapper, _attr, getattr(_wrapped, _attr))
    setattr(_wrapper, "__stagehand_pydantic_extract_patch__", True)  # noqa: B010
# Do not leave the loop variables behind in the module namespace.
del _wrapper, _wrapped, _attr


# Patch the session classes as a side effect of importing this module.
install_pydantic_extract_patch()
28217

29218

30219
class SessionsResourceWithHelpersRawResponse(SessionsResourceWithRawResponse):

0 commit comments

Comments
 (0)