Skip to content

Commit 8280a47

Browse files
committed
lint fixes after merge
1 parent 64c1b15 commit 8280a47

File tree

4 files changed

+43
-40
lines changed

4 files changed

+43
-40
lines changed

src/stagehand/_pydantic_extract.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44

55
import inspect
66
import logging
7-
from typing import Any, Dict, Type
7+
from typing import Any, Dict, Type, cast
8+
from typing_extensions import Literal
89

910
from pydantic import BaseModel, ConfigDict
1011

@@ -58,14 +59,17 @@ def validate_extract_response(
5859

5960
@lru_cache(maxsize=None)
6061
def _validation_schema(schema: Type[BaseModel], strict_response_validation: bool) -> Type[BaseModel]:
61-
extra_behavior = "forbid" if strict_response_validation else "allow"
62-
validation_schema = type(
63-
f"{schema.__name__}ExtractValidation",
64-
(schema,),
65-
{
66-
"__module__": schema.__module__,
67-
"model_config": ConfigDict(extra=extra_behavior),
68-
},
62+
extra_behavior: Literal["allow", "forbid"] = "forbid" if strict_response_validation else "allow"
63+
validation_schema = cast(
64+
Type[BaseModel],
65+
type(
66+
f"{schema.__name__}ExtractValidation",
67+
(schema,),
68+
{
69+
"__module__": schema.__module__,
70+
"model_config": ConfigDict(extra=extra_behavior),
71+
},
72+
),
6973
)
7074
validation_schema.model_rebuild(force=True)
7175
return validation_schema

src/stagehand/_session_extract.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,15 @@
22

33
from __future__ import annotations
44

5-
from typing import Any, cast
5+
from typing import Any, Mapping, cast
6+
from typing_extensions import Unpack
67

78
import httpx
8-
from typing_extensions import Unpack
99

10-
from ._pydantic_extract import is_pydantic_model, pydantic_model_to_json_schema, validate_extract_response
11-
from .session import AsyncSession, Session
12-
from ._types import Body, Headers, NotGiven, Query, not_given
1310
from .types import session_extract_params
11+
from ._types import Body, Query, Headers, NotGiven, not_given
12+
from .session import Session, AsyncSession
13+
from ._pydantic_extract import is_pydantic_model, validate_extract_response, pydantic_model_to_json_schema
1414
from .types.session_extract_response import SessionExtractResponse
1515

1616
_ORIGINAL_SESSION_EXTRACT = Session.extract
@@ -104,23 +104,22 @@ async def _async_extract( # type: ignore[override, misc]
104104

105105

106106
def _with_schema(
107-
params: session_extract_params.SessionExtractParamsNonStreaming,
107+
params: Mapping[str, object],
108108
schema: dict[str, object] | type | None,
109109
) -> session_extract_params.SessionExtractParamsNonStreaming:
110110
api_params = dict(params)
111111
if schema is not None:
112112
api_params["schema"] = cast(Any, schema)
113113
return cast(session_extract_params.SessionExtractParamsNonStreaming, api_params)
114114

115-
116115
_sync_extract.__module__ = _ORIGINAL_SESSION_EXTRACT.__module__
117116
_sync_extract.__name__ = _ORIGINAL_SESSION_EXTRACT.__name__
118117
_sync_extract.__qualname__ = _ORIGINAL_SESSION_EXTRACT.__qualname__
119118
_sync_extract.__doc__ = _ORIGINAL_SESSION_EXTRACT.__doc__
120-
_sync_extract.__stagehand_pydantic_extract_patch__ = True
119+
setattr(_sync_extract, "__stagehand_pydantic_extract_patch__", True) # noqa: B010
121120

122121
_async_extract.__module__ = _ORIGINAL_ASYNC_SESSION_EXTRACT.__module__
123122
_async_extract.__name__ = _ORIGINAL_ASYNC_SESSION_EXTRACT.__name__
124123
_async_extract.__qualname__ = _ORIGINAL_ASYNC_SESSION_EXTRACT.__qualname__
125124
_async_extract.__doc__ = _ORIGINAL_ASYNC_SESSION_EXTRACT.__doc__
126-
_async_extract.__stagehand_pydantic_extract_patch__ = True
125+
setattr(_async_extract, "__stagehand_pydantic_extract_patch__", True) # noqa: B010

src/stagehand/resources/sessions_helpers.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -6,26 +6,26 @@
66

77
import httpx
88

9-
from .._session_extract import install_pydantic_extract_patch
10-
from .._compat import cached_property
11-
from .._response import (
12-
async_to_raw_response_wrapper,
13-
async_to_streamed_response_wrapper,
14-
to_raw_response_wrapper,
15-
to_streamed_response_wrapper,
16-
)
9+
from ..types import session_start_params
1710
from .._types import Body, Omit, Query, Headers, NotGiven, omit, not_given
11+
from .._compat import cached_property
1812
from ..session import Session, AsyncSession
19-
from ..types import session_start_params
20-
from ..types.session_start_response import SessionStartResponse
2113
from .sessions import (
22-
AsyncSessionsResource,
23-
AsyncSessionsResourceWithRawResponse,
24-
AsyncSessionsResourceWithStreamingResponse,
2514
SessionsResource,
15+
AsyncSessionsResource,
2616
SessionsResourceWithRawResponse,
17+
AsyncSessionsResourceWithRawResponse,
2718
SessionsResourceWithStreamingResponse,
19+
AsyncSessionsResourceWithStreamingResponse,
2820
)
21+
from .._response import (
22+
to_raw_response_wrapper,
23+
to_streamed_response_wrapper,
24+
async_to_raw_response_wrapper,
25+
async_to_streamed_response_wrapper,
26+
)
27+
from .._session_extract import install_pydantic_extract_patch
28+
from ..types.session_start_response import SessionStartResponse
2929

3030
install_pydantic_extract_patch()
3131

tests/test_session_extract_pydantic.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,15 @@
44

55
import os
66
import json
7-
from typing import cast
7+
from typing import Any, cast
88

99
import httpx
1010
import pytest
11-
from pydantic import BaseModel
1211
from respx import MockRouter
12+
from pydantic import BaseModel
1313
from respx.models import Call
1414

15-
from stagehand import AsyncStagehand, Stagehand
15+
from stagehand import Stagehand, AsyncStagehand
1616

1717
base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010")
1818

@@ -45,7 +45,7 @@ def test_session_extract_accepts_pydantic_schema(respx_mock: MockRouter, client:
4545
)
4646

4747
session = client.sessions.start(model_name="openai/gpt-5-nano")
48-
response = session.extract(instruction="extract the user", schema=ExtractedUser)
48+
response = session.extract(instruction="extract the user", schema=cast(Any, ExtractedUser))
4949

5050
assert isinstance(response.data.result, ExtractedUser)
5151
assert response.data.result.user_name == "Ada"
@@ -79,7 +79,7 @@ def test_session_extract_allows_extra_fields_when_client_is_non_strict(
7979
)
8080

8181
session = client.sessions.start(model_name="openai/gpt-5-nano")
82-
response = session.extract(instruction="extract the user", schema=ExtractedName)
82+
response = session.extract(instruction="extract the user", schema=cast(Any, ExtractedName))
8383

8484
assert isinstance(response.data.result, ExtractedName)
8585
assert response.data.result.user_name == "Ada"
@@ -107,7 +107,7 @@ def test_session_extract_rejects_extra_fields_when_client_is_strict(
107107
)
108108

109109
session = client.sessions.start(model_name="openai/gpt-5-nano")
110-
response = session.extract(instruction="extract the user", schema=ExtractedName)
110+
response = session.extract(instruction="extract the user", schema=cast(Any, ExtractedName))
111111

112112
assert response.data.result == {"userName": "Ada", "favoriteColor": "blue"}
113113

@@ -133,7 +133,7 @@ async def test_async_session_extract_accepts_pydantic_schema(
133133
)
134134

135135
session = await async_client.sessions.start(model_name="openai/gpt-5-nano")
136-
response = await session.extract(instruction="extract the user", schema=ExtractedUser)
136+
response = await session.extract(instruction="extract the user", schema=cast(Any, ExtractedUser))
137137

138138
assert isinstance(response.data.result, ExtractedUser)
139139
assert response.data.result.user_name == "Grace"
@@ -167,7 +167,7 @@ async def test_async_session_extract_allows_extra_fields_when_client_is_non_stri
167167
)
168168

169169
session = await async_client.sessions.start(model_name="openai/gpt-5-nano")
170-
response = await session.extract(instruction="extract the user", schema=ExtractedName)
170+
response = await session.extract(instruction="extract the user", schema=cast(Any, ExtractedName))
171171

172172
assert isinstance(response.data.result, ExtractedName)
173173
assert response.data.result.user_name == "Grace"
@@ -195,6 +195,6 @@ async def test_async_session_extract_rejects_extra_fields_when_client_is_strict(
195195
)
196196

197197
session = await async_client.sessions.start(model_name="openai/gpt-5-nano")
198-
response = await session.extract(instruction="extract the user", schema=ExtractedName)
198+
response = await session.extract(instruction="extract the user", schema=cast(Any, ExtractedName))
199199

200200
assert response.data.result == {"userName": "Grace", "favoriteColor": "green"}

0 commit comments

Comments
 (0)