55from copy import deepcopy
66from functools import wraps
77from inspect import iscoroutinefunction
8- from typing import Any , TypeAlias
8+ from typing import Any , Generic , TypeAlias , TypeVar , cast
99
1010from amrita_core import SuspendObjectStream , logger
1111from amrita_core .hook .matcher import DependsFactory , MatcherFactory
2424 "WorkflowPC::each_node" # When stop at this checkpoint, change address is allowed.
2525)
2626
27+ io_T = TypeVar ("io_T" , bound = SuspendObjectStream , covariant = True )
2728
28- class WorkflowInterpreter :
29+
30+ class WorkflowInterpreter (Generic [io_T ]):
2931 _graph : NodeComposeRendered
3032 _pointer : PointerVector
3133 _ava_args : tuple
3234 _ava_kwargs : dict [str , Any ]
3335 _exc_ignored : tuple [type [BaseException ], ...]
34- object_io : SuspendObjectStream
36+ object_io : io_T
3537 _ret_addr_stack : Stack [PointerVector ]
3638 _jump_marked : bool
3739
3840 def __init__ (
3941 self ,
4042 node_compose : NodeComposeRendered | SelfCompileInstruction ,
41- object_io : SuspendObjectStream | None = None ,
43+ object_io : SuspendObjectStream [ Any ] | None = None ,
4244 * ,
4345 exception_ignored : tuple [type [BaseException ], ...],
4446 extra_args : tuple ,
@@ -52,7 +54,8 @@ def __init__(
5254 self ._ava_args = (self , * extra_args )
5355 self ._ava_kwargs = deepcopy (extra_kwargs )
5456 self ._exc_ignored = (* exception_ignored , InterruptNotice )
55- self .object_io = object_io or SuspendObjectStream ()
57+ object_io = object_io or SuspendObjectStream ()
58+ self .object_io = cast (io_T , object_io )
5659 self ._ret_addr_stack = addr_stack or Stack ()
5760 self ._jump_marked = False
5861
0 commit comments