Skip to content

Commit 94149a9

Browse files
committed
feat(pt): use num_epoch to set num_steps
1 parent 82a5f32 commit 94149a9

3 files changed

Lines changed: 385 additions & 17 deletions

File tree

deepmd/pt/train/training.py

Lines changed: 100 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -139,9 +139,11 @@ def __init__(
139139
else 1
140140
)
141141
self.num_model = len(self.model_keys)
142+
self.model_prob = None
142143

143144
# Iteration config
144-
self.num_steps = training_params["numb_steps"]
145+
self.num_steps = training_params.get("numb_steps")
146+
self.num_epoch = training_params.get("num_epoch")
145147
self.disp_file = training_params.get("disp_file", "lcurve.out")
146148
self.disp_freq = training_params.get("disp_freq", 1000)
147149
self.disp_avg = training_params.get("disp_avg", False)
@@ -247,6 +249,47 @@ def get_dataloader_and_iter(
247249
valid_numb_batch,
248250
)
249251

252+
def compute_total_numb_batch(
    numb_batches: Iterable[int],
    sampler_weights: np.ndarray,
) -> int:
    """Estimate how many sampled batches make up one epoch.

    The weights are normalised into probabilities ``p_i``.  The result is
    ``ceil(max_i(n_i / p_i))`` taken over the entries with ``p_i > 0``,
    where ``n_i`` is the matching entry of ``numb_batches``.

    Parameters
    ----------
    numb_batches : Iterable[int]
        Number of batches of each entry.  Any iterable is accepted,
        including generators; it must yield one value per weight.
    sampler_weights : np.ndarray
        One-dimensional, finite, non-negative sampling weights.  They do not
        need to be normalised.  Entries with zero weight are ignored.

    Returns
    -------
    int
        The number of batches, rounded up to an integer.

    Raises
    ------
    ValueError
        If the weights are not 1D, are empty, contain NaN/inf or negative
        values, or do not sum to a positive value; or if ``numb_batches``
        does not have exactly one entry per weight.
    """
    weights = np.asarray(sampler_weights, dtype=np.float64)
    if weights.ndim != 1:
        raise ValueError("Sampler weights must be 1D.")
    if weights.size == 0:
        raise ValueError("Sampler weights are empty.")
    # NaN/inf would otherwise slip through the sum check below and surface
    # later as a misleading "no positive entry" error.
    if not np.all(np.isfinite(weights)):
        raise ValueError("Sampler weights must be finite.")
    # A negative weight skews the normalisation and would then be silently
    # masked out, yielding a plausible-looking but meaningless result.
    if np.any(weights < 0.0):
        raise ValueError("Sampler weights must be non-negative.")
    weight_sum = float(np.sum(weights))
    if weight_sum <= 0.0:
        raise ValueError("Sampler weights must sum to a positive value.")
    probs = weights / weight_sum

    # Materialise generic iterables (e.g. generators) so that NumPy builds a
    # numeric array rather than failing on the iterator object.
    if not isinstance(numb_batches, np.ndarray):
        numb_batches = list(numb_batches)
    nbatches = np.asarray(numb_batches, dtype=np.float64)
    if nbatches.ndim != 1 or nbatches.shape[0] != probs.shape[0]:
        raise ValueError("Number of batches and sampler weights must match.")

    valid = probs > 0.0
    if not np.any(valid):
        raise ValueError(
            "Sampler probabilities must contain at least one positive entry."
        )
    return int(np.ceil(np.max(nbatches[valid] / probs[valid])))
274+
275+
def resolve_model_prob(
    model_keys: list[str],
    model_prob_config: dict[str, Any] | None,
    model_training_data: dict[str, DpLoaderSet],
) -> np.ndarray:
    """Build the normalised per-model sampling probabilities.

    Parameters
    ----------
    model_keys : list[str]
        Model names; the output follows this order.
    model_prob_config : dict[str, Any] | None
        Optional mapping from model name to an unnormalised probability.
        Models missing from the mapping get zero.  When ``None``, each model
        is weighted by ``len(model_training_data[key])`` instead.
    model_training_data : dict[str, DpLoaderSet]
        Training data per model; only consulted when ``model_prob_config``
        is ``None``.

    Returns
    -------
    np.ndarray
        A float64 array of length ``len(model_keys)`` that sums to one.

    Raises
    ------
    ValueError
        If the unnormalised probabilities do not sum to a positive value.
    """
    if model_prob_config is None:
        # No explicit configuration: weight every model by its data size.
        raw = [float(len(model_training_data[key])) for key in model_keys]
    else:
        raw = [
            float(model_prob_config[key]) if key in model_prob_config else 0.0
            for key in model_keys
        ]
    unnormalised = np.array(raw, dtype=np.float64)
    total = float(np.sum(unnormalised))
    if total <= 0.0:
        raise ValueError("Sum of model prob must be larger than 0!")
    return unnormalised / total
292+
250293
def single_model_stat(
251294
_model: Any,
252295
_data_stat_nbatch: int,
@@ -430,6 +473,56 @@ def get_lr(lr_params: dict[str, Any]) -> BaseLR:
430473
),
431474
)
432475

