@@ -31,9 +31,9 @@ def _cu_seqlens_triton_kernel(
3131 cu_seqlens_ptr , # [num_seqs + 1]
3232 chunk_indices_ptr , # [N] output
3333 chunk_offsets_ptr , # [N] output
34- num_seqs : tl . constexpr ,
34+ num_seqs ,
3535 chunk_size : tl .constexpr ,
36- N : tl . constexpr ,
36+ N ,
3737 BLOCK_SIZE : tl .constexpr ,
3838):
3939 """Computes chunk_indices and chunk_offsets in a single kernel launch."""
@@ -65,8 +65,18 @@ def _cu_seqlens_triton_kernel(
6565
6666def cu_seqlens_to_chunk_indices_offsets_triton (
6767 cu_seqlens : torch .Tensor ,
68- chunk_size : int ) -> Tuple [torch .Tensor , torch .Tensor ]:
69- """Optimized version of cu_seqlens_to_chunk_indices_offsets."""
68+ chunk_size : int ,
69+ total_seqlens : int = - 1 ,
70+ extra_chunks : int = - 1 ) -> Tuple [torch .Tensor , torch .Tensor ]:
71+ """Optimized version of cu_seqlens_to_chunk_indices_offsets.
72+
73+ Args:
74+ total_seqlens: If provided (>= 0), avoids a GPU->CPU sync to read
75+ cu_seqlens[-1]. Callers that already know the total number of
76+ context tokens should pass it here.
77+ extra_chunks: If provided (>= 0), avoids a GPU->CPU sync to compute
78+ the number of extra chunks from misaligned sequence boundaries.
79+ """
7080 device = cu_seqlens .device
7181 num_seqs = cu_seqlens .numel () - 1
7282
@@ -75,18 +85,20 @@ def cu_seqlens_to_chunk_indices_offsets_triton(
7585 torch .empty (0 , dtype = torch .int , device = device ))
7686
7787 cu = cu_seqlens .to (dtype = torch .int64 )
78- total_seqlens = cu [- 1 ].item ()
88+ if total_seqlens < 0 :
89+ total_seqlens = cu [- 1 ].item ()
7990
8091 if num_seqs == 1 :
8192 # Fast path for single sequence (no boundaries to process)
8293 N = (total_seqlens + chunk_size - 1 ) // chunk_size
8394 return (torch .arange (N , device = device , dtype = torch .int ),
8495 torch .zeros (N , device = device , dtype = torch .int ))
8596
86- seq_starts = cu [1 :- 1 ]
87- misaligned = ((seq_starts % chunk_size ) > 0 ).to (torch .int64 )
88- p = torch .cumsum (misaligned , dim = 0 )
89- extra_chunks = p [- 1 ].item () if p .numel () > 0 else 0
97+ if extra_chunks < 0 :
98+ seq_starts = cu [1 :- 1 ]
99+ misaligned = ((seq_starts % chunk_size ) > 0 ).to (torch .int64 )
100+ p = torch .cumsum (misaligned , dim = 0 )
101+ extra_chunks = p [- 1 ].item () if p .numel () > 0 else 0
90102 N = (total_seqlens + chunk_size - 1 ) // chunk_size + extra_chunks
91103 chunk_indices = torch .empty (N , device = device , dtype = torch .int )
92104 chunk_offsets = torch .empty (N , device = device , dtype = torch .int )
@@ -279,8 +291,21 @@ def prepare(self, attn_metadata: AttentionMetadata):
279291 self .has_initial_states_cpu [:num_contexts ].any ())
280292
281293 if self .use_initial_states :
294+ # Compute extra_chunks using pure Python arithmetic on CPU
295+ # seq_lens to avoid any GPU->CPU sync point.
296+ _cs = self .chunk_size
297+ _cumsum = 0
298+ _extra = 0
299+ for i in range (num_contexts - 1 ):
300+ _cumsum += int (attn_metadata .seq_lens [i ])
301+ if _cumsum % _cs != 0 :
302+ _extra += 1
303+
282304 self .chunk_indices , self .chunk_offsets = cu_seqlens_to_chunk_indices_offsets_triton (
283- self .cu_seqlens [:num_contexts + 1 ], self .chunk_size )
305+ self .cu_seqlens [:num_contexts + 1 ],
306+ self .chunk_size ,
307+ total_seqlens = num_ctx_tokens ,
308+ extra_chunks = _extra )
284309 else :
285310 self .chunk_indices = None
286311 self .chunk_offsets = None