Skip to content

Commit 352a2b5

Browse files
feat: use num_epoch to set num_steps (#5148)
<!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **New Features** * Per-model epoch targets (num_epoch, num_epoch_dict) with automatic stop-step derivation and automatic per-task probability/step allocation for multi-task training. * **Bug Fixes** * Stronger validation and clearer errors/logging for epoch/step resolution and sampler totals; safer defaults when configs are missing. * **Documentation** * Updated training argument docs and aliases; clarified mutual-exclusion rules for num_steps/num_epoch/num_epoch_dict/model_prob. * **Tests** * Added deterministic single- and multi-task sampling tests validating derived vs explicit stepping. <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 5eb5400 commit 352a2b5

File tree

12 files changed

+1182
-52
lines changed

12 files changed

+1182
-52
lines changed

deepmd/dpmodel/utils/__init__.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,11 @@
3636
save_dp_model,
3737
traverse_model_dict,
3838
)
39+
from .training_utils import (
40+
compute_total_numb_batch,
41+
resolve_model_prob,
42+
resolve_model_prob_from_epochs,
43+
)
3944

4045
__all__ = [
4146
"AtomExcludeMask",
@@ -49,6 +54,7 @@
4954
"aggregate",
5055
"build_multiple_neighbor_list",
5156
"build_neighbor_list",
57+
"compute_total_numb_batch",
5258
"extend_coord_with_ghosts",
5359
"get_graph_index",
5460
"get_multiple_nlist_key",
@@ -60,6 +66,8 @@
6066
"nlist_distinguish_types",
6167
"normalize_coord",
6268
"phys2inter",
69+
"resolve_model_prob",
70+
"resolve_model_prob_from_epochs",
6371
"save_dp_model",
6472
"to_face_distance",
6573
"traverse_model_dict",
Lines changed: 188 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,188 @@
1+
# SPDX-License-Identifier: LGPL-3.0-or-later
2+
import logging
3+
from collections.abc import (
4+
Iterable,
5+
)
6+
7+
import numpy as np
8+
9+
log = logging.getLogger(__name__)
10+
11+
12+
def compute_total_numb_batch(
    numb_batches: Iterable[int],
    sampler_weights: np.ndarray,
) -> int:
    """Derive the total batch count implied by per-system sampling weights.

    Each data system contributes ``numb_batches[i]`` batches and is drawn
    with probability proportional to ``sampler_weights[i]``; the result is
    the smallest integer step count that, in expectation, covers every
    positively-weighted system at least ``numb_batches[i]`` times.

    Parameters
    ----------
    numb_batches : Iterable[int]
        Number of batches for each data system.
    sampler_weights : np.ndarray
        Sampling weights for each data system.

    Returns
    -------
    int
        Total number of batches.

    Raises
    ------
    ValueError
        If input validation fails.
    """
    w = np.asarray(sampler_weights, dtype=np.float64)
    if w.ndim != 1:
        raise ValueError("Sampler weights must be 1D.")
    if w.size == 0:
        raise ValueError("Sampler weights are empty.")
    if not np.isfinite(w).all():
        raise ValueError("Sampler weights must be finite.")
    if (w < 0.0).any():
        raise ValueError("Sampler weights must be non-negative.")
    total_weight = float(w.sum())
    if total_weight <= 0.0:
        raise ValueError("Sampler weights must sum to a positive value.")
    p = w / total_weight
    counts = np.asarray(numb_batches, dtype=np.float64)
    if counts.ndim != 1:
        raise ValueError("Number of batches must be 1D.")
    if counts.size == 0:
        raise ValueError("Number of batches is empty.")
    if not np.isfinite(counts).all():
        raise ValueError("Number of batches must be finite.")
    if (counts < 0.0).any():
        raise ValueError("Number of batches must be non-negative.")
    if counts.shape[0] != p.shape[0]:
        raise ValueError("Number of batches and sampler weights must match.")
    sampled = p > 0.0
    if not sampled.any():
        raise ValueError(
            "Sampler probabilities must contain at least one positive entry."
        )
    # Zero-probability systems are never drawn, so only the sampled ones
    # constrain the step count; take the tightest (largest) requirement.
    return int(np.ceil((counts[sampled] / p[sampled]).max()))
65+
66+
67+
def resolve_model_prob(
68+
model_keys: list[str],
69+
model_prob_config: dict[str, float] | None,
70+
model_training_data: dict[str, object],
71+
rank: int = 0,
72+
) -> np.ndarray:
73+
"""Resolve model training probability for multi-task training.
74+
75+
Parameters
76+
----------
77+
model_keys : list[str]
78+
List of model keys.
79+
model_prob_config : dict[str, float] | None
80+
User-specified model probabilities. If None, use data size.
81+
model_training_data : dict[str, object]
82+
Training data for each model.
83+
rank : int, optional
84+
Process rank for distributed training, by default 0.
85+
86+
Returns
87+
-------
88+
np.ndarray
89+
Normalized model probabilities.
90+
91+
Raises
92+
------
93+
ValueError
94+
If input validation fails.
95+
"""
96+
model_prob = np.zeros(len(model_keys), dtype=np.float64)
97+
if model_prob_config:
98+
missing = [k for k in model_keys if k not in model_prob_config]
99+
if missing:
100+
raise ValueError(
101+
f"training.model_prob must specify all tasks; missing: {missing}"
102+
)
103+
for ii, model_key in enumerate(model_keys):
104+
if model_key in model_prob_config:
105+
model_prob[ii] = float(model_prob_config[model_key])
106+
else:
107+
if rank == 0:
108+
log.info(
109+
"training.model_prob is not set or empty; defaulting to the "
110+
"number of systems per task."
111+
)
112+
for ii, model_key in enumerate(model_keys):
113+
model_prob[ii] = float(len(model_training_data[model_key]))
114+
if not np.all(np.isfinite(model_prob)):
115+
raise ValueError("Model prob must be finite.")
116+
if np.any(model_prob < 0.0):
117+
raise ValueError("Model prob must be non-negative.")
118+
sum_prob = float(np.sum(model_prob))
119+
if sum_prob <= 0.0:
120+
raise ValueError("Sum of model prob must be larger than 0!")
121+
return model_prob / sum_prob
122+
123+
124+
def resolve_model_prob_from_epochs(
    model_keys: list[str],
    num_epoch_dict_config: dict[str, float],
    per_task_total: np.ndarray,
) -> tuple[np.ndarray, int, dict[str, float]]:
    """Resolve model probability and training steps from epoch configuration.

    Parameters
    ----------
    model_keys : list[str]
        List of model keys.
    num_epoch_dict_config : dict[str, float]
        Target epochs for each task.
    per_task_total : np.ndarray
        Total batches per task.

    Returns
    -------
    tuple[np.ndarray, int, dict[str, float]]
        Model probabilities, total training steps, and per-task steps.

    Raises
    ------
    ValueError
        If input validation fails.
    """
    if not num_epoch_dict_config:
        raise ValueError("training.num_epoch_dict must be set for multi-task epochs.")
    absent = [key for key in model_keys if key not in num_epoch_dict_config]
    if absent:
        raise ValueError(
            f"training.num_epoch_dict must specify all tasks; missing: {absent}"
        )
    targets = []
    for key in model_keys:
        raw = num_epoch_dict_config[key]
        if raw is None:
            raise ValueError(
                f"training.num_epoch_dict['{key}'] must be positive."
            )
        value = float(raw)
        if not np.isfinite(value) or value <= 0.0:
            raise ValueError(
                f"training.num_epoch_dict['{key}'] must be positive, got {value}."
            )
        targets.append(value)
    epoch_targets = np.array(targets, dtype=np.float64)
    totals = np.asarray(per_task_total, dtype=np.float64)
    if totals.ndim != 1:
        raise ValueError("Per-task total batches must be 1D.")
    if totals.shape[0] != epoch_targets.shape[0]:
        raise ValueError("Per-task totals and epoch targets must match.")
    if not np.isfinite(totals).all():
        raise ValueError("Per-task total batches must be finite.")
    if (totals <= 0.0).any():
        raise ValueError("Per-task total batches must be positive.")
    # A task's target step count is its full pass length times its epoch
    # target; probabilities allocate the shared step budget proportionally.
    target_steps = totals * epoch_targets
    grand_total = float(target_steps.sum())
    if grand_total <= 0.0:
        raise ValueError("Sum of target steps must be positive.")
    steps_by_key = {
        key: float(steps) for key, steps in zip(model_keys, target_steps)
    }
    return target_steps / grand_total, int(np.ceil(grand_total)), steps_by_key

deepmd/pd/train/training.py

Lines changed: 85 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,11 @@
3030
from deepmd.common import (
3131
symlink_prefix_files,
3232
)
33+
from deepmd.dpmodel.utils import (
34+
compute_total_numb_batch,
35+
resolve_model_prob,
36+
resolve_model_prob_from_epochs,
37+
)
3338
from deepmd.dpmodel.utils.learning_rate import (
3439
BaseLR,
3540
)
@@ -130,9 +135,12 @@ def __init__(
130135
else 1
131136
)
132137
self.num_model = len(self.model_keys)
138+
self.model_prob = None
133139

134140
# Iteration config
135-
self.num_steps = training_params["numb_steps"]
141+
self.num_steps = training_params.get("numb_steps")
142+
self.num_epoch = training_params.get("num_epoch")
143+
self.num_epoch_dict = training_params.get("num_epoch_dict")
136144
self.acc_freq: int = training_params.get(
137145
"acc_freq", 1
138146
) # gradient accumulation steps
@@ -390,6 +398,82 @@ def get_lr(lr_params: dict[str, Any]) -> BaseLR:
390398
),
391399
)
392400

