Skip to content

Commit 8f3a72c

Browse files
committed
Add fix for `RemoveColumnsCollator`, which fails the collator check when streaming datasets are used
1 parent f7210f7 commit 8f3a72c

1 file changed

Lines changed: 10 additions & 1 deletion

File tree

plugins/attention-and-distributed-packing/src/fms_acceleration_aadp/framework_plugin_padding_free.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from fms_acceleration import AccelerationPlugin
2121
from peft import LoraConfig
2222
from transformers import DataCollatorForSeq2Seq, TrainingArguments
23+
from transformers.trainer_utils import RemoveColumnsCollator
2324
from trl import DataCollatorForCompletionOnlyLM # pylint: disable=import-error
2425
import torch
2526

@@ -71,7 +72,7 @@ def _collator_check(collate_fn):
7172
# `DataCollatorForSeq2Seq` collate_fn,
7273
# otherwise the collation can be unreliable"
7374
return isinstance(
74-
collate_fn, (DataCollatorForSeq2Seq, DataCollatorForCompletionOnlyLM)
75+
collate_fn, (DataCollatorForSeq2Seq, DataCollatorForCompletionOnlyLM, RemoveColumnsCollator)
7576
)
7677

7778
# This check is done here to only patch the attention forward
@@ -97,6 +98,14 @@ def _collator_check(collate_fn):
9798

9899
def _collator_replacement_builder(collate_fn):
99100

101+
# When the collator is a RemoveColumnsCollator, the actual collate
102+
# function is wrapped inside it as its `data_collator` attribute
103+
if isinstance(collate_fn, RemoveColumnsCollator):
104+
actual_collate_fn = collate_fn.data_collator
105+
replacement = _collator_replacement_builder(actual_collate_fn)
106+
collate_fn.data_collator = replacement
107+
return collate_fn
108+
100109
# in this case, replace seq2seq with flattening collator
101110
if isinstance(collate_fn, DataCollatorForSeq2Seq):
102111
return DataCollatorWithFlattening()

0 commit comments

Comments
 (0)