@@ -243,6 +243,55 @@ def print_events(self, class_name: str) -> None:
243243 doc .append (f" return super().{ event_type } (event=event,f=f)" )
244244 print ("\n " .join (doc ))
245245
246+ def print_event_overloads (self , class_name : str , method_name : str ) -> None :
247+ """Emit ``@typing.overload`` stubs for ``expect_event`` / ``wait_for_event``
248+ keyed on ``Literal`` event names with their payload types from api.json,
249+ so pyright/mypy can narrow the return type at call sites.
250+ Must be called right before the implementation signature is emitted.
251+ """
252+ if class_name not in self .classes :
253+ return
254+ events = self .classes [class_name ].get ("events" ) or []
255+ if not events :
256+ return
257+ is_expect = method_name == "expect_event"
258+ async_prefix = "async " if not is_expect and self .is_async else ""
259+ if is_expect :
260+ ctx_mgr = (
261+ "AsyncEventContextManager" if self .is_async else "EventContextManager"
262+ )
263+ for event in events :
264+ payload = self .serialize_doc_type (event ["type" ], "" )
265+ if payload .startswith ("{" ):
266+ payload = "typing.Dict"
267+ if "Union[" in payload :
268+ payload = payload .replace ("Union[" , "typing.Union[" )
269+ return_type = f'{ ctx_mgr } ["{ payload } "]' if is_expect else f'"{ payload } "'
270+ event_literal = event ["name" ].lower ()
271+ print (" @typing.overload" )
272+ print (f" { async_prefix } def { method_name } (" )
273+ print (" self," )
274+ print (f' event: typing.Literal["{ event_literal } "],' )
275+ print (
276+ f' predicate: typing.Optional[typing.Callable[["{ payload } "], bool]] = None,'
277+ )
278+ print (" *," )
279+ print (" timeout: typing.Optional[float] = None," )
280+ print (f" ) -> { return_type } : ..." )
281+ print ("" )
282+ # Catch-all overload for non-literal event names — keeps pyright happy
283+ # with `event: str` callers without falling through to `Unknown`.
284+ catchall_return = f"{ ctx_mgr } [typing.Any]" if is_expect else "typing.Any"
285+ print (" @typing.overload" )
286+ print (f" { async_prefix } def { method_name } (" )
287+ print (" self," )
288+ print (" event: str," )
289+ print (" predicate: typing.Optional[typing.Callable[..., bool]] = None," )
290+ print (" *," )
291+ print (" timeout: typing.Optional[float] = None," )
292+ print (f" ) -> { catchall_return } : ..." )
293+ print ("" )
294+
246295 def indent_paragraph (self , p : str , indent : str ) -> str :
247296 lines = p .split ("\n " )
248297 result = [lines [0 ]]
0 commit comments