401+
per_task_total = []
402+
if not self.multi_task:
403+
if self.num_steps is None:
404+
if self.num_epoch is None:
405+
raise ValueError(
406+
"Either training.numb_steps or training.num_epoch must be set."
407+
)
408+
if self.num_epoch <= 0:
409+
raise ValueError("training.num_epoch must be positive.")
410+
sampler_weights = to_numpy_array(
411+
self.training_dataloader.batch_sampler.sampler.weights
412+
)
413+
total_numb_batch = compute_total_numb_batch(
414+
training_data.index,
415+
sampler_weights,
416+
)
417+
if total_numb_batch <= 0:
418+
raise ValueError(
419+
"Total number of training batches must be positive."
420+
)
421+
self.num_steps = int(np.ceil(self.num_epoch * total_numb_batch))
422+
log.info(
423+
"Computed num_steps=%d from num_epoch=%s and total_numb_batch=%d.",
424+
self.num_steps,
425+
self.num_epoch,
426+
total_numb_batch,
427+
)
428+
else:
429+
if self.num_epoch_dict:
430+
if self.num_steps is not None:
431+
raise ValueError(
432+
"training.numb_steps and training.num_epoch_dict "
433+
"are mutually exclusive."
434+
)
435+
for model_key in self.model_keys:
436+
sampler_weights = to_numpy_array(
437+
self.training_dataloader[
438+
model_key
439+
].batch_sampler.sampler.weights
440+
)
441+
per_task_total.append(
442+
compute_total_numb_batch(
443+
training_data[model_key].index,
444+
sampler_weights,
445+
)
446+
)
447+
(
448+
self.model_prob,
449+
self.num_steps,
450+
per_task_steps,
451+
) = resolve_model_prob_from_epochs(
452+
self.model_keys,
453+
self.num_epoch_dict,
454+
np.asarray(per_task_total, dtype=np.float64),
455+
)
456+
log.info(
457+
"Computed model_prob=%s and num_steps=%d from num_epoch_dict=%s "
458+
"with per-task target steps: %s.",
459+
self.model_prob,
460+
self.num_steps,
461+
self.num_epoch_dict,
462+
{k: int(np.ceil(v)) for k, v in per_task_steps.items()},
463+
)
464+
else:
465+
if self.num_steps is None:
466+
raise ValueError(
467+
"Either training.numb_steps (multi-task only) or "
468+
"training.num_epoch_dict must be set."
469+
)
470+
self.model_prob = resolve_model_prob(
471+
self.model_keys,
472+
training_params.get("model_prob"),
473+
training_data,
474+
rank=self.rank,
475+
)
476+
393477
# Learning rate
394478
self.warmup_steps = training_params.get("warmup_steps", 0)
395479
self.gradient_max_norm = training_params.get("gradient_max_norm", 0.0)
@@ -682,21 +766,6 @@ def warm_up_linear(step: int, warmup_steps: int) -> float:
682766
)
683767
self.optimizer = fleet.distributed_optimizer(self.optimizer)
684768

685-
# Get model prob for multi-task
686-
if self.multi_task:
687-
self.model_prob = np.array([0.0 for key in self.model_keys])
688-
if training_params.get("model_prob", None) is not None:
689-
model_prob = training_params["model_prob"]
690-
for ii, model_key in enumerate(self.model_keys):
691-
if model_key in model_prob:
692-
self.model_prob[ii] += float(model_prob[model_key])
693-
else:
694-
for ii, model_key in enumerate(self.model_keys):
695-
self.model_prob[ii] += float(len(self.training_data[model_key]))
696-
sum_prob = np.sum(self.model_prob)
697-
assert sum_prob > 0.0, "Sum of model prob must be larger than 0!"
698-
self.model_prob = self.model_prob / sum_prob
699-
700769
# Tensorboard
701770
self.enable_tensorboard = training_params.get("tensorboard", False)
702771
self.tensorboard_log_dir = training_params.get("tensorboard_log_dir", "log")

0 commit comments

Comments
 (0)