# See the License for the specific language governing permissions and
# limitations under the License.

+# Standard
+from typing import Optional

# Third Party
from torch.utils.data import IterableDataset
-from transformers import (
-    AutoConfig,
-    AutoTokenizer,
-    DataCollatorForSeq2Seq,
-    Seq2SeqTrainer,
-    Seq2SeqTrainingArguments,
-    Trainer,
-)
+from transformers import AutoConfig, AutoTokenizer
import torch

# First Party

# Local
from ...data_model import GenerationTrainRecord
+from ...resources.pretrained_model import (
+    HFAutoCausalLM,
+    HFAutoSeq2SeqLM,
+    PretrainedModelBase,
+)
from ...toolkit.data_stream_wrapper import SimpleIterableStreamWrapper
from ...toolkit.data_type_utils import get_torch_dtype

class FineTuning(ModuleBase):
    """Module to provide fine-tuning support for text generation task"""

-    def __init__(self, tokenizer, model):
+    RANDOM_SEED = 73
+    supported_resources = [HFAutoCausalLM, HFAutoSeq2SeqLM]
+
+    def __init__(
+        self,
+        tokenizer,
+        model,
+        bos_token: Optional[str] = None,
+        sep_token: Optional[str] = None,
+        eos_token: Optional[str] = None,
+        pad_token: Optional[str] = None,
+    ):
        super().__init__()

        self.tokenizer = tokenizer
-        # NOTE: self.model here can also be HF trainer. This is because
-        # if we have just trained the model then the models weights might be
-        # available in different devices (and configuration), depending on
-        # how it was trained. For now (July 10, 2023), we are not trying to
-        # extract the model out from trainer itself, since that would require
-        # us to essentially save it or reconstruct it to do normal inferring.
        self.model = model
+        self._bos_token = bos_token
+        self._sep_token = sep_token
+        self._eos_token = eos_token
+        self._pad_token = pad_token

    @classmethod
    def train(
@@ -78,12 +87,49 @@ def train(
        batch_size: int = 8,
        num_epochs: int = 5,
        accumulate_steps: int = 32,
+        random_seed: int = RANDOM_SEED,
        lr: float = 2e-5,
        # Directory where model predictions and checkpoints will be written
        checkpoint_dir: str = "/tmp",
+        **training_arguments,
    ):
        """
-        # FIXME: Below is currently configured for Seq2Seq only
+        Fine-tune a CausalLM or Seq2Seq text generation model.
+
+        Args:
+            base_model: Union[str, caikit_nlp.resources.pretrained_model.base.PretrainedModelBase]
+                Base resource model used for underlying generation.
+            train_stream: DataStream[GenerationTrainRecord] or DataStream[ClassificationTrainRecord]
+                Data to be used for fine-tuning the generation model.
+            torch_dtype: str
+                TODO: Optional[Union[torch.dtype, str]]
+                Data type to use for training/inference of the underlying text generation model.
+                If no value is provided, we pull from torch_dtype in config. If an in-memory
+                resource is provided which does not match the specified data type, the model
+                underpinning the resource will be converted in place to the correct torch dtype.
+            max_source_length: int
+                Max length of input sequences being considered. Default: 256.
+            max_target_length: int
+                Max length of target sequences being predicted. Default: 128.
+            batch_size: int
+                Batch size to be used for training / evaluation data. Default: 8.
+            num_epochs: int
+                Number of epochs to tune the model. Default: 5.
+            accumulate_steps: int
+                Number of steps to use for gradient accumulation. Default: 32.
+            random_seed: int
+                Random seed to be used for training. Default: RANDOM_SEED (73).
+            lr: float
+                Learning rate to be used while tuning the model. Default: 2e-5.
+            checkpoint_dir: str
+                Directory where model predictions and checkpoints will be written.
+            **training_arguments:
+                Arguments supported by HF Training Arguments.
+                TrainingArguments:
+                    https://huggingface.co/docs/transformers/v4.30.0/en/main_classes/trainer#transformers.TrainingArguments
+                Seq2SeqTrainingArguments:
+                    https://huggingface.co/docs/transformers/v4.30.0/en/main_classes/trainer#transformers.Seq2SeqTrainingArguments
+        Returns:
+            FineTuning
+                Instance of this class with a fine-tuned model.
        """

        torch_dtype = get_torch_dtype(torch_dtype)
@@ -92,11 +138,12 @@ def train(
        # text_generation module. In future, we would want to consolidate this into
        # a base class or a toolkit function
        # pylint: disable=duplicate-code
+        resource_type = None
+
        ## Load base model
        if isinstance(base_model, str):
            model_config = AutoConfig.from_pretrained(base_model)

-            resource_type = None
            for resource in cls.supported_resources:
                if model_config.model_type in resource.SUPPORTED_MODEL_TYPES:
                    resource_type = resource
@@ -112,8 +159,14 @@ def train(
            log.debug("Bootstrapping base resource [%s]", base_model)
            base_model = resource_type.bootstrap(base_model, torch_dtype=torch_dtype)

+        else:
+            # base_model is actually a resource object
+            resource_type = type(base_model)
+
+        error.type_check("<NLP03221895E>", PretrainedModelBase, base_model=base_model)
        ## Generate data loader from stream
        training_dataset: IterableDataset = cls._preprocess_function(
+            base_model=base_model,
            train_stream=train_stream,
            tokenizer=base_model.tokenizer,
            max_source_length=max_source_length,
@@ -144,47 +197,33 @@ def train(
        # by optionally accepting `training_args`
        # as argument to this train function.
        # TODO: Remove all the defaults used below and make them all configurable
-        training_args = Seq2SeqTrainingArguments(
-            output_dir=checkpoint_dir,
-            per_device_train_batch_size=batch_size,
-            per_device_eval_batch_size=batch_size,
-            num_train_epochs=num_epochs,
+
+        training_args = {
+            "output_dir": checkpoint_dir,
+            "per_device_train_batch_size": batch_size,
+            "per_device_eval_batch_size": batch_size,
+            "num_train_epochs": num_epochs,
+            "seed": random_seed,
            # NOTE: We have disabled evaluation for now
-            do_eval=False,
-            # evaluation_strategy = "epoch",
-            learning_rate=lr,
-            weight_decay=0.01,
-            save_total_limit=3,
-            predict_with_generate=True,
-            push_to_hub=False,
-            no_cuda=False,  # Default
-            generation_max_length=max_target_length,
-            remove_unused_columns=False,
-            dataloader_pin_memory=False,
-            gradient_accumulation_steps=accumulate_steps,
-            eval_accumulation_steps=accumulate_steps,
-            logging_strategy="epoch",
-            disable_tqdm=True,
-            # NOTE: Following not possible without save and eval strategy
-            # load_best_model_at_end=True,
+            "do_eval": False,
+            # "evaluation_strategy": "epoch",
+            "learning_rate": lr,
+            "weight_decay": 0.01,
+            "save_total_limit": 3,
+            "push_to_hub": False,
+            "no_cuda": False,  # Default
+            "remove_unused_columns": False,
+            "dataloader_pin_memory": False,
+            "gradient_accumulation_steps": accumulate_steps,
+            "eval_accumulation_steps": accumulate_steps,
            # eval_steps=1,
+            # load_best_model_at_end
+            **training_arguments,
            **dtype_based_params,
-            ## TODO: Make below configurable
-            # fsdp="full_shard auto_wrap",
-            # local_rank=0,
-        )
-
-        data_collator = DataCollatorForSeq2Seq(
-            tokenizer=base_model.tokenizer, model=base_model.model
-        )
+        }

-        trainer = Seq2SeqTrainer(
-            base_model.model,
-            training_args,
-            train_dataset=training_dataset,
-            data_collator=data_collator,
-            tokenizer=base_model.tokenizer,
-            # compute_metrics=compute_metrics,
+        trainer = base_model.get_trainer(
+            train_dataset=training_dataset, **training_args
        )

        if num_epochs < 1:
@@ -201,17 +240,25 @@ def train(

        # Start training via Trainer.train function
        trainer.train()
-        # NOTE: By default the model would be available in different ways
-        # depending on where and how it was trained. So we need to fetch the model
-        # from the trainer depending on the training method, like fsdp, ddp etc.
-        # For simplicity, currently we will use trainer as the model since it anyways
-        # enable the `predict` function on it and has all the layers of the model
-        # distributed already, so it will be most optimized to use trainer to
-        # perform prediction at this stage.
+
+        # Save the model temporarily and reload it. This is done because otherwise
+        # the model weights might be distributed across different devices, in which
+        # case it would be better to use the trainer's `prediction_step` function;
+        # however, that does not always expose an API similar to `generate` and thus
+        # causes incompatibilities in the `run` function.
+        trainer.save_model(checkpoint_dir)
+
+        model = resource_type.bootstrap(
+            checkpoint_dir, checkpoint_dir, torch_dtype=torch_dtype
+        )

        return cls(
-            tokenizer=base_model.tokenizer,
-            model=trainer,
+            tokenizer=model.tokenizer,
+            model=model,
+            bos_token=model.tokenizer.bos_token or None,
+            sep_token=model.tokenizer.sep_token or None,
+            eos_token=model.tokenizer.eos_token or None,
+            pad_token=model.tokenizer.pad_token or None,
        )

    # pylint: disable=unused-argument
@@ -236,44 +283,41 @@ def run(
            GeneratedTextResult
                Generated text result
        """
-        if isinstance(self.model, Trainer):
-            # Apply the tokenizer to the sample text & move to correct device
-            tok_tensors = self.tokenizer(text, return_tensors="pt")
-            # NOTE: below function is prediction on trainer, for which we need to supply
-            # the actual underlying model as well
-            # NOTE: We are using prediction_step instead of calling `self.model.generate`
-            # because this way HF Trainer automatically handles device placement of the
-            # data and model. Since the model is with Trainer at this point
-            # and thus the device placement be according to training strategy,
-            # its better to let Trainer handle the evaluation / prediction
-
-            # TODO: Add support for passing extra arguments to prediction_step
-            _, generated_tokens, _ = self.model.prediction_step(
-                self.model.model,
-                tok_tensors,
-                prediction_loss_only=False,
-                max_new_tokens=max_new_tokens,
-                min_new_tokens=min_new_tokens,
-            )

-            generated_text = self.tokenizer.batch_decode(
-                generated_tokens.detach().cpu().numpy(), skip_special_tokens=True
-            )[0]
+        inputs = self.model.tokenizer(text, return_tensors="pt")
+        generate_ids = self.model.model.generate(
+            input_ids=inputs["input_ids"],
+            max_new_tokens=max_new_tokens,
+            min_new_tokens=min_new_tokens,
+            use_cache=True,
+        )

-        else:
-            error(
-                "<NLP38929392E>",
-                NotImplementedError(
-                    "model prediction on pre-finetuned model currently not supported"
-                ),
+        token_count = generate_ids.size(1) - 1
+        preds = [
+            self.model.tokenizer.decode(
+                g, skip_special_tokens=True, clean_up_tokenization_spaces=True
            )
+            for g in generate_ids
+        ]
+        if generate_ids[0][-1].item() == self._eos_token:
+            finish_reason = "EOS_TOKEN"
+        elif generate_ids.size(1) - 1 == max_new_tokens:
+            finish_reason = "MAX_TOKENS"
+        else:
+            finish_reason = "OTHER"

-        return GeneratedTextResult(generated_text=generated_text)
+        return GeneratedTextResult(
+            generated_tokens=token_count,
+            generated_text=preds[0],
+            finish_reason=finish_reason,
+            producer_id=self.PRODUCER_ID,
+        )

    ################################## Private Functions ###########################################

    @staticmethod
    def _preprocess_function(
+        base_model: PretrainedModelBase,
        train_stream: DataStream[GenerationTrainRecord],
        tokenizer: AutoTokenizer,
        max_source_length: int,
@@ -282,28 +326,17 @@ def _preprocess_function(
    ):
        """Pre-process each example to get it prepared for training."""

-        # FIXME: Below is currently configured for Seq2Seq only
-
-        def _tokenization_func(
-            example: GenerationTrainRecord,
-        ):
-            model_inputs = tokenizer(
-                example.input,
-                max_length=max_source_length,
-                truncation=True,
-            )
-
-            labels = tokenizer(
-                example.output,
-                max_length=max_target_length,
-                padding="max_length",
-                truncation=True,
-            )
-
-            model_inputs["labels"] = labels["input_ids"]
-
-            return model_inputs
-
-        return SimpleIterableStreamWrapper(
-            train_stream.map(_tokenization_func), shuffle=shuffle
+        # TODO: We are using a default verbalizer which is strictly tied to
+        # source training record currently. We need to figure out a better
+        # way to make verbalizer optional for build_task_tokenize_function
+        (
+            tokenize_function,
+            requires_unwrapping,
+        ) = base_model.build_task_tokenize_function(
+            tokenizer, max_source_length, max_target_length, verbalizer="{{input}}"
        )
+        mapped_stream = train_stream.map(tokenize_function)
+        if requires_unwrapping:
+            mapped_stream = mapped_stream.flatten()
+
+        return SimpleIterableStreamWrapper(mapped_stream, shuffle=shuffle)
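
For reference, a minimal usage sketch of the module after this change. The model name, training record, import paths, and checkpoint directory below are illustrative assumptions, not part of this commit:

# Illustrative sketch; model name, data, and paths are assumptions.
from caikit.core.data_model import DataStream
from caikit_nlp.data_model import GenerationTrainRecord
from caikit_nlp.modules.text_generation import FineTuning

train_stream = DataStream.from_iterable(
    [GenerationTrainRecord(input="summarize: the weather was bad", output="bad weather")]
)

# Extra keyword arguments are forwarded to the HF (Seq2Seq)TrainingArguments.
finetuned = FineTuning.train(
    "google/flan-t5-small",
    train_stream,
    num_epochs=1,
    checkpoint_dir="/tmp/finetuning-demo",
    logging_strategy="epoch",
)

# Inference runs on the model reloaded from the checkpoint directory.
result = finetuned.run("summarize: the weather was bad", max_new_tokens=20)
print(result.generated_text, result.finish_reason)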