Skip to content

Commit d17f5ae

Browse files
chore(profiling): typing for _asyncio.py
1 parent 791f853 commit d17f5ae

1 file changed

Lines changed: 35 additions & 26 deletions

File tree

ddtrace/profiling/_asyncio.py

Lines changed: 35 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -92,20 +92,22 @@ def _get_running_loop() -> typing.Optional["aio.AbstractEventLoop"]:
9292

9393
# Python 3.14+: BaseDefaultEventLoopPolicy was renamed to _BaseDefaultEventLoopPolicy
9494
# Try both names for compatibility
95-
events_module = sys.modules["asyncio.events"]
95+
events_module: ModuleType = sys.modules["asyncio.events"]
9696
if sys.hexversion >= 0x030E0000:
9797
# Python 3.14+: Use _BaseDefaultEventLoopPolicy
98-
policy_class = getattr(events_module, "_BaseDefaultEventLoopPolicy", None)
98+
policy_class: typing.Optional[type[typing.Any]] = getattr(events_module, "_BaseDefaultEventLoopPolicy", None)
9999
else:
100100
# Python < 3.14: Use BaseDefaultEventLoopPolicy
101101
policy_class = getattr(events_module, "BaseDefaultEventLoopPolicy", None)
102102

103103
if policy_class is not None:
104104

