Skip to content

Commit d934455

Browse files
committed
♻️ Refactor trainer logic and move it to resources
Signed-off-by: gkumbhat <kumbhat.gaurav@gmail.com>
1 parent d346708 commit d934455

3 files changed

Lines changed: 105 additions & 43 deletions

File tree

caikit_nlp/modules/text_generation/fine_tuning.py

Lines changed: 28 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,7 @@
1515

1616
# Third Party
1717
from torch.utils.data import IterableDataset
18-
from transformers import (
19-
AutoConfig,
20-
AutoTokenizer,
21-
DataCollatorForSeq2Seq,
22-
Seq2SeqTrainer,
23-
Seq2SeqTrainingArguments,
24-
Trainer,
25-
)
18+
from transformers import AutoConfig, AutoTokenizer, Trainer
2619

2720
# First Party
2821
from caikit.core.data_model import DataStream
@@ -32,6 +25,7 @@
3225

3326
# Local
3427
from ...data_model import GeneratedResult, GenerationTrainRecord
28+
from ...resources.pretrained_model.base import PretrainedModelBase
3529
from ...toolkit.data_stream_wrapper import SimpleIterableStreamWrapper
3630
from ...toolkit.data_type_utils import get_torch_dtype
3731
from .text_generation_task import TextGenerationTask
@@ -79,6 +73,7 @@ def train(
7973
lr: float = 2e-5,
8074
# Directory where model predictions and checkpoints will be written
8175
checkpoint_dir: str = "/tmp",
76+
**training_arguments
8277
):
8378
"""
8479
# FIXME: Below is currently configured for Seq2Seq only
@@ -110,6 +105,7 @@ def train(
110105
log.debug("Bootstrapping base resource [%s]", base_model)
111106
base_model = resource_type.bootstrap(base_model, torch_dtype=torch_dtype)
112107

108+
error.type_check("<NLP03221895E>", PretrainedModelBase, base_model=base_model)
113109
## Generate data loader from stream
114110
training_dataset: IterableDataset = cls._preprocess_function(
115111
train_stream=train_stream,
@@ -125,40 +121,33 @@ def train(
125121
# by optionally accepting `training_args`
126122
# as argument to this train function.
127123
# TODO: Remove all the default used below and make them all configurable
128-
training_args = Seq2SeqTrainingArguments(
129-
output_dir=checkpoint_dir,
130-
per_device_train_batch_size=batch_size,
131-
per_device_eval_batch_size=batch_size,
132-
num_train_epochs=num_epochs,
124+
125+
training_args = {
126+
"output_dir": checkpoint_dir,
127+
"per_device_train_batch_size": batch_size,
128+
"per_device_eval_batch_size": batch_size,
129+
"num_train_epochs": num_epochs,
133130
# NOTE: We have disabled evaluation for now
134-
do_eval=False,
135-
# evaluation_strategy = "epoch",
136-
learning_rate=lr,
137-
weight_decay=0.01,
138-
save_total_limit=3,
139-
predict_with_generate=True,
140-
fp16=True,
141-
push_to_hub=False,
142-
no_cuda=False, # Default
143-
generation_max_length=max_target_length,
144-
remove_unused_columns=False,
145-
dataloader_pin_memory=False,
146-
gradient_accumulation_steps=accumulate_steps,
147-
eval_accumulation_steps=accumulate_steps,
131+
"do_eval": False,
# Evaluation is disabled (see NOTE above), so this option stays commented out.
# It must not be a real dict entry: every key here is unpacked into the
# training-arguments constructor, which rejects unknown keyword names.
# "evaluation_strategy": "epoch",
133+
"learning_rate": lr,
134+
"weight_decay": 0.01,
135+
"save_total_limit": 3,
136+
"predict_with_generate": True,
137+
"fp16": True,
138+
"push_to_hub": False,
139+
"no_cuda": False, # Default
140+
"generation_max_length": max_target_length,
141+
"remove_unused_columns": False,
142+
"dataloader_pin_memory": False,
143+
"gradient_accumulation_steps": accumulate_steps,
144+
"eval_accumulation_steps": accumulate_steps,
148145
# eval_steps=1,
149-
)
146+
**training_arguments,
147+
}
150148

151-
data_collator = DataCollatorForSeq2Seq(
152-
tokenizer=base_model.tokenizer, model=base_model.model
153-
)
154-
155-
trainer = Seq2SeqTrainer(
156-
base_model.model,
157-
training_args,
158-
train_dataset=training_dataset,
159-
data_collator=data_collator,
160-
tokenizer=base_model.tokenizer,
161-
# compute_metrics=compute_metrics,
149+
trainer = base_model.get_trainer(
150+
train_dataset=training_dataset, **training_args
162151
)
163152

164153
# Start training via Trainer.train function

caikit_nlp/resources/pretrained_model/base.py

Lines changed: 35 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,13 @@
1414

1515
# Standard
1616
from abc import ABC, abstractmethod
17-
from typing import List, Optional, Type
17+
from typing import List, Optional, Type, Union
1818
import json
1919
import os
2020

2121
# Third Party
22-
from transformers import AutoTokenizer
22+
from torch.utils.data import IterableDataset
23+
from transformers import AutoTokenizer, DataCollator, Trainer, TrainingArguments
2324
from transformers.models.auto.auto_factory import _BaseAutoModelClass
2425
import torch
2526

@@ -233,6 +234,38 @@ def save(
233234
self.tokenizer.save_pretrained(tok_abs_path)
234235
self.model.save_pretrained(model_abs_path)
235236

237+
def get_trainer(
    self,
    train_dataset: IterableDataset,
    eval_dataset: Union[IterableDataset, None] = None,
    optimizers=(None, None),
    **kwargs,
):
    """Build a Hugging Face ``Trainer`` around this resource's model.

    Args:
        train_dataset: IterableDataset
            Dataset handed to the trainer for training.
        eval_dataset: Union[IterableDataset, None]
            Optional dataset handed to the trainer for evaluation.
        optimizers: tuple
            ``(optimizer, lr_scheduler)`` pair forwarded to the trainer
            unchanged; the default ``(None, None)`` is passed through as is.
        **kwargs:
            Forwarded verbatim to ``TrainingArguments``; any key that is
            not a ``TrainingArguments`` field makes its constructor fail.

    Returns:
        Trainer
            Trainer wired to ``self._model`` and ``self._tokenizer``.

    NOTE: following parameters are not supported currently:
        1. model_init
        2. compute_metrics
        3. callbacks
        4. preprocess_logits_for_metrics
    """
    training_args = TrainingArguments(**kwargs)

    # TODO: Fetch DataCollator either from property of this
    # class or fetch it as an argument.
    # `transformers.DataCollator` is only a typing alias for a collate
    # callable; it cannot be constructed from a tokenizer / model pair.
    # Until a concrete collator is supplied, leave it unset so that Trainer
    # falls back to its own default collator.
    data_collator = None

    # pylint: disable=duplicate-code
    trainer_arguments = {
        "train_dataset": train_dataset,
        "data_collator": data_collator,
        "tokenizer": self._tokenizer,
        "optimizers": optimizers,
        "eval_dataset": eval_dataset,
    }

    return Trainer(self._model, training_args, **trainer_arguments)
268+
236269
# pylint: disable=unused-argument
237270
@classmethod
238271
def get_num_transformers_submodules(

caikit_nlp/resources/pretrained_model/hf_auto_seq2seq_lm.py

Lines changed: 42 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,16 @@
1515
Huggingface auto seq2seq LM resource type
1616
"""
1717
# Standard
18-
from typing import List
18+
from typing import List, Union
1919

