Skip to content

Commit 86415b6

Browse files
authored
fix(typing): typed overloads for expect_event and wait_for_event (#3061)
1 parent f3d8fd1 commit 86415b6

9 files changed

Lines changed: 1489 additions & 38 deletions

playwright/async_api/_generated.py

Lines changed: 715 additions & 14 deletions
Large diffs are not rendered by default.

playwright/sync_api/_generated.py

Lines changed: 715 additions & 14 deletions
Large diffs are not rendered by default.

scripts/documentation_provider.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -243,6 +243,55 @@ def print_events(self, class_name: str) -> None:
243243
doc.append(f" return super().{event_type}(event=event,f=f)")
244244
print("\n".join(doc))
245245

246+
def print_event_overloads(self, class_name: str, method_name: str) -> None:
247+
"""Emit ``@typing.overload`` stubs for ``expect_event`` / ``wait_for_event``
248+
keyed on ``Literal`` event names with their payload types from api.json,
249+
so pyright/mypy can narrow the return type at call sites.
250+
Must be called right before the implementation signature is emitted.
251+
"""
252+
if class_name not in self.classes:
253+
return
254+
events = self.classes[class_name].get("events") or []
255+
if not events:
256+
return
257+
is_expect = method_name == "expect_event"
258+
async_prefix = "async " if not is_expect and self.is_async else ""
259+
if is_expect:
260+
ctx_mgr = (
261+
"AsyncEventContextManager" if self.is_async else "EventContextManager"
262+
)
263+
for event in events:
264+
payload = self.serialize_doc_type(event["type"], "")
265+
if payload.startswith("{"):
266+
payload = "typing.Dict"
267+
if "Union[" in payload:
268+
payload = payload.replace("Union[", "typing.Union[")
269+
return_type = f'{ctx_mgr}["{payload}"]' if is_expect else f'"{payload}"'
270+
event_literal = event["name"].lower()
271+
print(" @typing.overload")
272+
print(f" {async_prefix}def {method_name}(")
273+
print(" self,")
274+
print(f' event: typing.Literal["{event_literal}"],')
275+
print(
276+
f' predicate: typing.Optional[typing.Callable[["{payload}"], bool]] = None,'
277+
)
278+
print(" *,")
279+
print(" timeout: typing.Optional[float] = None,")
280+
print(f" ) -> {return_type}: ...")
281+
print("")
282+
# Catch-all overload for non-literal event names — keeps pyright happy
283+
# with `event: str` callers without falling through to `Unknown`.
284+
catchall_return = f"{ctx_mgr}[typing.Any]" if is_expect else "typing.Any"
285+
print(" @typing.overload")
286+
print(f" {async_prefix}def {method_name}(")
287+
print(" self,")
288+
print(" event: str,")
289+
print(" predicate: typing.Optional[typing.Callable[..., bool]] = None,")
290+
print(" *,")
291+
print(" timeout: typing.Optional[float] = None,")
292+
print(f" ) -> {catchall_return}: ...")
293+
print("")
294+
246295
def indent_paragraph(self, p: str, indent: str) -> str:
247296
lines = p.split("\n")
248297
result = [lines[0]]

scripts/generate_async_api.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,8 @@ def generate(t: Any) -> None:
9292
'"Disposable"', '"AsyncContextManager"'
9393
).replace('"DisposableStub"', '"AsyncContextManager"')
9494
print("")
95+
if name in ("expect_event", "wait_for_event"):
96+
documentation_provider.print_event_overloads(class_name, name)
9597
async_prefix = "async " if is_async else ""
9698
print(
9799
f" {async_prefix}def {name}({signature(value, len(name) + 9)}) -> {return_type_value}:"

scripts/generate_sync_api.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,8 @@ def generate(t: Any) -> None:
9393
'"Disposable"', '"SyncContextManager"'
9494
).replace('"DisposableStub"', '"SyncContextManager"')
9595
print("")
96+
if name in ("expect_event", "wait_for_event"):
97+
documentation_provider.print_event_overloads(class_name, name)
9698
print(
9799
f" def {name}({signature(value, len(name) + 9)}) -> {return_type_value}:"
98100
)

tests/async/test_browsercontext_events.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -197,7 +197,7 @@ async def test_page_error_event_should_work(page: Page) -> None:
197197
await page.set_content('<script>throw new Error("boom")</script>')
198198
page_error = await page_error_info.value
199199
assert page_error.page == page
200-
assert "boom" in page_error.error.stack
200+
assert page_error.error.stack and "boom" in page_error.error.stack
201201

202202

203203
async def test_weberror_event_should_work(context: BrowserContext, page: Page) -> None:

tests/async/test_page_request_intercept.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
# limitations under the License.
1414

1515
import asyncio
16-
from typing import cast
1716

1817
import pytest
1918

@@ -98,4 +97,4 @@ async def route_handler(route: Route) -> None:
9897
[popup, _] = await asyncio.gather(
9998
page.wait_for_event("popup"), page.get_by_text("click me").click()
10099
)
101-
await expect(cast(Page, popup).locator("body")).to_have_text("hello")
100+
await expect(popup.locator("body")).to_have_text("hello")

tests/async/test_resource_timing.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,10 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from typing import Dict
1615

1716
import pytest
1817

19-
from playwright.async_api import Browser, Page
18+
from playwright.async_api import Browser, Page, ResourceTiming
2019
from tests.server import Server
2120

2221

@@ -95,7 +94,7 @@ def verify_timing_value(value: float, previous: float) -> None:
9594
assert value == -1 or value > 0 and value >= previous
9695

9796

98-
def verify_connections_timing_consistency(timing: Dict) -> None:
97+
def verify_connections_timing_consistency(timing: ResourceTiming) -> None:
9998
verify_timing_value(timing["domainLookupStart"], -1)
10099
verify_timing_value(timing["domainLookupEnd"], timing["domainLookupStart"])
101100
verify_timing_value(timing["connectStart"], timing["domainLookupEnd"])

tests/sync/test_resource_timing.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,9 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from typing import Dict
16-
1715
import pytest
1816

19-
from playwright.sync_api import Browser, Page
17+
from playwright.sync_api import Browser, Page, ResourceTiming
2018
from tests.server import Server
2119

2220

@@ -99,7 +97,7 @@ def verify_timing_value(value: float, previous: float) -> None:
9997
assert value == -1 or value > 0 and value >= previous
10098

10199

102-
def verify_connections_timing_consistency(timing: Dict) -> None:
100+
def verify_connections_timing_consistency(timing: ResourceTiming) -> None:
103101
verify_timing_value(timing["domainLookupStart"], -1)
104102
verify_timing_value(timing["domainLookupEnd"], timing["domainLookupStart"])
105103
verify_timing_value(timing["connectStart"], timing["domainLookupEnd"])

0 commit comments

Comments
 (0)