@@ -67,6 +67,7 @@ def __init__(
6767 # This should be longer than the controller's internal timeout (`rollout_timeout`)
6868 # to account for potential queuing delays and other overheads.
6969 self .timeout_multiplier = 2.0
70+ self .cancel_response_timeout = 5.0
7071 self .rollout_cfg = rollout_cfg
7172
7273 async def generate ( # type: ignore[override]
@@ -140,10 +141,25 @@ async def generate( # type: ignore[override]
140141
141142 response_future .append (fut )
142143 try :
144+ response_gather = asyncio .gather (* response_future )
143145 rollout_responses = await asyncio .wait_for (
144- asyncio .gather ( * response_future ), timeout = self .rollout_timeout * self .timeout_multiplier
146+ asyncio .shield ( response_gather ), timeout = self .rollout_timeout * self .timeout_multiplier
145147 )
148+ except asyncio .CancelledError as exc :
149+ for fut in response_future :
150+ ray .cancel (fut , recursive = True )
151+ try :
152+ rollout_responses = await asyncio .wait_for (
153+ asyncio .gather (* response_future , return_exceptions = True ),
154+ timeout = self .cancel_response_timeout ,
155+ )
156+ except BaseException :
157+ raise exc
158+ if not all (isinstance (response , RLRolloutResponseItem ) for response in rollout_responses ):
159+ raise exc
146160 except asyncio .TimeoutError :
161+ for fut in response_future :
162+ ray .cancel (fut , recursive = True )
147163 self .logger .error ("Get rollout controller response timeout and return the failed response." )
148164 rollout_responses = [RLRolloutResponseItem (state = "skipped" ) for _ in group_data_items ]
149165 group_data_items = update_rollout_item (group_data_items , rollout_responses )
@@ -171,12 +187,17 @@ async def run( # type: ignore[override]
171187 group_data_items = await self .generate (group_data_items , sample_params , extra_params ) # type: ignore[assignment]
172188 continue_judger = is_valid_for_training (group_data_items )
173189 if self .judger_controller and continue_judger :
190+ judger_response_ref = self .judger_controller .run .remote (group_data_items )
174191 try :
175192 judger_responses : List [RLJudgerResponseItem ] = await asyncio .wait_for (
176- self . judger_controller . run . remote ( group_data_items ) ,
193+ judger_response_ref ,
177194 timeout = self .judger_timeout * self .timeout_multiplier ,
178195 )
196+ except asyncio .CancelledError :
197+ ray .cancel (judger_response_ref , recursive = True )
198+ raise
179199 except asyncio .TimeoutError :
200+ ray .cancel (judger_response_ref , recursive = True )
180201 self .logger .error ("Get judger controller response timeout and return the failed response." )
181202 judger_responses = [
182203 RLJudgerResponseItem (
0 commit comments