bugfix: fix failures when EP/DP and ACL Graph are enabled simultaneously. #1218
DongheJin wants to merge 1 commit into jd-opensource:main
Conversation
Code Review
This pull request improves the stability and correctness of ACL graph execution, particularly for Data Parallel (DP) configurations. Key changes include adding robust validation checks for input tensors, ensuring padded decode slots are correctly initialized for empty or short DP shards, and disabling the unstable LCOC fused all2all path when graph mode is enabled. Review feedback identified critical logic errors in the generation of cumulative sequence lengths (q_cu_seq_lens), which require N+1 elements to correctly support attention kernels, and noted a style guide violation regarding implicit type conversion.
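As background for the N+1 requirement discussed in the review, here is a minimal illustrative sketch (plain C++, not the repository's code; `build_cu_seq_lens` is a hypothetical helper) of how cumulative sequence lengths are laid out:

```cpp
#include <cstdint>
#include <numeric>
#include <vector>

// Prefix sums over per-sequence token counts. Attention kernels read
// sequence i's token range as [cu[i], cu[i + 1]), so N sequences need
// N + 1 entries: a leading 0 plus one running total per sequence.
std::vector<int64_t> build_cu_seq_lens(const std::vector<int64_t>& seq_lens) {
  std::vector<int64_t> cu(seq_lens.size() + 1, 0);
  std::partial_sum(seq_lens.begin(), seq_lens.end(), cu.begin() + 1);
  return cu;
}
```

For `seq_lens = {3, 1, 4}` this yields `{0, 3, 4, 8}`; an array with only N entries would lose the end offset of the final sequence.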
```cpp
const int64_t q_cu_copy_len = std::min<int64_t>(actual_batch_size, q_cu_size);
if (q_cu_copy_len > 0) {
  q_cu_seq_lens_.slice(/*dim=*/0, /*start=*/0, /*end=*/q_cu_copy_len)
      .copy_(params.q_cu_seq_lens.slice(/*dim=*/0,
                                        /*start=*/0,
                                        /*end=*/q_cu_copy_len),
             /*non_blocking=*/true);
}
if (padded_batch_size > q_cu_copy_len) {
  auto tail_q_seq_lens = q_seq_lens_.slice(/*dim=*/0,
                                           /*start=*/q_cu_copy_len,
                                           /*end=*/padded_batch_size);
  auto tail_cu = torch::cumsum(tail_q_seq_lens, /*dim=*/0);
  if (q_cu_copy_len > 0) {
    auto last_prefix = q_cu_seq_lens_.slice(/*dim=*/0,
                                            /*start=*/q_cu_copy_len - 1,
                                            /*end=*/q_cu_copy_len);
    tail_cu = tail_cu + last_prefix;
  }
  // Copy data
  q_cu_seq_lens_.slice(/*dim=*/0, /*start=*/0, /*end=*/actual_batch_size)
      .copy_(params.q_cu_seq_lens, /*non_blocking=*/true);
  q_cu_seq_lens_
      .slice(/*dim=*/0,
             /*start=*/q_cu_copy_len,
             /*end=*/padded_batch_size)
      .copy_(tail_cu, /*non_blocking=*/true);
}
```
The logic for extending `q_cu_seq_lens_` is incorrect. A cumulative sum for N sequences requires N + 1 elements. The current code:

- Only copies up to `actual_batch_size` elements, missing the total sum of the actual sequences at index `actual_batch_size`.
- Incorrectly calculates the extension starting from `q_cu_copy_len`, which uses the length of the first padded sequence to compute the offset for the sequence at that index, instead of using the length of the last actual sequence.
- Fills only up to index `padded_batch_size - 1`, leaving the final cumulative sum value at index `padded_batch_size` as zero.

This will lead to incorrect attention offsets in kernels like MLA. The suggested fix correctly handles the N + 1 layout:
```cpp
const int64_t q_cu_copy_len = (has_q_cu && q_cu_size > 0)
                                  ? std::min<int64_t>(actual_batch_size + 1, q_cu_size)
                                  : 0;
if (q_cu_copy_len > 0) {
  q_cu_seq_lens_.slice(/*dim=*/0, /*start=*/0, /*end=*/q_cu_copy_len)
      .copy_(params.q_cu_seq_lens.slice(/*dim=*/0, /*start=*/0, /*end=*/q_cu_copy_len),
             /*non_blocking=*/true);
}
const int64_t fill_start = std::max<int64_t>(1, q_cu_copy_len);
if (padded_batch_size + 1 > fill_start) {
  auto tail_q_seq_lens = q_seq_lens_.slice(/*dim=*/0, /*start=*/fill_start - 1,
                                           /*end=*/padded_batch_size);
  auto tail_cu = torch::cumsum(tail_q_seq_lens, /*dim=*/0);
  auto last_prefix = q_cu_seq_lens_.slice(/*dim=*/0, /*start=*/fill_start - 1,
                                          /*end=*/fill_start);
  tail_cu = tail_cu + last_prefix;
  q_cu_seq_lens_.slice(/*dim=*/0, /*start=*/fill_start, /*end=*/padded_batch_size + 1)
      .copy_(tail_cu, /*non_blocking=*/true);
}
```

```diff
 params_for_capture->q_cu_seq_lens =
     q_cu_seq_lens_.slice(/*dim=*/0,
                          /*start=*/0,
-                         /*end=*/actual_batch_size);
+                         /*end=*/padded_batch_size);
```
The slice end for `q_cu_seq_lens` should be `padded_batch_size + 1` to include the final cumulative sum value required by the attention kernels.
```diff
 params_for_capture->q_cu_seq_lens =
     q_cu_seq_lens_.slice(/*dim=*/0,
                          /*start=*/0,
-                         /*end=*/padded_batch_size);
+                         /*end=*/padded_batch_size + 1);
```
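To see why the `+ 1` matters, a minimal sketch in plain C++ (the values standing in for `padded_batch_size` and the decode lengths are illustrative, and `slice_prefix` is a hypothetical helper mimicking a tensor slice):

```cpp
#include <cstdint>
#include <vector>

// With padded_batch_size = 3 decode slots of length 1 each, the full
// cumulative array is {0, 1, 2, 3}. Sequence i's token range is
// [cu[i], cu[i + 1]), so the last sequence needs element index 3:
// a slice of length padded_batch_size + 1, not padded_batch_size.
std::vector<int64_t> slice_prefix(const std::vector<int64_t>& cu, int64_t end) {
  return std::vector<int64_t>(cu.begin(), cu.begin() + end);
}
```

`slice_prefix({0, 1, 2, 3}, 3)` yields `{0, 1, 2}`, losing the last sequence's end offset; `slice_prefix({0, 1, 2, 3}, 4)` keeps it.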
```cpp
// groups; local shard can be empty on some ranks.
uint32_t graph_num_tokens = tokens_tensor.size(/*dim=*/0);
if (params_single.dp_global_token_nums.size() > 1) {
  graph_num_tokens = util::max(params_single.dp_global_token_nums);
}
```
Implicit conversion from int32_t to uint32_t should be avoided. Please use static_cast to adhere to the repository style guide.
```diff
-  graph_num_tokens = util::max(params_single.dp_global_token_nums);
+  graph_num_tokens = static_cast<uint32_t>(util::max(params_single.dp_global_token_nums));
```
References
- Use static_cast for all type conversions. Never use C-style casts. (link)
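A minimal sketch of the style-guide point (the container values are illustrative, and `max_tokens` is a hypothetical stand-in for `util::max` over `dp_global_token_nums`, assumed `int32_t`):

```cpp
#include <algorithm>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for util::max over dp_global_token_nums.
int32_t max_tokens(const std::vector<int32_t>& nums) {
  return *std::max_element(nums.begin(), nums.end());
}

uint32_t pick_graph_num_tokens(const std::vector<int32_t>& dp_global_token_nums) {
  // The explicit static_cast documents the signed-to-unsigned narrowing,
  // per the "use static_cast for all type conversions" rule.
  return static_cast<uint32_t>(max_tokens(dp_global_token_nums));
}
```

The explicit cast has no runtime cost here; it only makes the conversion visible at the call site and silences sign-conversion warnings.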