1+ import typing
12import asyncio
23import gc
34import weakref
45import pytest
5- from catalyst .domain .engine import WorkflowEngine , TaskError
6+ from catalyst .domain .engine import WorkflowEngine
67
78
89@pytest .mark .asyncio
@@ -43,7 +44,7 @@ async def test_reference_cycle_is_broken() -> None:
4344 """Ensure that the execution dictionary is not held in a reference cycle."""
4445 engine = WorkflowEngine ()
4546
46- async def my_task ():
47+ async def my_task () -> str :
4748 return "success"
4849
4950 engine .add_task ("task_a" , my_task )
@@ -52,22 +53,25 @@ async def my_task():
5253 results = await engine .execute ()
5354 assert results ["task_a" ] == "success"
5455
55- weak_dep = None
56+ weak_dep : typing . Any = None
5657
5758 orig_run_node = engine ._run_node
5859
5960 # We create a dummy class to hold a reference to tasks so we can weakref it
6061 class TaskHolder :
61- def __init__ (self , tasks ) :
62+ def __init__ (self , tasks : tuple [ asyncio . Task [ typing . Any ], ...]) -> None :
6263 self .tasks = tasks
6364
64- async def wrapped_run_node (node : str , dep_tasks : tuple [asyncio .Task , ...]):
65+
66+
67+
68+ async def wrapped_run_node (node : str , dep_tasks : tuple [asyncio .Task [typing .Any ], ...]) -> typing .Any :
6569 nonlocal weak_dep
6670 holder = TaskHolder (dep_tasks )
6771 weak_dep = weakref .ref (holder )
6872 return await orig_run_node (node , dep_tasks )
6973
70- engine ._run_node = wrapped_run_node
74+ engine ._run_node = wrapped_run_node # type: ignore
7175
7276 await engine .execute ()
7377 gc .collect ()
0 commit comments