Skip to content

Commit 8cef984

Browse files
committed
add num_epoch_dict for multitask training
1 parent 7c9813b commit 8cef984

5 files changed

Lines changed: 252 additions & 34 deletions

File tree

deepmd/pd/train/training.py

Lines changed: 55 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,7 @@ def __init__(
134134
# Iteration config
135135
self.num_steps = training_params.get("numb_steps")
136136
self.num_epoch = training_params.get("num_epoch")
137+
self.num_epoch_dict = training_params.get("num_epoch_dict")
137138
self.acc_freq: int = training_params.get(
138139
"acc_freq", 1
139140
) # gradient accumulation steps
@@ -465,24 +466,63 @@ def get_lr(lr_params):
465466
np.ceil(np.sum(np.asarray(per_task_total) * self.model_prob))
466467
)
467468
if self.num_steps is None:
468-
if self.num_epoch is None:
469+
# === Step 1. Check num_epoch_dict first (multi-task only) ===
470+
if self.multi_task and self.num_epoch_dict:
471+
missing = [k for k in self.model_keys if k not in self.num_epoch_dict]
472+
if missing:
473+
raise ValueError(
474+
f"training.num_epoch_dict must specify all tasks; missing: {missing}"
475+
)
476+
# Validate epoch values
477+
for model_key in self.model_keys:
478+
epoch_value = self.num_epoch_dict[model_key]
479+
if epoch_value is not None and epoch_value <= 0:
480+
raise ValueError(
481+
f"training.num_epoch_dict['{model_key}'] must be positive, got {epoch_value}."
482+
)
483+
# Compute steps needed for each task to complete its epochs
484+
per_task_steps = []
485+
for ii, model_key in enumerate(self.model_keys):
486+
epoch_value = self.num_epoch_dict[model_key]
487+
if epoch_value is not None:
488+
# steps_i = epoch_i * per_task_total[i] / model_prob[i]
489+
steps_i = epoch_value * per_task_total[ii] / self.model_prob[ii]
490+
per_task_steps.append(steps_i)
491+
self.num_steps = int(np.ceil(np.max(per_task_steps)))
492+
log.info(
493+
"Computed num_steps=%d from num_epoch_dict=%s with per-task steps: %s.",
494+
self.num_steps,
495+
self.num_epoch_dict,
496+
{
497+
k: int(np.ceil(v))
498+
for k, v in zip(self.model_keys, per_task_steps)
499+
},
500+
)
501+
# === Step 2. Fall back to num_epoch ===
502+
elif self.num_epoch is None:
469503
raise ValueError(
470-
"Either training.numb_steps or training.num_epoch must be set."
504+
"Either training.numb_steps, training.num_epoch, or "
505+
"training.num_epoch_dict (multi-task only) must be set."
471506
)
472-
if self.num_epoch <= 0:
473-
raise ValueError("training.num_epoch must be positive.")
474-
if total_numb_batch <= 0:
475-
raise ValueError("Total number of training batches must be positive.")
476-
self.num_steps = int(np.ceil(self.num_epoch * total_numb_batch))
477-
log.info(
478-
"Computed num_steps=%d from num_epoch=%s and total_numb_batch=%d.",
479-
self.num_steps,
480-
self.num_epoch,
481-
total_numb_batch,
482-
)
483-
elif self.num_epoch is not None:
507+
else:
508+
if self.num_epoch <= 0:
509+
raise ValueError("training.num_epoch must be positive.")
510+
if total_numb_batch <= 0:
511+
raise ValueError(
512+
"Total number of training batches must be positive."
513+
)
514+
self.num_steps = int(np.ceil(self.num_epoch * total_numb_batch))
515+
log.info(
516+
"Computed num_steps=%d from num_epoch=%s and total_numb_batch=%d.",
517+
self.num_steps,
518+
self.num_epoch,
519+
total_numb_batch,
520+
)
521+
elif self.num_epoch is not None or (
522+
self.multi_task and self.num_epoch_dict is not None
523+
):
484524
log.warning(
485-
"Both training.numb_steps and training.num_epoch are set; "
525+
"Both training.numb_steps and training.num_epoch (or num_epoch_dict) are set; "
486526
"using numb_steps=%d.",
487527
self.num_steps,
488528
)

deepmd/pt/train/training.py

