2020 LoopNotFoundError ,
2121 LoopPausedError ,
2222 LoopStoppedError ,
23+ WorkflowMaxRetriesError ,
2324 WorkflowNextError ,
2425 WorkflowNotFoundError ,
2526 WorkflowRepeatError ,
2627)
2728from .logging import setup_logger
2829from .state .state import LoopState , StateManager
29- from .types import BaseConfig , LoopEventSender , LoopStatus
30+ from .types import BaseConfig , LoopEventSender , LoopStatus , RetryPolicy
3031from .utils import get_func_import_path
3132
33+ DEFAULT_RETRY_POLICY = RetryPolicy ()
34+
3235if TYPE_CHECKING :
3336 from .context import LoopContext
3437
@@ -135,16 +138,25 @@ class WorkflowState:
135138 blocks : list [dict [str , Any ]] = field (default_factory = list )
136139 current_block_index : int = 0
137140 next_payload : dict [str , Any ] | None = None
141+ completed_blocks : list [int ] = field (default_factory = list )
142+ block_attempts : dict [int , int ] = field (default_factory = dict )
143+ last_error : str | None = None
138144
def to_dict(self) -> dict[str, Any]:
    """Return the state's attributes as a plain dict.

    The result is a shallow copy of the instance ``__dict__``, except
    that ``block_attempts`` is rebuilt with its keys converted to ``str``.
    ``from_json`` converts those keys back to ``int`` when loading.

    Returns:
        A new dict of attribute name -> value. Apart from the rebuilt
        ``block_attempts`` mapping, values are shared with the instance
        (shallow copy).
    """
    d = self.__dict__.copy()
    # Replace the entry on the copy only; the instance keeps its int keys.
    d["block_attempts"] = {str(k): v for k, v in self.block_attempts.items()}
    return d
141149
def to_string(self) -> str:
    """Serialize the state to a JSON string.

    Serialization goes through ``to_dict`` so the JSON form and the dict
    form are produced from the same data. Values that ``json`` cannot
    encode natively are converted with ``str`` (``default=str``).
    """
    return json.dumps(self.to_dict(), default=str)
144152
@classmethod
def from_json(cls, json_str: str) -> "WorkflowState":
    """Build an instance from a JSON string.

    Args:
        json_str: JSON text whose top-level object holds the constructor
            keyword arguments.

    Returns:
        ``cls(**data)`` for the decoded object.

    If the decoded object has a ``block_attempts`` entry that is a dict,
    its keys are converted to ``int`` before construction. JSON object
    keys are always strings, so this restores the integer keys.
    """
    data = json.loads(json_str)
    if "block_attempts" in data and isinstance(data["block_attempts"], dict):
        data["block_attempts"] = {
            int(k): v for k, v in data["block_attempts"].items()
        }
    return cls(**data)
