-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathstrategy_trainer.py
More file actions
596 lines (519 loc) · 27.8 KB
/
strategy_trainer.py
File metadata and controls
596 lines (519 loc) · 27.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
import json
from packaging import version
from torch import nn
from torch.utils.data import Dataset, DataLoader, IterableDataset
from transformers import PreTrainedTokenizer
from transformers.deepspeed import is_deepspeed_zero3_enabled
from transformers.trainer import Trainer
from transformers.trainer_utils import PredictionOutput, IntervalStrategy
from transformers.utils import logging
from transformers.trainer_utils import EvalLoopOutput, EvalPrediction, denumpify_detensorize
from transformers.deepspeed import deepspeed_init
import collections
from transformers.file_utils import is_torch_tpu_available
import numpy as np
from transformers.trainer_pt_utils import (
IterableDatasetShard,
find_batch_size,
nested_concat,
nested_numpify,
nested_truncate,
)
from transformers.trainer import unwrap_model
from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
if is_torch_tpu_available():
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met
import torch_xla.distributed.parallel_loader as pl
if version.parse(torch.__version__) >= version.parse("1.6"):
from torch.cuda.amp import autocast
logger = logging.get_logger(__name__)
class Seq2SeqTrainer(Trainer):
# num_beams is taken as a method argument so that, once the model is trained, it can be evaluated with different beam sizes on the fly.
def evaluate(
    self,
    eval_dataset: Optional[Dataset] = None,
    ignore_keys: Optional[List[str]] = None,
    metric_key_prefix: str = "eval",
    max_length: Optional[int] = None,
    num_beams: Optional[int] = None,
) -> Dict[str, float]:
    """
    Configure the generation settings for this run, then delegate to the parent :obj:`evaluate`.

    Args:
        eval_dataset (:obj:`Dataset`, `optional`):
            Forwarded unchanged to the parent :obj:`evaluate`.
        ignore_keys (:obj:`List[str]`, `optional`):
            Forwarded unchanged to the parent :obj:`evaluate`.
        metric_key_prefix (:obj:`str`, `optional`, defaults to :obj:`"eval"`):
            Forwarded unchanged to the parent :obj:`evaluate`.
        max_length (:obj:`int`, `optional`):
            Accepted for interface compatibility but not used: the stored maximum length is always reset to
            :obj:`None`, so :obj:`prediction_step` falls back to :obj:`self.model.config.max_length`.
        num_beams (:obj:`int`, `optional`):
            Beam count stored for generation. Falls back to :obj:`self.args.generation_num_beams` when
            :obj:`None`.

    Returns:
        Whatever the parent :obj:`evaluate` returns (annotated as a dictionary of metrics).

    .. note::
        As a side effect, once :obj:`self.state.epoch` is at least 3 both :obj:`self.args.save_strategy` and
        :obj:`self.args.evaluation_strategy` are switched to :obj:`IntervalStrategy.STEPS`.
    """
    if num_beams is None:
        num_beams = self.args.generation_num_beams

    # Evaluation always produces one sequence per example and never asks for scores.
    self._max_length, self._sequence_num, self._output_score = None, 1, None
    self._num_beams = num_beams

    current_epoch = self.state.epoch
    if current_epoch is not None and current_epoch >= 3:
        # From the third epoch on, save and evaluate on a step schedule.
        for strategy_field in ("save_strategy", "evaluation_strategy"):
            setattr(self.args, strategy_field, IntervalStrategy.STEPS)

    return super().evaluate(eval_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix)