105-
@partial(wrap, policy_class.set_event_loop)
105+
@partial(wrap, policy_class.set_event_loop) # type: ignore[arg-type] # pyright: ignore[reportArgumentType]
106106
def _(
107-
f: typing.Callable[..., typing.Any], args: tuple[typing.Any, ...], kwargs: dict[str, typing.Any]
108-
) -> typing.Any:
107+
f: typing.Callable[[object, typing.Optional["aio.AbstractEventLoop"]], None],
108+
args: typing.Any,
109+
kwargs: typing.Any,
110+
) -> None:
109111
loop: typing.Optional["aio.AbstractEventLoop"] = get_argument_value(args, kwargs, 1, "loop")
110112
if init_stack:
111113
stack.track_asyncio_loop(typing.cast(int, ddtrace_threading.current_thread().ident), loop)
@@ -118,13 +120,16 @@ def _(f: typing.Callable[..., None], args: tuple[typing.Any, ...], kwargs: dict[
118120
try:
119121
return f(*args, **kwargs)
120122
finally:
121-
children = get_argument_value(args, kwargs, 1, "children")
123+
children: list["aio.Future[typing.Any]"] = typing.cast(
124+
list["aio.Future[typing.Any]"], get_argument_value(args, kwargs, 1, "children")
125+
)
122126
assert children is not None # nosec: assert is used for typing
123127

124128
if globals()["get_running_loop"]() is not None:
125-
parent = globals()["current_task"]()
126-
for child in children:
127-
stack.link_tasks(parent, child)
129+
parent: typing.Optional["aio.Task[typing.Any]"] = globals()["current_task"]()
130+
if parent is not None:
131+
for child in children:
132+
stack.link_tasks(parent, child)
128133

129134
@partial(wrap, sys.modules["asyncio"].tasks._wait)
130135
def _(
@@ -178,9 +183,9 @@ def _(
178183
) -> typing.Any:
179184
loop = typing.cast(typing.Optional["aio.AbstractEventLoop"], kwargs.get("loop"))
180185
awaitable = typing.cast("aio.Future[typing.Any]", get_argument_value(args, kwargs, 0, "arg"))
181-
future = asyncio.ensure_future(awaitable, loop=loop)
186+
future: "aio.Future[typing.Any]" = asyncio.ensure_future(awaitable, loop=loop)
182187

183-
parent = globals()["current_task"]()
188+
parent: typing.Optional["aio.Task[typing.Any]"] = globals()["current_task"]()
184189
if parent is not None:
185190
stack.link_tasks(parent, future)
186191

@@ -196,20 +201,20 @@ def _(
196201

197202
# Wrap asyncio.TaskGroup.create_task to link parent task to created tasks (Python 3.11+)
198203
if sys.hexversion >= 0x030B0000: # Python 3.11+
199-
taskgroups_module = sys.modules.get("asyncio.taskgroups")
204+
taskgroups_module: typing.Optional[ModuleType] = sys.modules.get("asyncio.taskgroups")
200205
if taskgroups_module is not None:
201-
taskgroup_class = getattr(taskgroups_module, "TaskGroup", None)
206+
taskgroup_class: typing.Optional[type[typing.Any]] = getattr(taskgroups_module, "TaskGroup", None)
202207
if taskgroup_class is not None and hasattr(taskgroup_class, "create_task"):
203208

204209
@partial(wrap, taskgroup_class.create_task)
205210
def _(
206211
f: typing.Callable[..., "aio.Task[typing.Any]"],
207212
args: tuple[typing.Any, ...],
208213
kwargs: dict[str, typing.Any],
209-
) -> typing.Any:
210-
result = f(*args, **kwargs)
214+
) -> "aio.Task[typing.Any]":
215+
result: "aio.Task[typing.Any]" = f(*args, **kwargs)
211216

212-
parent = globals()["current_task"]()
217+
parent: typing.Optional["aio.Task[typing.Any]"] = globals()["current_task"]()
213218
if parent is not None and result is not None:
214219
# Link parent task to the task created by TaskGroup
215220
stack.link_tasks(parent, result)
@@ -259,18 +264,20 @@ def _(uvloop: ModuleType) -> None:
259264
init_stack: bool = config.stack.enabled and stack.is_available
260265

261266
# Wrap uvloop.new_event_loop to track loops when they're created
262-
new_event_loop_func = getattr(uvloop, "new_event_loop", None)
267+
new_event_loop_func: typing.Optional[typing.Callable[[], "asyncio.AbstractEventLoop"]] = getattr(
268+
uvloop, "new_event_loop", None
269+
)
263270
if new_event_loop_func is not None:
264271

265-
@partial(wrap, new_event_loop_func)
272+
@partial(wrap, new_event_loop_func) # type: ignore[arg-type]
266273
def _(
267-
f: typing.Callable[..., "asyncio.AbstractEventLoop"],
274+
f: typing.Callable[[], "asyncio.AbstractEventLoop"],
268275
args: tuple[typing.Any, ...],
269276
kwargs: dict[str, typing.Any],
270277
) -> "asyncio.AbstractEventLoop":
271-
loop = f(*args, **kwargs)
278+
loop: "asyncio.AbstractEventLoop" = f(*args, **kwargs)
272279
if init_stack:
273-
thread_id = typing.cast(int, ddtrace_threading.current_thread().ident)
280+
thread_id: int = typing.cast(int, ddtrace_threading.current_thread().ident)
274281
stack.set_uvloop_mode(thread_id, True)
275282

276283
stack.track_asyncio_loop(thread_id, loop)
@@ -280,14 +287,16 @@ def _(
280287
return loop
281288

282289
# Wrap uvloop.EventLoopPolicy.set_event_loop for uvloop.install() + asyncio.run() pattern
283-
policy_class = getattr(uvloop, "EventLoopPolicy", None)
290+
policy_class: typing.Optional[type[typing.Any]] = getattr(uvloop, "EventLoopPolicy", None)
284291
if policy_class is not None and hasattr(policy_class, "set_event_loop"):
285292

286-
@partial(wrap, policy_class.set_event_loop)
293+
@partial(wrap, policy_class.set_event_loop) # type: ignore[arg-type] # pyright: ignore[reportArgumentType]
287294
def _(
288-
f: typing.Callable[..., typing.Any], args: tuple[typing.Any, ...], kwargs: dict[str, typing.Any]
289-
) -> typing.Any:
290-
thread_id = typing.cast(int, ddtrace_threading.current_thread().ident)
295+
f: typing.Callable[[object, typing.Optional["asyncio.AbstractEventLoop"]], None],
296+
args: typing.Any,
297+
kwargs: typing.Any,
298+
) -> None:
299+
thread_id: int = typing.cast(int, ddtrace_threading.current_thread().ident)
291300
if init_stack:
292301
stack.set_uvloop_mode(thread_id, True)
293302

0 commit comments

Comments
 (0)