Commit 0b5dbfc

Fix: NotebookProgressCallback crash when evaluating with the Trainer (#44949)
* Fix NotebookProgressCallback to allow evaluate() before and after train
* Add unit test for NotebookProgressCallback evaluating before and after training
* Skip NotebookProgressCallback tests when IPython is not installed
* Display eval metrics when training tracker is None on NotebookProgressCallback
* Add is_ipython_available and require_ipython test decorator
* Filter model_preparation_time metric and add code comments in on_evaluate
1 parent 6db69b9 commit 0b5dbfc

5 files changed: 101 additions & 6 deletions
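For context, a minimal repro sketch of the scenario this commit fixes, assuming an IPython/notebook environment where NotebookProgressCallback replaces the default progress callback; `model`, `train_dataset`, and `eval_dataset` are placeholders for any small setup:

from transformers import Trainer, TrainingArguments
from transformers.utils.notebook import NotebookProgressCallback

trainer = Trainer(
    model=model,
    args=TrainingArguments("out", eval_strategy="epoch", report_to=[]),
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    callbacks=[NotebookProgressCallback()],
)
trainer.evaluate()  # before the fix: crashed, the training tracker is only created in on_train_begin
trainer.train()
trainer.evaluate()  # per the commit title, evaluating after training needed the fix as well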


src/transformers/testing_utils.py (6 additions & 0 deletions)

@@ -106,6 +106,7 @@
     is_hadamard_available,
     is_hqq_available,
     is_huggingface_hub_greater_or_equal,
+    is_ipython_available,
     is_jinja_available,
     is_jmespath_available,
     is_jumanpp_available,
@@ -1179,6 +1180,11 @@ def require_faiss(test_case):
     return unittest.skipUnless(is_faiss_available(), "test requires `faiss`")(test_case)
 
 
+def require_ipython(test_case):
+    """Decorator marking a test that requires IPython. These tests are skipped when IPython isn't installed."""
+    return unittest.skipUnless(is_ipython_available(), "test requires `IPython`")(test_case)
+
+
 def require_optuna(test_case):
     """
     Decorator marking a test that requires optuna.
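A minimal usage sketch for the new decorator (the test class below is hypothetical); like the other `require_*` helpers, it skips rather than fails when the dependency is missing:

import unittest

from transformers.testing_utils import require_ipython


@require_ipython
class HypotheticalNotebookTest(unittest.TestCase):
    def test_ipython_import(self):
        # Safe to import here: the decorator skips this test when IPython is absent.
        import IPython  # noqa: F401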

src/transformers/utils/__init__.py (1 addition & 0 deletions)

@@ -150,6 +150,7 @@
     is_hqq_available,
     is_huggingface_hub_greater_or_equal,
     is_in_notebook,
+    is_ipython_available,
     is_jinja_available,
     is_jmespath_available,
     is_jumanpp_available,

src/transformers/utils/import_utils.py (5 additions & 0 deletions)

@@ -1540,6 +1540,11 @@ def msg_callable():
     torch._check_with(error_type, cond, msg_callable)
 
 
+@lru_cache
+def is_ipython_available() -> bool:
+    return importlib.util.find_spec("IPython") is not None
+
+
 @lru_cache
 def is_in_notebook() -> bool:
     try:
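The check leans on `importlib.util.find_spec`, which locates a package without importing it, and `@lru_cache` memoizes the result so the lookup runs at most once per process. A standalone sketch of the same behavior:

import importlib.util

# find_spec returns a ModuleSpec when the package can be located and None
# otherwise, without actually importing it (cheap and side-effect free).
print(importlib.util.find_spec("IPython") is not None)  # True iff IPython is installed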

src/transformers/utils/notebook.py (16 additions & 5 deletions)

@@ -351,7 +351,9 @@ def on_log(self, args, state, control, logs=None, **kwargs):
             tt.write_line(values)
 
     def on_evaluate(self, args, state, control, metrics=None, **kwargs):
-        tt = _require(self.training_tracker, "on_train_begin must be called before on_evaluate")
+        # Recompute first_column here since on_evaluate can be called before on_train_begin,
+        # where it is normally initialized.
+        self.first_column = "Epoch" if args.eval_strategy == IntervalStrategy.EPOCH else "Step"
 
         values = {"Training Loss": "No log", "Validation Loss": "No log"}
         for log in reversed(state.log_history):
@@ -374,18 +376,27 @@ def on_evaluate(self, args, state, control, metrics=None, **kwargs):
             _ = metrics.pop(f"{metric_key_prefix}_runtime", None)
             _ = metrics.pop(f"{metric_key_prefix}_samples_per_second", None)
             _ = metrics.pop(f"{metric_key_prefix}_steps_per_second", None)
+            _ = metrics.pop(f"{metric_key_prefix}_model_preparation_time", None)
+
             for k, v in metrics.items():
                 splits = k.split("_")
                 name = " ".join([part.capitalize() for part in splits[1:]])
                 if name == "Loss":
                     # Single dataset
                     name = "Validation Loss"
                 values[name] = v
-            tt.write_line(values)
-            tt.remove_child()
+
+            if self.training_tracker is not None:
+                tt = self.training_tracker
+                tt.write_line(values)
+                tt.remove_child()
+                # Evaluation takes a long time so we should force the next update.
+                self._force_next_update = True
+            else:
+                # No training tracker, but still show the metrics
+                disp.display(disp.HTML(text_to_html_table([list(values.keys()), list(values.values())])))
+
             self.prediction_bar = None
-            # Evaluation takes a long time so we should force the next update.
-            self._force_next_update = True
 
     def on_train_end(self, args, state, control, **kwargs):
         tt = _require(self.training_tracker, "on_train_begin must be called before on_train_end")
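The display names written into `values` come from the unchanged loop above: the metric-key prefix is dropped, the remaining parts are capitalized, and a bare `Loss` is renamed. A standalone sketch of that derivation, with made-up metric values:

# Mirrors the name-derivation loop in on_evaluate (sample values are invented).
metrics = {"eval_loss": 0.42, "eval_mean_absolute_error": 0.1}
values = {}
for k, v in metrics.items():
    splits = k.split("_")
    name = " ".join([part.capitalize() for part in splits[1:]])
    if name == "Loss":
        name = "Validation Loss"  # single-dataset case
    values[name] = v
print(values)  # {'Validation Loss': 0.42, 'Mean Absolute Error': 0.1}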

tests/trainer/test_trainer_callback.py (73 additions & 1 deletion)

@@ -43,7 +43,7 @@
     is_torch_available,
 )
 from transformers.integrations.integration_utils import KubeflowCallback, SwanLabCallback
-from transformers.testing_utils import require_torch
+from transformers.testing_utils import require_ipython, require_torch
 from transformers.trainer_callback import CallbackHandler, ExportableState, TrainerControl
 
 
@@ -1269,3 +1269,75 @@ def state(self):
 
         self.assertEqual(instance.name, "test")
         self.assertEqual(instance.counter, 5)
+
+
+@require_torch
+@require_ipython
+class NotebookProgressCallbackTest(unittest.TestCase):
+    """Tests for NotebookProgressCallback behavior in notebook environments."""
+
+    def setUp(self):
+        self.output_dir = tempfile.mkdtemp()
+
+    def tearDown(self):
+        shutil.rmtree(self.output_dir)
+
+    def _create_trainer(self):
+        train_dataset = RegressionDataset(length=16)
+        eval_dataset = RegressionDataset(length=16)
+        config = RegressionModelConfig(a=0, b=0)
+        model = RegressionPreTrainedModel(config)
+
+        args = TrainingArguments(
+            self.output_dir,
+            per_device_train_batch_size=2,
+            per_device_eval_batch_size=2,
+            num_train_epochs=1,
+            logging_strategy="no",
+            report_to=[],
+            eval_strategy="epoch",
+            disable_tqdm=True,
+        )
+
+        from transformers.utils.notebook import NotebookProgressCallback
+
+        trainer = Trainer(
+            model=model,
+            args=args,
+            train_dataset=train_dataset,
+            eval_dataset=eval_dataset,
+            callbacks=[NotebookProgressCallback()],  # force it
+        )
+        return trainer
+
+    def test_evaluate_before_training(self):
+        """Calling evaluate() before training does not crash and returns metrics."""
+        trainer = self._create_trainer()
+        metrics = trainer.evaluate()
+        self.assertIn("eval_loss", metrics)
+        # Check that the notebook callback exists in callback handler
+        from transformers.utils.notebook import NotebookProgressCallback
+
+        cb = next(
+            (c for c in trainer.callback_handler.callbacks if isinstance(c, NotebookProgressCallback)),
+            None,
+        )
+        self.assertIsNotNone(cb)
+
+    def test_evaluate_after_training(self):
+        """Calling evaluate() after training does not crash and returns metrics."""
+        trainer = self._create_trainer()
+        trainer.train()
+        metrics = trainer.evaluate()
+        self.assertIn("eval_loss", metrics)
+
+    def test_multiple_evaluate_calls(self):
+        """Calling evaluate() multiple times in a row works in notebook environment."""
+        trainer = self._create_trainer()
+        metrics1 = trainer.evaluate()
+        trainer.train()
+        metrics2 = trainer.evaluate()
+        metrics3 = trainer.evaluate()
+        self.assertIn("eval_loss", metrics1)
+        self.assertIn("eval_loss", metrics2)
+        self.assertIn("eval_loss", metrics3)
