Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from fms_acceleration import AccelerationPlugin
from peft import LoraConfig
from transformers import DataCollatorForSeq2Seq, TrainingArguments
from transformers.trainer_utils import RemoveColumnsCollator
from trl import ( # pylint: disable=import-error, no-name-in-module
DataCollatorForCompletionOnlyLM,
)
Expand Down Expand Up @@ -72,6 +73,8 @@ def _collator_check(collate_fn):
# "The padding-free plugin currently only works with a
# `DataCollatorForSeq2Seq` collate_fn,
# otherwise the collation can be unreliable"
if isinstance(collate_fn, RemoveColumnsCollator):
collate_fn = collate_fn.data_collator
return isinstance(
collate_fn, (DataCollatorForSeq2Seq, DataCollatorForCompletionOnlyLM)
)
Expand Down Expand Up @@ -99,6 +102,14 @@ def _collator_check(collate_fn):

def _collator_replacement_builder(collate_fn):

# in case of remove columns collator the actual collate
# function is wrapped inside
if isinstance(collate_fn, RemoveColumnsCollator):
actual_collate_fn = collate_fn.data_collator
replacement = _collator_replacement_builder(actual_collate_fn)
collate_fn.data_collator = replacement
return collate_fn
Comment thread
kmehant marked this conversation as resolved.

# in this case, replace seq2seq with flattening collator
if isinstance(collate_fn, DataCollatorForSeq2Seq):
return DataCollatorWithFlattening()
Expand Down