@@ -23,19 +23,21 @@ void CudaGraphRunner::capturePrefill() {
2323 inputs.attention_inputs .prefix_lengths .fill_ (0 );
2424 // Must set cu_seqlens/cu_kv_seqlens/input_lengths to match actual seq_len,
2525 // otherwise FlashInfer plans for max_seq_len tokens but q/k/v only have seq_len tokens
26- inputs.attention_inputs .cu_seqlens .data_ptr <int >()[0 ] = 0 ;
27- inputs.attention_inputs .cu_seqlens .data_ptr <int >()[1 ] = seq_len;
28- inputs.attention_inputs .input_lengths .data_ptr <int >()[0 ] = seq_len;
26+ inputs.attention_inputs .cu_seqlens_host [0 ] = 0 ;
27+ inputs.attention_inputs .cu_seqlens_host [1 ] = seq_len;
28+ inputs.attention_inputs .cu_seqlens .copy_ (inputs.attention_inputs .cu_seqlens_host , false );
29+ inputs.attention_inputs .input_lengths [0 ] = seq_len;
2930 } else {
30- inputs.attention_inputs .cu_seqlens .fill_ (seq_len);
31+ inputs.attention_inputs .cu_seqlens_host .fill_ (seq_len);
32+ inputs.attention_inputs .cu_seqlens_host [0 ] = 0 ;
33+ inputs.attention_inputs .cu_seqlens .copy_ (inputs.attention_inputs .cu_seqlens_host , false );
3134 inputs.attention_inputs .input_lengths .fill_ (0 );
3235 int kv_len = max_seq_len_ + seq_len;
3336 int prefix_len = kv_len;
3437 inputs.attention_inputs .cu_kv_seqlens .fill_ (kv_len);
38+ inputs.attention_inputs .cu_kv_seqlens [0 ] = 0 ;
3539 inputs.attention_inputs .prefix_lengths .fill_ (prefix_len);
36- inputs.attention_inputs .cu_seqlens .data_ptr <int >()[0 ] = 0 ;
37- inputs.attention_inputs .cu_kv_seqlens .data_ptr <int >()[0 ] = 0 ;
38- inputs.attention_inputs .input_lengths .data_ptr <int >()[0 ] = seq_len;
40+ inputs.attention_inputs .input_lengths [0 ] = seq_len;
3941 }
4042
4143 inputs.attention_inputs .context_total_kv_length = seq_len;
0 commit comments