@@ -22,6 +22,7 @@ limitations under the License.
2222#include < torch_npu/csrc/libs/init_npu.h>
2323#include < torch_npu/torch_npu.h>
2424
25+ #include < algorithm>
2526#include < numeric>
2627
2728#include " core/common/global_flags.h"
@@ -209,39 +210,113 @@ std::optional<ModelInputParams> GraphPersistentParam::update(
209210 const int64_t actual_batch_size = params.num_sequences ;
210211
211212 // Copy data from input parameters to persistent graph tensors
212- persistent_tokens_.slice (/* dim=*/ 0 , /* start=*/ 0 , /* end=*/ actual_num_tokens)
213- .copy_ (tokens, /* non_blocking=*/ true );
213+ if (actual_num_tokens > 0 ) {
214+ persistent_tokens_.slice (/* dim=*/ 0 , /* start=*/ 0 , /* end=*/ actual_num_tokens)
215+ .copy_ (tokens, /* non_blocking=*/ true );
216+ }
214217 // mRoPE positions have shape [3, num_tokens], slice on dim 1
215- if (use_mrope_) {
216- persistent_positions_
217- .slice (/* dim=*/ 1 , /* start=*/ 0 , /* end=*/ actual_num_tokens)
218- .copy_ (positions, /* non_blocking=*/ true );
219- } else {
220- persistent_positions_
221- .slice (/* dim=*/ 0 , /* start=*/ 0 , /* end=*/ actual_num_tokens)
222- .copy_ (positions, /* non_blocking=*/ true );
218+ if (actual_num_tokens > 0 ) {
219+ if (use_mrope_) {
220+ persistent_positions_
221+ .slice (/* dim=*/ 1 , /* start=*/ 0 , /* end=*/ actual_num_tokens)
222+ .copy_ (positions, /* non_blocking=*/ true );
223+ } else {
224+ persistent_positions_
225+ .slice (/* dim=*/ 0 , /* start=*/ 0 , /* end=*/ actual_num_tokens)
226+ .copy_ (positions, /* non_blocking=*/ true );
227+ }
228+ }
229+ if (actual_batch_size > 0 && params.q_seq_lens .defined () &&
230+ params.q_seq_lens .dim () >= 1 && params.q_seq_lens .numel () > 0 ) {
231+ const int64_t q_copy_len =
232+ std::min<int64_t >(actual_batch_size, params.q_seq_lens .size (0 ));
233+ if (q_copy_len > 0 ) {
234+ q_seq_lens_.slice (/* dim=*/ 0 , /* start=*/ 0 , /* end=*/ q_copy_len)
235+ .copy_ (params.q_seq_lens .slice (/* dim=*/ 0 ,
236+ /* start=*/ 0 ,
237+ /* end=*/ q_copy_len),
238+ /* non_blocking=*/ true );
239+ }
240+ }
241+ if (actual_batch_size > 0 && params.kv_seq_lens .defined () &&
242+ params.kv_seq_lens .dim () >= 1 && params.kv_seq_lens .numel () > 0 ) {
243+ const int64_t kv_copy_len =
244+ std::min<int64_t >(actual_batch_size, params.kv_seq_lens .size (0 ));
245+ if (kv_copy_len > 0 ) {
246+ kv_seq_lens_.slice (/* dim=*/ 0 , /* start=*/ 0 , /* end=*/ kv_copy_len)
247+ .copy_ (params.kv_seq_lens .slice (/* dim=*/ 0 ,
248+ /* start=*/ 0 ,
249+ /* end=*/ kv_copy_len),
250+ /* non_blocking=*/ true );
251+ }
252+ }
253+ // Keep padded decode slots valid for empty/local-short DP shards.
254+ // These tensors are consumed by ATB setup alongside *_seq_lens_vec.
255+ const int64_t padded_batch_size = static_cast <int64_t >(padded_num_tokens);
256+ if (padded_batch_size > 0 ) {
257+ const int64_t seq_fill_start =
258+ std::min<int64_t >(actual_batch_size, padded_batch_size);
259+ if (seq_fill_start < padded_batch_size) {
260+ q_seq_lens_
261+ .slice (/* dim=*/ 0 , /* start=*/ seq_fill_start, /* end=*/ padded_batch_size)
262+ .fill_ (1 );
263+ kv_seq_lens_
264+ .slice (/* dim=*/ 0 , /* start=*/ seq_fill_start, /* end=*/ padded_batch_size)
265+ .fill_ (1 );
266+ }
223267 }
224- q_seq_lens_.slice (/* dim=*/ 0 , /* start=*/ 0 , /* end=*/ actual_batch_size)
225- .copy_ (params.q_seq_lens , /* non_blocking=*/ true );
226- kv_seq_lens_.slice (/* dim=*/ 0 , /* start=*/ 0 , /* end=*/ actual_batch_size)
227- .copy_ (params.kv_seq_lens , /* non_blocking=*/ true );
228268
229- persistent_new_cache_slots_
230- .slice (/* dim=*/ 0 , /* start=*/ 0 , /* end=*/ actual_num_tokens)
231- .copy_ (params.new_cache_slots , /* non_blocking=*/ true );
269+ if (actual_num_tokens > 0 && params.new_cache_slots .defined () &&
270+ params.new_cache_slots .dim () >= 1 && params.new_cache_slots .numel () > 0 ) {
271+ const int64_t slot_copy_len =
272+ std::min<int64_t >(static_cast <int64_t >(actual_num_tokens),
273+ params.new_cache_slots .size (0 ));
274+ if (slot_copy_len > 0 ) {
275+ persistent_new_cache_slots_
276+ .slice (/* dim=*/ 0 , /* start=*/ 0 , /* end=*/ slot_copy_len)
277+ .copy_ (params.new_cache_slots .slice (/* dim=*/ 0 ,
278+ /* start=*/ 0 ,
279+ /* end=*/ slot_copy_len),
280+ /* non_blocking=*/ true );
281+ }
282+ }
283+ if (actual_num_tokens < padded_num_tokens) {
284+ persistent_new_cache_slots_
285+ .slice (/* dim=*/ 0 ,
286+ /* start=*/ actual_num_tokens,
287+ /* end=*/ static_cast <int64_t >(padded_num_tokens))
288+ .fill_ (0 );
289+ }
232290
233291 // Copy block table data
234- const int64_t actual_block_table_len = params.block_tables .size (1 );
235- auto slice_persistent_block_tables =
236- persistent_block_tables_
237- .slice (/* dim=*/ 0 , /* start=*/ 0 , /* end=*/ actual_batch_size)
238- .slice (/* dim=*/ 1 , /* start=*/ 0 , /* end=*/ actual_block_table_len);
239- slice_persistent_block_tables.copy_ (params.block_tables ,
240- /* non_blocking=*/ true );
292+ if (actual_batch_size > 0 && params.block_tables .defined () &&
293+ params.block_tables .dim () >= 2 && params.block_tables .numel () > 0 ) {
294+ const int64_t block_rows_to_copy =
295+ std::min<int64_t >(actual_batch_size, params.block_tables .size (0 ));
296+ const int64_t actual_block_table_len = params.block_tables .size (1 );
297+ if (block_rows_to_copy > 0 && actual_block_table_len > 0 ) {
298+ auto slice_persistent_block_tables =
299+ persistent_block_tables_
300+ .slice (/* dim=*/ 0 , /* start=*/ 0 , /* end=*/ block_rows_to_copy)
301+ .slice (/* dim=*/ 1 , /* start=*/ 0 , /* end=*/ actual_block_table_len);
302+ slice_persistent_block_tables.copy_ (
303+ params.block_tables .slice (/* dim=*/ 0 ,
304+ /* start=*/ 0 ,
305+ /* end=*/ block_rows_to_copy),
306+ /* non_blocking=*/ true );
307+ }
308+ }
309+ if (actual_batch_size < padded_batch_size) {
310+ persistent_block_tables_
311+ .slice (/* dim=*/ 0 ,
312+ /* start=*/ actual_batch_size,
313+ /* end=*/ padded_batch_size)
314+ .fill_ (0 );
315+ }
241316
242317 // Update persistent embedding from input_embedding if available
243318 const auto & embedding = params.input_embedding ;
244- if (embedding.defined ()) {
319+ if (embedding.defined () && embedding. dim () >= 2 ) {
245320 const int64_t embedding_tokens = embedding.size (0 );
246321
247322 // Initialize persistent_embedding_ if needed and not already initialized
@@ -255,21 +330,50 @@ std::optional<ModelInputParams> GraphPersistentParam::update(
255330 }
256331
257332 // Copy embedding data to persistent buffer
258- persistent_embedding_
259- .slice (/* dim=*/ 0 , /* start=*/ 0 , /* end=*/ embedding_tokens)
260- .copy_ (embedding, /* non_blocking=*/ true );
333+ if (embedding_tokens > 0 && embedding.numel () > 0 ) {
334+ persistent_embedding_
335+ .slice (/* dim=*/ 0 , /* start=*/ 0 , /* end=*/ embedding_tokens)
336+ .copy_ (embedding, /* non_blocking=*/ true );
337+ }
338+ }
339+ // Update q_cu_seq_lens used by sparse MLA indexer.
340+ // Empty local DP shards can carry empty q_cu_seq_lens from upper layers;
341+ // for graph capture we still need a valid non-empty length tensor for padded
342+ // decode slots.
343+ if (q_cu_seq_lens_.numel () == 0 ) {
344+ const int64_t max_seqs_per_batch = get_decode_graph_capacity (options_);
345+ q_cu_seq_lens_ = torch::zeros ({max_seqs_per_batch + 1 },
346+ torch::dtype (torch::kInt ).device (device_));
261347 }
262- // Update q_cu_seq_lens only if params.q_cu_seq_lens is defined
263- if (params.q_cu_seq_lens .defined ()) {
264- // Lazy initialization: if q_cu_seq_lens_ is not initialized, initialize it
265- if (q_cu_seq_lens_.numel () == 0 ) {
266- const int64_t max_seqs_per_batch = get_decode_graph_capacity (options_);
267- q_cu_seq_lens_ = torch::zeros ({max_seqs_per_batch + 1 },
268- torch::dtype (torch::kInt ).device (device_));
348+ const bool has_q_cu =
349+ params.q_cu_seq_lens .defined () && params.q_cu_seq_lens .dim () >= 1 ;
350+ const int64_t q_cu_size = (has_q_cu && params.q_cu_seq_lens .numel () > 0 )
351+ ? params.q_cu_seq_lens .size (0 )
352+ : 0 ;
353+ const int64_t q_cu_copy_len = std::min<int64_t >(actual_batch_size, q_cu_size);
354+ if (q_cu_copy_len > 0 ) {
355+ q_cu_seq_lens_.slice (/* dim=*/ 0 , /* start=*/ 0 , /* end=*/ q_cu_copy_len)
356+ .copy_ (params.q_cu_seq_lens .slice (/* dim=*/ 0 ,
357+ /* start=*/ 0 ,
358+ /* end=*/ q_cu_copy_len),
359+ /* non_blocking=*/ true );
360+ }
361+ if (padded_batch_size > q_cu_copy_len) {
362+ auto tail_q_seq_lens = q_seq_lens_.slice (/* dim=*/ 0 ,
363+ /* start=*/ q_cu_copy_len,
364+ /* end=*/ padded_batch_size);
365+ auto tail_cu = torch::cumsum (tail_q_seq_lens, /* dim=*/ 0 );
366+ if (q_cu_copy_len > 0 ) {
367+ auto last_prefix = q_cu_seq_lens_.slice (/* dim=*/ 0 ,
368+ /* start=*/ q_cu_copy_len - 1 ,
369+ /* end=*/ q_cu_copy_len);
370+ tail_cu = tail_cu + last_prefix;
269371 }
270- // Copy data
271- q_cu_seq_lens_.slice (/* dim=*/ 0 , /* start=*/ 0 , /* end=*/ actual_batch_size)
272- .copy_ (params.q_cu_seq_lens , /* non_blocking=*/ true );
372+ q_cu_seq_lens_
373+ .slice (/* dim=*/ 0 ,
374+ /* start=*/ q_cu_copy_len,
375+ /* end=*/ padded_batch_size)
376+ .copy_ (tail_cu, /* non_blocking=*/ true );
273377 }
274378
275379 // Update attention mask only if needed
@@ -297,12 +401,19 @@ std::optional<ModelInputParams> GraphPersistentParam::update(
297401 params_for_capture->kv_seq_lens_vec .resize (padded_num_tokens);
298402 params_for_capture->q_seq_lens_vec .resize (padded_num_tokens);
299403 // Copy actual values from original params
300- for (int i = 0 ; i < actual_batch_size; i++) {
404+ const int64_t kv_vec_size =
405+ static_cast <int64_t >(params.kv_seq_lens_vec .size ());
406+ const int64_t q_vec_size =
407+ static_cast <int64_t >(params.q_seq_lens_vec .size ());
408+ const int64_t vec_copy_len =
409+ std::min<int64_t >(actual_batch_size, std::min (kv_vec_size, q_vec_size));
410+ for (int64_t i = 0 ; i < vec_copy_len; i++) {
301411 params_for_capture->kv_seq_lens_vec [i] = params.kv_seq_lens_vec [i];
302412 params_for_capture->q_seq_lens_vec [i] = params.q_seq_lens_vec [i];
303413 }
304414 // Fill padded positions with default values
305- for (int i = actual_batch_size; i < padded_num_tokens; i++) {
415+ for (int64_t i = vec_copy_len; i < static_cast <int64_t >(padded_num_tokens);
416+ i++) {
306417 params_for_capture->kv_seq_lens_vec [i] = 1 ;
307418 params_for_capture->q_seq_lens_vec [i] = 1 ;
308419 }
@@ -320,16 +431,17 @@ std::optional<ModelInputParams> GraphPersistentParam::update(
320431 }
321432 params_for_capture->graph_buffer .tiling_data = tiling_data ();
322433 // Set persistent embedding if available
323- if (params.input_embedding .defined ()) {
434+ if (params.input_embedding .defined () && params.input_embedding .dim () >= 2 &&
435+ persistent_embedding_.defined () && persistent_embedding_.numel () > 0 ) {
324436 params_for_capture->input_embedding =
325437 persistent_embedding (padded_num_tokens);
326438 }
327- // Set q_cu_seq_lens if available
328- if (params. q_cu_seq_lens . defined ()) {
439+ // Keep q_cu_seq_lens aligned with padded capture batch.
440+ if (q_cu_seq_lens_. defined () && q_cu_seq_lens_. numel () > 0 ) {
329441 params_for_capture->q_cu_seq_lens =
330442 q_cu_seq_lens_.slice (/* dim=*/ 0 ,
331443 /* start=*/ 0 ,
332- /* end=*/ actual_batch_size );
444+ /* end=*/ padded_batch_size );
333445 }
334446
335447 return params_for_capture;
@@ -981,10 +1093,16 @@ ModelOutput AclGraphExecutorImpl::run(const torch::Tensor& tokens,
9811093 }
9821094
9831095 // Only use acl graph in decode phase for performance optimization
984- // Get actual num_tokens from tokens shape
1096+ // For DP, decode graph bucket should be based on global max tokens across dp
1097+ // groups; local shard can be empty on some ranks.
1098+ uint32_t graph_num_tokens = tokens_tensor.size (/* dim=*/ 0 );
1099+ if (params_single.dp_global_token_nums .size () > 1 ) {
1100+ graph_num_tokens = util::max (params_single.dp_global_token_nums );
1101+ }
1102+ // Keep actual n_tokens for replay output slicing.
9851103 const uint32_t n_tokens = tokens_tensor.size (/* dim=*/ 0 );
9861104 const uint32_t actual_batch_size = n_tokens / options_.num_decoding_tokens ();
987- const uint32_t bucket_num_tokens = get_bucket_num_tokens (n_tokens );
1105+ const uint32_t bucket_num_tokens = get_bucket_num_tokens (graph_num_tokens );
9881106
9891107 // Check if conditions are suitable for graph execution (replay or capture)
9901108 const auto max_seq_len = args_.max_position_embeddings ();
0 commit comments