def predict(
    self,
    test_dataset: Dataset,
    ignore_keys: Optional[List[str]] = None,
    metric_key_prefix: str = "eval",
    max_length: Optional[int] = None,
    num_beams: Optional[int] = None,
    generate_sequence_num: Optional[int] = None,
    output_score: Optional[bool] = False,
) -> PredictionOutput:
    """
    Configure the generation settings for this run, then delegate to the parent :obj:`predict`.

    Args:
        test_dataset (:obj:`Dataset`):
            Forwarded unchanged to the parent :obj:`predict`.
        ignore_keys (:obj:`List[str]`, `optional`):
            Forwarded unchanged to the parent :obj:`predict`.
        metric_key_prefix (:obj:`str`, `optional`, defaults to :obj:`"eval"`):
            Forwarded unchanged to the parent :obj:`predict`.
        max_length (:obj:`int`, `optional`):
            Accepted for interface compatibility but not used: the stored maximum length is always reset to
            :obj:`None`, so :obj:`prediction_step` falls back to :obj:`self.model.config.max_length`.
        num_beams (:obj:`int`, `optional`):
            Beam count stored for generation. Falls back to :obj:`self.args.generation_num_beams` when
            :obj:`None`.
        generate_sequence_num (:obj:`int`, `optional`):
            Stored as the number of sequences to return per example; :obj:`prediction_step` falls back to
            :obj:`self.model.config.num_return_sequences` when it is :obj:`None`.
        output_score (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Stored flag that :obj:`prediction_step` passes to ``generate`` as both ``output_scores`` and
            ``return_dict_in_generate``.

    Returns:
        Whatever the parent :obj:`predict` returns (annotated as :obj:`PredictionOutput`).
    """
    if num_beams is not None:
        beam_count = num_beams
    else:
        beam_count = self.args.generation_num_beams

    # The caller-supplied max_length is deliberately not stored here.
    self._max_length = None
    self._num_beams = beam_count
    self._sequence_num, self._output_score = generate_sequence_num, output_score

    return super().predict(test_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix)
