|
43 | 43 | is_torch_available, |
44 | 44 | ) |
45 | 45 | from transformers.integrations.integration_utils import KubeflowCallback, SwanLabCallback |
46 | | -from transformers.testing_utils import require_torch |
| 46 | +from transformers.testing_utils import require_ipython, require_torch |
47 | 47 | from transformers.trainer_callback import CallbackHandler, ExportableState, TrainerControl |
48 | 48 |
|
49 | 49 |
|
@@ -1269,3 +1269,75 @@ def state(self): |
1269 | 1269 |
|
1270 | 1270 | self.assertEqual(instance.name, "test") |
1271 | 1271 | self.assertEqual(instance.counter, 5) |
| 1272 | + |
| 1273 | + |
@require_torch
@require_ipython
class NotebookProgressCallbackTest(unittest.TestCase):
    """Tests for NotebookProgressCallback behavior in notebook environments.

    The callback is attached explicitly (rather than relying on auto-detection)
    so that its progress-bar bookkeeping is exercised even when the test suite
    itself does not run inside a real Jupyter/IPython notebook.
    """

    def setUp(self):
        # Fresh output directory per test so Trainer checkpoints/artifacts do
        # not leak between test cases.
        self.output_dir = tempfile.mkdtemp()

    def tearDown(self):
        # ignore_errors=True: a cleanup failure (e.g. the directory was already
        # removed, or was never fully created because the test errored early)
        # must not mask the actual test result.
        shutil.rmtree(self.output_dir, ignore_errors=True)

    def _create_trainer(self):
        """Build a minimal regression Trainer with NotebookProgressCallback attached.

        Returns:
            Trainer: configured for one epoch of tiny regression training with
            per-epoch evaluation and tqdm disabled (so only the notebook
            callback drives progress reporting).
        """
        train_dataset = RegressionDataset(length=16)
        eval_dataset = RegressionDataset(length=16)
        config = RegressionModelConfig(a=0, b=0)
        model = RegressionPreTrainedModel(config)

        args = TrainingArguments(
            self.output_dir,
            per_device_train_batch_size=2,
            per_device_eval_batch_size=2,
            num_train_epochs=1,
            logging_strategy="no",
            report_to=[],
            eval_strategy="epoch",
            disable_tqdm=True,
        )

        # Imported lazily: transformers.utils.notebook needs IPython, which is
        # only guaranteed to be present under the @require_ipython decorator.
        from transformers.utils.notebook import NotebookProgressCallback

        trainer = Trainer(
            model=model,
            args=args,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            # Force the callback on; outside a real notebook the Trainer would
            # not select it automatically.
            callbacks=[NotebookProgressCallback()],
        )
        return trainer

    def test_evaluate_before_training(self):
        """Calling evaluate() before training does not crash and returns metrics."""
        trainer = self._create_trainer()
        metrics = trainer.evaluate()
        self.assertIn("eval_loss", metrics)
        # Check that the notebook callback exists in the callback handler.
        from transformers.utils.notebook import NotebookProgressCallback

        cb = next(
            (c for c in trainer.callback_handler.callbacks if isinstance(c, NotebookProgressCallback)),
            None,
        )
        self.assertIsNotNone(cb)

    def test_evaluate_after_training(self):
        """Calling evaluate() after training does not crash and returns metrics."""
        trainer = self._create_trainer()
        trainer.train()
        metrics = trainer.evaluate()
        self.assertIn("eval_loss", metrics)

    def test_multiple_evaluate_calls(self):
        """Calling evaluate() multiple times in a row works in notebook environment."""
        trainer = self._create_trainer()
        metrics1 = trainer.evaluate()
        trainer.train()
        metrics2 = trainer.evaluate()
        metrics3 = trainer.evaluate()
        self.assertIn("eval_loss", metrics1)
        self.assertIn("eval_loss", metrics2)
        self.assertIn("eval_loss", metrics3)
0 commit comments