
fix: prevent async RL dispatch crashes on uneven batches#1225

Open
yyypluto wants to merge 2 commits into areal-project:main from yyypluto:fix/async-rl-crashes

Conversation

yyypluto commented Apr 22, 2026

Summary

  • pad eval batches in _custom_function_call and _async_custom_function_call before DP partitioning, then trim dummy outputs after collection
  • relax balanced_greedy_partition so remainder items are assigned instead of raising on uneven input lengths
  • add regression tests for remainder-aware partitioning and padded eval dispatch trimming

Why

TrainController can still route uneven tensor-like eval batches into _dispatch_tensors, which calls balanced_greedy_partition(group_weights, K=dp_size). When the grouped batch size is not divisible by the DP size, the async RL eval path can fail instead of padding safely.
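For a concrete picture of the failure, consider a strict partitioner that raises on uneven input. The sketch below is illustrative only and assumes a simplified helper; it is not AReaL's actual balanced_greedy_partition:

# Hypothetical strict partitioner mirroring the failure mode described above.
def strict_partition(weights, k):
    if len(weights) % k != 0:
        raise ValueError(f"cannot split {len(weights)} items evenly into {k} groups")
    per_group = len(weights) // k
    return [weights[i * per_group:(i + 1) * per_group] for i in range(k)]

group_weights = [1.0] * 10  # 10 grouped eval items
dp_size = 4                 # data-parallel world size
try:
    strict_partition(group_weights, dp_size)
except ValueError as err:
    print(err)  # 10 % 4 != 0, so the eval dispatch path would crash here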

Validation

  • training runs verified by the author on the three affected logic paths
  • python3 -m py_compile on the edited files
  • git diff --check
  • ad hoc remainder partitioning checks

Wire eval padding into the custom dispatch path so tensor-like batches are padded to a multiple of data-parallel group size before partitioning, then trim the dummy outputs after collection.
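A minimal sketch of that pad-then-trim flow, assuming illustrative helper names (pad_to_multiple and trim below are not the PR's actual functions):

import math

def pad_to_multiple(batch, multiple):
    """Pad with copies of the last element to a multiple; return (padded, original length)."""
    target = math.ceil(len(batch) / multiple) * multiple
    return batch + [batch[-1]] * (target - len(batch)), len(batch)

def trim(outputs, orig_len):
    """Drop the outputs produced by the dummy padding items."""
    return outputs[:orig_len]

batch = [{"prompt": i} for i in range(10)]
padded, orig_len = pad_to_multiple(batch, multiple=4)  # 10 -> 12 items
outputs = [f"out-{x['prompt']}" for x in padded]       # stand-in for dispatch
assert len(trim(outputs, orig_len)) == 10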

Also relax balanced_greedy_partition so it can assign remainder items instead of raising when the number of groups does not evenly divide the input length.
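A sketch of the relaxed behavior, assuming a weight-balanced greedy (LPT-style) scheme that returns per-group index lists; the repository's actual algorithm and signature may differ:

import heapq

def balanced_greedy_partition_sketch(weights, k):
    """Assign each item index to the currently lightest of k groups.
    Tolerates len(weights) % k != 0 by letting some groups take one
    extra item instead of raising."""
    heap = [(0.0, g) for g in range(k)]  # (running total weight, group id)
    heapq.heapify(heap)
    groups = [[] for _ in range(k)]
    for idx in sorted(range(len(weights)), key=lambda i: -weights[i]):
        total, g = heapq.heappop(heap)
        groups[g].append(idx)
        heapq.heappush(heap, (total + weights[idx], g))
    return groups

# 7 items into 3 groups: no exception, and group totals stay balanced.
print(balanced_greedy_partition_sketch([5, 3, 3, 2, 2, 1, 1], k=3))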

Add regression tests for remainder-aware partitioning and padded eval dispatch trimming.
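Illustrative pytest-style checks built on the sketch helpers above (the PR's actual test suite will differ):

def test_remainder_partition_covers_all_items():
    groups = balanced_greedy_partition_sketch([1.0] * 10, k=4)
    assert sorted(i for g in groups for i in g) == list(range(10))
    assert all(len(g) in (2, 3) for g in groups)

def test_padded_dispatch_trims_dummy_outputs():
    batch = list(range(10))
    padded, orig_len = pad_to_multiple(batch, multiple=4)
    assert len(padded) == 12
    assert trim(padded, orig_len) == batch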
gemini-code-assist Bot (Contributor) left a comment


Code Review

This pull request introduces padding and trimming logic for evaluation items in the TrainController and updates the balanced_greedy_partition utility to handle cases where the number of items is not divisible by the group count. While these changes improve flexibility, the review identifies several critical issues: the padding logic ignores arguments passed via kwargs, and it fails to consistently pad multiple tensor-like lists, which could lead to IndexError or engine crashes. Furthermore, the duplicated logic between the synchronous and asynchronous dispatch methods should be refactored into a shared helper to ensure maintainability.

Comment on lines +457 to +460
for arg in args:
    if isinstance(arg, list) and arg and _is_tensor_like(arg):
        orig_len = len(arg)
        break

high

The orig_len detection logic only inspects args, ignoring kwargs. If a tensor-like batch is passed via kwargs, it will not be padded (as _pad_eval_dispatch_args also only handles args), and the results will not be trimmed. This inconsistency could lead to crashes or incorrect behavior in engines that require even batches across data-parallel ranks. Consider extending this detection and the padding logic to include kwargs.
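A sketch of kwargs-aware detection along the lines suggested (names are assumptions, not the PR's code):

def find_orig_len(args, kwargs, is_tensor_like):
    """Return the length of the first tensor-like list in args or kwargs."""
    for value in list(args) + list(kwargs.values()):
        if isinstance(value, list) and value and is_tensor_like(value):
            return len(value)
    return None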

Comment on lines +462 to +464
args, kwargs = self._pad_eval_dispatch_args(
    args, kwargs, group_size=group_size
)

high

There is a potential IndexError when multiple tensor-like lists are passed as arguments. The underlying _pad_eval_batch function (called via _pad_eval_dispatch_args) only pads the first tensor-like list it encounters. If multiple parallel lists (e.g., states and actions) are provided, they will end up with different lengths after padding, causing _partition_inputs to fail when it attempts to index the unpadded lists using indices derived from the padded one. All tensor-like lists should be padded consistently to the same target length.
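A sketch of padding every tensor-like list in args to one shared target length, as the comment suggests (names are illustrative):

def pad_all_lists(args, multiple, is_tensor_like):
    """Pad all tensor-like lists to the same multiple-aligned length."""
    lists = [a for a in args if isinstance(a, list) and a and is_tensor_like(a)]
    if not lists:
        return args
    longest = max(len(lst) for lst in lists)
    target = -(-longest // multiple) * multiple  # ceil to a multiple
    padded = []
    for a in args:
        if isinstance(a, list) and a and is_tensor_like(a):
            padded.append(a + [a[-1]] * (target - len(a)))
        else:
            padded.append(a)
    return tuple(padded)

Because parallel lists such as states and actions start at the same length, padding each to the shared target keeps their indices aligned through _partition_inputs.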

Comment on lines +488 to +511
group_size = kwargs.get("group_size", 1)
orig_len = None
for arg in args:
    if isinstance(arg, list) and arg and _is_tensor_like(arg):
        orig_len = len(arg)
        break

args, kwargs = self._pad_eval_dispatch_args(
    args, kwargs, group_size=group_size
)
dp_args, dp_kwargs, group_indices = self._prepare_dispatch(*args, **kwargs)
results = await self._call_workers(
    method, dp_args, dp_kwargs, rpc_meta=rpc_meta
)
merged_results = self._collect_results(results, group_indices)

if (
    orig_len is not None
    and isinstance(merged_results, list)
    and len(merged_results) > orig_len
):
    merged_results = merged_results[:orig_len]

return merged_results

medium

The padding and trimming logic in _async_custom_function_call is identical to that in _custom_function_call. This duplication should be refactored into a shared helper method to improve maintainability and ensure that future fixes (such as handling kwargs or multiple lists) are applied consistently to both execution paths.
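One way to realize the suggested refactor, sketched with illustrative names; both _custom_function_call and _async_custom_function_call would delegate to these instead of inlining the logic:

def pad_for_dispatch(args, kwargs, group_size, pad_fn, find_len_fn):
    """Record the original batch length, then pad via the supplied helper."""
    orig_len = find_len_fn(args, kwargs)
    args, kwargs = pad_fn(args, kwargs, group_size=group_size)
    return args, kwargs, orig_len

def trim_collected(merged, orig_len):
    """Drop trailing outputs that came from dummy padding items."""
    if orig_len is not None and isinstance(merged, list) and len(merged) > orig_len:
        return merged[:orig_len]
    return merged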

garrett4wade (Collaborator) commented

Hi @yyypluto, thanks for the fix! A similar fix has been applied in #1109, where we enforce padding the batch before dispatching eval data. Are there any follow-up issues after that fix?

yyypluto (Author) commented

@garrett4wade Thanks for pointing this out. I think the current scope is slightly different from #1109.

My understanding is that #1109 fixed the explicit evaluate_* path by padding eval batches before dispatch. The follow-up issue I hit here was on the custom dispatch path: _custom_function_call / _async_custom_function_call could still route uneven top-level tensor-like inputs into _dispatch_tensors without applying that padding consistently, especially when the batch came through kwargs or when multiple parallel tensor-like lists had to stay aligned.

So the current PR is not just the original padding change. It now covers:

  • padding eval batches in _custom_function_call and _async_custom_function_call before DP partitioning, and trimming the dummy outputs after collection
  • relaxing balanced_greedy_partition so remainder items are assigned instead of raising on uneven input lengths
  • regression tests for remainder-aware partitioning and padded eval dispatch trimming

After this follow-up fix, I have not reproduced additional issues on the three affected logic paths in my validation.

garrett4wade (Collaborator) commented

@yyypluto Hi, thanks for the follow-ups. However, I'm still confused about the motivation of this fix. Could you please:

  1. Provide an exact case where the current code runs into a bug. Which method is called under which circumstances (e.g., batch size, DP size, group size, whether the input is chat messages or tensors)?
  2. Illustrate why we need mirrored fixes in both the controller and the partition function. If we have already padded the batch, then there's no need to "fix" the partition function, and vice versa.

github-actions Bot commented May 8, 2026

This pull request has been automatically marked as stale because it has not had recent activity within the last 14 days.

Please add a comment or push new commits to keep it active.

Thank you for your contribution!

github-actions Bot added the stale label May 8, 2026