def prediction_step(
    self,
    model: nn.Module,
    inputs: Dict[str, Union[torch.Tensor, Any]],
    prediction_loss_only: bool,
    ignore_keys: Optional[List[str]] = None,
) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor], Optional[
    torch.Tensor]]:
    """
    Perform an evaluation step on :obj:`model` using :obj:`inputs`, generating sequences when requested.

    Args:
        model (:obj:`nn.Module`):
            The model used for the teacher-forced forward pass that produces the loss.
        inputs (:obj:`Dict[str, Union[torch.Tensor, Any]]`):
            The inputs and targets of the model. The optional keys ``history_attention_mask``/``history_ids``
            and ``vads`` are forwarded to ``generate`` when present.
        prediction_loss_only (:obj:`bool`):
            Whether or not to return the loss only.
        ignore_keys (:obj:`List[str]`, `optional`):
            Only used when the call is delegated to the parent implementation.

    Return:
        A 5-tuple ``(loss, generated_tokens, labels, sequence_scores, first_step_scores)``. Every element may be
        :obj:`None`. The last two are only filled when ``generate`` returns a dict-like output that contains
        them. When the call is delegated to the parent implementation, its ``(loss, logits, labels)`` result is
        extended with two :obj:`None` values so callers can always unpack five values.
    """
    if not self.args.predict_with_generate or prediction_loss_only:
        # The parent returns three values; pad to the 5-tuple this class (and evaluation_loop) works with.
        loss, logits, labels = super().prediction_step(
            model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys
        )
        return (loss, logits, labels, None, None)

    has_labels = "labels" in inputs
    inputs = self._prepare_inputs(inputs)

    # XXX: adapt synced_gpus for fairscale as well
    gen_kwargs = {
        "max_length": self._max_length if self._max_length is not None else self.model.config.max_length,
        "num_beams": self._num_beams if self._num_beams is not None else self.model.config.num_beams,
        "synced_gpus": True if is_deepspeed_zero3_enabled() else False,
        "num_return_sequences": self._sequence_num if self._sequence_num is not None else self.model.config.num_return_sequences,
        "output_scores": self._output_score,
        "return_dict_in_generate": self._output_score,
        "do_sample": True,
        "length_penalty": 1.0,
    }

    # `is_sampling` only exists after candidate_sampling() has been called at least once.
    if getattr(self, "is_sampling", False):
        assert self._num_beams is not None and self._sequence_num is not None, "Sampling requires num_beams and num_return_sequences to be set"
        gen_kwargs["diversity_penalty"] = 4.0
        gen_kwargs["do_sample"] = True

    if self.tokenizer is not None:
        generation_inputs = {k: v for k, v in inputs.items() if k in self.tokenizer.model_input_names}
        # very ugly hack to make it work
        generation_inputs["input_ids"] = generation_inputs.pop(self.tokenizer.model_input_names[0])
    else:
        # Must be a dict: it is extended below and unpacked with ** into generate().
        generation_inputs = {"input_ids": inputs["input_ids"]}
        if "attention_mask" in inputs:
            generation_inputs["attention_mask"] = inputs["attention_mask"]

    if "history_attention_mask" in inputs:
        generation_inputs['history_attention_mask'] = inputs['history_attention_mask']
        generation_inputs['history_ids'] = inputs['history_ids']
    if "vads" in inputs:
        generation_inputs['vads'] = inputs['vads']

    result = self.model.generate(
        **generation_inputs,
        **gen_kwargs,
    )

    sequence_scores = None
    scores = None
    if isinstance(result, dict):
        generated_tokens = result['sequences']
        # Not every generation mode reports these fields, so read them defensively.
        scores = result.get('scores')
        sequence_scores = result.get('sequences_scores')
    else:
        generated_tokens = result

    # in case the batch is shorter than max length, the output should be padded
    if generated_tokens.shape[-1] < gen_kwargs["max_length"]:
        generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_kwargs["max_length"])

    with torch.no_grad():
        if self.use_amp:
            with autocast():
                outputs = model(**inputs)
        else:
            outputs = model(**inputs)

        if has_labels:
            if self.label_smoother is not None:
                loss = self.label_smoother(outputs, inputs["labels"]).mean().detach()
            else:
                import torch.nn.functional as F

                labels = inputs['labels']
                pad_token_id = self.model.config.pad_token_id
                logits = outputs['logits']
                # Ignore exactly the positions that are excluded from the length normaliser below.
                token_loss = F.cross_entropy(
                    logits.view(-1, logits.size(-1)).contiguous(),
                    labels.view(-1),
                    reduction="none",
                    ignore_index=pad_token_id,
                )
                token_loss = token_loss.view(labels.size(0), labels.size(1))
                label_size = torch.sum(labels.ne(pad_token_id), dim=1).type_as(token_loss)
                # Per-sequence mean token loss, averaged over the batch. The clamp keeps an all-padding
                # row from producing 0/0 = NaN.
                loss = torch.mean(
                    torch.sum(token_loss, dim=1).float() / label_size.float().clamp(min=1.0)
                ).detach()
        else:
            loss = None

    if self.args.prediction_loss_only:
        return (loss, None, None, None, None)

    if has_labels:
        labels = inputs["labels"]
        if labels.shape[-1] < gen_kwargs["max_length"]:
            labels = self._pad_tensors_to_max_len(labels, gen_kwargs["max_length"])
    else:
        labels = None

    return (
        loss,
        generated_tokens,
        labels,
        sequence_scores,
        # Only the scores of the first generation step are kept.
        scores[0] if scores is not None else None,
    )