2020
# Third Party
21-
from transformers import AutoModelForSeq2SeqLM
21+
from torch.utils.data import IterableDataset
22+
from transformers import (
23+
AutoModelForSeq2SeqLM,
24+
DataCollatorForSeq2Seq,
25+
Seq2SeqTrainer,
26+
Seq2SeqTrainingArguments,
27+
)
2228
from transformers.models.auto import modeling_auto
2329

2430
# First Party
@@ -64,3 +70,37 @@ def get_num_transformers_submodules(
6470
"<NLP71505742E>", 0 < num_transformer_submodules <= cls.MAX_NUM_TRANSFORMERS
6571
)
6672
return num_transformer_submodules
73+
74+
def get_trainer(
    self,
    train_dataset: IterableDataset,
    eval_dataset: Union[IterableDataset, None] = None,
    optimizers=(None, None),
    **kwargs
):
    """Build a ``Seq2SeqTrainer`` around this resource's model.

    Args:
        train_dataset: IterableDataset
            Dataset handed to the trainer for training.
        eval_dataset: Union[IterableDataset, None]
            Optional dataset handed to the trainer for evaluation.
        optimizers: tuple
            ``(optimizer, lr_scheduler)`` pair forwarded to the trainer
            unchanged.
        **kwargs:
            Forwarded verbatim to ``Seq2SeqTrainingArguments``.

    Returns:
        Seq2SeqTrainer
            Trainer wired to ``self._model`` and ``self._tokenizer``, using
            a ``DataCollatorForSeq2Seq`` built from the same pair.

    NOTE: following parameters are not supported currently:
        1. model_init
        2. compute_metrics
        3. callbacks
        4. preprocess_logits_for_metrics
    """
    seq2seq_args = Seq2SeqTrainingArguments(**kwargs)

    # TODO: Fetch DataCollator either from property of this
    # class or fetch it as an argument.
    collator = DataCollatorForSeq2Seq(
        tokenizer=self._tokenizer,
        model=self._model,
    )

    # pylint: disable=duplicate-code
    return Seq2SeqTrainer(
        self._model,
        seq2seq_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        data_collator=collator,
        tokenizer=self._tokenizer,
        optimizers=optimizers,
    )

0 commit comments

Comments
 (0)