
Commit 11f1d97

bugfix: fix failures when EP/DP and ACL Graph are enabled simultaneously.
1 parent fb683db commit 11f1d97

4 files changed: 181 additions & 54 deletions

File tree

third_party/xllm_atb_layers

Submodule xllm_atb_layers updated from d6aa214 to aed1d6d

xllm/core/framework/model/npu_dp_ep_padding.cpp

Lines changed: 8 additions & 5 deletions

@@ -118,11 +118,14 @@ void DpEpPadding::prepare_indices() {
   lm_head_skip_padding_token_indices_ = build_lm_head_indices();

   dp_padding_idx_ = build_dp_padding_indices();
-  // padding_idx_ =
-  //     dp_padding_idx_[mapping_npu_["attnDp"]["rank"].get<int64_t>()];
-  padding_idx_ = dp_padding_idx_[mapping_npu_["attnDpSize"].get<int64_t>() - 1];
-  attn_padding_idx_ =
-      dp_padding_idx_[mapping_npu_["attnDp"]["rank"].get<int64_t>()];
+  const int64_t attn_dp_rank = mapping_npu_["attnDp"]["rank"].get<int64_t>();
+  CHECK_GE(attn_dp_rank, 0);
+  CHECK_LT(attn_dp_rank, static_cast<int64_t>(dp_padding_idx_.size()));
+  // in_padding_idx (used by moe distribute dispatch) must match current DP
+  // shard layout; using a fixed DP group's indices can corrupt dispatch
+  // indexing when DP groups have different token sizes.
+  padding_idx_ = dp_padding_idx_[attn_dp_rank];
+  attn_padding_idx_ = dp_padding_idx_[attn_dp_rank];
   un_padding_idx_ = torch::zeros({1}, torch::kInt32);

   prepare_gather_indices();
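
Why indexing by rank matters here: each attnDp group pads its own token shard, so the gather table built for group g only lines up with group g's layout. A minimal standalone sketch (plain C++, hypothetical tables and shapes, not the xllm API) of the failure mode the old `attnDpSize - 1` indexing could hit:

#include <cassert>
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // Hypothetical per-DP-group padding tables: dp_padding_idx[g][i] maps
  // padded slot i of group g back to a source token (or a pad sentinel 0).
  // Group 0 holds 2 real tokens, group 1 holds 3; both padded to 4 slots.
  std::vector<std::vector<int64_t>> dp_padding_idx = {
      {0, 1, 0, 0},  // group 0: tokens 0..1, slots 2..3 are padding
      {0, 1, 2, 0},  // group 1: tokens 0..2, slot 3 is padding
  };

  const int64_t attn_dp_size = 2;
  const int64_t attn_dp_rank = 0;  // this rank belongs to group 0

  // Old behavior: always take the last group's table.
  const auto& wrong = dp_padding_idx[attn_dp_size - 1];
  // Fixed behavior: take the table matching this rank's shard layout.
  const auto& right = dp_padding_idx[attn_dp_rank];

  // Slot 2 is padding for group 0 but a real token for group 1, so the
  // two tables disagree exactly when group token counts differ.
  assert(wrong[2] != right[2]);
  std::cout << "slot 2: wrong=" << wrong[2] << " right=" << right[2] << "\n";
  return 0;
}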

xllm/core/layers/npu/npu_deepseek_v32_decoder_layer_impl.cpp

Lines changed: 8 additions & 2 deletions

@@ -426,7 +426,11 @@ void NpuDeepseekV32DecoderLayerImpl::initialize_mlp_parameters(
   param.enableATBGateMatmul = true;

   param.enableIndexGmm = false;
-  param.enableLcocAll2All = param.isPrefill && dp_size_ == 1;
+  // LCOC fused all2all path is unstable under ACL graph launch in current
+  // runtime; keep it for eager mode and fall back to the standard dynamic-ep
+  // path when graph is enabled.
+  param.enableLcocAll2All =
+      param.isPrefill && dp_size_ == 1 && !FLAGS_enable_graph;

   if (layer_id_ >= param.firstKDenseReplace) {
     param.enableQkvdownDp = false;

@@ -962,7 +966,9 @@ void NpuDeepseekV32DecoderLayerImpl::build_node_variant_pack(
   node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 30) =
       atb_speed::Utils::AtTensor2Tensor(kv_cache.get_index_cache());

-  if (input_params.q_seq_lens.numel() != 0) {
+  if (input_params.q_seq_lens.numel() != 0 &&
+      input_params.q_cu_seq_lens.defined() &&
+      input_params.q_cu_seq_lens.numel() != 0) {
     node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 31) =
         atb_speed::Utils::AtTensor2Tensor(input_params.q_cu_seq_lens);
   } else {
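
The extra checks above follow the usual libtorch defensive pattern for optional tensor inputs: test `defined()` before any shape query, then treat a zero-length tensor the same as a missing one. A small illustrative sketch (the `usable` helper is hypothetical, not part of xllm):

#include <torch/torch.h>

// Treat a tensor as usable only if it is defined AND non-empty, so an
// undefined or zero-length optional input never reaches the kernel.
// Checking defined() first short-circuits before numel() is queried.
inline bool usable(const torch::Tensor& t) {
  return t.defined() && t.numel() != 0;
}

int main() {
  torch::Tensor undefined;                  // default-constructed, undefined
  torch::Tensor empty = torch::empty({0});  // defined but holds no elements
  torch::Tensor real = torch::arange(4);

  TORCH_CHECK(!usable(undefined));
  TORCH_CHECK(!usable(empty));
  TORCH_CHECK(usable(real));
  return 0;
}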

xllm/core/runtime/acl_graph_executor_impl.cpp

Lines changed: 164 additions & 46 deletions

@@ -22,6 +22,7 @@ limitations under the License.
 #include <torch_npu/csrc/libs/init_npu.h>
 #include <torch_npu/torch_npu.h>

+#include <algorithm>
 #include <numeric>

 #include "core/common/global_flags.h"
@@ -209,39 +210,113 @@ std::optional<ModelInputParams> GraphPersistentParam::update(
   const int64_t actual_batch_size = params.num_sequences;

   // Copy data from input parameters to persistent graph tensors
-  persistent_tokens_.slice(/*dim=*/0, /*start=*/0, /*end=*/actual_num_tokens)
-      .copy_(tokens, /*non_blocking=*/true);
+  if (actual_num_tokens > 0) {
+    persistent_tokens_.slice(/*dim=*/0, /*start=*/0, /*end=*/actual_num_tokens)
+        .copy_(tokens, /*non_blocking=*/true);
+  }
   // mRoPE positions have shape [3, num_tokens], slice on dim 1
-  if (use_mrope_) {
-    persistent_positions_
-        .slice(/*dim=*/1, /*start=*/0, /*end=*/actual_num_tokens)
-        .copy_(positions, /*non_blocking=*/true);
-  } else {
-    persistent_positions_
-        .slice(/*dim=*/0, /*start=*/0, /*end=*/actual_num_tokens)
-        .copy_(positions, /*non_blocking=*/true);
+  if (actual_num_tokens > 0) {
+    if (use_mrope_) {
+      persistent_positions_
+          .slice(/*dim=*/1, /*start=*/0, /*end=*/actual_num_tokens)
+          .copy_(positions, /*non_blocking=*/true);
+    } else {
+      persistent_positions_
+          .slice(/*dim=*/0, /*start=*/0, /*end=*/actual_num_tokens)
+          .copy_(positions, /*non_blocking=*/true);
+    }
+  }
+  if (actual_batch_size > 0 && params.q_seq_lens.defined() &&
+      params.q_seq_lens.dim() >= 1 && params.q_seq_lens.numel() > 0) {
+    const int64_t q_copy_len =
+        std::min<int64_t>(actual_batch_size, params.q_seq_lens.size(0));
+    if (q_copy_len > 0) {
+      q_seq_lens_.slice(/*dim=*/0, /*start=*/0, /*end=*/q_copy_len)
+          .copy_(params.q_seq_lens.slice(/*dim=*/0,
+                                         /*start=*/0,
+                                         /*end=*/q_copy_len),
+                 /*non_blocking=*/true);
+    }
+  }
+  if (actual_batch_size > 0 && params.kv_seq_lens.defined() &&
+      params.kv_seq_lens.dim() >= 1 && params.kv_seq_lens.numel() > 0) {
+    const int64_t kv_copy_len =
+        std::min<int64_t>(actual_batch_size, params.kv_seq_lens.size(0));
+    if (kv_copy_len > 0) {
+      kv_seq_lens_.slice(/*dim=*/0, /*start=*/0, /*end=*/kv_copy_len)
+          .copy_(params.kv_seq_lens.slice(/*dim=*/0,
+                                          /*start=*/0,
+                                          /*end=*/kv_copy_len),
+                 /*non_blocking=*/true);
+    }
+  }
+  // Keep padded decode slots valid for empty/local-short DP shards.
+  // These tensors are consumed by ATB setup alongside *_seq_lens_vec.
+  const int64_t padded_batch_size = static_cast<int64_t>(padded_num_tokens);
+  if (padded_batch_size > 0) {
+    const int64_t seq_fill_start =
+        std::min<int64_t>(actual_batch_size, padded_batch_size);
+    if (seq_fill_start < padded_batch_size) {
+      q_seq_lens_
+          .slice(/*dim=*/0, /*start=*/seq_fill_start, /*end=*/padded_batch_size)
+          .fill_(1);
+      kv_seq_lens_
+          .slice(/*dim=*/0, /*start=*/seq_fill_start, /*end=*/padded_batch_size)
+          .fill_(1);
+    }
   }
-  q_seq_lens_.slice(/*dim=*/0, /*start=*/0, /*end=*/actual_batch_size)
-      .copy_(params.q_seq_lens, /*non_blocking=*/true);
-  kv_seq_lens_.slice(/*dim=*/0, /*start=*/0, /*end=*/actual_batch_size)
-      .copy_(params.kv_seq_lens, /*non_blocking=*/true);

-  persistent_new_cache_slots_
-      .slice(/*dim=*/0, /*start=*/0, /*end=*/actual_num_tokens)
-      .copy_(params.new_cache_slots, /*non_blocking=*/true);
+  if (actual_num_tokens > 0 && params.new_cache_slots.defined() &&
+      params.new_cache_slots.dim() >= 1 && params.new_cache_slots.numel() > 0) {
+    const int64_t slot_copy_len =
+        std::min<int64_t>(static_cast<int64_t>(actual_num_tokens),
+                          params.new_cache_slots.size(0));
+    if (slot_copy_len > 0) {
+      persistent_new_cache_slots_
+          .slice(/*dim=*/0, /*start=*/0, /*end=*/slot_copy_len)
+          .copy_(params.new_cache_slots.slice(/*dim=*/0,
+                                              /*start=*/0,
+                                              /*end=*/slot_copy_len),
+                 /*non_blocking=*/true);
+    }
+  }
+  if (actual_num_tokens < padded_num_tokens) {
+    persistent_new_cache_slots_
+        .slice(/*dim=*/0,
+               /*start=*/actual_num_tokens,
+               /*end=*/static_cast<int64_t>(padded_num_tokens))
+        .fill_(0);
+  }

   // Copy block table data
-  const int64_t actual_block_table_len = params.block_tables.size(1);
-  auto slice_persistent_block_tables =
-      persistent_block_tables_
-          .slice(/*dim=*/0, /*start=*/0, /*end=*/actual_batch_size)
-          .slice(/*dim=*/1, /*start=*/0, /*end=*/actual_block_table_len);
-  slice_persistent_block_tables.copy_(params.block_tables,
-                                      /*non_blocking=*/true);
+  if (actual_batch_size > 0 && params.block_tables.defined() &&
+      params.block_tables.dim() >= 2 && params.block_tables.numel() > 0) {
+    const int64_t block_rows_to_copy =
+        std::min<int64_t>(actual_batch_size, params.block_tables.size(0));
+    const int64_t actual_block_table_len = params.block_tables.size(1);
+    if (block_rows_to_copy > 0 && actual_block_table_len > 0) {
+      auto slice_persistent_block_tables =
+          persistent_block_tables_
+              .slice(/*dim=*/0, /*start=*/0, /*end=*/block_rows_to_copy)
+              .slice(/*dim=*/1, /*start=*/0, /*end=*/actual_block_table_len);
+      slice_persistent_block_tables.copy_(
+          params.block_tables.slice(/*dim=*/0,
+                                    /*start=*/0,
+                                    /*end=*/block_rows_to_copy),
+          /*non_blocking=*/true);
+    }
+  }
+  if (actual_batch_size < padded_batch_size) {
+    persistent_block_tables_
+        .slice(/*dim=*/0,
+               /*start=*/actual_batch_size,
+               /*end=*/padded_batch_size)
+        .fill_(0);
+  }

   // Update persistent embedding from input_embedding if available
   const auto& embedding = params.input_embedding;
-  if (embedding.defined()) {
+  if (embedding.defined() && embedding.dim() >= 2) {
     const int64_t embedding_tokens = embedding.size(0);

     // Initialize persistent_embedding_ if needed and not already initialized
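
The hunk above applies one pattern to every persistent buffer: clamp the copy length to the minimum of destination capacity and source size, copy only that prefix, then fill the padded tail with a safe default (1 for sequence lengths, 0 for cache slots and block tables) so a fixed-shape graph replay never reads stale data. A standalone sketch of the pattern with plain vectors (the `copy_then_pad` helper is hypothetical, assuming a fixed-capacity persistent buffer):

#include <algorithm>
#include <cstdint>
#include <vector>

// Copy up to `wanted` elements from src into the front of dest, clamped to
// both sizes, then fill the rest of dest with `pad` so every padded slot
// stays valid for a fixed-shape replay. (Illustrative sketch only.)
template <typename T>
void copy_then_pad(std::vector<T>& dest, const std::vector<T>& src,
                   int64_t wanted, T pad) {
  const int64_t copy_len =
      std::min<int64_t>({wanted, static_cast<int64_t>(src.size()),
                         static_cast<int64_t>(dest.size())});
  std::copy_n(src.begin(), copy_len, dest.begin());
  std::fill(dest.begin() + copy_len, dest.end(), pad);
}

int main() {
  std::vector<int64_t> persistent(8);       // padded capacity: 8 slots
  std::vector<int64_t> shard = {42, 7, 3};  // local DP shard: 3 real entries
  copy_then_pad(persistent, shard, /*wanted=*/3, int64_t{1});
  // persistent == {42, 7, 3, 1, 1, 1, 1, 1}: tail slots mimic length-1
  // decode sequences instead of carrying stale garbage.
  return 0;
}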
@@ -255,21 +330,50 @@ std::optional<ModelInputParams> GraphPersistentParam::update(
     }

     // Copy embedding data to persistent buffer
-    persistent_embedding_
-        .slice(/*dim=*/0, /*start=*/0, /*end=*/embedding_tokens)
-        .copy_(embedding, /*non_blocking=*/true);
+    if (embedding_tokens > 0 && embedding.numel() > 0) {
+      persistent_embedding_
+          .slice(/*dim=*/0, /*start=*/0, /*end=*/embedding_tokens)
+          .copy_(embedding, /*non_blocking=*/true);
+    }
+  }
+  // Update q_cu_seq_lens used by sparse MLA indexer.
+  // Empty local DP shards can carry empty q_cu_seq_lens from upper layers;
+  // for graph capture we still need a valid non-empty length tensor for padded
+  // decode slots.
+  if (q_cu_seq_lens_.numel() == 0) {
+    const int64_t max_seqs_per_batch = get_decode_graph_capacity(options_);
+    q_cu_seq_lens_ = torch::zeros({max_seqs_per_batch + 1},
+                                  torch::dtype(torch::kInt).device(device_));
   }
-  // Update q_cu_seq_lens only if params.q_cu_seq_lens is defined
-  if (params.q_cu_seq_lens.defined()) {
-    // Lazy initialization: if q_cu_seq_lens_ is not initialized, initialize it
-    if (q_cu_seq_lens_.numel() == 0) {
-      const int64_t max_seqs_per_batch = get_decode_graph_capacity(options_);
-      q_cu_seq_lens_ = torch::zeros({max_seqs_per_batch + 1},
-                                    torch::dtype(torch::kInt).device(device_));
+  const bool has_q_cu =
+      params.q_cu_seq_lens.defined() && params.q_cu_seq_lens.dim() >= 1;
+  const int64_t q_cu_size = (has_q_cu && params.q_cu_seq_lens.numel() > 0)
+                                ? params.q_cu_seq_lens.size(0)
+                                : 0;
+  const int64_t q_cu_copy_len = std::min<int64_t>(actual_batch_size, q_cu_size);
+  if (q_cu_copy_len > 0) {
+    q_cu_seq_lens_.slice(/*dim=*/0, /*start=*/0, /*end=*/q_cu_copy_len)
+        .copy_(params.q_cu_seq_lens.slice(/*dim=*/0,
+                                          /*start=*/0,
+                                          /*end=*/q_cu_copy_len),
+               /*non_blocking=*/true);
+  }
+  if (padded_batch_size > q_cu_copy_len) {
+    auto tail_q_seq_lens = q_seq_lens_.slice(/*dim=*/0,
+                                             /*start=*/q_cu_copy_len,
+                                             /*end=*/padded_batch_size);
+    auto tail_cu = torch::cumsum(tail_q_seq_lens, /*dim=*/0);
+    if (q_cu_copy_len > 0) {
+      auto last_prefix = q_cu_seq_lens_.slice(/*dim=*/0,
+                                              /*start=*/q_cu_copy_len - 1,
+                                              /*end=*/q_cu_copy_len);
+      tail_cu = tail_cu + last_prefix;
     }
-    // Copy data
-    q_cu_seq_lens_.slice(/*dim=*/0, /*start=*/0, /*end=*/actual_batch_size)
-        .copy_(params.q_cu_seq_lens, /*non_blocking=*/true);
+    q_cu_seq_lens_
+        .slice(/*dim=*/0,
+               /*start=*/q_cu_copy_len,
+               /*end=*/padded_batch_size)
+        .copy_(tail_cu, /*non_blocking=*/true);
   }

   // Update attention mask only if needed
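
The tail reconstruction above keeps `q_cu_seq_lens_` a valid running prefix sum over the whole padded batch: copy the real prefix, then continue the cumulative sum over the padded length-1 slots, offset by the last real value. A worked standalone sketch with plain vectors (assumed semantics: `q_cu_seq_lens[i]` holds the cumulative query length through sequence i, mirroring the tensor indexing in the diff):

#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  const int64_t padded_batch_size = 6;
  // Real shard: 2 sequences with cumulative lengths [3, 5] already copied.
  std::vector<int64_t> q_cu_seq_lens = {3, 5, 0, 0, 0, 0};
  std::vector<int64_t> q_seq_lens = {3, 2, 1, 1, 1, 1};  // tail padded to 1
  const int64_t q_cu_copy_len = 2;

  // Continue the prefix sum over the padded tail, offset by the last real
  // cumulative value, mirroring tail_cu = cumsum(tail) + last_prefix.
  int64_t running = (q_cu_copy_len > 0) ? q_cu_seq_lens[q_cu_copy_len - 1] : 0;
  for (int64_t i = q_cu_copy_len; i < padded_batch_size; ++i) {
    running += q_seq_lens[i];
    q_cu_seq_lens[i] = running;
  }
  // q_cu_seq_lens == {3, 5, 6, 7, 8, 9}: strictly increasing, so padded
  // decode slots look like valid length-1 sequences to the indexer.
  for (int64_t v : q_cu_seq_lens) std::cout << v << " ";
  std::cout << "\n";
  return 0;
}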
@@ -297,12 +401,19 @@ std::optional<ModelInputParams> GraphPersistentParam::update(
   params_for_capture->kv_seq_lens_vec.resize(padded_num_tokens);
   params_for_capture->q_seq_lens_vec.resize(padded_num_tokens);
   // Copy actual values from original params
-  for (int i = 0; i < actual_batch_size; i++) {
+  const int64_t kv_vec_size =
+      static_cast<int64_t>(params.kv_seq_lens_vec.size());
+  const int64_t q_vec_size =
+      static_cast<int64_t>(params.q_seq_lens_vec.size());
+  const int64_t vec_copy_len =
+      std::min<int64_t>(actual_batch_size, std::min(kv_vec_size, q_vec_size));
+  for (int64_t i = 0; i < vec_copy_len; i++) {
     params_for_capture->kv_seq_lens_vec[i] = params.kv_seq_lens_vec[i];
     params_for_capture->q_seq_lens_vec[i] = params.q_seq_lens_vec[i];
   }
   // Fill padded positions with default values
-  for (int i = actual_batch_size; i < padded_num_tokens; i++) {
+  for (int64_t i = vec_copy_len; i < static_cast<int64_t>(padded_num_tokens);
+       i++) {
     params_for_capture->kv_seq_lens_vec[i] = 1;
     params_for_capture->q_seq_lens_vec[i] = 1;
   }
@@ -320,16 +431,17 @@ std::optional<ModelInputParams> GraphPersistentParam::update(
   }
   params_for_capture->graph_buffer.tiling_data = tiling_data();
   // Set persistent embedding if available
-  if (params.input_embedding.defined()) {
+  if (params.input_embedding.defined() && params.input_embedding.dim() >= 2 &&
+      persistent_embedding_.defined() && persistent_embedding_.numel() > 0) {
     params_for_capture->input_embedding =
         persistent_embedding(padded_num_tokens);
   }
-  // Set q_cu_seq_lens if available
-  if (params.q_cu_seq_lens.defined()) {
+  // Keep q_cu_seq_lens aligned with padded capture batch.
+  if (q_cu_seq_lens_.defined() && q_cu_seq_lens_.numel() > 0) {
     params_for_capture->q_cu_seq_lens =
         q_cu_seq_lens_.slice(/*dim=*/0,
                              /*start=*/0,
-                             /*end=*/actual_batch_size);
+                             /*end=*/padded_batch_size);
   }

   return params_for_capture;
@@ -981,10 +1093,16 @@ ModelOutput AclGraphExecutorImpl::run(const torch::Tensor& tokens,
   }

   // Only use acl graph in decode phase for performance optimization
-  // Get actual num_tokens from tokens shape
+  // For DP, decode graph bucket should be based on global max tokens across dp
+  // groups; local shard can be empty on some ranks.
+  uint32_t graph_num_tokens = tokens_tensor.size(/*dim=*/0);
+  if (params_single.dp_global_token_nums.size() > 1) {
+    graph_num_tokens = util::max(params_single.dp_global_token_nums);
+  }
+  // Keep actual n_tokens for replay output slicing.
   const uint32_t n_tokens = tokens_tensor.size(/*dim=*/0);
   const uint32_t actual_batch_size = n_tokens / options_.num_decoding_tokens();
-  const uint32_t bucket_num_tokens = get_bucket_num_tokens(n_tokens);
+  const uint32_t bucket_num_tokens = get_bucket_num_tokens(graph_num_tokens);

   // Check if conditions are suitable for graph execution (replay or capture)
   const auto max_seq_len = args_.max_position_embeddings();
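
The bucket change matters because the captured graph contains collective ops: every DP rank must select the same captured graph even when its local shard is empty. A minimal sketch of the selection logic (hypothetical bucket sizes; `std::max_element` stands in for `util::max` over the exchanged per-group token counts):

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

// Round a token count up to the next capture bucket (illustrative buckets).
uint32_t get_bucket_num_tokens_sketch(uint32_t n) {
  static const std::vector<uint32_t> buckets = {8, 16, 32, 64, 128};
  for (uint32_t b : buckets)
    if (n <= b) return b;
  return buckets.back();
}

int main() {
  // dp_global_token_nums: token count of every DP group this step.
  std::vector<int32_t> dp_global_token_nums = {0, 24, 3, 17};
  uint32_t local_tokens = 0;  // this rank's shard happens to be empty

  // The bucket must come from the GLOBAL max so all ranks replay the same
  // captured graph; bucketing on the empty local shard would desynchronize
  // the collectives inside the graph.
  uint32_t graph_num_tokens = local_tokens;
  if (dp_global_token_nums.size() > 1) {
    graph_num_tokens = static_cast<uint32_t>(*std::max_element(
        dp_global_token_nums.begin(), dp_global_token_nums.end()));
  }
  std::cout << "bucket=" << get_bucket_num_tokens_sketch(graph_num_tokens)
            << "\n";  // every rank agrees on bucket 32
  return 0;
}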
