Skip to content

Commit af8ffbe

Browse files
committed
fix collator
Signed-off-by: Peter St. John <pstjohn@nvidia.com>
1 parent 11d583b commit af8ffbe

1 file changed

Lines changed: 5 additions & 0 deletions

File tree

models/esm2/src/esm/collator.py

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -111,7 +111,9 @@ def __init__(
         self.flattening_collator = DataCollatorWithFlattening(
             return_flash_attn_kwargs=return_flash_attn_kwargs,
             return_seq_idx=return_seq_idx,
+            return_tensors=return_tensors,
         )
+        self.return_tensors = return_tensors
 
     def __call__(self, features, return_tensors=None):
         """Process a batch of variable-length sequences for Flash Attention with MLM.
@@ -168,6 +170,9 @@ def __call__(self, features, return_tensors=None):
             sequence_length=total_tokens, optimized for Flash Attention's variable-length
             sequence processing capabilities.
         """
+        if return_tensors is None:
+            return_tensors = self.return_tensors
+
         batch = self.flattening_collator(features, return_tensors)
 
         special_tokens_mask = batch.pop("special_tokens_mask", None)

0 commit comments

Comments
 (0)