Skip to content

Commit 6321b0c

Browse files
authored
Add CP support to modeling_llama_te.py (#1407)
* Makes the attention_mask type more easily controllable through a config kwarg * passes the additional `_padded` kwargs through to the TE layers * simplifies some of the pre-processing logic and adds comments as to why they're needed Fixes BIO-12 Signed-off-by: Peter St. John <pstjohn@nvidia.com>
1 parent 6da7e6d commit 6321b0c

4 files changed

Lines changed: 65 additions & 83 deletions

File tree

bionemo-recipes/models/llama3/modeling_llama_te.py

Lines changed: 28 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ class NVLlamaConfig(LlamaConfig):
4343
"""NVLlama configuration."""
4444

4545
attn_input_format: str = "thd"
46+
self_attn_mask_type: str = "padding_causal"
4647

4748

4849
class NVLlamaPreTrainedModel(PreTrainedModel):
@@ -118,7 +119,7 @@ def _init_method(x):
118119
normalization="RMSNorm",
119120
activation="swiglu",
120121
attn_input_format=config.attn_input_format,
121-
self_attn_mask_type="padding_causal",
122+
self_attn_mask_type=config.self_attn_mask_type,
122123
num_gqa_groups=config.num_key_value_heads,
123124
layer_number=layer_idx + 1,
124125
params_dtype=config.dtype,
@@ -181,49 +182,32 @@ def forward(
181182

182183
hidden_states = inputs_embeds
183184

185+
# TE-specific input handling.
184186
has_thd_input = [x in kwargs for x in ["cu_seq_lens_q", "cu_seq_lens_k", "max_length_q", "max_length_k"]]
185187
should_pack_inputs = not any(has_thd_input) and self.config.attn_input_format == "thd"
186188

187-
# This might be slower for BSHD + padding with fused attention backend. But it should be faster for the flash
188-
# attention backend.
189-
self_attn_mask_type = "padding_causal"
190189
if should_pack_inputs:
191-
# Left-side padding is not supported in TE layers, so to make generation work with TE we dynamically convert
192-
# to THD-style inputs in our forward pass, and then convert back to BSHD for the output. This lets the
193-
# entire transformer stack run in THD mode.
190+
# Left-side padding is not supported in TE layers, so to make huggingface-style generation work with TE we
191+
# dynamically convert to THD-style inputs in our forward pass, and then convert back to BSHD for the output.
192+
# This lets the entire transformer stack run in THD mode. This might be slower for BSHD + padding with fused
193+
# attention backend, but it should be faster for the flash attention backend.
194194
assert attention_mask is not None, "Attention mask is required when packing BSHD inputs."
195195
batch_size = hidden_states.size(0)
196196
hidden_states, indices, cu_seqlens, max_seqlen, _ = _unpad_input(hidden_states, attention_mask)
197-
cu_seq_lens_q = cu_seq_lens_k = cu_seqlens
198-
max_length_q = max_length_k = max_seqlen
197+
kwargs["cu_seq_lens_q"] = kwargs["cu_seq_lens_k"] = cu_seqlens
198+
kwargs["max_length_q"] = kwargs["max_length_k"] = max_seqlen
199199

200-
elif self.config.attn_input_format == "thd":
201-
# Here, we're providing THD-style inputs, so we can just grab the kwargs.
202-
assert hidden_states.dim() == 3 and hidden_states.size(0) == 1, (
203-
"THD expects embeddings shaped [1, total_tokens, hidden_size]."
204-
)
200+
if self.config.attn_input_format == "thd" and hidden_states.dim() == 3 and hidden_states.size(0) == 1:
201+
# For THD, the embedding output is a 3-dimensional tensor with shape [1, total_tokens, hidden_size], but TE
202+
# expects a 2-dimensional tensor with shape [total_tokens, hidden_size].
205203
hidden_states = hidden_states.squeeze(0)
206-
cu_seq_lens_q = kwargs["cu_seq_lens_q"]
207-
cu_seq_lens_k = kwargs["cu_seq_lens_k"]
208-
max_length_q = kwargs["max_length_q"]
209-
max_length_k = kwargs["max_length_k"]
210-
211-
else:
212-
if attention_mask is not None:
213-
attention_mask = attention_mask[:, None, None, :] < -1
214-
else:
215-
self_attn_mask_type = "causal"
216-
cu_seq_lens_q = cu_seq_lens_k = None
217-
max_length_q = max_length_k = hidden_states.size(1)
218-
219-
# If we're using kv-caching, we can't trust the max_length_q value as the true max length for rotary
220-
# embeddings, since this will be 1 in generation. Instead we can take the max sequence length from the past
221-
# key values object.
222-
te_rope_emb = self.rotary_emb(
223-
max_seq_len=max_length_q if past_key_values is None else past_key_values.max_ctx_len
224-
)
225204

226-
if isinstance(past_key_values, InferenceParams):
205+
if self.config.attn_input_format == "bshd" and attention_mask is not None and attention_mask.dim() == 2:
206+
# If we're using padded BSHD inputs, we need to convert the 2-dimensional mask to a 4-dimensional mask in
207+
# the expected boolean format for TE.
208+
attention_mask = attention_mask[:, None, None, :] < -1
209+
210+
if isinstance(past_key_values, InferenceParams): # InferenceParams is TE's way of managing kv-caching.
227211
# In generation mode, we set the length to 1 for each batch index. Otherwise, we use the attention mask to
228212
# compute the lengths of each sequence in the batch.
229213
lengths = (
@@ -233,6 +217,8 @@ def forward(
233217
)
234218
past_key_values.pre_step(OrderedDict(zip(list(range(len(lengths))), lengths)))
235219

220+
te_rope_emb = self.rotary_emb(max_seq_len=self.config.max_position_embeddings)
221+
236222
for decoder_layer in self.layers[: self.config.num_hidden_layers]:
237223
if output_hidden_states:
238224
all_hidden_states = (*all_hidden_states, hidden_states)
@@ -241,12 +227,14 @@ def forward(
241227
hidden_states,
242228
attention_mask=None if self.config.attn_input_format == "thd" else attention_mask,
243229
rotary_pos_emb=te_rope_emb,
244-
self_attn_mask_type=self_attn_mask_type,
245230
inference_params=past_key_values,
246-
cu_seqlens_q=cu_seq_lens_q,
247-
cu_seqlens_kv=cu_seq_lens_k,
248-
max_seqlen_q=max_length_q,
249-
max_seqlen_kv=max_length_k,
231+
cu_seqlens_q=kwargs.get("cu_seq_lens_q", None),
232+
cu_seqlens_kv=kwargs.get("cu_seq_lens_k", None),
233+
cu_seqlens_q_padded=kwargs.get("cu_seq_lens_q_padded", None),
234+
cu_seqlens_kv_padded=kwargs.get("cu_seq_lens_k_padded", None),
235+
max_seqlen_q=kwargs.get("max_length_q", None),
236+
max_seqlen_kv=kwargs.get("max_length_k", None),
237+
pad_between_seqs=kwargs.get("pad_between_seqs", None),
250238
)
251239

252240
hidden_states = self.norm(hidden_states)
@@ -258,7 +246,7 @@ def forward(
258246

259247
if should_pack_inputs:
260248
# If we've converted BSHD to THD for our TE layers, we need to convert back to BSHD for the output.
261-
hidden_states = _pad_input(hidden_states, indices, batch_size, max_length_q)
249+
hidden_states = _pad_input(hidden_states, indices, batch_size, max_seqlen)
262250

263251
return BaseModelOutputWithPast(
264252
last_hidden_state=hidden_states,

bionemo-recipes/models/llama3/tests/test_lm_eval.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,10 @@
2929
def model_checkpoint(tmp_path: Path):
3030
tokenizer = AutoTokenizer.from_pretrained("nvidia/Llama-3.1-8B-Instruct-FP8")
3131
config = NVLlamaConfig.from_pretrained(
32-
"nvidia/Llama-3.1-8B-Instruct-FP8", num_hidden_layers=2, attn_input_format="bshd"
32+
"nvidia/Llama-3.1-8B-Instruct-FP8",
33+
num_hidden_layers=2,
34+
attn_input_format="bshd",
35+
self_attn_mask_type="causal",
3336
)
3437
model = NVLlamaForCausalLM(config)
3538
model.save_pretrained(tmp_path / "checkpoint")

bionemo-recipes/models/llama3/tests/test_modeling_llama_te.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,10 @@ def test_llama_model_forward_pass(input_text, attn_input_format):
6969
def test_llama_model_forward_pass_no_attention_mask():
7070
tokenizer = AutoTokenizer.from_pretrained("nvidia/Llama-3.1-8B-Instruct-FP8")
7171
config = NVLlamaConfig.from_pretrained(
72-
"nvidia/Llama-3.1-8B-Instruct-FP8", num_hidden_layers=2, attn_input_format="bshd"
72+
"nvidia/Llama-3.1-8B-Instruct-FP8",
73+
num_hidden_layers=2,
74+
attn_input_format="bshd",
75+
self_attn_mask_type="causal",
7376
)
7477
model = NVLlamaForCausalLM(config)
7578

@@ -269,7 +272,7 @@ def test_hf_llama_model_generate_bshd():
269272
def test_te_llama_model_generate_with_cache():
270273
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
271274
model_hf = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.1-8B-Instruct", dtype=torch.bfloat16)
272-
model_te = convert_llama_hf_to_te(model_hf)
275+
model_te = convert_llama_hf_to_te(model_hf, self_attn_mask_type="padding_causal")
273276

274277
prompt = """
275278
Licensed under the Apache License, Version 2.0 (the "License");

bionemo-recipes/recipes/llama3_native_te/modeling_llama_te.py

Lines changed: 28 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ class NVLlamaConfig(LlamaConfig):
4343
"""NVLlama configuration."""
4444

4545
attn_input_format: str = "thd"
46+
self_attn_mask_type: str = "padding_causal"
4647

4748

4849
class NVLlamaPreTrainedModel(PreTrainedModel):
@@ -118,7 +119,7 @@ def _init_method(x):
118119
normalization="RMSNorm",
119120
activation="swiglu",
120121
attn_input_format=config.attn_input_format,
121-
self_attn_mask_type="padding_causal",
122+
self_attn_mask_type=config.self_attn_mask_type,
122123
num_gqa_groups=config.num_key_value_heads,
123124
layer_number=layer_idx + 1,
124125
params_dtype=config.dtype,
@@ -181,49 +182,32 @@ def forward(
181182

182183
hidden_states = inputs_embeds
183184

185+
# TE-specific input handling.
184186
has_thd_input = [x in kwargs for x in ["cu_seq_lens_q", "cu_seq_lens_k", "max_length_q", "max_length_k"]]
185187
should_pack_inputs = not any(has_thd_input) and self.config.attn_input_format == "thd"
186188

187-
# This might be slower for BSHD + padding with fused attention backend. But it should be faster for the flash
188-
# attention backend.
189-
self_attn_mask_type = "padding_causal"
190189
if should_pack_inputs:
191-
# Left-side padding is not supported in TE layers, so to make generation work with TE we dynamically convert
192-
# to THD-style inputs in our forward pass, and then convert back to BSHD for the output. This lets the
193-
# entire transformer stack run in THD mode.
190+
# Left-side padding is not supported in TE layers, so to make huggingface-style generation work with TE we
191+
# dynamically convert to THD-style inputs in our forward pass, and then convert back to BSHD for the output.
192+
# This lets the entire transformer stack run in THD mode. This might be slower for BSHD + padding with fused
193+
# attention backend, but it should be faster for the flash attention backend.
194194
assert attention_mask is not None, "Attention mask is required when packing BSHD inputs."
195195
batch_size = hidden_states.size(0)
196196
hidden_states, indices, cu_seqlens, max_seqlen, _ = _unpad_input(hidden_states, attention_mask)
197-
cu_seq_lens_q = cu_seq_lens_k = cu_seqlens
198-
max_length_q = max_length_k = max_seqlen
197+
kwargs["cu_seq_lens_q"] = kwargs["cu_seq_lens_k"] = cu_seqlens
198+
kwargs["max_length_q"] = kwargs["max_length_k"] = max_seqlen
199199

200-
elif self.config.attn_input_format == "thd":
201-
# Here, we're providing THD-style inputs, so we can just grab the kwargs.
202-
assert hidden_states.dim() == 3 and hidden_states.size(0) == 1, (
203-
"THD expects embeddings shaped [1, total_tokens, hidden_size]."
204-
)
200+
if self.config.attn_input_format == "thd" and hidden_states.dim() == 3 and hidden_states.size(0) == 1:
201+
# For THD, the embedding output is a 3-dimensional tensor with shape [1, total_tokens, hidden_size], but TE
202+
# expects a 2-dimensional tensor with shape [total_tokens, hidden_size].
205203
hidden_states = hidden_states.squeeze(0)
206-
cu_seq_lens_q = kwargs["cu_seq_lens_q"]
207-
cu_seq_lens_k = kwargs["cu_seq_lens_k"]
208-
max_length_q = kwargs["max_length_q"]
209-
max_length_k = kwargs["max_length_k"]
210-
211-
else:
212-
if attention_mask is not None:
213-
attention_mask = attention_mask[:, None, None, :] < -1
214-
else:
215-
self_attn_mask_type = "causal"
216-
cu_seq_lens_q = cu_seq_lens_k = None
217-
max_length_q = max_length_k = hidden_states.size(1)
218-
219-
# If we're using kv-caching, we can't trust the max_length_q value as the true max length for rotary
220-
# embeddings, since this will be 1 in generation. Instead we can take the max sequence length from the past
221-
# key values object.
222-
te_rope_emb = self.rotary_emb(
223-
max_seq_len=max_length_q if past_key_values is None else past_key_values.max_ctx_len
224-
)
225204

226-
if isinstance(past_key_values, InferenceParams):
205+
if self.config.attn_input_format == "bshd" and attention_mask is not None and attention_mask.dim() == 2:
206+
# If we're using padded BSHD inputs, we need to convert the 2-dimensional mask to a 4-dimensional mask in
207+
# the expected boolean format for TE.
208+
attention_mask = attention_mask[:, None, None, :] < -1
209+
210+
if isinstance(past_key_values, InferenceParams): # InferenceParams is TE's way of managing kv-caching.
227211
# In generation mode, we set the length to 1 for each batch index. Otherwise, we use the attention mask to
228212
# compute the lengths of each sequence in the batch.
229213
lengths = (
@@ -233,6 +217,8 @@ def forward(
233217
)
234218
past_key_values.pre_step(OrderedDict(zip(list(range(len(lengths))), lengths)))
235219

220+
te_rope_emb = self.rotary_emb(max_seq_len=self.config.max_position_embeddings)
221+
236222
for decoder_layer in self.layers[: self.config.num_hidden_layers]:
237223
if output_hidden_states:
238224
all_hidden_states = (*all_hidden_states, hidden_states)
@@ -241,12 +227,14 @@ def forward(
241227
hidden_states,
242228
attention_mask=None if self.config.attn_input_format == "thd" else attention_mask,
243229
rotary_pos_emb=te_rope_emb,
244-
self_attn_mask_type=self_attn_mask_type,
245230
inference_params=past_key_values,
246-
cu_seqlens_q=cu_seq_lens_q,
247-
cu_seqlens_kv=cu_seq_lens_k,
248-
max_seqlen_q=max_length_q,
249-
max_seqlen_kv=max_length_k,
231+
cu_seqlens_q=kwargs.get("cu_seq_lens_q", None),
232+
cu_seqlens_kv=kwargs.get("cu_seq_lens_k", None),
233+
cu_seqlens_q_padded=kwargs.get("cu_seq_lens_q_padded", None),
234+
cu_seqlens_kv_padded=kwargs.get("cu_seq_lens_k_padded", None),
235+
max_seqlen_q=kwargs.get("max_length_q", None),
236+
max_seqlen_kv=kwargs.get("max_length_k", None),
237+
pad_between_seqs=kwargs.get("pad_between_seqs", None),
250238
)
251239

252240
hidden_states = self.norm(hidden_states)
@@ -258,7 +246,7 @@ def forward(
258246

259247
if should_pack_inputs:
260248
# If we've converted BSHD to THD for our TE layers, we need to convert back to BSHD for the output.
261-
hidden_states = _pad_input(hidden_states, indices, batch_size, max_length_q)
249+
hidden_states = _pad_input(hidden_states, indices, batch_size, max_seqlen)
262250

263251
return BaseModelOutputWithPast(
264252
last_hidden_state=hidden_states,

0 commit comments

Comments
 (0)