476+
# Resolve training steps
477+
if not self.multi_task:
478+
sampler_weights = to_numpy_array(self.training_dataloader.sampler.weights)
479+
total_numb_batch = compute_total_numb_batch(
480+
training_data.index,
481+
sampler_weights,
482+
)
483+
else:
484+
per_task_total = []
485+
for model_key in self.model_keys:
486+
sampler_weights = to_numpy_array(
487+
self.training_dataloader[model_key].sampler.weights
488+
)
489+
per_task_total.append(
490+
compute_total_numb_batch(
491+
training_data[model_key].index,
492+
sampler_weights,
493+
)
494+
)
495+
self.model_prob = resolve_model_prob(
496+
self.model_keys,
497+
training_params.get("model_prob"),
498+
training_data,
499+
)
500+
total_numb_batch = int(
501+
np.ceil(np.sum(np.asarray(per_task_total) * self.model_prob))
502+
)
503+
if self.num_steps is None:
504+
if self.num_epoch is None:
505+
raise ValueError(
506+
"Either training.numb_steps or training.num_epoch must be set."
507+
)
508+
if self.num_epoch <= 0:
509+
raise ValueError("training.num_epoch must be positive.")
510+
if total_numb_batch <= 0:
511+
raise ValueError("Total number of training batches must be positive.")
512+
self.num_steps = int(np.ceil(self.num_epoch * total_numb_batch))
513+
log.info(
514+
"Computed num_steps=%d from num_epoch=%s and total_numb_batch=%d.",
515+
self.num_steps,
516+
self.num_epoch,
517+
total_numb_batch,
518+
)
519+
elif self.num_epoch is not None:
520+
log.warning(
521+
"Both training.numb_steps and training.num_epoch are set; "
522+
"using numb_steps=%d.",
523+
self.num_steps,
524+
)
525+
433526
# Learning rate
434527
self.warmup_steps = training_params.get("warmup_steps", 0)
435528
self.gradient_max_norm = training_params.get("gradient_max_norm", 0.0)
@@ -637,19 +730,12 @@ def single_model_finetune(
637730
)
638731

639732
# Get model prob for multi-task
640-
if self.multi_task:
641-
self.model_prob = np.array([0.0 for key in self.model_keys])
642-
if training_params.get("model_prob", None) is not None:
643-
model_prob = training_params["model_prob"]
644-
for ii, model_key in enumerate(self.model_keys):
645-
if model_key in model_prob:
646-
self.model_prob[ii] += float(model_prob[model_key])
647-
else:
648-
for ii, model_key in enumerate(self.model_keys):
649-
self.model_prob[ii] += float(len(self.training_data[model_key]))
650-
sum_prob = np.sum(self.model_prob)
651-
assert sum_prob > 0.0, "Sum of model prob must be larger than 0!"
652-
self.model_prob = self.model_prob / sum_prob
733+
if self.multi_task and self.model_prob is None:
734+
self.model_prob = resolve_model_prob(
735+
self.model_keys,
736+
training_params.get("model_prob"),
737+
training_data,
738+
)
653739

654740
# Multi-task share params
655741
if shared_links is not None:

deepmd/utils/argcheck.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3213,7 +3213,16 @@ def mixed_precision_args() -> list[Argument]: # ! added by Denghui.
32133213
def training_args(
32143214
multi_task: bool = False,
32153215
) -> list[Argument]: # ! modified by Ziyao: data configuration isolated.
3216-
doc_numb_steps = "Number of training batch. Each training uses one batch of data."
3216+
doc_numb_steps = "Number of training batches. Each training uses one batch of data. If set, this value takes precedence over num_epoch."
3217+
doc_num_epoch = (
3218+
"Number of training epochs. "
3219+
"When numb_steps is not set, the total steps are computed as "
3220+
"ceil(num_epoch * total_numb_batch). For each training dataset, "
3221+
"total_numb_batch is computed as ceil(max_i(n_bch_i / p_i)), where p_i "
3222+
"is the sampling probability of system i after sys_probs/auto_prob. "
3223+
"In multi-task mode, total_numb_batch is the model_prob-weighted sum "
3224+
"over tasks."
3225+
)
32173226
doc_seed = "The random seed for getting frames from the training data set."
32183227
doc_disp_file = "The file for printing learning curve."
32193228
doc_disp_freq = "The frequency of printing learning curve."
@@ -3286,7 +3295,13 @@ def training_args(
32863295
args += [
32873296
mixed_precision_data,
32883297
Argument(
3289-
"numb_steps", int, optional=False, doc=doc_numb_steps, alias=["stop_batch"]
3298+
"numb_steps", int, optional=True, doc=doc_numb_steps, alias=["stop_batch"]
3299+
),
3300+
Argument(
3301+
"num_epoch",
3302+
[int, float],
3303+
optional=True,
3304+
doc=doc_only_pt_supported + doc_num_epoch,
32903305
),
32913306
Argument("seed", [int, None], optional=True, doc=doc_seed),
32923307
Argument(

0 commit comments

Comments
 (0)