def _pad_tensors_to_max_len(self, tensor, max_length):
if self.tokenizer is not None and hasattr(self.tokenizer, "pad_token_id"):
# If PAD token is not defined at least EOS token has to be defined
pad_token_id = (
self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else self.tokenizer.eos_token_id
)
else:
if self.model.config.pad_token_id is not None:
pad_token_id = self.model.config.pad_token_id
else:
raise ValueError("Pad_token_id must be set in the configuration of the model, in order to pad tensors")
padded_tensor = pad_token_id * torch.ones(
(tensor.shape[0], max_length), dtype=tensor.dtype, device=tensor.device
)
padded_tensor[:, : tensor.shape[-1]] = tensor
return padded_tensor
def evaluation_loop(
    self,
    dataloader: DataLoader,
    description: str,
    prediction_loss_only: Optional[bool] = None,
    ignore_keys: Optional[List[str]] = None,
    metric_key_prefix: str = "eval",
) -> EvalLoopOutput:
    """
    Prediction/evaluation loop, shared by :obj:`Trainer.evaluate()` and :obj:`Trainer.predict()`.

    Works both with or without labels. In addition to losses, predictions and labels, this loop also collects
    the sequence scores and first-step scores returned by :obj:`prediction_step`.

    Args:
        dataloader (:obj:`DataLoader`): Batches to run through :obj:`prediction_step`.
        description (:obj:`str`): Label used in the "Running ..." log line.
        prediction_loss_only (:obj:`bool`, `optional`):
            Falls back to :obj:`self.args.prediction_loss_only` when :obj:`None`.
        ignore_keys (:obj:`List[str]`, `optional`): Forwarded to :obj:`prediction_step`.
        metric_key_prefix (:obj:`str`, `optional`, defaults to :obj:`"eval"`):
            Prefix applied to every metric key that does not already carry it.

    Returns:
        :obj:`EvalLoopOutput` whose ``predictions`` field is the tuple
        ``(all_preds, all_sequence_scores, all_scores)``. When losses were collected, the metrics contain
        ``<prefix>_loss`` (mean of the collected losses) and ``<prefix>_ppl`` (mean of their exponentials).
    """
    print(' in evaluation loop: ')
    prediction_loss_only = (
        prediction_loss_only if prediction_loss_only is not None else self.args.prediction_loss_only
    )

    # if eval is called w/o train init deepspeed here
    if self.args.deepspeed and not self.deepspeed:
        # XXX: eval doesn't have `resume_from_checkpoint` arg but we should be able to do eval
        # from the checkpoint eventually
        deepspeed_engine, _, _ = deepspeed_init(self, num_training_steps=0, resume_from_checkpoint=None)
        self.model = deepspeed_engine.module
        self.model_wrapped = deepspeed_engine
        self.deepspeed = deepspeed_engine
        # XXX: we don't need optim/sched for inference, but this needs to be sorted out, since
        # for example the Z3-optimizer is a must for zero3 to work even for inference - what we
        # don't need is the deepspeed basic optimizer which is self.optimizer.optimizer
        deepspeed_engine.optimizer.optimizer = None
        deepspeed_engine.lr_scheduler = None

    model = self._wrap_model(self.model, training=False)

    # if full fp16 is wanted on eval and this ``evaluation`` or ``predict`` isn't called while
    # ``train`` is running, halve it first and then put on device
    if not self.is_in_train and self.args.fp16_full_eval:
        model = model.half().to(self.args.device)

    batch_size = dataloader.batch_size

    logger.info(f"***** Running {description} *****")
    if isinstance(dataloader.dataset, collections.abc.Sized):
        logger.info(f" Num examples = {self.num_examples(dataloader)}")
    else:
        logger.info(" Num examples: Unknown")
    logger.info(f" Batch size = {batch_size}")

    model.eval()

    self.callback_handler.eval_dataloader = dataloader
    # Do this before wrapping.
    eval_dataset = dataloader.dataset

    if is_torch_tpu_available():
        dataloader = pl.ParallelLoader(dataloader, [self.args.device]).per_device_loader(self.args.device)

    if self.args.past_index >= 0:
        self._past = None

    # Initialize containers
    # losses/preds/labels/scores on GPU/TPU (accumulated for eval_accumulation_steps)
    losses_host = None
    preds_host = None
    labels_host = None
    sequence_scores_host = None
    scores_host = None
    # losses/preds/labels/scores on CPU (final containers)
    all_losses = None
    all_preds = None
    all_labels = None
    all_sequence_scores = None
    all_scores = None
    # Will be useful when we have an iterable dataset so don't know its length.
    observed_num_examples = 0

    # Main evaluation loop
    for step, inputs in enumerate(dataloader):
        # Update the observed num examples
        observed_batch_size = find_batch_size(inputs)
        if observed_batch_size is not None:
            observed_num_examples += observed_batch_size
            # For batch samplers, batch_size is not known by the dataloader in advance.
            if batch_size is None:
                batch_size = observed_batch_size

        # Prediction step
        loss, logits, labels, sequence_scores, scores = self.prediction_step(
            model, inputs, prediction_loss_only, ignore_keys=ignore_keys
        )

        if loss is not None:
            # The batch loss is repeated once per example before gathering; it is truncated to the
            # number of samples after the loop.
            losses = self._nested_gather(loss.repeat(batch_size))
            losses_host = losses if losses_host is None else torch.cat((losses_host, losses), dim=0)
        if sequence_scores is not None:
            gathered_seq_scores = self._nested_gather(sequence_scores)
            sequence_scores_host = gathered_seq_scores if sequence_scores_host is None else torch.cat(
                (sequence_scores_host, gathered_seq_scores), dim=0)
        if scores is not None:
            gathered_scores = self._nested_gather(scores)
            # Always store the gathered tensor, including on the very first step.
            scores_host = gathered_scores if scores_host is None else torch.cat(
                (scores_host, gathered_scores), dim=0)
        if logits is not None:
            logits = self._pad_across_processes(logits)
            logits = self._nested_gather(logits)
            if labels is not None and torch.is_tensor(logits) and logits.size(0) > labels.size(0):
                # Several sequences were generated per example: group them per example. The group
                # size is inferred so it also works when it differs from the beam count.
                logits = logits.view(labels.size(0), -1, logits.size(-1))
            preds_host = logits if preds_host is None else nested_concat(preds_host, logits, padding_index=-100)
        if labels is not None:
            labels = self._pad_across_processes(labels)
            labels = self._nested_gather(labels)
            labels_host = labels if labels_host is None else nested_concat(labels_host, labels, padding_index=-100)
        self.control = self.callback_handler.on_prediction_step(self.args, self.state, self.control)

        # Gather all tensors and put them back on the CPU if we have done enough accumulation steps.
        if self.args.eval_accumulation_steps is not None and (step + 1) % self.args.eval_accumulation_steps == 0:
            if losses_host is not None:
                losses = nested_numpify(losses_host)
                all_losses = losses if all_losses is None else np.concatenate((all_losses, losses), axis=0)
            if sequence_scores_host is not None:
                seq_scores_np = nested_numpify(sequence_scores_host)
                all_sequence_scores = seq_scores_np if all_sequence_scores is None else np.concatenate(
                    (all_sequence_scores, seq_scores_np), axis=0)
            if scores_host is not None:
                scores_np = nested_numpify(scores_host)
                all_scores = scores_np if all_scores is None else np.concatenate((all_scores, scores_np), axis=0)
            if preds_host is not None:
                logits = nested_numpify(preds_host)
                all_preds = logits if all_preds is None else nested_concat(all_preds, logits, padding_index=-100)
            if labels_host is not None:
                labels = nested_numpify(labels_host)
                all_labels = (
                    labels if all_labels is None else nested_concat(all_labels, labels, padding_index=-100)
                )

            # Set back to None to begin a new accumulation
            losses_host, preds_host, labels_host = None, None, None
            sequence_scores_host, scores_host = None, None

    if self.args.past_index and hasattr(self, "_past"):
        # Clean the state at the end of the evaluation loop
        delattr(self, "_past")

    # Gather all remaining tensors and put them back on the CPU
    if losses_host is not None:
        losses = nested_numpify(losses_host)
        all_losses = losses if all_losses is None else np.concatenate((all_losses, losses), axis=0)
    if sequence_scores_host is not None:
        seq_scores_np = nested_numpify(sequence_scores_host)
        all_sequence_scores = seq_scores_np if all_sequence_scores is None else np.concatenate(
            (all_sequence_scores, seq_scores_np), axis=0)
    if scores_host is not None:
        scores_np = nested_numpify(scores_host)
        all_scores = scores_np if all_scores is None else np.concatenate((all_scores, scores_np), axis=0)
    if preds_host is not None:
        logits = nested_numpify(preds_host)
        all_preds = logits if all_preds is None else nested_concat(all_preds, logits, padding_index=-100)
    if labels_host is not None:
        labels = nested_numpify(labels_host)
        all_labels = labels if all_labels is None else nested_concat(all_labels, labels, padding_index=-100)

    # Number of samples
    if not isinstance(eval_dataset, IterableDataset):
        num_samples = len(eval_dataset)
    # The instance check is weird and does not actually check for the type, but whether the dataset has the right
    # methods. Therefore we need to make sure it also has the attribute.
    elif isinstance(eval_dataset, IterableDatasetShard) and hasattr(eval_dataset, "num_examples"):
        num_samples = eval_dataset.num_examples
    else:
        num_samples = observed_num_examples

    # Number of losses has been rounded to a multiple of batch_size and in a distributed training, the number of
    # samplers has been rounded to a multiple of batch_size, so we truncate.
    if all_losses is not None:
        all_losses = all_losses[:num_samples]
    # Predictions are intentionally left untruncated here.
    if all_labels is not None:
        all_labels = nested_truncate(all_labels, num_samples)

    assert all_sequence_scores is None or len(all_preds) == len(all_sequence_scores), (
        f"preds: {len(all_preds)}, sequence_scores: {len(all_sequence_scores)}"
    )

    # Metrics!
    if self.compute_metrics is not None and all_preds is not None and all_labels is not None:
        metrics = self.compute_metrics(EvalPrediction(predictions=all_preds, label_ids=all_labels))
    else:
        metrics = {}

    # To be JSON-serializable, we need to remove numpy types or zero-d tensors
    metrics = denumpify_detensorize(metrics)

    if all_losses is not None:
        # Debug output; only valid when losses were actually collected.
        print("all_loss:", len(all_losses), type(all_losses))
        print("all_loss: ", all_losses)
        metrics[f"{metric_key_prefix}_loss"] = all_losses.mean().item()
        metrics[f"{metric_key_prefix}_ppl"] = np.exp(all_losses).mean().item()

    # Prefix all keys with metric_key_prefix + '_'
    for key in list(metrics.keys()):
        if not key.startswith(f"{metric_key_prefix}_"):
            metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key)

    return EvalLoopOutput(
        predictions=(all_preds, all_sequence_scores, all_scores),
        label_ids=all_labels,
        metrics=metrics,
        num_samples=num_samples
    )
