fix: prevent async RL dispatch crashes on uneven batches #1225
Conversation
Wire eval padding into the custom dispatch path so tensor-like batches are padded to a multiple of data-parallel group size before partitioning, then trim the dummy outputs after collection. Also relax balanced_greedy_partition so it can assign remainder items instead of raising when the number of groups does not evenly divide the input length. Add regression tests for remainder-aware partitioning and padded eval dispatch trimming.
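The relaxed partition behavior the description names can be sketched roughly as follows. This is a hypothetical reconstruction, not the PR's actual implementation: a greedy heaviest-first assignment that absorbs remainder items into the lightest group instead of raising when `len(weights) % k != 0`.

```python
from typing import Sequence


def balanced_greedy_partition(weights: Sequence[float], k: int) -> list[list[int]]:
    """Assign item indices to k groups, heaviest item first, always into the
    currently lightest group. Remainder items are absorbed rather than
    rejected, so uneven input lengths never raise. (Sketch only.)"""
    groups: list[list[int]] = [[] for _ in range(k)]
    loads = [0.0] * k
    # Place big items first so the greedy choice stays balanced.
    for idx in sorted(range(len(weights)), key=lambda i: -weights[i]):
        target = min(range(k), key=lambda g: loads[g])
        groups[target].append(idx)
        loads[target] += weights[idx]
    return groups
```

With 5 items and 2 groups, the remainder item simply lands in whichever group is lighter, instead of triggering an error.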
Code Review
This pull request introduces padding and trimming logic for evaluation items in the TrainController and updates the balanced_greedy_partition utility to handle cases where the number of items is not divisible by the group count. While these changes improve flexibility, the review identifies several critical issues: the padding logic ignores arguments passed via kwargs, and it fails to consistently pad multiple tensor-like lists, which could lead to IndexError or engine crashes. Furthermore, the duplicated logic between the synchronous and asynchronous dispatch methods should be refactored into a shared helper to ensure maintainability.
```python
for arg in args:
    if isinstance(arg, list) and arg and _is_tensor_like(arg):
        orig_len = len(arg)
        break
```
The orig_len detection logic only inspects args, ignoring kwargs. If a tensor-like batch is passed via kwargs, it will not be padded (as _pad_eval_dispatch_args also only handles args), and the results will not be trimmed. This inconsistency could lead to crashes or incorrect behavior in engines that require even batches across data-parallel ranks. Consider extending this detection and the padding logic to include kwargs.
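A sketch of the suggested extension, scanning kwargs as well as args, might look like this. The helper name and the `_is_tensor_like` stand-in are hypothetical, not the repository's real API:

```python
def _is_tensor_like(batch):
    # Hypothetical stand-in: treat a list whose elements all expose a
    # .shape attribute (tensors, ndarrays) as a tensor-like batch.
    return all(hasattr(x, "shape") for x in batch)


def find_batch_len(args, kwargs):
    """Return the length of the first tensor-like list found in either
    args or kwargs, so keyword-passed batches are detected too (sketch)."""
    for value in list(args) + list(kwargs.values()):
        if isinstance(value, list) and value and _is_tensor_like(value):
            return len(value)
    return None
```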
```python
args, kwargs = self._pad_eval_dispatch_args(
    args, kwargs, group_size=group_size
)
```
There is a potential IndexError when multiple tensor-like lists are passed as arguments. The underlying _pad_eval_batch function (called via _pad_eval_dispatch_args) only pads the first tensor-like list it encounters. If multiple parallel lists (e.g., states and actions) are provided, they will end up with different lengths after padding, causing _partition_inputs to fail when it attempts to index the unpadded lists using indices derived from the padded one. All tensor-like lists should be padded consistently to the same target length.
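One way to pad every tensor-like list to a common target length, as suggested, could look like the sketch below. The function name and the repeat-last-element padding scheme are assumptions for illustration, not the PR's code:

```python
import math


def pad_all_tensor_lists(args, group_size):
    """Pad every non-empty list in args to one shared target length (the
    next multiple of group_size), repeating the last element as a dummy,
    so parallel lists such as states and actions stay aligned. (Sketch.)"""
    lens = [len(a) for a in args if isinstance(a, list) and a]
    if not lens:
        return args
    target = math.ceil(max(lens) / group_size) * group_size
    padded = []
    for a in args:
        if isinstance(a, list) and a and len(a) < target:
            a = a + [a[-1]] * (target - len(a))
        padded.append(a)
    return tuple(padded)
```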
```diff
 group_size = kwargs.get("group_size", 1)
 orig_len = None
 for arg in args:
     if isinstance(arg, list) and arg and _is_tensor_like(arg):
         orig_len = len(arg)
         break

 args, kwargs = self._pad_eval_dispatch_args(
     args, kwargs, group_size=group_size
 )
 dp_args, dp_kwargs, group_indices = self._prepare_dispatch(*args, **kwargs)
 results = await self._call_workers(
     method, dp_args, dp_kwargs, rpc_meta=rpc_meta
 )
-return self._collect_results(results, group_indices)
+merged_results = self._collect_results(results, group_indices)
+
+if (
+    orig_len is not None
+    and isinstance(merged_results, list)
+    and len(merged_results) > orig_len
+):
+    merged_results = merged_results[:orig_len]
+
+return merged_results
```
The padding and trimming logic in _async_custom_function_call is identical to that in _custom_function_call. This duplication should be refactored into a shared helper method to improve maintainability and ensure that future fixes (such as handling kwargs or multiple lists) are applied consistently to both execution paths.
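A minimal sketch of such a shared helper, with toy stand-ins for the worker call, might look like this. All names here are hypothetical; only the pad/trim bookkeeping pattern is the point:

```python
import asyncio
import math


class DispatchMixin:
    """Factor the pad/trim bookkeeping into shared helpers so the sync and
    async call paths cannot drift apart. (Illustrative sketch only.)"""

    def _pad_for_eval(self, batch, group_size):
        # Record the original length, then pad to a multiple of group_size
        # by repeating the last element as a dummy item.
        orig_len = len(batch)
        target = math.ceil(orig_len / group_size) * group_size
        return batch + [batch[-1]] * (target - orig_len), orig_len

    def _trim_results(self, results, orig_len):
        # Drop outputs produced by the dummy padding items.
        if isinstance(results, list) and len(results) > orig_len:
            results = results[:orig_len]
        return results

    def call(self, batch, group_size):
        batch, orig_len = self._pad_for_eval(batch, group_size)
        # Stand-in for the real dispatch; doubles each item.
        return self._trim_results([x * 2 for x in batch], orig_len)

    async def async_call(self, batch, group_size):
        batch, orig_len = self._pad_for_eval(batch, group_size)
        await asyncio.sleep(0)  # stand-in for the worker RPC round trip
        return self._trim_results([x * 2 for x in batch], orig_len)
```

Both paths now share one padding rule and one trimming rule, so a future fix (e.g. kwargs handling) lands in exactly one place.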
Hi @yyypluto, thanks for the fix! A similar fix has been applied in #1109, where we enforce padding the batch before dispatching eval data. Do we have follow-up issues after the fix?
@garrett4wade Thanks for pointing this out. I think the current scope is slightly different from #1109. My understanding is that #1109 fixed the explicit padding of the batch before dispatching eval data, so the current PR is not just the original padding change. It now covers:
- padding tensor-like batches in the custom dispatch path and trimming the dummy outputs after collection
- relaxing balanced_greedy_partition to assign remainder items instead of raising on uneven input lengths
- regression tests for both changes
After this follow-up fix, I have not reproduced additional issues on the three affected logic paths in my validation.
@yyypluto Hi, thanks for the follow-ups. However, I'm still confused about the motivation of this fix. Could you please:
This pull request has been automatically marked as stale because it has not had recent activity within the last 14 days. Please add a comment or push new commits to keep it active. Thank you for your contribution!
Summary
- Pad tensor-like eval batches in `_custom_function_call` and `_async_custom_function_call` before DP partitioning, then trim dummy outputs after collection
- Relax `balanced_greedy_partition` so remainder items are assigned instead of raising on uneven input lengths

Why
`TrainController` can still route uneven tensor-like eval batches into `_dispatch_tensors`, which calls `balanced_greedy_partition(group_weights, K=dp_size)`. When the grouped batch size is not divisible by the DP size, the async RL eval path can fail instead of padding safely.

Validation
- `python3 -m py_compile` on the edited files
- `git diff --check`
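As a rough illustration of the "pad safely" behavior the Why section describes, assuming a simple repeat-last-element padding scheme (function names are invented for this sketch):

```python
def pad_to_multiple(batch, dp_size):
    """Round a batch up to the next multiple of dp_size by repeating the
    last element, returning the padded batch and the original length."""
    pad = (-len(batch)) % dp_size
    return batch + [batch[-1]] * pad, len(batch)


def even_chunks(batch, dp_size):
    # Strict even split across DP ranks: only valid once the padded
    # length is divisible by dp_size.
    per_rank = len(batch) // dp_size
    return [batch[i * per_rank:(i + 1) * per_rank] for i in range(dp_size)]
```

A batch of 5 items on dp_size 4 pads to 8, partitions into four equal chunks, and the recorded original length lets the caller trim the dummy outputs afterwards.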