Skip to content

Commit 99a9602

Browse files
authored
Propagate dataflow cleanup cancellation to rollout Ray tasks (#1699)
1 parent 7a66e66 commit 99a9602

4 files changed

Lines changed: 60 additions & 6 deletions

File tree

xtuner/v1/ray/dataflow/flow.py

Lines changed: 18 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -12,7 +12,11 @@
1212
from tqdm.auto import tqdm
1313
from typing_extensions import Annotated
1414

15-
from xtuner.v1.data_proto.rl_data import MultimodalTrainInfo, RLDataFlowItem, RolloutState
15+
from xtuner.v1.data_proto.rl_data import (
16+
MultimodalTrainInfo,
17+
RLDataFlowItem,
18+
RolloutState,
19+
)
1620
from xtuner.v1.ray.environment import SingleTurnEnvironment
1721
from xtuner.v1.ray.rollout.controller import SampleParams
1822
from xtuner.v1.ray.utils import create_task
@@ -180,6 +184,7 @@ def __init__(
180184
)
181185
self.logger.info(f"DataFlowConfig:\n{self.config.model_dump_json(indent=2)}")
182186
self.cleanup_task_time = 5 * 60 # 5 minutes
187+
self.cancel_response_timeout = 5.0
183188

184189
def _reset_internal_states(
185190
self,
@@ -253,11 +258,22 @@ async def worker_task(self, group_samples_for_retry: Optional[List[RLDataFlowIte
253258
assert len(group_data_items) > 0, "Sampled empty group data items from replay buffer."
254259
action_id = group_data_items[0].uid.action_id
255260
# step 2: env generate
256-
group_data_items = await self.env_controller.run.remote( # type: ignore[attr-defined]
261+
env_run_ref = self.env_controller.run.remote( # type: ignore[attr-defined]
257262
group_data_items,
258263
sample_params=self.sample_params,
259264
extra_params=self.extra_params,
260265
)
266+
try:
267+
group_data_items = await asyncio.shield(env_run_ref)
268+
except asyncio.CancelledError as exc:
269+
ray.cancel(env_run_ref, recursive=True)
270+
try:
271+
group_data_items = await asyncio.wait_for(
272+
asyncio.shield(env_run_ref),
273+
timeout=self.cancel_response_timeout,
274+
)
275+
except BaseException:
276+
raise exc
261277

262278
# Step 3: Determine the sample's state and act accordingly.
263279
group_state = determine_group_state(group_data_items)

xtuner/v1/ray/environment/single_turn_env.py

Lines changed: 23 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -67,6 +67,7 @@ def __init__(
6767
# This should be longer than the controller's internal timeout (`rollout_timeout`)
6868
# to account for potential queuing delays and other overheads.
6969
self.timeout_multiplier = 2.0
70+
self.cancel_response_timeout = 5.0
7071
self.rollout_cfg = rollout_cfg
7172

7273
async def generate( # type: ignore[override]
@@ -140,10 +141,25 @@ async def generate( # type: ignore[override]
140141

141142
response_future.append(fut)
142143
try:
144+
response_gather = asyncio.gather(*response_future)
143145
rollout_responses = await asyncio.wait_for(
144-
asyncio.gather(*response_future), timeout=self.rollout_timeout * self.timeout_multiplier
146+
asyncio.shield(response_gather), timeout=self.rollout_timeout * self.timeout_multiplier
145147
)
148+
except asyncio.CancelledError as exc:
149+
for fut in response_future:
150+
ray.cancel(fut, recursive=True)
151+
try:
152+
rollout_responses = await asyncio.wait_for(
153+
asyncio.gather(*response_future, return_exceptions=True),
154+
timeout=self.cancel_response_timeout,
155+
)
156+
except BaseException:
157+
raise exc
158+
if not all(isinstance(response, RLRolloutResponseItem) for response in rollout_responses):
159+
raise exc
146160
except asyncio.TimeoutError:
161+
for fut in response_future:
162+
ray.cancel(fut, recursive=True)
147163
self.logger.error("Get rollout controller response timeout and return the failed response.")
148164
rollout_responses = [RLRolloutResponseItem(state="skipped") for _ in group_data_items]
149165
group_data_items = update_rollout_item(group_data_items, rollout_responses)
@@ -171,12 +187,17 @@ async def run( # type: ignore[override]
171187
group_data_items = await self.generate(group_data_items, sample_params, extra_params) # type: ignore[assignment]
172188
continue_judger = is_valid_for_training(group_data_items)
173189
if self.judger_controller and continue_judger:
190+
judger_response_ref = self.judger_controller.run.remote(group_data_items)
174191
try:
175192
judger_responses: List[RLJudgerResponseItem] = await asyncio.wait_for(
176-
self.judger_controller.run.remote(group_data_items),
193+
judger_response_ref,
177194
timeout=self.judger_timeout * self.timeout_multiplier,
178195
)
196+
except asyncio.CancelledError:
197+
ray.cancel(judger_response_ref, recursive=True)
198+
raise
179199
except asyncio.TimeoutError:
200+
ray.cancel(judger_response_ref, recursive=True)
180201
self.logger.error("Get judger controller response timeout and return the failed response.")
181202
judger_responses = [
182203
RLJudgerResponseItem(

xtuner/v1/ray/rollout/controller.py

Lines changed: 9 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -163,6 +163,7 @@ def __init__(
163163
# This should be longer than the controller's internal timeout (`rollout_timeout`)
164164
# to account for potential queuing delays and other overheads.
165165
self.timeout_multiplier = 2.0
166+
self.cancel_response_timeout = 5.0
166167

167168
def _get_worker_status_for_router(self) -> Dict[RolloutWorker, bool]:
168169
"""Helper to generate the status dict required by the SessionRouter."""
@@ -403,15 +404,22 @@ async def rollout(
403404
try:
404405
selected_worker_info = self.workers_info[server_url]
405406
response = await asyncio.wait_for(
406-
response_ref, timeout=self.config.rollout_timeout * self.timeout_multiplier
407+
asyncio.shield(response_ref), timeout=self.config.rollout_timeout * self.timeout_multiplier
407408
)
408409
selected_worker_info.success_count += 1
409410
if response.state == "failed" or response.state == "skipped":
410411
selected_worker_info.failure_count += 1
411412
self.logger.error(f"Get failed/skipped response from rollout worker {worker}, deactivate it.")
412413
self.deactivate_worker_by_url(server_url)
413414
return response
415+
except asyncio.CancelledError as exc:
416+
ray.cancel(response_ref, recursive=True)
417+
try:
418+
return await asyncio.wait_for(asyncio.shield(response_ref), timeout=self.cancel_response_timeout)
419+
except BaseException:
420+
raise exc
414421
except asyncio.TimeoutError:
422+
ray.cancel(response_ref, recursive=True)
415423
selected_worker_info.failure_count += 1
416424
self.logger.error(f"Get response from rollout worker {worker} timeout and return skip this sample.")
417425
self.deactivate_worker_by_url(server_url)

xtuner/v1/ray/rollout/worker.py

Lines changed: 10 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -331,6 +331,7 @@ def _check_infer_engine_version(self, return_token_ids: bool):
331331
self.check_flag = False
332332

333333
async def _safe_post_request(self, url, headers, payload) -> HttpRequestResult:
334+
send_task = None
334335
try:
335336
if self.receive_abort_request.is_set():
336337
self.logger.debug(f"Request to {url} was cancelled before sending due to an abort signal.")
@@ -341,10 +342,18 @@ async def _safe_post_request(self, url, headers, payload) -> HttpRequestResult:
341342
headers=headers,
342343
json=payload,
343344
)
344-
r = await self.client.send(req)
345+
send_task = asyncio.create_task(self.client.send(req))
346+
r = await send_task
345347
r.raise_for_status()
346348
return HttpRequestResult(response=r)
347349

350+
except asyncio.CancelledError:
351+
self.logger.debug(f"Request to {url} was cancelled while waiting for the response.")
352+
if send_task is not None and not send_task.done():
353+
send_task.cancel()
354+
await asyncio.gather(send_task, return_exceptions=True)
355+
self.receive_abort_request.set()
356+
return HttpRequestResult(error_type=HttpRequestErrorType.REQUEST_ABORTED, url=url, payload=payload)
348357
except Exception as e:
349358
error_type = HttpRequestErrorType.from_exception(e)
350359
result = HttpRequestResult(error_type=error_type, exception=e, url=url, payload=payload)

0 commit comments

Comments
 (0)