Skip to content

Commit b8fea20

Browse files
committed
move custom code to dedicated files and cleanup v1 vs v2 checks
1 parent 14b032a commit b8fea20

File tree

10 files changed

+745
-286
lines changed

10 files changed

+745
-286
lines changed

pyproject.toml

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ authors = [
1010

1111
dependencies = [
1212
"httpx>=0.23.0, <1",
13-
"pydantic>=2.0.0, <3",
13+
"pydantic>=1.9.0, <3",
1414
"typing-extensions>=4.14, <5",
1515
"anyio>=3.5.0, <5",
1616
"distro>=1.7.0, <2",
@@ -46,7 +46,12 @@ aiohttp = ["aiohttp", "httpx_aiohttp>=0.1.9"]
4646
[tool.uv]
4747
managed = true
4848
required-version = ">=0.9"
49-
conflicts = []
49+
conflicts = [
50+
[
51+
{ group = "pydantic-v1" },
52+
{ group = "pydantic-v2" },
53+
],
54+
]
5055

5156
[dependency-groups]
5257
# version pins are in uv.lock
@@ -64,6 +69,9 @@ dev = [
6469
"pytest-xdist>=3.6.1",
6570
"dotenv>=0.9.9",
6671
]
72+
pydantic-v1 = [
73+
"pydantic>=1.9.0,<2",
74+
]
6775
pydantic-v2 = [
6876
"pydantic~=2.0 ; python_full_version < '3.14'",
6977
"pydantic~=2.12 ; python_full_version >= '3.14'",

scripts/test

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,14 @@ PY_VERSION_MIN=">=3.9.0"
1414
PY_VERSION_MAX=">=3.14.0"
1515

1616
function run_tests() {
17-
echo "==> Running tests"
17+
echo "==> Running tests with Pydantic v2"
1818
uv run --isolated --all-extras pytest "$@"
19+
20+
# Skip Pydantic v1 tests on latest Python (not supported)
21+
if [[ "$UV_PYTHON" != "$PY_VERSION_MAX" ]]; then
22+
echo "==> Running tests with Pydantic v1"
23+
uv run --isolated --all-extras --group=pydantic-v1 pytest "$@"
24+
fi
1925
}
2026

2127
# If UV_PYTHON is already set in the environment, just run the command once

src/stagehand/_compat.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,12 @@
1919

2020
PYDANTIC_V1 = pydantic.VERSION.startswith("1.")
2121

22+
if PYDANTIC_V1:
23+
raise ImportError(
24+
f"stagehand requires Pydantic v2 or newer; found Pydantic {pydantic.VERSION}. "
25+
"Install `pydantic>=2,<3`."
26+
)
27+
2228
if TYPE_CHECKING:
2329

2430
def parse_date(value: date | StrBytesIntFloat) -> date: # noqa: ARG001

src/stagehand/_pydantic_extract.py

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,9 @@
66
import logging
77
from typing import Any, Dict, Type
88

9-
from pydantic import BaseModel
9+
from pydantic import BaseModel, ConfigDict
10+
11+
from ._utils import lru_cache
1012

1113
logger = logging.getLogger(__name__)
1214

@@ -27,7 +29,9 @@ def pydantic_model_to_json_schema(schema: Type[BaseModel]) -> Dict[str, object]:
2729
return schema.model_json_schema()
2830

2931

30-
def validate_extract_response(result: object, schema: Type[BaseModel]) -> Any:
32+
def validate_extract_response(
33+
result: object, schema: Type[BaseModel], *, strict_response_validation: bool
34+
) -> Any:
3135
"""Validate raw extract result data against a Pydantic model.
3236
3337
Tries direct validation first. On failure, falls back to normalizing
@@ -36,12 +40,13 @@ def validate_extract_response(result: object, schema: Type[BaseModel]) -> Any:
3640
Returns the validated Pydantic model instance, or the raw result if
3741
both attempts fail.
3842
"""
43+
validation_schema = _validation_schema(schema, strict_response_validation)
3944
try:
40-
return schema.model_validate(result)
45+
return validation_schema.model_validate(result)
4146
except Exception:
4247
try:
4348
normalized = _convert_dict_keys_to_snake_case(result)
44-
return schema.model_validate(normalized)
49+
return validation_schema.model_validate(normalized)
4550
except Exception:
4651
logger.warning(
4752
"Failed to validate extracted data against schema %s. "
@@ -51,6 +56,21 @@ def validate_extract_response(result: object, schema: Type[BaseModel]) -> Any:
5156
return result
5257

5358

59+
# The cache is bounded on purpose: the key is a caller-supplied class, and
# callers that build schema classes dynamically would otherwise keep every
# schema (and the subclass derived from it) alive for the whole process.
@lru_cache(maxsize=128)
def _validation_schema(schema: Type[BaseModel], strict_response_validation: bool) -> Type[BaseModel]:
    """Return a subclass of ``schema`` configured for validating extract results.

    The subclass is named ``<schema name>ExtractValidation``, reports the same
    ``__module__`` as ``schema`` and overrides ``model_config`` so that the
    ``extra`` setting follows the client's strictness flag.

    Args:
        schema: The Pydantic model class supplied by the caller.
        strict_response_validation: When true the subclass is built with
            ``extra="forbid"``; otherwise with ``extra="allow"``.

    Returns:
        The derived model class. Results are memoized per
        ``(schema, strict_response_validation)`` pair, so repeated calls with
        the same arguments normally return the same class object.
    """
    extra_behavior = "forbid" if strict_response_validation else "allow"
    validation_schema = type(
        f"{schema.__name__}ExtractValidation",
        (schema,),
        {
            # Attribute the generated class to the caller's module, not this one.
            "__module__": schema.__module__,
            "model_config": ConfigDict(extra=extra_behavior),
        },
    )
    # Force a rebuild so the overridden config is reflected in the validator.
    validation_schema.model_rebuild(force=True)
    return validation_schema
72+
73+
5474
def _camel_to_snake(name: str) -> str:
5575
"""Convert a camelCase or PascalCase string to snake_case."""
5676
chars: list[str] = []

src/stagehand/_session_extract.py

Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
"""Custom extract patch installed on top of the session helpers."""
2+
3+
from __future__ import annotations
4+
5+
from typing import Any, cast
6+
7+
import httpx
8+
from typing_extensions import Unpack
9+
10+
from ._pydantic_extract import is_pydantic_model, pydantic_model_to_json_schema, validate_extract_response
11+
from ._types import Body, Headers, NotGiven, Query, not_given
12+
from .session import AsyncSession, Session
13+
from .types import session_extract_params
14+
from .types.session_extract_response import SessionExtractResponse
15+
16+
# References to the unpatched methods, captured at import time; the wrapper
# functions in this module delegate to them.
_ORIGINAL_SESSION_EXTRACT = Session.extract
_ORIGINAL_ASYNC_SESSION_EXTRACT = AsyncSession.extract


def install_pydantic_extract_patch() -> None:
    """Replace ``Session.extract`` and ``AsyncSession.extract`` with the wrappers.

    The call is a no-op when ``Session.extract`` already carries the
    ``__stagehand_pydantic_extract_patch__`` marker, so it is safe to invoke
    more than once.
    """
    already_installed = getattr(Session.extract, "__stagehand_pydantic_extract_patch__", False)
    if not already_installed:
        Session.extract = _sync_extract  # type: ignore[assignment]
        AsyncSession.extract = _async_extract  # type: ignore[assignment]
26+
27+
28+
def _sync_extract(  # type: ignore[override, misc]
    self: Session,
    *,
    schema: dict[str, object] | type | None = None,
    page: Any | None = None,
    extra_headers: Headers | None = None,
    extra_query: Query | None = None,
    extra_body: Body | None = None,
    timeout: float | httpx.Timeout | None | NotGiven = not_given,
    **params: Unpack[session_extract_params.SessionExtractParamsNonStreaming],  # pyright: ignore[reportGeneralTypeIssues]
) -> SessionExtractResponse:
    """Synchronous ``extract`` wrapper that accepts a Pydantic model as ``schema``.

    A model class is converted to a JSON schema before the original
    ``Session.extract`` is called, and the returned ``data.result`` is then
    passed through ``validate_extract_response`` with that model. Any other
    schema value is forwarded unchanged and the response is returned as-is.

    Note: at runtime this function's ``__doc__`` is replaced with the original
    method's docstring further down in this module.
    """
    # An explicit ``schema=`` argument takes precedence over a "schema" entry
    # that arrived through **params; the entry is removed from params either way.
    schema_from_params = params.pop("schema", None)  # type: ignore[misc]
    effective_schema = schema_from_params if schema is None else schema

    model_cls: type[Any] | None = None
    if is_pydantic_model(effective_schema):
        model_cls = effective_schema  # type: ignore[assignment]
        effective_schema = pydantic_model_to_json_schema(model_cls)  # type: ignore[arg-type]

    request_params = _with_schema(params, effective_schema)
    response = _ORIGINAL_SESSION_EXTRACT(
        self,
        page=page,
        extra_headers=extra_headers,
        extra_query=extra_query,
        extra_body=extra_body,
        timeout=timeout,
        **request_params,
    )

    # Nothing to validate unless a model was supplied and a result came back.
    if model_cls is None or not response.data or response.data.result is None:
        return response

    response.data.result = validate_extract_response(
        response.data.result,
        model_cls,
        strict_response_validation=self._client._strict_response_validation,
    )
    return response
65+
66+
67+
async def _async_extract(  # type: ignore[override, misc]
    self: AsyncSession,
    *,
    schema: dict[str, object] | type | None = None,
    page: Any | None = None,
    extra_headers: Headers | None = None,
    extra_query: Query | None = None,
    extra_body: Body | None = None,
    timeout: float | httpx.Timeout | None | NotGiven = not_given,
    **params: Unpack[session_extract_params.SessionExtractParamsNonStreaming],  # pyright: ignore[reportGeneralTypeIssues]
) -> SessionExtractResponse:
    """Asynchronous ``extract`` wrapper that accepts a Pydantic model as ``schema``.

    A model class is converted to a JSON schema before the original
    ``AsyncSession.extract`` is awaited, and the returned ``data.result`` is
    then passed through ``validate_extract_response`` with that model. Any
    other schema value is forwarded unchanged and the response is returned
    as-is.

    Note: at runtime this function's ``__doc__`` is replaced with the original
    method's docstring further down in this module.
    """
    # An explicit ``schema=`` argument takes precedence over a "schema" entry
    # that arrived through **params; the entry is removed from params either way.
    schema_from_params = params.pop("schema", None)  # type: ignore[misc]
    effective_schema = schema_from_params if schema is None else schema

    model_cls: type[Any] | None = None
    if is_pydantic_model(effective_schema):
        model_cls = effective_schema  # type: ignore[assignment]
        effective_schema = pydantic_model_to_json_schema(model_cls)  # type: ignore[arg-type]

    request_params = _with_schema(params, effective_schema)
    response = await _ORIGINAL_ASYNC_SESSION_EXTRACT(
        self,
        page=page,
        extra_headers=extra_headers,
        extra_query=extra_query,
        extra_body=extra_body,
        timeout=timeout,
        **request_params,
    )

    # Nothing to validate unless a model was supplied and a result came back.
    if model_cls is None or not response.data or response.data.result is None:
        return response

    response.data.result = validate_extract_response(
        response.data.result,
        model_cls,
        strict_response_validation=self._client._strict_response_validation,
    )
    return response
104+
105+
106+
def _with_schema(
    params: session_extract_params.SessionExtractParamsNonStreaming,
    schema: dict[str, object] | type | None,
) -> session_extract_params.SessionExtractParamsNonStreaming:
    """Return a shallow copy of ``params`` with ``"schema"`` set when one is given.

    The input mapping is never modified. When ``schema`` is ``None`` the copy
    is returned without a ``"schema"`` entry being added.
    """
    if schema is None:
        copied = dict(params)
    else:
        # Keyword form sets (or overrides) the "schema" key on the copy.
        copied = dict(params, schema=cast(Any, schema))
    return cast(session_extract_params.SessionExtractParamsNonStreaming, copied)
114+
115+
116+
# Make each wrapper present itself as the method it replaces (module, name,
# qualified name, docstring) and tag it with the marker attribute that
# install_pydantic_extract_patch() checks to detect an installed patch.
for _wrapper, _wrapped in (
    (_sync_extract, _ORIGINAL_SESSION_EXTRACT),
    (_async_extract, _ORIGINAL_ASYNC_SESSION_EXTRACT),
):
    for _attr in ("__module__", "__name__", "__qualname__", "__doc__"):
        setattr(_wrapper, _attr, getattr(_wrapped, _attr))
    setattr(_wrapper, "__stagehand_pydantic_extract_patch__", True)
# Do not leave the loop variables behind in the module namespace.
del _wrapper, _wrapped, _attr

src/stagehand/resources/sessions_helpers.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
import httpx
88

9+
from .._session_extract import install_pydantic_extract_patch
910
from ..types import session_start_params
1011
from .._types import Body, Omit, Query, Headers, NotGiven, omit, not_given
1112
from .._compat import cached_property
@@ -26,6 +27,8 @@
2627
)
2728
from ..types.session_start_response import SessionStartResponse
2829

30+
install_pydantic_extract_patch()
31+
2932

3033
class SessionsResourceWithHelpersRawResponse(SessionsResourceWithRawResponse):
3134
def __init__(self, sessions: SessionsResourceWithHelpers) -> None: # type: ignore[name-defined]

src/stagehand/session.py

Lines changed: 18 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
)
1818
from ._types import Body, Omit, Query, Headers, NotGiven, omit, not_given
1919
from ._exceptions import StagehandError
20-
from ._pydantic_extract import is_pydantic_model, validate_extract_response, pydantic_model_to_json_schema
2120
from .types.session_act_response import SessionActResponse
2221
from .types.session_end_response import SessionEndResponse
2322
from .types.session_start_response import Data as SessionStartResponseData, SessionStartResponse
@@ -201,47 +200,28 @@ def observe(
201200
),
202201
)
203202

204-
def extract( # type: ignore[misc]
203+
def extract(
205204
self,
206205
*,
207-
schema: dict[str, object] | type | None = None,
208206
page: Any | None = None,
209207
extra_headers: Headers | None = None,
210208
extra_query: Query | None = None,
211209
extra_body: Body | None = None,
212210
timeout: float | httpx.Timeout | None | NotGiven = not_given,
213-
**params: Unpack[session_extract_params.SessionExtractParamsNonStreaming], # pyright: ignore[reportGeneralTypeIssues]
211+
**params: Unpack[session_extract_params.SessionExtractParamsNonStreaming],
214212
) -> SessionExtractResponse:
215-
# If the caller passed schema via **params (TypedDict), prefer the explicit kwarg.
216-
params_schema = params.pop("schema", None) # type: ignore[misc]
217-
resolved_schema = schema if schema is not None else params_schema
218-
219-
pydantic_cls: type[Any] | None = None
220-
if is_pydantic_model(resolved_schema):
221-
pydantic_cls = resolved_schema # type: ignore[assignment]
222-
resolved_schema = pydantic_model_to_json_schema(pydantic_cls) # type: ignore[arg-type]
223-
224-
api_params: dict[str, Any] = _maybe_inject_frame_id(dict(params), page)
225-
if resolved_schema is not None:
226-
api_params["schema"] = resolved_schema
227-
228-
response: SessionExtractResponse = cast(
213+
return cast(
229214
SessionExtractResponse,
230215
self._client.sessions.extract(
231-
id=self.id,
232-
extra_headers=extra_headers,
233-
extra_query=extra_query,
234-
extra_body=extra_body,
235-
timeout=timeout,
236-
**api_params,
216+
id=self.id,
217+
extra_headers=extra_headers,
218+
extra_query=extra_query,
219+
extra_body=extra_body,
220+
timeout=timeout,
221+
**_maybe_inject_frame_id(dict(params), page),
237222
),
238223
)
239224

240-
if pydantic_cls is not None and response.data and response.data.result is not None:
241-
response.data.result = validate_extract_response(response.data.result, pydantic_cls)
242-
243-
return response
244-
245225
def execute(
246226
self,
247227
*,
@@ -355,47 +335,28 @@ async def observe(
355335
),
356336
)
357337

358-
async def extract( # type: ignore[misc]
338+
async def extract(
359339
self,
360340
*,
361-
schema: dict[str, object] | type | None = None,
362341
page: Any | None = None,
363342
extra_headers: Headers | None = None,
364343
extra_query: Query | None = None,
365344
extra_body: Body | None = None,
366345
timeout: float | httpx.Timeout | None | NotGiven = not_given,
367-
**params: Unpack[session_extract_params.SessionExtractParamsNonStreaming], # pyright: ignore[reportGeneralTypeIssues]
346+
**params: Unpack[session_extract_params.SessionExtractParamsNonStreaming],
368347
) -> SessionExtractResponse:
369-
# If the caller passed schema via **params (TypedDict), prefer the explicit kwarg.
370-
params_schema = params.pop("schema", None) # type: ignore[misc]
371-
resolved_schema = schema if schema is not None else params_schema
372-
373-
pydantic_cls: type[Any] | None = None
374-
if is_pydantic_model(resolved_schema):
375-
pydantic_cls = resolved_schema # type: ignore[assignment]
376-
resolved_schema = pydantic_model_to_json_schema(pydantic_cls) # type: ignore[arg-type]
377-
378-
api_params: dict[str, Any] = await _maybe_inject_frame_id_async(dict(params), page)
379-
if resolved_schema is not None:
380-
api_params["schema"] = resolved_schema
381-
382-
response: SessionExtractResponse = cast(
348+
return cast(
383349
SessionExtractResponse,
384350
await self._client.sessions.extract(
385-
id=self.id,
386-
extra_headers=extra_headers,
387-
extra_query=extra_query,
388-
extra_body=extra_body,
389-
timeout=timeout,
390-
**api_params,
351+
id=self.id,
352+
extra_headers=extra_headers,
353+
extra_query=extra_query,
354+
extra_body=extra_body,
355+
timeout=timeout,
356+
**(await _maybe_inject_frame_id_async(dict(params), page)),
391357
),
392358
)
393359

394-
if pydantic_cls is not None and response.data and response.data.result is not None:
395-
response.data.result = validate_extract_response(response.data.result, pydantic_cls)
396-
397-
return response
398-
399360
async def execute(
400361
self,
401362
*,

0 commit comments

Comments
 (0)