
Commit b5d29aa

Merge pull request #80 from gkumbhat/add_support_causalm_finetune
Add support causalm finetune
2 parents 5f1d9cc + 664a3d5 commit b5d29aa

9 files changed

Lines changed: 403 additions & 137 deletions

File tree

caikit_nlp/modules/text_generation/fine_tuning.py

Lines changed: 149 additions & 116 deletions
@@ -12,17 +12,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+# Standard
+from typing import Optional
 
 # Third Party
 from torch.utils.data import IterableDataset
-from transformers import (
-    AutoConfig,
-    AutoTokenizer,
-    DataCollatorForSeq2Seq,
-    Seq2SeqTrainer,
-    Seq2SeqTrainingArguments,
-    Trainer,
-)
+from transformers import AutoConfig, AutoTokenizer
 import torch
 
 # First Party
@@ -35,6 +30,11 @@
 
 # Local
 from ...data_model import GenerationTrainRecord
+from ...resources.pretrained_model import (
+    HFAutoCausalLM,
+    HFAutoSeq2SeqLM,
+    PretrainedModelBase,
+)
 from ...toolkit.data_stream_wrapper import SimpleIterableStreamWrapper
 from ...toolkit.data_type_utils import get_torch_dtype
 
@@ -55,17 +55,26 @@
 class FineTuning(ModuleBase):
     """Module to provide fine-tuning support for text generation task"""
 
-    def __init__(self, tokenizer, model):
+    RANDOM_SEED = 73
+    supported_resources = [HFAutoCausalLM, HFAutoSeq2SeqLM]
+
+    def __init__(
+        self,
+        tokenizer,
+        model,
+        bos_token: Optional[str] = None,
+        sep_token: Optional[str] = None,
+        eos_token: Optional[str] = None,
+        pad_token: Optional[str] = None,
+    ):
         super().__init__()
 
         self.tokenizer = tokenizer
-        # NOTE: self.model here can also be HF trainer. This is because
-        # if we have just trained the model then the models weights might be
-        # available in different devices (and configuration), depending on
-        # how it was trained. For now (July 10, 2023), we are not trying to
-        # extract the model out from trainer itself, since that would require
-        # us to essentially save it or reconstruct it to do normal inferring.
         self.model = model
+        self._bos_token = bos_token
+        self._sep_token = sep_token
+        self._eos_token = eos_token
+        self._pad_token = pad_token
 
     @classmethod
     def train(
@@ -78,12 +87,49 @@ def train(
         batch_size: int = 8,
         num_epochs: int = 5,
         accumulate_steps: int = 32,
+        random_seed: int = RANDOM_SEED,
         lr: float = 2e-5,
         # Directory where model predictions and checkpoints will be written
         checkpoint_dir: str = "/tmp",
+        **training_arguments,
     ):
         """
-        # FIXME: Below is currently configured for Seq2Seq only
+        Fine-tune a CausalLM or Seq2Seq text generation model.
+
+        Args:
+            base_model: Union[str, caikit_nlp.resources.pretrained_model.base.PretrainedModelBase]
+                Base resource model used for underlying generation.
+            train_stream: DataStream[GenerationTrainRecord] or DataStream[ClassificationTrainRecord]
+                Data to be used for fine-tuning the generation model.
+            torch_dtype: str
+                TODO: Optional[Union[torch.dtype, str]]
+                Data type to use for training/inference of the underlying text generation model.
+                If no value is provided, we pull from torch_dtype in config. If an in-memory
+                resource is provided which does not match the specified data type, the model
+                underpinning the resource will be converted in place to the correct torch dtype.
+            max_source_length: int
+                Max length of input sequences being considered. Default: 256.
+            max_target_length: int
+                Max length of target sequences being predicted. Default: 128.
+            batch_size: int
+                Batch size to be used for training / evaluation data. Default: 8.
+            num_epochs: int
+                Number of epochs to tune the model. Default: 5.
+            accumulate_steps: int
+                Number of steps to use for gradient accumulation. Default: 32.
+            lr: float
+                Learning rate to be used while tuning the model. Default: 2e-5.
+            checkpoint_dir: str
+                Directory where model predictions and checkpoints will be written.
+            **training_arguments:
+                Arguments supported by HF TrainingArguments.
+                TrainingArguments:
+                    https://huggingface.co/docs/transformers/v4.30.0/en/main_classes/trainer#transformers.TrainingArguments
+                Seq2SeqTrainingArguments:
+                    https://huggingface.co/docs/transformers/v4.30.0/en/main_classes/trainer#transformers.Seq2SeqTrainingArguments
+        Returns:
+            FineTuning
+                Instance of this class with fine-tuned models.
         """
 
         torch_dtype = get_torch_dtype(torch_dtype)
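For context, a usage sketch of the new `train` signature. This is illustrative only and not part of this commit; the import paths, model name, stream contents, and the pass-through `logging_strategy` kwarg are assumptions, and any field accepted by HF TrainingArguments can be forwarded the same way:

    # Hypothetical caller-side example of the fine-tuning API introduced above
    from caikit.core.data_model import DataStream
    from caikit_nlp.data_model import GenerationTrainRecord
    from caikit_nlp.modules.text_generation import FineTuning

    train_stream = DataStream.from_iterable(
        [
            GenerationTrainRecord(input="@foo what a day", output="complaint"),
            GenerationTrainRecord(input="this is fantastic!", output="compliment"),
        ]
    )

    model = FineTuning.train(
        base_model="t5-small",      # or an already-bootstrapped resource object
        train_stream=train_stream,
        torch_dtype="float32",
        num_epochs=1,
        batch_size=4,
        logging_strategy="epoch",   # forwarded to the trainer via **training_arguments
    )
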
@@ -92,11 +138,12 @@ def train(
         # text_generation module. In future, we would want to consolidate this into
         # a base class or a toolkit function
         # pylint: disable=duplicate-code
+        resource_type = None
+
         ## Load base model
         if isinstance(base_model, str):
             model_config = AutoConfig.from_pretrained(base_model)
 
-            resource_type = None
             for resource in cls.supported_resources:
                 if model_config.model_type in resource.SUPPORTED_MODEL_TYPES:
                     resource_type = resource
@@ -112,8 +159,14 @@ def train(
             log.debug("Bootstrapping base resource [%s]", base_model)
             base_model = resource_type.bootstrap(base_model, torch_dtype=torch_dtype)
 
+        else:
+            # base_model is actually a resource object
+            resource_type = type(base_model)
+
+        error.type_check("<NLP03221895E>", PretrainedModelBase, base_model=base_model)
         ## Generate data loader from stream
         training_dataset: IterableDataset = cls._preprocess_function(
+            base_model=base_model,
             train_stream=train_stream,
             tokenizer=base_model.tokenizer,
             max_source_length=max_source_length,
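The dispatch above picks the resource class from the Hugging Face config's `model_type`. A minimal standalone sketch of the same idea; the class bodies and `SUPPORTED_MODEL_TYPES` values here are illustrative, the real resource classes live in caikit_nlp.resources.pretrained_model:

    from transformers import AutoConfig

    class HFAutoCausalLM:
        SUPPORTED_MODEL_TYPES = ["gpt2", "bloom"]  # illustrative values

    class HFAutoSeq2SeqLM:
        SUPPORTED_MODEL_TYPES = ["t5", "mt5"]  # illustrative values

    def resolve_resource(model_name_or_path: str):
        # Mirror of the loop above: read model_type from the HF config and
        # return the first resource class that claims to support it.
        model_config = AutoConfig.from_pretrained(model_name_or_path)
        for resource in (HFAutoCausalLM, HFAutoSeq2SeqLM):
            if model_config.model_type in resource.SUPPORTED_MODEL_TYPES:
                return resource
        raise ValueError(f"Model type {model_config.model_type} is not supported")
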
@@ -144,47 +197,33 @@ def train(
         # by optionally accepting `training_args`
         # as argument to this train function.
         # TODO: Remove all the default used below and make them all configurable
-        training_args = Seq2SeqTrainingArguments(
-            output_dir=checkpoint_dir,
-            per_device_train_batch_size=batch_size,
-            per_device_eval_batch_size=batch_size,
-            num_train_epochs=num_epochs,
+
+        training_args = {
+            "output_dir": checkpoint_dir,
+            "per_device_train_batch_size": batch_size,
+            "per_device_eval_batch_size": batch_size,
+            "num_train_epochs": num_epochs,
+            "seed": random_seed,
             # NOTE: We have disabled evaluation for now
-            do_eval=False,
-            # evaluation_strategy = "epoch",
-            learning_rate=lr,
-            weight_decay=0.01,
-            save_total_limit=3,
-            predict_with_generate=True,
-            push_to_hub=False,
-            no_cuda=False, # Default
-            generation_max_length=max_target_length,
-            remove_unused_columns=False,
-            dataloader_pin_memory=False,
-            gradient_accumulation_steps=accumulate_steps,
-            eval_accumulation_steps=accumulate_steps,
-            logging_strategy="epoch",
-            disable_tqdm=True,
-            # NOTE: Following not possible without save and eval strategy
-            # load_best_model_at_end=True,
+            "do_eval": False,
+            # "evaluation_strategy ": "epoch",
+            "learning_rate": lr,
+            "weight_decay": 0.01,
+            "save_total_limit": 3,
+            "push_to_hub": False,
+            "no_cuda": False, # Default
+            "remove_unused_columns": False,
+            "dataloader_pin_memory": False,
+            "gradient_accumulation_steps": accumulate_steps,
+            "eval_accumulation_steps": accumulate_steps,
             # eval_steps=1,
+            # load_best_model_at_end
+            **training_arguments,
             **dtype_based_params,
-            ## TODO: Make below configurable
-            # fsdp="full_shard auto_wrap",
-            # local_rank=0,
-        )
-
-        data_collator = DataCollatorForSeq2Seq(
-            tokenizer=base_model.tokenizer, model=base_model.model
-        )
+        }
 
-        trainer = Seq2SeqTrainer(
-            base_model.model,
-            training_args,
-            train_dataset=training_dataset,
-            data_collator=data_collator,
-            tokenizer=base_model.tokenizer,
-            # compute_metrics=compute_metrics,
+        trainer = base_model.get_trainer(
+            train_dataset=training_dataset, **training_args
         )
 
         if num_epochs < 1:
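With this change, the seq2seq-specific wiring removed above (Seq2SeqTrainingArguments, DataCollatorForSeq2Seq, Seq2SeqTrainer) moves behind the resource's `get_trainer` method, so causal LM resources can supply their own trainer. A hedged sketch of what such a method could look like for a seq2seq resource, assembled only from pieces visible in this diff; the actual implementation in caikit_nlp.resources.pretrained_model may differ:

    from transformers import (
        DataCollatorForSeq2Seq,
        Seq2SeqTrainer,
        Seq2SeqTrainingArguments,
    )

    def get_trainer(self, train_dataset, **kwargs):
        # Build HF training arguments from the plain dict assembled in train(),
        # attach a collator for this resource's tokenizer/model, return a trainer.
        training_args = Seq2SeqTrainingArguments(**kwargs)
        data_collator = DataCollatorForSeq2Seq(
            tokenizer=self.tokenizer, model=self.model
        )
        return Seq2SeqTrainer(
            self.model,
            training_args,
            train_dataset=train_dataset,
            data_collator=data_collator,
            tokenizer=self.tokenizer,
        )
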
@@ -201,17 +240,25 @@ def train(
 
         # Start training via Trainer.train function
         trainer.train()
-        # NOTE: By default the model would be available in different ways
-        # depending on where and how it was trained. So we need to fetch the model
-        # from the trainer depending on the training method, like fsdp, ddp etc.
-        # For simplicity, currently we will use trainer as the model since it anyways
-        # enable the `predict` function on it and has all the layers of the model
-        # distributed already, so it will be most optimized to use trainer to
-        # perform prediction at this stage.
+
+        # Save the model temporarily and reload it. Otherwise the model might be
+        # distributed across different devices, in which case it would be better to
+        # use the trainer's `prediction_step` functions; those, however, do not
+        # always expose an API similar to `generate` and thus cause
+        # incompatibilities in the `run` function.
+        trainer.save_model(checkpoint_dir)
+
+        model = resource_type.bootstrap(
+            checkpoint_dir, checkpoint_dir, torch_dtype=torch_dtype
+        )
 
         return cls(
-            tokenizer=base_model.tokenizer,
-            model=trainer,
+            tokenizer=model.tokenizer,
+            model=model,
+            bos_token=model.tokenizer.bos_token or None,
+            sep_token=model.tokenizer.sep_token or None,
+            eos_token=model.tokenizer.eos_token or None,
+            pad_token=model.tokenizer.pad_token or None,
         )
 
     # pylint: disable=unused-argument
@@ -236,44 +283,41 @@ def run(
             GeneratedTextResult
                 Generated text result
         """
-        if isinstance(self.model, Trainer):
-            # Apply the tokenizer to the sample text & move to correct device
-            tok_tensors = self.tokenizer(text, return_tensors="pt")
-            # NOTE: below function is prediction on trainer, for which we need to supply
-            # the actual underlying model as well
-            # NOTE: We are using prediction_step instead of calling `self.model.generate`
-            # because this way HF Trainer automatically handles device placement of the
-            # data and model. Since the model is with Trainer at this point
-            # and thus the device placement be according to training strategy,
-            # its better to let Trainer handle the evaluation / prediction
-
-            # TODO: Add support for passing extra arguments to prediction_step
-            _, generated_tokens, _ = self.model.prediction_step(
-                self.model.model,
-                tok_tensors,
-                prediction_loss_only=False,
-                max_new_tokens=max_new_tokens,
-                min_new_tokens=min_new_tokens,
-            )
 
-            generated_text = self.tokenizer.batch_decode(
-                generated_tokens.detach().cpu().numpy(), skip_special_tokens=True
-            )[0]
+        inputs = self.model.tokenizer(text, return_tensors="pt")
+        generate_ids = self.model.model.generate(
+            input_ids=inputs["input_ids"],
+            max_new_tokens=max_new_tokens,
+            min_new_tokens=min_new_tokens,
+            use_cache=True,
+        )
 
-        else:
-            error(
-                "<NLP38929392E>",
-                NotImplementedError(
-                    "model prediction on pre-finetuned model currently not supported"
-                ),
+        token_count = generate_ids.size(1) - 1
+        preds = [
+            self.model.tokenizer.decode(
+                g, skip_special_tokens=True, clean_up_tokenization_spaces=True
             )
+            for g in generate_ids
+        ]
+        if generate_ids[0][-1].item() == self._eos_token:
+            finish_reason = "EOS_TOKEN"
+        elif generate_ids.size(1) - 1 == max_new_tokens:
+            finish_reason = "MAX_TOKENS"
+        else:
+            finish_reason = "OTHER"
 
-        return GeneratedTextResult(generated_text=generated_text)
+        return GeneratedTextResult(
+            generated_tokens=token_count,
+            generated_text=preds[0],
+            finish_reason=finish_reason,
+            producer_id=self.PRODUCER_ID,
+        )
 
     ################################## Private Functions ###########################################
 
     @staticmethod
     def _preprocess_function(
+        base_model: PretrainedModelBase,
         train_stream: DataStream[GenerationTrainRecord],
         tokenizer: AutoTokenizer,
         max_source_length: int,
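After the `run` rewrite above, inference goes straight through the reloaded model's `generate` call. A hedged usage sketch of the resulting behavior (assumes a module returned by `train`; argument names follow the diff):

    # Hypothetical inference call on the fine-tuned module
    result = model.run("this is fantastic!", max_new_tokens=20, min_new_tokens=1)
    print(result.generated_text)    # first decoded prediction, preds[0]
    print(result.generated_tokens)  # token count, i.e. generate_ids.size(1) - 1
    print(result.finish_reason)     # "EOS_TOKEN", "MAX_TOKENS", or "OTHER"
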
@@ -282,28 +326,17 @@ def _preprocess_function(
     ):
         """Pre-process each example to get it prepared for training."""
 
-        # FIXME: Below is currently configured for Seq2Seq only
-
-        def _tokenization_func(
-            example: GenerationTrainRecord,
-        ):
-            model_inputs = tokenizer(
-                example.input,
-                max_length=max_source_length,
-                truncation=True,
-            )
-
-            labels = tokenizer(
-                example.output,
-                max_length=max_target_length,
-                padding="max_length",
-                truncation=True,
-            )
-
-            model_inputs["labels"] = labels["input_ids"]
-
-            return model_inputs
-
-        return SimpleIterableStreamWrapper(
-            train_stream.map(_tokenization_func), shuffle=shuffle
+        # TODO: We are using a default verbalizer which is strictly tied to
+        # source training record currently. We need to figure out a better
+        # way to make verbalizer optional for build_task_tokenize_function
+        (
+            tokenize_function,
+            requires_unwrapping,
+        ) = base_model.build_task_tokenize_function(
+            tokenizer, max_source_length, max_target_length, verbalizer="{{input}}"
         )
+        mapped_stream = train_stream.map(tokenize_function)
+        if requires_unwrapping:
+            mapped_stream = mapped_stream.flatten()
+
+        return SimpleIterableStreamWrapper(mapped_stream, shuffle=shuffle)
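The default verbalizer passed above, "{{input}}", is a template rendered against each training record before tokenization. A minimal sketch of that rendering, for illustration only; the real helper lives in caikit_nlp's toolkit and may differ in detail:

    import re

    def render_verbalizer(template: str, record) -> str:
        # Replace {{field}} placeholders with matching record attributes;
        # with the default "{{input}}" template this simply returns record.input.
        return re.sub(
            r"\{\{(\w+)\}\}", lambda m: str(getattr(record, m.group(1))), template
        )

    class Record:
        input = "It was a bright cold day in April"
        output = "fiction"

    assert render_verbalizer("{{input}}", Record()) == "It was a bright cold day in April"
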
