Skip to content

Commit 7b37b34

Browse files
authored
Add THD input format support to ESM-2 model (#1149)
Allows our ESM-2 model to natively support THD inputs via

```python
AutoModel.from_pretrained(
    "nvidia/esm2_t6_8M_UR50D", trust_remote_code=True, attn_input_format="thd"
)
```

Also adds golden-value tests to ensure that the outputs from the BSHD and THD attention formats are equivalent.

<!-- This is an auto-generated comment: release notes by coderabbit.ai -->

## Summary by CodeRabbit

- **New Features**
  - Added flattening+MLM collators that emit Flash Attention/THD metadata and optional sequence indices.
  - ESM‑2 models now support an explicit "thd" attention-input mode for Transformer Engine execution.
- **Tests**
  - New THD-focused tests and fixtures validating parity between THD and batch‑first behavior.
  - Updated distributed test to use explicit collator setup.
- **Chores**
  - Lint/isort configuration tweaks.
  - Public API signatures/returns adjusted and one legacy helper removed.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: Peter St. John <pstjohn@nvidia.com>
1 parent 54bfba8 commit 7b37b34

8 files changed

Lines changed: 596 additions & 84 deletions

File tree

models/.ruff.toml

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,21 +2,21 @@ line-length = 119
22
target-version = "py312"
33

44
[lint]
5-
ignore = ["D100", "E501", "N811", "N814"]
5+
ignore = ["C901", "D100", "E501", "N811", "N814"]
66
select = [
7-
"C", # Pylint conventions
8-
"D", # Documentation formatting
9-
"E", # style stuff, whitespaces
10-
"F", # important pyflakes lints
11-
"I", # import sorting
12-
"RUF", # Some Ruff-specific lints, unused noqas, etc.
13-
"W", # Pylint warnings
7+
"C", # Pylint conventions
8+
"D", # Documentation formatting
9+
"E", # style stuff, whitespaces
10+
"F", # important pyflakes lints
11+
"FURB",
12+
"I", # import sorting
1413
"N",
1514
"NPY",
1615
"PERF",
1716
"PLE",
1817
"PLW",
19-
"FURB",
18+
"RUF", # Some Ruff-specific lints, unused noqas, etc.
19+
"W", # Pylint warnings
2020
]
2121

2222
# Allow fix for all enabled rules (when `--fix` is provided).
@@ -59,6 +59,7 @@ exclude = [
5959
[lint.isort]
6060
lines-after-imports = 2
6161
known-third-party = ["wandb"]
62+
known-first-party = ["esm"]
6263

6364
[lint.pydocstyle]
6465
convention = "google"

models/esm2/src/esm/collator.py

Lines changed: 334 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,334 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: LicenseRef-Apache2
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
"""Data collator for THD input format tests.
17+
18+
This should eventually get moved to a separate package, or possibly upstreamed into `transformers`.
19+
"""
20+
21+
from dataclasses import dataclass
22+
23+
import numpy as np
24+
from transformers import DataCollatorForLanguageModeling, DefaultDataCollator, PreTrainedTokenizerBase
25+
26+
27+
class MLMDataCollatorWithFlattening:
    """Flatten a batch of variable-length sequences, then apply MLM masking to the packed row.

    The collator chains two steps:

    1. ``DataCollatorWithFlattening`` concatenates every sequence of the mini-batch into a single row of
       shape ``[1, total_tokens]`` (no padding) and, when requested, emits Flash Attention metadata
       (``cu_seq_lens_{q,k}``, the cumulative sequence lengths, and ``max_length_{q,k}``) and ``seq_idx``.
    2. The masking routine of ``DataCollatorForLanguageModeling`` is applied to the packed ``input_ids`` and
       produces the MLM ``labels``.

    The result is a packed ("THD": total tokens, heads, head dim) batch in which sequence boundaries are
    carried by ``cu_seq_lens_*`` instead of an attention mask.

    Attributes:
        mlm_collator (DataCollatorForLanguageModeling): Handles MLM token masking.
        flattening_collator (DataCollatorWithFlattening): Handles sequence packing and Flash Attention
            metadata generation.
        return_tensors (str): Default tensor framework ("pt" or "np") used when ``__call__`` is not given one.

    Example:
        >>> from transformers import AutoTokenizer
        >>> from esm.collator import MLMDataCollatorWithFlattening
        >>>
        >>> tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")
        >>>
        >>> # Input: Variable-length protein sequences
        >>> sequences = [
        ...     {"input_ids": [0, 5, 6, 7, 2]},  # CLS + amino acids + EOS (5 tokens)
        ...     {"input_ids": [0, 8, 9, 10, 11, 2]},  # CLS + amino acids + EOS (6 tokens)
        ...     {"input_ids": [0, 12, 13, 2]},  # CLS + amino acids + EOS (4 tokens)
        ... ]
        >>>
        >>> collator = MLMDataCollatorWithFlattening(
        ...     tokenizer=tokenizer,
        ...     mlm_probability=0.15,
        ...     return_flash_attn_kwargs=True,
        ... )
        >>> batch = collator(sequences)
        >>>
        >>> # Illustrative output (masking is random): flattened and masked sequences
        >>> print(batch['input_ids'])
        >>> # tensor([[ 0,  5,  6,  7,  2,  0,  8,  9, 10, 11,  2,  0, 12, 16,  2]])
        >>> #                                                            ↑ masked token
        >>> print(batch['labels'])
        >>> # tensor([[-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 13, -100]])
        >>> #                                                                                         ↑ original token
        >>> print(batch['cu_seq_lens_q'])
        >>> # tensor([ 0,  5, 11, 15], dtype=torch.int32)  # Sequence boundaries: [0:5], [5:11], [11:15]
    """

    def __init__(
        self,
        # DataCollatorForLanguageModeling
        tokenizer: PreTrainedTokenizerBase,
        mlm: bool = True,
        mlm_probability: float | None = 0.15,
        mask_replace_prob: float = 0.8,
        random_replace_prob: float = 0.1,
        pad_to_multiple_of: int | None = None,
        tf_experimental_compile: bool = False,
        return_tensors: str = "pt",
        seed: int | None = None,
        # DataCollatorWithFlattening
        return_flash_attn_kwargs=True,
        return_seq_idx=False,
    ):
        """Initialize the MLMDataCollatorWithFlattening.

        Args:
            tokenizer: Tokenizer forwarded to ``DataCollatorForLanguageModeling``.
            mlm: Forwarded to ``DataCollatorForLanguageModeling``.
            mlm_probability: Forwarded to ``DataCollatorForLanguageModeling``.
            mask_replace_prob: Forwarded to ``DataCollatorForLanguageModeling``.
            random_replace_prob: Forwarded to ``DataCollatorForLanguageModeling``.
            pad_to_multiple_of: Forwarded to ``DataCollatorForLanguageModeling``.
            tf_experimental_compile: Forwarded to ``DataCollatorForLanguageModeling``.
            return_tensors: Default tensor framework; passed to both wrapped collators. ``__call__`` only
                supports "pt" and "np".
            seed: Forwarded to ``DataCollatorForLanguageModeling``.
            return_flash_attn_kwargs: Forwarded to ``DataCollatorWithFlattening``.
            return_seq_idx: Forwarded to ``DataCollatorWithFlattening``.
        """
        self.mlm_collator = DataCollatorForLanguageModeling(
            tokenizer=tokenizer,
            mlm=mlm,
            mlm_probability=mlm_probability,
            mask_replace_prob=mask_replace_prob,
            random_replace_prob=random_replace_prob,
            pad_to_multiple_of=pad_to_multiple_of,
            tf_experimental_compile=tf_experimental_compile,
            return_tensors=return_tensors,
            seed=seed,
        )
        self.flattening_collator = DataCollatorWithFlattening(
            return_flash_attn_kwargs=return_flash_attn_kwargs,
            return_seq_idx=return_seq_idx,
            return_tensors=return_tensors,
        )
        self.return_tensors = return_tensors

    def _ensure_seeded_rng(self, return_tensors):
        """Create the MLM collator's seeded generator if a seed was configured and none exists yet.

        NOTE(review): in the transformers releases that accept ``seed``, the generator appears to be created
        lazily inside ``torch_call`` / ``numpy_call``. This class calls ``*_mask_tokens`` directly and so
        bypasses that step, which would leave ``seed`` without effect. Confirm against the installed
        transformers version; every access below is guarded so that this is a no-op when the attributes
        are missing.
        """
        mlm = self.mlm_collator
        if not getattr(mlm, "seed", None) or getattr(mlm, "generator", None) is not None:
            return
        # The generator's framework is presumably chosen from the MLM collator's own `return_tensors`;
        # only create it when this call uses the same framework, otherwise keep using the global RNG.
        if return_tensors == getattr(mlm, "return_tensors", None) and hasattr(mlm, "create_rng"):
            mlm.create_rng()

    def __call__(self, features, return_tensors=None):
        """Flatten ``features`` into one packed row and apply MLM masking to it.

        Args:
            features (List[Dict[str, List[int]]]): List of tokenized sequences, each containing 'input_ids'.
                Example:
                [
                    {"input_ids": [0, 5, 6, 7, 2]},         # Protein sequence 1
                    {"input_ids": [0, 8, 9, 10, 11, 2]},    # Protein sequence 2
                    {"input_ids": [0, 12, 13, 2]}           # Protein sequence 3
                ]
            return_tensors (str, optional): "pt" for PyTorch tensors or "np" for NumPy arrays.
                Defaults to None (uses the value given at construction).

        Returns:
            dict: The batch produced by the flattening collator, with ``input_ids`` and ``labels`` replaced
            by the outputs of the masking routine:
                - input_ids: Flattened, MLM-masked tokens. Shape: [1, total_tokens] where total_tokens is
                  the sum of all sequence lengths.
                - labels: MLM labels as returned by the masking routine. Same shape as input_ids.
                - seq_idx: Only when ``return_seq_idx`` is enabled. Index of the source sequence of each
                  token. Shape: [1, total_tokens].
                - cu_seq_lens_q / cu_seq_lens_k: Only when ``return_flash_attn_kwargs`` is enabled.
                  Cumulative sequence lengths. Shape: [num_sequences + 1]. Example: [0, 5, 11, 15].
                - max_length_q / max_length_k (int): Only when ``return_flash_attn_kwargs`` is enabled.
                  Length of the longest sequence in the batch.

        Raises:
            ValueError: If ``return_tensors`` is neither "pt" nor "np".

        Example:
            >>> features = [
            ...     {"input_ids": [0, 5, 6, 7, 2]},         # 5 tokens
            ...     {"input_ids": [0, 8, 9, 10, 11, 2]},    # 6 tokens
            ...     {"input_ids": [0, 12, 13, 2]}           # 4 tokens
            ... ]
            >>> batch = collator(features)
            >>> batch['input_ids'].shape  # torch.Size([1, 15])
            >>> batch['labels'].shape     # torch.Size([1, 15])
            >>> batch['cu_seq_lens_q']    # tensor([0, 5, 11, 15], dtype=torch.int32)
        """
        if return_tensors is None:
            return_tensors = self.return_tensors

        batch = self.flattening_collator(features, return_tensors)

        # Forward a special-tokens mask to the masking routine if the flattened batch carries one.
        special_tokens_mask = batch.pop("special_tokens_mask", None)

        if return_tensors == "pt":
            mask_tokens = self.mlm_collator.torch_mask_tokens
        elif return_tensors == "np":
            mask_tokens = self.mlm_collator.numpy_mask_tokens
        else:
            raise ValueError(f'return_tensors must be one of ("pt", "np"), {return_tensors=} not supported')

        # Honour a configured `seed` even though the masking routine is invoked directly.
        self._ensure_seeded_rng(return_tensors)

        batch["input_ids"], batch["labels"] = mask_tokens(
            batch["input_ids"], special_tokens_mask=special_tokens_mask
        )

        return batch
192+
193+
194+
@dataclass
class DataCollatorWithFlattening(DefaultDataCollator):
    """Data collator used for the padding-free approach.

    Modified from transformers.data.data_collator.DataCollatorWithFlattening to not use a separator_id.

    Does the following:

    - concatenates the entire mini batch into a single long sequence of shape [1, total_tokens]
    - adds no padding; returns `input_ids` and `labels` by default
    - optionally returns the Flash Attention kwargs (`cu_seq_lens_{q,k}`, `max_length_{q,k}`)
    - optionally returns `seq_idx`, indicating which sequence each token belongs to

    Warning:
        Using `DataCollatorWithFlattening` will flatten the entire mini batch into a single long sequence.
        Make sure your attention computation is able to handle it!
    """

    def __init__(
        self,
        *args,
        return_flash_attn_kwargs=True,
        return_seq_idx=False,
        **kwargs,
    ):
        """Initialize the DataCollatorWithFlattening.

        Args:
            *args: Arguments for the parent class.
            return_flash_attn_kwargs (bool): Whether to return FlashAttention kwargs.
            return_seq_idx (bool): Whether to return sequence indices.
            **kwargs: Keyword arguments for the parent class.
        """
        super().__init__(*args, **kwargs)
        self.return_flash_attn_kwargs = return_flash_attn_kwargs
        self.return_seq_idx = return_seq_idx
        # Keys converted to int64; every other converted key becomes int32.
        self._int_64_keys = {"labels", "position_ids", "input_ids"}
        # Keys that receive a leading batch dimension of size 1.
        self._batch_dim_keys = {"labels", "position_ids", "input_ids", "seq_idx"}
        # Keys left as plain Python ints (never converted to tensors/arrays).
        self._py_int_keys = {"max_length_q", "max_length_k"}

    def __call__(self, features, return_tensors=None):
        """Concatenate a batch of variable-length sequences into one packed row.

        Args:
            features (List[Dict[str, List[int]]]): List of tokenized sequences, each containing 'input_ids'
                and optionally 'labels'. Whether labels are used is decided from the first feature only.
                Example:
                [
                    {"input_ids": [0, 5, 6, 7, 2]},         # Protein sequence 1
                    {"input_ids": [0, 8, 9, 10, 11, 2]},    # Protein sequence 2
                    {"input_ids": [0, 12, 13, 2]}           # Protein sequence 3
                ]
            return_tensors (str, optional): "pt" for PyTorch tensors or "np" for NumPy arrays.
                Defaults to None (uses the collator default).

        Returns:
            dict: Batch dictionary containing:
                - input_ids (int64): Concatenated tokens. Shape: [1, total_tokens] where total_tokens is the
                  sum of all sequence lengths.
                - labels (int64): Concatenated per-feature labels when the features provide them; otherwise
                  filled with the ignore index -100. Same shape as input_ids.
                - seq_idx (int32): Only when ``return_seq_idx`` is enabled. Index of the source sequence of
                  each token. Shape: [1, total_tokens].
                - cu_seq_lens_q / cu_seq_lens_k (int32): Only when ``return_flash_attn_kwargs`` is enabled.
                  Cumulative sequence lengths. Shape: [num_sequences + 1]. Example: [0, 5, 11, 15].
                - max_length_q / max_length_k (int): Only when ``return_flash_attn_kwargs`` is enabled.
                  Length of the longest sequence in the batch.

        Raises:
            ValueError: If ``return_tensors`` is neither "pt" nor "np".

        Example:
            >>> features = [
            ...     {"input_ids": [0, 5, 6, 7, 2]},         # 5 tokens
            ...     {"input_ids": [0, 8, 9, 10, 11, 2]},    # 6 tokens
            ...     {"input_ids": [0, 12, 13, 2]}           # 4 tokens
            ... ]
            >>> batch = collator(features)
            >>> batch['input_ids'].shape  # torch.Size([1, 15])
            >>> batch['labels'].shape     # torch.Size([1, 15])
            >>> batch['cu_seq_lens_q']    # tensor([0, 5, 11, 15], dtype=torch.int32)

        Note:
            The output is a packed ("THD": total tokens, heads, head dim) batch with batch_size=1 and
            sequence_length=total_tokens; sequence boundaries are carried by ``cu_seq_lens_*``.
        """
        if return_tensors is None:
            return_tensors = self.return_tensors

        # Resolve the output framework first so that an unsupported value fails before any work is done.
        if return_tensors == "pt":
            import torch

            data_cls = torch.tensor
            dtype_64 = torch.int64
            dtype_32 = torch.int32
        elif return_tensors == "np":
            data_cls = np.array
            dtype_64 = np.int64
            dtype_32 = np.int32
        else:
            raise ValueError(f'return_tensors must be one of ("pt", "np"), {return_tensors=} not supported')

        is_labels_provided = "labels" in features[0]
        batch = {"input_ids": [], "labels": []}
        if self.return_seq_idx:
            batch["seq_idx"] = []

        cu_seq_lens = [0]
        max_length = 0
        for seq_idx, sample in enumerate(features):
            input_ids = sample["input_ids"]
            n_tokens = len(input_ids)
            batch["input_ids"] += input_ids
            if is_labels_provided:
                batch["labels"] += sample["labels"]
            else:
                # Keep `labels` aligned with `input_ids` (instead of an empty [1, 0] row) by marking
                # every position with the ignore index.
                batch["labels"] += [-100] * n_tokens
            if self.return_seq_idx:
                batch["seq_idx"] += [seq_idx] * n_tokens
            cu_seq_lens.append(cu_seq_lens[-1] + n_tokens)
            max_length = max(max_length, n_tokens)

        if self.return_flash_attn_kwargs:
            batch["cu_seq_lens_q"] = batch["cu_seq_lens_k"] = cu_seq_lens
            batch["max_length_q"] = batch["max_length_k"] = max_length

        # FlashAttentionKwargs and seq_idx are expected to be int32s.
        for key, value in batch.items():
            # Flash attention max_len_{q,k} stay python ints.
            if key in self._py_int_keys:
                continue
            payload = [value] if key in self._batch_dim_keys else value
            batch[key] = data_cls(payload, dtype=dtype_64 if key in self._int_64_keys else dtype_32)

        return batch

0 commit comments

Comments
 (0)