Skip to content

Commit e920c2f

Browse files
authored
fix: missing log in multitask training (#5382)
Fix missing log during multitask training, resulting from #4850.

<!-- This is an auto-generated comment: release notes by coderabbit.ai -->
## Summary by CodeRabbit

* **Refactor**
  * Cleaner multi-task training logs: task-specific training metrics are always shown, and validation metrics are emitted only when valid results are present to avoid incomplete validation output.
* **Tests**
  * Added post-training checks to ensure the header of the training log (`lcurve.out`) contains columns for every model key listed in `model_prob`, and that the first data row has the same number of columns as the header.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->
1 parent 3992159 commit e920c2f

2 files changed

Lines changed: 33 additions & 14 deletions

File tree

deepmd/pt/train/training.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1372,25 +1372,25 @@ def log_loss_valid(_task_key: str = "Default") -> dict:
13721372
train_results[_key] = log_loss_train(
13731373
loss, more_loss, _task_key=_key
13741374
)
1375-
valid_results[_key] = log_loss_valid(_task_key=_key)
1376-
if self.rank == 0:
1377-
log.info(
1378-
format_training_message_per_task(
1379-
batch=display_step_id,
1380-
task_name=_key + "_trn",
1381-
rmse=train_results[_key],
1382-
learning_rate=cur_lr,
1383-
)
1384-
)
1385-
if valid_results[_key]:
1375+
valid_results[_key] = log_loss_valid(_task_key=_key)
1376+
if self.rank == 0:
13861377
log.info(
13871378
format_training_message_per_task(
13881379
batch=display_step_id,
1389-
task_name=_key + "_val",
1390-
rmse=valid_results[_key],
1391-
learning_rate=None,
1380+
task_name=_key + "_trn",
1381+
rmse=train_results[_key],
1382+
learning_rate=cur_lr,
13921383
)
13931384
)
1385+
if valid_results[_key]:
1386+
log.info(
1387+
format_training_message_per_task(
1388+
batch=display_step_id,
1389+
task_name=_key + "_val",
1390+
rmse=valid_results[_key],
1391+
learning_rate=None,
1392+
)
1393+
)
13941394
self.wrapper.train()
13951395

13961396
if self.disp_avg:

source/tests/pt/test_multitask.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,25 @@ def test_multitask_train(self) -> None:
5959
self.share_fitting = getattr(self, "share_fitting", False)
6060
trainer = get_trainer(deepcopy(self.config), shared_links=self.shared_links)
6161
trainer.run()
62+
63+
# check lcurve.out columns for all model keys
64+
with open("lcurve.out") as f:
65+
lines = f.readlines()
66+
header_line = lines[0]
67+
header_cols = header_line.strip().lstrip("#").split()
68+
# each model key should appear in header columns
69+
model_keys = list(self.config["training"]["model_prob"].keys())
70+
for mk in model_keys:
71+
cols_for_model = [c for c in header_cols if mk in c]
72+
self.assertGreater(
73+
len(cols_for_model), 0, f"No lcurve columns found for {mk}"
74+
)
75+
# data line column count should match header
76+
data_lines = [l for l in lines if not l.startswith("#")]
77+
self.assertGreater(len(data_lines), 0, "No data lines in lcurve.out")
78+
data_cols = data_lines[0].split()
79+
self.assertEqual(len(data_cols), len(header_cols))
80+
6281
# check model keys
6382
self.assertEqual(len(trainer.wrapper.model), 2)
6483
self.assertIn("model_1", trainer.wrapper.model)

0 commit comments

Comments
 (0)