diff --git a/plugins/attention-and-distributed-packing/src/fms_acceleration_aadp/framework_plugin_padding_free.py b/plugins/attention-and-distributed-packing/src/fms_acceleration_aadp/framework_plugin_padding_free.py index 67c75369..65ba520b 100644 --- a/plugins/attention-and-distributed-packing/src/fms_acceleration_aadp/framework_plugin_padding_free.py +++ b/plugins/attention-and-distributed-packing/src/fms_acceleration_aadp/framework_plugin_padding_free.py @@ -20,6 +20,7 @@ from fms_acceleration import AccelerationPlugin from peft import LoraConfig from transformers import DataCollatorForSeq2Seq, TrainingArguments +from transformers.trainer_utils import RemoveColumnsCollator from trl import ( # pylint: disable=import-error, no-name-in-module DataCollatorForCompletionOnlyLM, ) @@ -72,6 +73,8 @@ def _collator_check(collate_fn): # "The padding-free plugin currently only works with a # `DataCollatorForSeq2Seq` collate_fn, # otherwise the collation can be unreliable" + if isinstance(collate_fn, RemoveColumnsCollator): + collate_fn = collate_fn.data_collator return isinstance( collate_fn, (DataCollatorForSeq2Seq, DataCollatorForCompletionOnlyLM) ) @@ -99,6 +102,14 @@ def _collator_check(collate_fn): def _collator_replacement_builder(collate_fn): + # in case of remove columns collator the actual collate + # function is wrapped inside + if isinstance(collate_fn, RemoveColumnsCollator): + actual_collate_fn = collate_fn.data_collator + replacement = _collator_replacement_builder(actual_collate_fn) + collate_fn.data_collator = replacement + return collate_fn + # in this case, replace seq2seq with flattening collator if isinstance(collate_fn, DataCollatorForSeq2Seq): return DataCollatorWithFlattening()