Lines changed: 55 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,7 @@ def __init__(
144144
# Iteration config
145145
self.num_steps = training_params.get("numb_steps")
146146
self.num_epoch = training_params.get("num_epoch")
147+
self.num_epoch_dict = training_params.get("num_epoch_dict")
147148
self.disp_file = training_params.get("disp_file", "lcurve.out")
148149
self.disp_freq = training_params.get("disp_freq", 1000)
149150
self.disp_avg = training_params.get("disp_avg", False)
@@ -519,24 +520,63 @@ def get_lr(lr_params: dict[str, Any]) -> BaseLR:
519520
np.ceil(np.sum(np.asarray(per_task_total) * self.model_prob))
520521
)
521522
if self.num_steps is None:
522-
if self.num_epoch is None:
523+
# === Step 1. Check num_epoch_dict first (multi-task only) ===
524+
if self.multi_task and self.num_epoch_dict:
525+
missing = [k for k in self.model_keys if k not in self.num_epoch_dict]
526+
if missing:
527+
raise ValueError(
528+
f"training.num_epoch_dict must specify all tasks; missing: {missing}"
529+
)
530+
# Validate epoch values
531+
for model_key in self.model_keys:
532+
epoch_value = self.num_epoch_dict[model_key]
533+
if epoch_value is not None and epoch_value <= 0:
534+
raise ValueError(
535+
f"training.num_epoch_dict['{model_key}'] must be positive, got {epoch_value}."
536+
)
537+
# Compute steps needed for each task to complete its epochs
538+
per_task_steps = []
539+
for ii, model_key in enumerate(self.model_keys):
540+
epoch_value = self.num_epoch_dict[model_key]
541+
if epoch_value is not None:
542+
# steps_i = epoch_i * per_task_total[i] / model_prob[i]
543+
steps_i = epoch_value * per_task_total[ii] / self.model_prob[ii]
544+
per_task_steps.append(steps_i)
545+
self.num_steps = int(np.ceil(np.max(per_task_steps)))
546+
log.info(
547+
"Computed num_steps=%d from num_epoch_dict=%s with per-task steps: %s.",
548+
self.num_steps,
549+
self.num_epoch_dict,
550+
{
551+
k: int(np.ceil(v))
552+
for k, v in zip(self.model_keys, per_task_steps)
553+
},
554+
)
555+
# === Step 2. Fall back to num_epoch ===
556+
elif self.num_epoch is None:
523557
raise ValueError(
524-
"Either training.numb_steps or training.num_epoch must be set."
558+
"Either training.numb_steps, training.num_epoch, or "
559+
"training.num_epoch_dict (multi-task only) must be set."
525560
)
526-
if self.num_epoch <= 0:
527-
raise ValueError("training.num_epoch must be positive.")
528-
if total_numb_batch <= 0:
529-
raise ValueError("Total number of training batches must be positive.")
530-
self.num_steps = int(np.ceil(self.num_epoch * total_numb_batch))
531-
log.info(
532-
"Computed num_steps=%d from num_epoch=%s and total_numb_batch=%d.",
533-
self.num_steps,
534-
self.num_epoch,
535-
total_numb_batch,
536-
)
537-
elif self.num_epoch is not None:
561+
else:
562+
if self.num_epoch <= 0:
563+
raise ValueError("training.num_epoch must be positive.")
564+
if total_numb_batch <= 0:
565+
raise ValueError(
566+
"Total number of training batches must be positive."
567+
)
568+
self.num_steps = int(np.ceil(self.num_epoch * total_numb_batch))
569+
log.info(
570+
"Computed num_steps=%d from num_epoch=%s and total_numb_batch=%d.",
571+
self.num_steps,
572+
self.num_epoch,
573+
total_numb_batch,
574+
)
575+
elif self.num_epoch is not None or (
576+
self.multi_task and self.num_epoch_dict is not None
577+
):
538578
log.warning(
539-
"Both training.numb_steps and training.num_epoch are set; "
579+
"Both training.numb_steps and training.num_epoch (or num_epoch_dict) are set; "
540580
"using numb_steps=%d.",
541581
self.num_steps,
542582
)

