Skip to content

Commit 64c1b15

Browse files
authored
Move custom code to dedicated files and cleanup v1 vs v2 checks (#331)
* move custom code to dedicated files and cleanup v1 vs v2 checks * hard-require pydantic v2 or greater * lint
1 parent 14b032a commit 64c1b15

File tree

9 files changed

+572
-232
lines changed

9 files changed

+572
-232
lines changed

pyproject.toml

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -64,10 +64,6 @@ dev = [
6464
"pytest-xdist>=3.6.1",
6565
"dotenv>=0.9.9",
6666
]
67-
pydantic-v2 = [
68-
"pydantic~=2.0 ; python_full_version < '3.14'",
69-
"pydantic~=2.12 ; python_full_version >= '3.14'",
70-
]
7167

7268
[build-system]
7369
requires = ["hatchling==1.26.3", "hatch-fancy-pypi-readme", "packaging"]

src/stagehand/_compat.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,12 @@
1919

2020
PYDANTIC_V1 = pydantic.VERSION.startswith("1.")
2121

22+
if PYDANTIC_V1:
23+
raise ImportError(
24+
f"stagehand requires Pydantic v2 or newer; found Pydantic {pydantic.VERSION}. "
25+
"Install `pydantic>=2,<3`."
26+
)
27+
2228
if TYPE_CHECKING:
2329

2430
def parse_date(value: date | StrBytesIntFloat) -> date: # noqa: ARG001

src/stagehand/_pydantic_extract.py

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,9 @@
66
import logging
77
from typing import Any, Dict, Type
88

9-
from pydantic import BaseModel
9+
from pydantic import BaseModel, ConfigDict
10+
11+
from ._utils import lru_cache
1012

1113
logger = logging.getLogger(__name__)
1214

@@ -27,7 +29,9 @@ def pydantic_model_to_json_schema(schema: Type[BaseModel]) -> Dict[str, object]:
2729
return schema.model_json_schema()
2830

2931

30-
def validate_extract_response(result: object, schema: Type[BaseModel]) -> Any:
32+
def validate_extract_response(
33+
result: object, schema: Type[BaseModel], *, strict_response_validation: bool
34+
) -> Any:
3135
"""Validate raw extract result data against a Pydantic model.
3236
3337
Tries direct validation first. On failure, falls back to normalizing
@@ -36,12 +40,13 @@ def validate_extract_response(result: object, schema: Type[BaseModel]) -> Any:
3640
Returns the validated Pydantic model instance, or the raw result if
3741
both attempts fail.
3842
"""
43+
validation_schema = _validation_schema(schema, strict_response_validation)
3944
try:
40-
return schema.model_validate(result)
45+
return validation_schema.model_validate(result)
4146
except Exception:
4247
try:
4348
normalized = _convert_dict_keys_to_snake_case(result)
44-
return schema.model_validate(normalized)
49+
return validation_schema.model_validate(normalized)
4550
except Exception:
4651
logger.warning(
4752
"Failed to validate extracted data against schema %s. "
@@ -51,6 +56,21 @@ def validate_extract_response(result: object, schema: Type[BaseModel]) -> Any:
5156
return result
5257

5358

59+
@lru_cache(maxsize=None)
60+
def _validation_schema(schema: Type[BaseModel], strict_response_validation: bool) -> Type[BaseModel]:
61+
extra_behavior = "forbid" if strict_response_validation else "allow"
62+
validation_schema = type(
63+
f"{schema.__name__}ExtractValidation",
64+
(schema,),
65+
{
66+
"__module__": schema.__module__,
67+
"model_config": ConfigDict(extra=extra_behavior),
68+
},
69+
)
70+
validation_schema.model_rebuild(force=True)
71+
return validation_schema
72+
73+
5474
def _camel_to_snake(name: str) -> str:
5575
"""Convert a camelCase or PascalCase string to snake_case."""
5676
chars: list[str] = []

src/stagehand/_session_extract.py

Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
"""Custom extract patch installed on top of the session helpers."""
2+
3+
from __future__ import annotations
4+
5+
from typing import Any, cast
6+
7+
import httpx
8+
from typing_extensions import Unpack
9+
10+
from ._pydantic_extract import is_pydantic_model, pydantic_model_to_json_schema, validate_extract_response
11+
from .session import AsyncSession, Session
12+
from ._types import Body, Headers, NotGiven, Query, not_given
13+
from .types import session_extract_params
14+
from .types.session_extract_response import SessionExtractResponse
15+
16+
_ORIGINAL_SESSION_EXTRACT = Session.extract
17+
_ORIGINAL_ASYNC_SESSION_EXTRACT = AsyncSession.extract
18+
19+
20+
def install_pydantic_extract_patch() -> None:
21+
if getattr(Session.extract, "__stagehand_pydantic_extract_patch__", False):
22+
return
23+
24+
Session.extract = _sync_extract # type: ignore[assignment]
25+
AsyncSession.extract = _async_extract # type: ignore[assignment]
26+
27+
28+
def _sync_extract( # type: ignore[override, misc]
29+
self: Session,
30+
*,
31+
schema: dict[str, object] | type | None = None,
32+
page: Any | None = None,
33+
extra_headers: Headers | None = None,
34+
extra_query: Query | None = None,
35+
extra_body: Body | None = None,
36+
timeout: float | httpx.Timeout | None | NotGiven = not_given,
37+
**params: Unpack[session_extract_params.SessionExtractParamsNonStreaming], # pyright: ignore[reportGeneralTypeIssues]
38+
) -> SessionExtractResponse:
39+
params_schema = params.pop("schema", None) # type: ignore[misc]
40+
resolved_schema = schema if schema is not None else params_schema
41+
42+
pydantic_cls: type[Any] | None = None
43+
if is_pydantic_model(resolved_schema):
44+
pydantic_cls = resolved_schema # type: ignore[assignment]
45+
resolved_schema = pydantic_model_to_json_schema(pydantic_cls) # type: ignore[arg-type]
46+
47+
response = _ORIGINAL_SESSION_EXTRACT(
48+
self,
49+
page=page,
50+
extra_headers=extra_headers,
51+
extra_query=extra_query,
52+
extra_body=extra_body,
53+
timeout=timeout,
54+
**_with_schema(params, resolved_schema),
55+
)
56+
57+
if pydantic_cls is not None and response.data and response.data.result is not None:
58+
response.data.result = validate_extract_response(
59+
response.data.result,
60+
pydantic_cls,
61+
strict_response_validation=self._client._strict_response_validation,
62+
)
63+
64+
return response
65+
66+
67+
async def _async_extract( # type: ignore[override, misc]
68+
self: AsyncSession,
69+
*,
70+
schema: dict[str, object] | type | None = None,
71+
page: Any | None = None,
72+
extra_headers: Headers | None = None,
73+
extra_query: Query | None = None,
74+
extra_body: Body | None = None,
75+
timeout: float | httpx.Timeout | None | NotGiven = not_given,
76+
**params: Unpack[session_extract_params.SessionExtractParamsNonStreaming], # pyright: ignore[reportGeneralTypeIssues]
77+
) -> SessionExtractResponse:
78+
params_schema = params.pop("schema", None) # type: ignore[misc]
79+
resolved_schema = schema if schema is not None else params_schema
80+
81+
pydantic_cls: type[Any] | None = None
82+
if is_pydantic_model(resolved_schema):
83+
pydantic_cls = resolved_schema # type: ignore[assignment]
84+
resolved_schema = pydantic_model_to_json_schema(pydantic_cls) # type: ignore[arg-type]
85+
86+
response = await _ORIGINAL_ASYNC_SESSION_EXTRACT(
87+
self,
88+
page=page,
89+
extra_headers=extra_headers,
90+
extra_query=extra_query,
91+
extra_body=extra_body,
92+
timeout=timeout,
93+
**_with_schema(params, resolved_schema),
94+
)
95+
96+
if pydantic_cls is not None and response.data and response.data.result is not None:
97+
response.data.result = validate_extract_response(
98+
response.data.result,
99+
pydantic_cls,
100+
strict_response_validation=self._client._strict_response_validation,
101+
)
102+
103+
return response
104+
105+
106+
def _with_schema(
107+
params: session_extract_params.SessionExtractParamsNonStreaming,
108+
schema: dict[str, object] | type | None,
109+
) -> session_extract_params.SessionExtractParamsNonStreaming:
110+
api_params = dict(params)
111+
if schema is not None:
112+
api_params["schema"] = cast(Any, schema)
113+
return cast(session_extract_params.SessionExtractParamsNonStreaming, api_params)
114+
115+
116+
_sync_extract.__module__ = _ORIGINAL_SESSION_EXTRACT.__module__
117+
_sync_extract.__name__ = _ORIGINAL_SESSION_EXTRACT.__name__
118+
_sync_extract.__qualname__ = _ORIGINAL_SESSION_EXTRACT.__qualname__
119+
_sync_extract.__doc__ = _ORIGINAL_SESSION_EXTRACT.__doc__
120+
_sync_extract.__stagehand_pydantic_extract_patch__ = True
121+
122+
_async_extract.__module__ = _ORIGINAL_ASYNC_SESSION_EXTRACT.__module__
123+
_async_extract.__name__ = _ORIGINAL_ASYNC_SESSION_EXTRACT.__name__
124+
_async_extract.__qualname__ = _ORIGINAL_ASYNC_SESSION_EXTRACT.__qualname__
125+
_async_extract.__doc__ = _ORIGINAL_ASYNC_SESSION_EXTRACT.__doc__
126+
_async_extract.__stagehand_pydantic_extract_patch__ = True

src/stagehand/resources/sessions_helpers.py

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -6,25 +6,28 @@
66

77
import httpx
88

9-
from ..types import session_start_params
10-
from .._types import Body, Omit, Query, Headers, NotGiven, omit, not_given
9+
from .._session_extract import install_pydantic_extract_patch
1110
from .._compat import cached_property
11+
from .._response import (
12+
async_to_raw_response_wrapper,
13+
async_to_streamed_response_wrapper,
14+
to_raw_response_wrapper,
15+
to_streamed_response_wrapper,
16+
)
17+
from .._types import Body, Omit, Query, Headers, NotGiven, omit, not_given
1218
from ..session import Session, AsyncSession
19+
from ..types import session_start_params
20+
from ..types.session_start_response import SessionStartResponse
1321
from .sessions import (
14-
SessionsResource,
1522
AsyncSessionsResource,
16-
SessionsResourceWithRawResponse,
1723
AsyncSessionsResourceWithRawResponse,
18-
SessionsResourceWithStreamingResponse,
1924
AsyncSessionsResourceWithStreamingResponse,
25+
SessionsResource,
26+
SessionsResourceWithRawResponse,
27+
SessionsResourceWithStreamingResponse,
2028
)
21-
from .._response import (
22-
to_raw_response_wrapper,
23-
to_streamed_response_wrapper,
24-
async_to_raw_response_wrapper,
25-
async_to_streamed_response_wrapper,
26-
)
27-
from ..types.session_start_response import SessionStartResponse
29+
30+
install_pydantic_extract_patch()
2831

2932

3033
class SessionsResourceWithHelpersRawResponse(SessionsResourceWithRawResponse):

src/stagehand/session.py

Lines changed: 18 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
)
1818
from ._types import Body, Omit, Query, Headers, NotGiven, omit, not_given
1919
from ._exceptions import StagehandError
20-
from ._pydantic_extract import is_pydantic_model, validate_extract_response, pydantic_model_to_json_schema
2120
from .types.session_act_response import SessionActResponse
2221
from .types.session_end_response import SessionEndResponse
2322
from .types.session_start_response import Data as SessionStartResponseData, SessionStartResponse
@@ -201,47 +200,28 @@ def observe(
201200
),
202201
)
203202

