2020from fms_acceleration import AccelerationPlugin
2121from peft import LoraConfig
2222from transformers import DataCollatorForSeq2Seq , TrainingArguments
23+ from transformers .trainer_utils import RemoveColumnsCollator
2324from trl import DataCollatorForCompletionOnlyLM # pylint: disable=import-error
2425import torch
2526
@@ -71,7 +72,7 @@ def _collator_check(collate_fn):
7172 # `DataCollatorForSeq2Seq` collate_fn,
7273 # otherwise the collation can be unreliable"
7374 return isinstance (
74- collate_fn , (DataCollatorForSeq2Seq , DataCollatorForCompletionOnlyLM )
75+ collate_fn , (DataCollatorForSeq2Seq , DataCollatorForCompletionOnlyLM , RemoveColumnsCollator )
7576 )
7677
7778 # This check is done here to only patch the attention forward
@@ -97,6 +98,14 @@ def _collator_check(collate_fn):
9798
9899 def _collator_replacement_builder (collate_fn ):
99100
101+ # in case of remove columns collator the actual collate
102+ # function is wrapped inside
103+ if isinstance (collate_fn , RemoveColumnsCollator ):
104+ actual_collate_fn = collate_fn .data_collator
105+ replacement = _collator_replacement_builder (actual_collate_fn )
106+ collate_fn .data_collator = replacement
107+ return collate_fn
108+
100109 # in this case, replace seq2seq with flattening collator
101110 if isinstance (collate_fn , DataCollatorForSeq2Seq ):
102111 return DataCollatorWithFlattening ()
0 commit comments