deepmd/utils/argcheck.py

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3229,10 +3229,26 @@ def training_args(
32293229
"as above, and the final total_numb_batch is their model_prob-weighted sum. "
32303230
"Note that in multi-task mode, this defines an 'expected epoch' where each "
32313231
"sample is visited once in expectation across all tasks, rather than a "
3232-
"full epoch for each individual task. For multi-task pretraining scenarios "
3233-
"where different tasks require different numbers of visits, using numb_steps "
3234-
"directly is recommended for more explicit control. At least one of numb_steps "
3235-
"or num_epoch must be set; otherwise a ValueError is raised."
3232+
"full epoch for each individual task. In multi-task mode, num_epoch_dict "
3233+
"takes precedence over num_epoch if both are set. For multi-task pretraining "
3234+
"scenarios where different tasks require different numbers of visits, using "
3235+
"numb_steps directly is recommended for more explicit control. At least one "
3236+
"of numb_steps or num_epoch (or num_epoch_dict in multi-task mode) must be "
3237+
"set; otherwise a ValueError is raised."
3238+
)
3239+
doc_num_epoch_dict = (
3240+
"Number of training epochs for each model branch in multi-task mode "
3241+
"(can be fractional). This is a dictionary mapping model keys to the "
3242+
"number of epochs to train that specific model. When set, the total "
3243+
"training steps are computed as max_i(num_epoch_dict[i] * per_task_total[i] / model_prob[i]), "
3244+
"ensuring each model completes at least its specified number of epochs. "
3245+
"The model requiring the most steps will complete approximately its target "
3246+
"epochs, while other models may complete more epochs. This is particularly "
3247+
"useful for multi-task fine-tuning scenarios where a data-rich pretrained model "
3248+
"is jointly trained with a data-scarce downstream task, and only the downstream "
3249+
"task's epoch count is of interest. In multi-task mode, this parameter takes "
3250+
"precedence over num_epoch if both are set. All model keys must be specified "
3251+
"in the dictionary."
32363252
)
32373253
doc_seed = "The random seed for getting frames from the training data set."
32383254
doc_disp_file = "The file for printing learning curve."
@@ -3303,6 +3319,13 @@ def training_args(
33033319
if not multi_task
33043320
else [
33053321
Argument("model_prob", dict, optional=True, default={}, doc=doc_model_prob),
3322+
Argument(
3323+
"num_epoch_dict",
3324+
dict,
3325+
optional=True,
3326+
default={},
3327+
doc=doc_num_epoch_dict,
3328+
),
33063329
Argument("data_dict", dict, data_args, repeat=True, doc=doc_data_dict),
33073330
]
33083331
)

doc/train/multi-task-training.md

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,14 @@ Specifically, there are several parts that need to be modified:
8181
You can specify any positive real number weight for each task. The higher the weight, the higher the probability of being sampled in each training.
8282
This setting is optional, and if not set, tasks will be sampled with equal weights.
8383

84+
- (Optional) {ref}`training/num_epoch_dict <training/num_epoch_dict>`: The number of training epochs for each model branch, specified as a dictionary mapping `model_key` to epoch values.
85+
This allows different tasks to train for different numbers of epochs, which is particularly useful for multi-task fine-tuning scenarios
86+
where a data-rich pretrained model is jointly trained with a data-scarce downstream task.
87+
When set, the total training steps are computed as `max_i(num_epoch_dict[i] * per_task_total[i] / model_prob[i])`,
88+
ensuring each model completes at least its specified number of epochs.
89+
The model requiring the most steps will complete approximately its target epochs, while other models may complete more epochs.
90+
In multi-task mode, this parameter takes precedence over `num_epoch` if both are set.
91+
8492
An example input for multi-task training two models in water system is shown as following:
8593

