1515
1616# Third Party
1717from torch .utils .data import IterableDataset
18- from transformers import (
19- AutoConfig ,
20- AutoTokenizer ,
21- DataCollatorForSeq2Seq ,
22- Seq2SeqTrainer ,
23- Seq2SeqTrainingArguments ,
24- Trainer ,
25- )
18+ from transformers import AutoConfig , AutoTokenizer , Trainer
2619
2720# First Party
2821from caikit .core .data_model import DataStream
3225
3326# Local
3427from ...data_model import GeneratedResult , GenerationTrainRecord
28+ from ...resources .pretrained_model .base import PretrainedModelBase
3529from ...toolkit .data_stream_wrapper import SimpleIterableStreamWrapper
3630from ...toolkit .data_type_utils import get_torch_dtype
3731from .text_generation_task import TextGenerationTask
@@ -79,6 +73,7 @@ def train(
7973 lr : float = 2e-5 ,
8074 # Directory where model predictions and checkpoints will be written
8175 checkpoint_dir : str = "/tmp" ,
76+ ** training_arguments
8277 ):
8378 """
8479 # FIXME: Below is currently configured for Seq2Seq only
@@ -110,6 +105,7 @@ def train(
110105 log .debug ("Bootstrapping base resource [%s]" , base_model )
111106 base_model = resource_type .bootstrap (base_model , torch_dtype = torch_dtype )
112107
108+ error .type_check ("<NLP03221895E>" , PretrainedModelBase , base_model = base_model )
113109 ## Generate data loader from stream
114110 training_dataset : IterableDataset = cls ._preprocess_function (
115111 train_stream = train_stream ,
@@ -125,40 +121,33 @@ def train(
125121 # by optionally accepting `training_args`
126122 # as argument to this train function.
127123 # TODO: Remove all the defaults used below and make them configurable
128- training_args = Seq2SeqTrainingArguments (
129- output_dir = checkpoint_dir ,
130- per_device_train_batch_size = batch_size ,
131- per_device_eval_batch_size = batch_size ,
132- num_train_epochs = num_epochs ,
124+
125+ training_args = {
126+ "output_dir" : checkpoint_dir ,
127+ "per_device_train_batch_size" : batch_size ,
128+ "per_device_eval_batch_size" : batch_size ,
129+ "num_train_epochs" : num_epochs ,
133130 # NOTE: We have disabled evaluation for now
134- do_eval = False ,
135- # evaluation_strategy = "epoch",
136- learning_rate = lr ,
137- weight_decay = 0.01 ,
138- save_total_limit = 3 ,
139- predict_with_generate = True ,
140- fp16 = True ,
141- push_to_hub = False ,
142- no_cuda = False , # Default
143- generation_max_length = max_target_length ,
144- remove_unused_columns = False ,
145- dataloader_pin_memory = False ,
146- gradient_accumulation_steps = accumulate_steps ,
147- eval_accumulation_steps = accumulate_steps ,
131+ " do_eval" : False ,
132+ " # evaluation_strategy " : "epoch" ,
133+ " learning_rate" : lr ,
134+ " weight_decay" : 0.01 ,
135+ " save_total_limit" : 3 ,
136+ " predict_with_generate" : True ,
137+ " fp16" : True ,
138+ " push_to_hub" : False ,
139+ " no_cuda" : False , # Default
140+ " generation_max_length" : max_target_length ,
141+ " remove_unused_columns" : False ,
142+ " dataloader_pin_memory" : False ,
143+ " gradient_accumulation_steps" : accumulate_steps ,
144+ " eval_accumulation_steps" : accumulate_steps ,
148145 # eval_steps=1,
149- )
146+ ** training_arguments ,
147+ }
150148
151- data_collator = DataCollatorForSeq2Seq (
152- tokenizer = base_model .tokenizer , model = base_model .model
153- )
154-
155- trainer = Seq2SeqTrainer (
156- base_model .model ,
157- training_args ,
158- train_dataset = training_dataset ,
159- data_collator = data_collator ,
160- tokenizer = base_model .tokenizer ,
161- # compute_metrics=compute_metrics,
149+ trainer = base_model .get_trainer (
150+ train_dataset = training_dataset , ** training_args
162151 )
163152
164153 # Start training via Trainer.train function
0 commit comments