204-
def extract( # type: ignore[misc]
203+
def extract(
205204
self,
206205
*,
207-
schema: dict[str, object] | type | None = None,
208206
page: Any | None = None,
209207
extra_headers: Headers | None = None,
210208
extra_query: Query | None = None,
211209
extra_body: Body | None = None,
212210
timeout: float | httpx.Timeout | None | NotGiven = not_given,
213-
**params: Unpack[session_extract_params.SessionExtractParamsNonStreaming], # pyright: ignore[reportGeneralTypeIssues]
211+
**params: Unpack[session_extract_params.SessionExtractParamsNonStreaming],
214212
) -> SessionExtractResponse:
215-
# If the caller passed schema via **params (TypedDict), prefer the explicit kwarg.
216-
params_schema = params.pop("schema", None) # type: ignore[misc]
217-
resolved_schema = schema if schema is not None else params_schema
218-
219-
pydantic_cls: type[Any] | None = None
220-
if is_pydantic_model(resolved_schema):
221-
pydantic_cls = resolved_schema # type: ignore[assignment]
222-
resolved_schema = pydantic_model_to_json_schema(pydantic_cls) # type: ignore[arg-type]
223-
224-
api_params: dict[str, Any] = _maybe_inject_frame_id(dict(params), page)
225-
if resolved_schema is not None:
226-
api_params["schema"] = resolved_schema
227-
228-
response: SessionExtractResponse = cast(
213+
return cast(
229214
SessionExtractResponse,
230215
self._client.sessions.extract(
231-
id=self.id,
232-
extra_headers=extra_headers,
233-
extra_query=extra_query,
234-
extra_body=extra_body,
235-
timeout=timeout,
236-
**api_params,
216+
id=self.id,
217+
extra_headers=extra_headers,
218+
extra_query=extra_query,
219+
extra_body=extra_body,
220+
timeout=timeout,
221+
**_maybe_inject_frame_id(dict(params), page),
237222
),
238223
)
239224

240-
if pydantic_cls is not None and response.data and response.data.result is not None:
241-
response.data.result = validate_extract_response(response.data.result, pydantic_cls)
242-
243-
return response
244-
245225
def execute(
246226
self,
247227
*,
@@ -355,47 +335,28 @@ async def observe(
355335
),
356336
)
357337

358-
async def extract( # type: ignore[misc]
338+
async def extract(
359339
self,
360340
*,
361-
schema: dict[str, object] | type | None = None,
362341
page: Any | None = None,
363342
extra_headers: Headers | None = None,
364343
extra_query: Query | None = None,
365344
extra_body: Body | None = None,
366345
timeout: float | httpx.Timeout | None | NotGiven = not_given,
367-
**params: Unpack[session_extract_params.SessionExtractParamsNonStreaming], # pyright: ignore[reportGeneralTypeIssues]
346+
**params: Unpack[session_extract_params.SessionExtractParamsNonStreaming],
368347
) -> SessionExtractResponse:
369-
# If the caller passed schema via **params (TypedDict), prefer the explicit kwarg.
370-
params_schema = params.pop("schema", None) # type: ignore[misc]
371-
resolved_schema = schema if schema is not None else params_schema
372-
373-
pydantic_cls: type[Any] | None = None
374-
if is_pydantic_model(resolved_schema):
375-
pydantic_cls = resolved_schema # type: ignore[assignment]
376-
resolved_schema = pydantic_model_to_json_schema(pydantic_cls) # type: ignore[arg-type]
377-
378-
api_params: dict[str, Any] = await _maybe_inject_frame_id_async(dict(params), page)
379-
if resolved_schema is not None:
380-
api_params["schema"] = resolved_schema
381-
382-
response: SessionExtractResponse = cast(
348+
return cast(
383349
SessionExtractResponse,
384350
await self._client.sessions.extract(
385-
id=self.id,
386-
extra_headers=extra_headers,
387-
extra_query=extra_query,
388-
extra_body=extra_body,
389-
timeout=timeout,
390-
**api_params,
351+
id=self.id,
352+
extra_headers=extra_headers,
353+
extra_query=extra_query,
354+
extra_body=extra_body,
355+
timeout=timeout,
356+
**(await _maybe_inject_frame_id_async(dict(params), page)),
391357
),
392358
)
393359

394-
if pydantic_cls is not None and response.data and response.data.result is not None:
395-
response.data.result = validate_extract_response(response.data.result, pydantic_cls)
396-
397-
return response
398-
399360
async def execute(
400361
self,
401362
*,

0 commit comments

Comments
 (0)