@@ -54,11 +54,21 @@ def process_input_for_thd(
5454 [total_tokens, hidden_dim] for 3D embeddings
5555 - 'labels': Reshaped labels tensor of shape [total_tokens]
5656 - 'position_ids': Reshaped tensor of shape [total_tokens]
57- - 'cu_seqlens': Cumulative padded sequence lengths tensor of shape [num_sequences + 1] (int32)
57+ - 'cu_seqlens': Cumulative REAL sequence lengths tensor of shape [num_sequences + 1] (int32)
5858 where num_sequences is the total count of non-padded sequences across the batch.
59- NOTE: This contains cumulative lengths from seq_lens_padded (not seq_lens) since
60- CP doesn't support padding between sequences (resulting in NaNs). The labels or loss mask
61- will ensure that loss is computed correctly.
59+ Built from seq_lens (the unpadded real lengths). When the trailing pack-pad is
60+ purely at the end (cp_size == 1), the last entry is grown to total_tokens to absorb
61+ that pad and avoid TE's ``pad_between_seqs=True`` path; see the absorption block in
62+ the function body for the gate.
63+ - 'cu_seqlens_padded': (optional) Cumulative PADDED sequence lengths tensor of the same
64+ shape as ``cu_seqlens``. Only emitted when it differs from ``cu_seqlens`` after
65+ absorption (i.e., when padding lives between sub-sequences, which is the CP case).
66+ Forwarded to TE as ``cu_seqlens_q_padded`` / ``cu_seqlens_kv_padded`` with
67+ ``pad_between_seqs=True`` so the kernel reads memory offsets from the padded
68+ variant while attending only over the real-length slots.
69+ - 'max_seqlen': (optional) Scalar int32 tensor equal to ``max(cu_seqlens[i+1] - cu_seqlens[i])``
70+ after any absorption; omitted when ``cu_seqlens`` is None or has fewer than two entries.
71+ Honors TE's contract that ``max_seqlen_q >= max(cu_seqlens_q[i+1] - cu_seqlens_q[i])``.
6272 - 'padding_mask': Boolean tensor of shape [total_tokens] indicating padding positions
6373 - Non-tensor keys from input batch are preserved (e.g., 'qkv_format')
6474
@@ -77,8 +87,11 @@ def process_input_for_thd(
7787 >>> # result['input_ids'].shape: [12] (2D input collapsed to 1D)
7888 >>> # result['labels'].shape: [12]
7989 >>> # result['position_ids'].shape: [12]
80- >>> # result['cu_seqlens']: tensor([0, 4, 6, 12], dtype=torch.int32)
90+ >>> # result['cu_seqlens']: tensor([0, 3, 5, 11], dtype=torch.int32)
91+ >>> # Breakdown: [0] + cumsum([3, 2, 6]) = [0, 3, 5, 11] (from seq_lens — real lengths)
92+ >>> # result['cu_seqlens_padded']: tensor([0, 4, 6, 12], dtype=torch.int32)
8193 >>> # Breakdown: [0] + cumsum([4, 2, 6]) = [0, 4, 6, 12] (from seq_lens_padded)
94+ >>> # result['max_seqlen']: tensor(6, dtype=torch.int32) # max slot width in cu_seqlens
8295 >>> # result['padding_mask'].shape: [12]
8396 """
8497 input_ids = batch ["input_ids" ]
@@ -96,13 +109,13 @@ def process_input_for_thd(
96109 input_ids_thd = input_ids .reshape (total_tokens , - 1 ).squeeze (- 1 )
97110 labels_thd = labels .reshape (total_tokens , - 1 ).squeeze (- 1 )
98111
112+ cu_seqlens = None
113+ cu_seqlens_padded = None
114+ max_seqlen = None
99115 if seq_lens is not None :
100- # Filter out padding values and flatten
101- # seq_lens shape: [batch_size, num_packs] -> flatten and remove padding values
102116 seq_lens_flat = seq_lens .reshape (- 1 )
103117 valid_seq_lens = seq_lens_flat [seq_lens_flat != seq_lens_padding_value ]
104118
105- # Compute cumulative sequence lengths for attention
106119 cu_seqlens = torch .cat (
107120 [
108121 torch .tensor ([0 ], dtype = valid_seq_lens .dtype , device = valid_seq_lens .device ),
@@ -112,7 +125,6 @@ def process_input_for_thd(
112125 cu_seqlens = cu_seqlens .to (dtype = torch .int32 ).to (device = valid_seq_lens .device )
113126
114127 if seq_lens_padded is not None :
115- # Same processing for padded sequence lengths
116128 seq_lens_padded_flat = seq_lens_padded .reshape (- 1 )
117129 valid_seq_lens_padded = seq_lens_padded_flat [seq_lens_padded_flat != seq_lens_padding_value ]
118130
@@ -121,16 +133,46 @@ def process_input_for_thd(
121133 )
122134 cu_seqlens_padded = cu_seqlens_padded .to (dtype = torch .int32 ).to (device = valid_seq_lens_padded .device )
123135
136+ # Trailing-only pack-pad (cp_size==1): absorb into cu_seqlens[-1] so
137+ # the emit gate below drops cu_seqlens_padded and TE skips its
138+ # pad_between_seqs=True path. CP>1 differs in multiple entries and
139+ # falls through; both arrays are emitted and TE handles padding.
140+ if (
141+ cu_seqlens is not None
142+ and cu_seqlens_padded is not None
143+ and cu_seqlens .numel () == cu_seqlens_padded .numel ()
144+ and cu_seqlens .numel () > 1
145+ and torch .equal (cu_seqlens [:- 1 ], cu_seqlens_padded [:- 1 ])
146+ ):
147+ _total = int (total_tokens )
148+ _real_total = int (cu_seqlens [- 1 ].item ())
149+ if _real_total < _total :
150+ _extended = cu_seqlens .clone ()
151+ _extended [- 1 ] = _total
152+ cu_seqlens = _extended
153+ cu_seqlens_padded = cu_seqlens .clone ()
154+
155+ # Compute max_seqlen from the FINAL cu_seqlens to honor TE's contract
156+ # (``max_seqlen_q >= max(cu_seqlens[i+1] - cu_seqlens[i])``, see TE's
157+ # cpp_extensions/fused_attn.py:152-159).
158+ if cu_seqlens is not None and cu_seqlens .numel () > 1 :
159+ max_seqlen = (cu_seqlens [1 :] - cu_seqlens [:- 1 ]).max ().to (dtype = torch .int32 )
160+
124161 result = {
125162 "input_ids" : input_ids_thd ,
126163 "position_ids" : position_ids_thd ,
127- # Pass cu_seqlens_padded here since CP doesn't support padding between sequences correctly, the labels or loss mask will ensure that loss is computed correctly.
128- "cu_seqlens" : cu_seqlens_padded ,
164+ "cu_seqlens" : cu_seqlens ,
129165 "labels" : labels_thd ,
130166 "padding_mask" : (input_ids_thd == padding_token_id ),
131167 }
168+ # Emit cu_seqlens_padded only when it differs from cu_seqlens — its
169+ # presence is what flips TE's pad_between_seqs=True path in
170+ # attention/utils.py.
171+ if cu_seqlens_padded is not None and not torch .equal (cu_seqlens_padded , cu_seqlens ):
172+ result ["cu_seqlens_padded" ] = cu_seqlens_padded
173+ if max_seqlen is not None :
174+ result ["max_seqlen" ] = max_seqlen
132175
133- # Preserve qkv_format and other non-tensor keys from the original batch
134176 for key , value in batch .items ():
135177 if key not in result and not isinstance (value , torch .Tensor ):
136178 result [key ] = value
@@ -175,8 +217,14 @@ def split_batch_into_thd_chunks(
175217 - 'input_ids': [num_chunks, tokens_per_chunk] or [num_chunks, tokens_per_chunk, hidden_dim]
176218 - 'labels': [num_chunks, tokens_per_chunk]
177219 - 'position_ids': [num_chunks, tokens_per_chunk]
178- - 'cu_seqlens': [num_chunks, max_sequences_per_chunk + 1] (padded with seq_lens_padding_value).
179- Contains cumulative lengths from seq_lens_padded for CP compatibility.
220+ - 'cu_seqlens': [num_chunks, max_sequences_per_chunk + 1] (right-padded with
221+ seq_lens_padding_value across chunks for rectangularity). Built from seq_lens
222+ (real lengths) per chunk; see ``process_input_for_thd`` for the absorption
223+ semantics applied per chunk.
224+ - 'cu_seqlens_padded': (optional) Same shape, emitted whenever ANY chunk emits it.
225+ For chunks that absorbed (no separate padded variant), this row equals the
226+ chunk's ``cu_seqlens``.
227+ - 'max_seqlen': (optional) Tensor of shape [num_chunks] holding each chunk's scalar max_seqlen; emitted only when every chunk has one.
180228 - 'padding_mask': [num_chunks, tokens_per_chunk]
181229 - Non-tensor keys from input batch are preserved
182230 - When num_chunks <= 1:
@@ -230,12 +278,21 @@ def pad_and_stack(tensor_list, padding_value):
230278 for i in range (num_chunks )
231279 ]
232280
233- # Stack results
234- return {
281+ stacked : dict = {
235282 "input_ids" : torch .stack ([c ["input_ids" ] for c in chunk_results ]),
236283 "labels" : torch .stack ([c ["labels" ] for c in chunk_results ]),
237284 "position_ids" : torch .stack ([c ["position_ids" ] for c in chunk_results ]),
238285 "cu_seqlens" : pad_and_stack ([c ["cu_seqlens" ] for c in chunk_results ], seq_lens_padding_value ),
239286 "padding_mask" : torch .stack ([c ["padding_mask" ] for c in chunk_results ]),
240- ** {k : v for k , v in chunk_results [0 ].items () if not isinstance (v , torch .Tensor )},
241287 }
288+ # Emit cu_seqlens_padded whenever any chunk emits it; absorbed chunks
289+ # fall back to their cu_seqlens (semantically equal) for rectangularity.
290+ if any ("cu_seqlens_padded" in c for c in chunk_results ):
291+ stacked ["cu_seqlens_padded" ] = pad_and_stack (
292+ [c .get ("cu_seqlens_padded" , c ["cu_seqlens" ]) for c in chunk_results ],
293+ seq_lens_padding_value ,
294+ )
295+ if all ("max_seqlen" in c for c in chunk_results ):
296+ stacked ["max_seqlen" ] = torch .stack ([c ["max_seqlen" ] for c in chunk_results ])
297+ stacked .update ({k : v for k , v in chunk_results [0 ].items () if not isinstance (v , torch .Tensor )})
298+ return stacked
0 commit comments