8694
```{literalinclude} ../../examples/water_multi_task/pytorch_example/input_torch.json

source/tests/pt/test_sampler.py

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -410,6 +410,113 @@ def test_sampling_stability_multi_task(self) -> None:
410410
)
411411
)
412412

413+
def test_num_epoch_dict(self) -> None:
414+
"""Test num_epoch_dict calculation logic for multi-task training."""
415+
# === Step 1. Build Datasets ===
416+
model_keys = ["model_1", "model_2"]
417+
systems_1 = [
418+
str(Path(__file__).parent / "water/data/data_0"),
419+
str(Path(__file__).parent / "water/data/data_1"),
420+
]
421+
systems_2 = [
422+
str(Path(__file__).parent / "water/data/data_1"),
423+
str(Path(__file__).parent / "water/data/single"),
424+
]
425+
dataset_1 = pt_dataloader.DpLoaderSet(
426+
systems_1,
427+
self.batch_size,
428+
self.type_map,
429+
seed=10,
430+
shuffle=False,
431+
)
432+
dataset_2 = pt_dataloader.DpLoaderSet(
433+
systems_2,
434+
self.batch_size,
435+
self.type_map,
436+
seed=10,
437+
shuffle=False,
438+
)
439+
sampler_1 = pt_dataloader.get_sampler_from_params(
440+
dataset_1, {"sys_probs": [0.7, 0.3], "auto_prob": "prob_sys_size"}
441+
)
442+
sampler_2 = pt_dataloader.get_sampler_from_params(
443+
dataset_2, {"sys_probs": [0.4, 0.6], "auto_prob": "prob_sys_size"}
444+
)
445+
probs_1 = self._normalize_probs(np.asarray(sampler_1.weights))
446+
probs_2 = self._normalize_probs(np.asarray(sampler_2.weights))
447+
448+
# === Step 2. Compute per-task total_numb_batch ===
449+
per_task_total = np.array(
450+
[
451+
self._compute_total_numb_batch(
452+
np.asarray(dataset_1.index, dtype=np.float64), probs_1
453+
),
454+
self._compute_total_numb_batch(
455+
np.asarray(dataset_2.index, dtype=np.float64), probs_2
456+
),
457+
],
458+
dtype=np.float64,
459+
)
460+
461+
# === Step 3. Test num_epoch_dict calculation ===
462+
model_prob = np.asarray([0.4, 0.6], dtype=np.float64)
463+
model_prob = model_prob / np.sum(model_prob)
464+
num_epoch_dict = {model_keys[0]: 2.0, model_keys[1]: 5.0}
465+
466+
# Compute expected steps for each task
467+
# steps_i = epoch_i * per_task_total[i] / model_prob[i]
468+
per_task_steps = np.array(
469+
[
470+
num_epoch_dict[model_keys[0]] * per_task_total[0] / model_prob[0],
471+
num_epoch_dict[model_keys[1]] * per_task_total[1] / model_prob[1],
472+
],
473+
dtype=np.float64,
474+
)
475+
476+
# Total steps should be max of per-task steps
477+
expected_num_steps = int(np.ceil(np.max(per_task_steps)))
478+
479+
# Verify the calculation matches the expected formula
480+
self.assertIsInstance(expected_num_steps, int)
481+
self.assertGreater(expected_num_steps, 0)
482+
483+
# Verify that running expected_num_steps would give each task at least
484+
# its target epochs (may be more for tasks needing fewer steps)
485+
expected_model_0_counts = expected_num_steps * model_prob[0]
486+
expected_model_1_counts = expected_num_steps * model_prob[1]
487+
488+
# Each task should complete at least its target epochs
489+
expected_epochs_0 = expected_model_0_counts / per_task_total[0]
490+
expected_epochs_1 = expected_model_1_counts / per_task_total[1]
491+
492+
self.assertGreaterEqual(
493+
expected_epochs_0,
494+
num_epoch_dict[model_keys[0]],
495+
msg="Model 0 should complete at least 2 epochs",
496+
)
497+
self.assertGreaterEqual(
498+
expected_epochs_1,
499+
num_epoch_dict[model_keys[1]],
500+
msg="Model 1 should complete at least 5 epochs",
501+
)
502+
503+
# The task requiring the most steps should complete approximately its target
504+
max_task_idx = int(np.argmax(per_task_steps))
505+
if max_task_idx == 0:
506+
self.assertAlmostEqual(
507+
expected_epochs_0,
508+
num_epoch_dict[model_keys[0]],
509+
delta=0.1,
510+
msg="Model 0 (max steps) should complete approximately 2 epochs",
511+
)
512+
else:
513+
self.assertAlmostEqual(
514+
expected_epochs_1,
515+
num_epoch_dict[model_keys[1]],
516+
delta=0.1,
517+
msg="Model 1 (max steps) should complete approximately 5 epochs",
518+
)
519+
413520

414521
if __name__ == "__main__":
415522
unittest.main()

0 commit comments

Comments
 (0)