@@ -43,6 +43,7 @@ class NVLlamaConfig(LlamaConfig):
4343 """NVLlama configuration."""
4444
4545 attn_input_format : str = "thd"
46+ self_attn_mask_type : str = "padding_causal"
4647
4748
4849class NVLlamaPreTrainedModel (PreTrainedModel ):
@@ -118,7 +119,7 @@ def _init_method(x):
118119 normalization = "RMSNorm" ,
119120 activation = "swiglu" ,
120121 attn_input_format = config .attn_input_format ,
121- self_attn_mask_type = "padding_causal" ,
122+ self_attn_mask_type = config . self_attn_mask_type ,
122123 num_gqa_groups = config .num_key_value_heads ,
123124 layer_number = layer_idx + 1 ,
124125 params_dtype = config .dtype ,
@@ -181,49 +182,32 @@ def forward(
181182
182183 hidden_states = inputs_embeds
183184
185+ # TE-specific input handling.
184186 has_thd_input = [x in kwargs for x in ["cu_seq_lens_q" , "cu_seq_lens_k" , "max_length_q" , "max_length_k" ]]
185187 should_pack_inputs = not any (has_thd_input ) and self .config .attn_input_format == "thd"
186188
187- # This might be slower for BSHD + padding with fused attention backend. But it should be faster for the flash
188- # attention backend.
189- self_attn_mask_type = "padding_causal"
190189 if should_pack_inputs :
191- # Left-side padding is not supported in TE layers, so to make generation work with TE we dynamically convert
192- # to THD-style inputs in our forward pass, and then convert back to BSHD for the output. This lets the
193- # entire transformer stack run in THD mode.
190+ # Left-side padding is not supported in TE layers, so to make huggingface-style generation work with TE we
191+ # dynamically convert to THD-style inputs in our forward pass, and then convert back to BSHD for the output.
192+ # This lets the entire transformer stack run in THD mode. This might be slower for BSHD + padding with fused
193+ # attention backend, but it should be faster for the flash attention backend.
194194 assert attention_mask is not None , "Attention mask is required when packing BSHD inputs."
195195 batch_size = hidden_states .size (0 )
196196 hidden_states , indices , cu_seqlens , max_seqlen , _ = _unpad_input (hidden_states , attention_mask )
197- cu_seq_lens_q = cu_seq_lens_k = cu_seqlens
198- max_length_q = max_length_k = max_seqlen
197+ kwargs [ " cu_seq_lens_q" ] = kwargs [ " cu_seq_lens_k" ] = cu_seqlens
198+ kwargs [ " max_length_q" ] = kwargs [ " max_length_k" ] = max_seqlen
199199
200- elif self .config .attn_input_format == "thd" :
201- # Here, we're providing THD-style inputs, so we can just grab the kwargs.
202- assert hidden_states .dim () == 3 and hidden_states .size (0 ) == 1 , (
203- "THD expects embeddings shaped [1, total_tokens, hidden_size]."
204- )
200+ if self .config .attn_input_format == "thd" and hidden_states .dim () == 3 and hidden_states .size (0 ) == 1 :
201+ # For THD, the embedding output is a 3-dimensional tensor with shape [1, total_tokens, hidden_size], but TE
202+ # expects a 2-dimensional tensor with shape [total_tokens, hidden_size].
205203 hidden_states = hidden_states .squeeze (0 )
206- cu_seq_lens_q = kwargs ["cu_seq_lens_q" ]
207- cu_seq_lens_k = kwargs ["cu_seq_lens_k" ]
208- max_length_q = kwargs ["max_length_q" ]
209- max_length_k = kwargs ["max_length_k" ]
210-
211- else :
212- if attention_mask is not None :
213- attention_mask = attention_mask [:, None , None , :] < - 1
214- else :
215- self_attn_mask_type = "causal"
216- cu_seq_lens_q = cu_seq_lens_k = None
217- max_length_q = max_length_k = hidden_states .size (1 )
218-
219- # If we're using kv-caching, we can't trust the max_length_q value as the true max length for rotary
220- # embeddings, since this will be 1 in generation. Instead we can take the max sequence length from the past
221- # key values object.
222- te_rope_emb = self .rotary_emb (
223- max_seq_len = max_length_q if past_key_values is None else past_key_values .max_ctx_len
224- )
225204
226- if isinstance (past_key_values , InferenceParams ):
205+ if self .config .attn_input_format == "bshd" and attention_mask is not None and attention_mask .dim () == 2 :
206+ # If we're using padded BSHD inputs, we need to convert the 2-dimensional mask to a 4-dimensional mask in
207+ # the expected boolean format for TE.
208+ attention_mask = attention_mask [:, None , None , :] < - 1
209+
210+ if isinstance (past_key_values , InferenceParams ): # InferenceParams is TE's way of managing kv-caching.
227211 # In generation mode, we set the length to 1 for each batch index. Otherwise, we use the attention mask to
228212 # compute the lengths of each sequence in the batch.
229213 lengths = (
@@ -233,6 +217,8 @@ def forward(
233217 )
234218 past_key_values .pre_step (OrderedDict (zip (list (range (len (lengths ))), lengths )))
235219
220+ te_rope_emb = self .rotary_emb (max_seq_len = self .config .max_position_embeddings )
221+
236222 for decoder_layer in self .layers [: self .config .num_hidden_layers ]:
237223 if output_hidden_states :
238224 all_hidden_states = (* all_hidden_states , hidden_states )
@@ -241,12 +227,14 @@ def forward(
241227 hidden_states ,
242228 attention_mask = None if self .config .attn_input_format == "thd" else attention_mask ,
243229 rotary_pos_emb = te_rope_emb ,
244- self_attn_mask_type = self_attn_mask_type ,
245230 inference_params = past_key_values ,
246- cu_seqlens_q = cu_seq_lens_q ,
247- cu_seqlens_kv = cu_seq_lens_k ,
248- max_seqlen_q = max_length_q ,
249- max_seqlen_kv = max_length_k ,
231+ cu_seqlens_q = kwargs .get ("cu_seq_lens_q" , None ),
232+ cu_seqlens_kv = kwargs .get ("cu_seq_lens_k" , None ),
233+ cu_seqlens_q_padded = kwargs .get ("cu_seq_lens_q_padded" , None ),
234+ cu_seqlens_kv_padded = kwargs .get ("cu_seq_lens_k_padded" , None ),
235+ max_seqlen_q = kwargs .get ("max_length_q" , None ),
236+ max_seqlen_kv = kwargs .get ("max_length_k" , None ),
237+ pad_between_seqs = kwargs .get ("pad_between_seqs" , None ),
250238 )
251239
252240 hidden_states = self .norm (hidden_states )
@@ -258,7 +246,7 @@ def forward(
258246
259247 if should_pack_inputs :
260248 # If we've converted BSHD to THD for our TE layers, we need to convert back to BSHD for the output.
261- hidden_states = _pad_input (hidden_states , indices , batch_size , max_length_q )
249+ hidden_states = _pad_input (hidden_states , indices , batch_size , max_seqlen )
262250
263251 return BaseModelOutputWithPast (
264252 last_hidden_state = hidden_states ,
0 commit comments