def candidate_sampling(
    self,
    sampling_dataset=None,
    tokenizer: Optional[PreTrainedTokenizer] = None,
    conv_context: Optional[List[dict]] = None,
    ignore_keys: Optional[List[str]] = None,
    metric_key_prefix: str = "sampling",
    candidate_num: Optional[int] = 10,
    output_score: Optional[bool] = False,
    save_dir: str = None,
):
    """
    Generate :obj:`candidate_num` candidate responses per example and optionally write them to disk.

    Args:
        sampling_dataset: Dataset handed to the parent :obj:`predict`.
        tokenizer (:obj:`PreTrainedTokenizer`, `optional`):
            Tokenizer used to decode labels and candidates. Falls back to :obj:`self.tokenizer` when
            :obj:`None`.
        conv_context (:obj:`List[dict]`):
            One dict per example, each providing a ``"context"`` entry that is copied into the result.
        ignore_keys (:obj:`List[str]`, `optional`): Forwarded to the parent :obj:`predict`.
        metric_key_prefix (:obj:`str`, `optional`, defaults to :obj:`"sampling"`):
            Forwarded to the parent :obj:`predict`.
        candidate_num (:obj:`int`, `optional`, defaults to 10):
            Used as both the beam count and the number of returned sequences.
        output_score (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Stored flag that :obj:`prediction_step` passes to ``generate``.
        save_dir (:obj:`str`, `optional`):
            Directory that receives ``candidates.json`` (one indented JSON array) and ``candidates.txt`` (one
            JSON object per line). Nothing is written when :obj:`None`.

    Returns:
        The list of result dicts with keys ``"sample_id"``, ``"context"`` and ``"responses"``. In
        ``"responses"`` the decoded reference labels come first, followed by the decoded candidates.

    Raises:
        ValueError: if no tokenizer is available or :obj:`conv_context` is missing.
    """
    tokenizer = tokenizer if tokenizer is not None else self.tokenizer
    if tokenizer is None:
        raise ValueError("candidate_sampling needs a tokenizer to decode the generated candidates")
    if conv_context is None:
        raise ValueError("candidate_sampling needs conv_context to attach a context to each sample")

    # Metrics are switched off only for the duration of the sampling run.
    saved_compute_metrics = self.compute_metrics
    self.is_sampling = True
    self._max_length = None
    self._num_beams = candidate_num
    self._sequence_num = candidate_num
    self._output_score = output_score
    self.compute_metrics = None
    try:
        predictions = super().predict(
            sampling_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix
        )
    finally:
        # Restore the trainer state even when prediction fails.
        self.is_sampling = False
        self.compute_metrics = saved_compute_metrics

    all_preds = predictions.predictions[0]
    all_labels = predictions.label_ids

    results = []
    for sample_id, (labels, preds) in enumerate(zip(all_labels, all_preds)):
        # Row 0 is the reference; the remaining rows are the candidates
        # (assumes preds is 2-D per example - TODO confirm).
        decoded_responses = tokenizer.batch_decode(
            np.concatenate([labels.reshape(1, -1), preds], axis=0),
            skip_special_tokens=True,
        )
        results.append({
            "sample_id": sample_id,
            "context": conv_context[sample_id]["context"],
            "responses": [r.strip() for r in decoded_responses],
        })

    # Checked before writing so a mismatched result set is never persisted.
    assert len(results) == len(conv_context)

    if save_dir is not None:
        os.makedirs(save_dir, exist_ok=True)
        # ensure_ascii=False emits raw non-ASCII text, so the encoding must not depend on the locale.
        with open(os.path.join(save_dir, "candidates.json"), "w", encoding="utf-8") as f:
            json.dump(results, f, ensure_ascii=False, indent=2, sort_keys=False)
        with open(os.path.join(save_dir, "candidates.txt"), "w", encoding="utf-8") as f:
            for sample in results:
                f.write(json.dumps(sample, ensure_ascii=False) + "\n")

    return results