149161
150162
@@ -434,6 +446,27 @@ def __init__(self, state_manager: StateManager):
434446 self .tasks : dict [str , asyncio .Task [None ]] = {}
435447 self .state_manager = state_manager
436448
449+ async def _persist_block_attempt (
450+ self , workflow_id : str , idx : int , error : str | None = None
451+ ) -> int :
452+ workflow = await self .state_manager .get_workflow (workflow_id )
453+ attempts = workflow .block_attempts .get (idx , 0 ) + 1
454+ workflow .block_attempts [idx ] = attempts
455+ workflow .last_error = error
456+ await self .state_manager .update_workflow (workflow_id , workflow )
457+ return attempts
458+
459+ async def _mark_block_completed (
460+ self , workflow_id : str , idx : int , next_idx : int , payload : dict | None
461+ ) -> None :
462+ workflow = await self .state_manager .get_workflow (workflow_id )
463+ if idx not in workflow .completed_blocks :
464+ workflow .completed_blocks .append (idx )
465+ workflow .current_block_index = next_idx
466+ workflow .next_payload = payload
467+ workflow .block_attempts .pop (idx , None )
468+ await self .state_manager .update_workflow (workflow_id , workflow )
469+
437470 async def _run (
438471 self ,
439472 func : Callable [..., Any ],
@@ -442,12 +475,13 @@ async def _run(
442475 on_stop : Callable [..., Any ] | None ,
443476 on_block_complete : Callable [..., Any ] | None ,
444477 on_error : Callable [..., Any ] | None ,
478+ retry_policy : RetryPolicy ,
445479 ) -> None :
446480 try :
447481 async with self .state_manager .with_workflow_claim (workflow_id ):
448482 while True :
449483 workflow = await self .state_manager .get_workflow (workflow_id )
450- if workflow .status == LoopStatus .STOPPED :
484+ if workflow .status in ( LoopStatus .STOPPED , LoopStatus . FAILED ) :
451485 break
452486
453487 blocks = [WorkflowBlock (** b ) for b in workflow .blocks ]
@@ -456,6 +490,12 @@ async def _run(
456490 if idx >= len (blocks ):
457491 raise LoopStoppedError ()
458492
493+ if idx in workflow .completed_blocks :
494+ await self ._mark_block_completed (
495+ workflow_id , idx , idx + 1 , workflow .next_payload
496+ )
497+ continue
498+
459499 current_block = blocks [idx ]
460500 context .block_index = idx
461501 context .block_count = len (blocks )
@@ -467,12 +507,12 @@ async def _run(
467507 raise LoopStoppedError ()
468508
469509 except WorkflowNextError as e :
510+ await self ._mark_block_completed (
511+ workflow_id , idx , idx + 1 , e .payload
512+ )
470513 await _call (
471514 on_block_complete , context , current_block , e .payload
472515 )
473- await self .state_manager .update_workflow_block_index (
474- workflow_id , idx + 1 , e .payload
475- )
476516
477517 except WorkflowRepeatError :
478518 pass
@@ -481,20 +521,61 @@ async def _run(
481521 raise
482522
483523 except BaseException as e :
524+ error_str = str (e )
484525 logger .error (
485- "Workflow error" ,
526+ "Workflow block error" ,
486527 extra = {
487528 "workflow_id" : workflow_id ,
488- "error" : str (e ),
529+ "block_index" : idx ,
530+ "error" : error_str ,
489531 "traceback" : traceback .format_exc (),
490532 },
491533 )
534+
535+ attempts = await self ._persist_block_attempt (
536+ workflow_id , idx , error_str
537+ )
538+
539+ should_retry = False
492540 if on_error :
493541 try :
494542 await _call (on_error , context , current_block , e )
495543 except WorkflowRepeatError :
496- continue
497- raise LoopStoppedError () from e
544+ should_retry = True
545+
546+ if not should_retry and attempts < retry_policy .max_attempts :
547+ should_retry = True
548+
549+ if should_retry and attempts < retry_policy .max_attempts :
550+ delay = retry_policy .compute_delay (attempts )
551+ logger .info (
552+ "Retrying workflow block" ,
553+ extra = {
554+ "workflow_id" : workflow_id ,
555+ "block_index" : idx ,
556+ "attempt" : attempts ,
557+ "delay" : delay ,
558+ },
559+ )
560+ await asyncio .sleep (delay )
561+ continue
562+
563+ max_retries_error = WorkflowMaxRetriesError (
564+ workflow_id , idx , attempts , error_str
565+ )
566+ logger .error (
567+ "Workflow block failed after max retries" ,
568+ extra = {
569+ "workflow_id" : workflow_id ,
570+ "block_index" : idx ,
571+ "attempts" : attempts ,
572+ },
573+ )
574+ await _call (on_error , context , current_block , max_retries_error )
575+ await self .state_manager .update_workflow_status (
576+ workflow_id , LoopStatus .FAILED
577+ )
578+ return
498579
499580 except asyncio .CancelledError :
500581 pass
@@ -521,6 +602,7 @@ async def start(
521602 on_stop : Callable [..., Any ] | None = None ,
522603 on_block_complete : Callable [..., Any ] | None = None ,
523604 on_error : Callable [..., Any ] | None = None ,
605+ retry_policy : RetryPolicy | None = None ,
524606 ) -> bool :
525607 if workflow .workflow_id in self .tasks :
526608 return False
@@ -535,6 +617,7 @@ async def start(
535617 on_stop ,
536618 on_block_complete ,
537619 on_error ,
620+ retry_policy or DEFAULT_RETRY_POLICY ,
538621 )
539622 )
540623 return True
0 commit comments