def mitigate_base_model(self, **kwargs):
    """
    Updating entry point: run :obj:`self.train()` with the :obj:`is_aligning` flag raised.

    Args:
        **kwargs: Accepted for interface compatibility; currently not forwarded to :obj:`train`.

    Returns:
        Whatever :obj:`self.train()` returns.
    """
    self.is_aligning = True
    try:
        return self.train()
    finally:
        # Always lower the flag, even when training raises.
        self.is_aligning = False
def compute_loss(self, model, inputs, return_outputs=False):
    """
    Compute the training loss for one batch.

    When a label smoother is configured and the batch has ``"labels"``, the labels are removed from
    :obj:`inputs` before the forward pass and the loss comes from the label smoother. Otherwise the loss is read
    from the model output: the ``"loss"`` entry of a dict-like output, or element 0 of anything else.

    Args:
        model: Callable invoked as ``model(**inputs)``.
        inputs: Keyword inputs for the model. May be mutated (``"labels"`` popped).
        return_outputs: When true, return ``(loss, outputs)`` instead of just the loss.

    Raises:
        ValueError: if the model returns a dict-like output without a ``"loss"`` entry and no label smoothing
            is applied.
    """
    use_smoothing = self.label_smoother is not None and "labels" in inputs
    targets = inputs.pop("labels") if use_smoothing else None

    model_outputs = model(**inputs)

    # Save past state if it exists
    # TODO: this needs to be fixed and made cleaner later.
    if self.args.past_index >= 0:
        self._past = model_outputs[self.args.past_index]

    if targets is None:
        is_mapping = isinstance(model_outputs, dict)
        if is_mapping and "loss" not in model_outputs:
            raise ValueError(
                "The model did not return a loss from the inputs, only the following keys: "
                f"{','.join(model_outputs.keys())}. For reference, the inputs it received are {','.join(inputs.keys())}."
            )
        # We don't use .loss here since the model may return tuples instead of ModelOutput.
        loss = model_outputs["loss"] if is_mapping else model_outputs[0]
    elif unwrap_model(model)._get_name() in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values():
        # Causal LMs need their labels shifted inside the smoother.
        loss = self.label_smoother(model_outputs, targets, shift_labels=True)
    else:
        loss = self.label_smoother(model_outputs, targets)

    if return_outputs:
        return (loss, model_outputs)
    return loss