Skip to content

Commit 56af3e5

Browse files
committed
refactor
1 parent 273171a commit 56af3e5

6 files changed

Lines changed: 278 additions & 206 deletions

File tree

deepmd/pd/train/training.py

Lines changed: 87 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -267,6 +267,55 @@ def resolve_model_prob(
267267
raise ValueError("Sum of model prob must be larger than 0!")
268268
return model_prob / sum_prob
269269

270+
def resolve_model_prob_from_epochs(
    model_keys,
    num_epoch_dict_config,
    per_task_total,
) -> tuple[np.ndarray, int, dict[str, float]]:
    """Derive per-task sampling probabilities and the total step count from epoch targets.

    Each task's target step count is ``epochs[task] * batches_per_epoch[task]``;
    sampling probabilities are made proportional to these targets so that every
    task completes its requested number of epochs over the same training run.

    Parameters
    ----------
    model_keys
        Ordered task names; fixes the order of the returned probability array.
    num_epoch_dict_config
        Mapping from task name to a positive, finite epoch target. Must contain
        an entry for every key in ``model_keys``; ``None`` values are rejected.
    per_task_total
        1-D array-like of positive, finite batches-per-epoch counts, aligned
        with ``model_keys``.

    Returns
    -------
    tuple[np.ndarray, int, dict[str, float]]
        ``(model_prob, num_steps, per_task_steps)``: probabilities summing to 1,
        the ceiling of the summed target steps, and a map from task name to its
        (float) target step count.

    Raises
    ------
    ValueError
        If the epoch dict is empty or missing tasks, any epoch target or batch
        count is null/non-positive/non-finite, or the array shapes disagree.
    """
    if not num_epoch_dict_config:
        raise ValueError(
            "training.num_epoch_dict must be set for multi-task epochs."
        )
    missing = [k for k in model_keys if k not in num_epoch_dict_config]
    if missing:
        raise ValueError(
            "training.num_epoch_dict must specify all tasks; "
            f"missing: {missing}"
        )
    epoch_targets = np.zeros(len(model_keys), dtype=np.float64)
    for ii, model_key in enumerate(model_keys):
        epoch_value = num_epoch_dict_config[model_key]
        if epoch_value is None:
            # A null entry is a distinct configuration error from a bad number:
            # report it as missing rather than claiming a sign problem.
            raise ValueError(
                f"training.num_epoch_dict['{model_key}'] must be set to a positive value, got None."
            )
        epoch_value = float(epoch_value)
        if not np.isfinite(epoch_value) or epoch_value <= 0.0:
            raise ValueError(
                f"training.num_epoch_dict['{model_key}'] must be positive, got {epoch_value}."
            )
        epoch_targets[ii] = epoch_value
    per_task_total = np.asarray(per_task_total, dtype=np.float64)
    if per_task_total.ndim != 1:
        raise ValueError("Per-task total batches must be 1D.")
    if per_task_total.shape[0] != epoch_targets.shape[0]:
        raise ValueError("Per-task totals and epoch targets must match.")
    if not np.all(np.isfinite(per_task_total)):
        raise ValueError("Per-task total batches must be finite.")
    if np.any(per_task_total <= 0.0):
        raise ValueError("Per-task total batches must be positive.")
    # Target steps per task; sampling probabilities are proportional to these.
    per_task_steps = per_task_total * epoch_targets
    total_target_steps = float(np.sum(per_task_steps))
    if total_target_steps <= 0.0:
        raise ValueError("Sum of target steps must be positive.")
    model_prob = per_task_steps / total_target_steps
    # Round up so every task reaches at least its epoch target in expectation.
    num_steps = int(np.ceil(total_target_steps))
    per_task_steps_map = {
        model_key: float(per_task_steps[ii])
        for ii, model_key in enumerate(model_keys)
    }
    return model_prob, num_steps, per_task_steps_map
318+
270319
def single_model_stat(
271320
_model,
272321
_data_stat_nbatch,
@@ -456,6 +505,24 @@ def get_lr(lr_params: dict[str, Any]) -> BaseLR:
456505
training_data.index,
457506
sampler_weights,
458507
)
508+
if self.num_steps is None:
509+
if self.num_epoch is None:
510+
raise ValueError(
511+
"Either training.numb_steps or training.num_epoch must be set."
512+
)
513+
if self.num_epoch <= 0:
514+
raise ValueError("training.num_epoch must be positive.")
515+
if total_numb_batch <= 0:
516+
raise ValueError(
517+
"Total number of training batches must be positive."
518+
)
519+
self.num_steps = int(np.ceil(self.num_epoch * total_numb_batch))
520+
log.info(
521+
"Computed num_steps=%d from num_epoch=%s and total_numb_batch=%d.",
522+
self.num_steps,
523+
self.num_epoch,
524+
total_numb_batch,
525+
)
459526
else:
460527
for model_key in self.model_keys:
461528
sampler_weights = to_numpy_array(
@@ -467,80 +534,35 @@ def get_lr(lr_params: dict[str, Any]) -> BaseLR:
467534
sampler_weights,
468535
)
469536
)
470-
self.model_prob = resolve_model_prob(
471-
self.model_keys,
472-
training_params.get("model_prob"),
473-
training_data,
474-
)
475-
total_numb_batch = int(
476-
np.ceil(np.sum(np.asarray(per_task_total) * self.model_prob))
477-
)
478-
if self.num_steps is None:
479-
# === Step 1. Check num_epoch_dict first (multi-task only) ===
480-
if self.multi_task and self.num_epoch_dict:
481-
missing = [k for k in self.model_keys if k not in self.num_epoch_dict]
482-
if missing:
483-
raise ValueError(
484-
f"training.num_epoch_dict must specify all tasks; missing: {missing}"
485-
)
486-
# Validate epoch values
487-
for model_key in self.model_keys:
488-
epoch_value = self.num_epoch_dict[model_key]
489-
if epoch_value is not None and epoch_value <= 0:
490-
raise ValueError(
491-
f"training.num_epoch_dict['{model_key}'] must be positive, got {epoch_value}."
492-
)
493-
# Compute steps needed for each task to complete its epochs
494-
per_task_steps: dict[str, float] = {}
495-
for ii, model_key in enumerate(self.model_keys):
496-
epoch_value = self.num_epoch_dict[model_key]
497-
if epoch_value is not None:
498-
if self.model_prob[ii] <= 0.0:
499-
raise ValueError(
500-
f"training.model_prob['{model_key}'] must be positive when num_epoch_dict targets it."
501-
)
502-
# steps_i = epoch_i * per_task_total[i] / model_prob[i]
503-
steps_i = epoch_value * per_task_total[ii] / self.model_prob[ii]
504-
per_task_steps[model_key] = float(steps_i)
505-
if not per_task_steps:
506-
raise ValueError(
507-
"training.num_epoch_dict must have at least one non-null epoch target."
508-
)
509-
self.num_steps = int(np.ceil(np.max(list(per_task_steps.values()))))
537+
if self.num_epoch_dict:
538+
(
539+
self.model_prob,
540+
self.num_steps,
541+
per_task_steps,
542+
) = resolve_model_prob_from_epochs(
543+
self.model_keys,
544+
self.num_epoch_dict,
545+
np.asarray(per_task_total, dtype=np.float64),
546+
)
510547
log.info(
511-
"Computed num_steps=%d from num_epoch_dict=%s with per-task steps: %s.",
548+
"Computed model_prob=%s and num_steps=%d from num_epoch_dict=%s "
549+
"with per-task target steps: %s.",
550+
self.model_prob,
512551
self.num_steps,
513552
self.num_epoch_dict,
514553
{k: int(np.ceil(v)) for k, v in per_task_steps.items()},
515554
)
516-
# === Step 2. Fall back to num_epoch ===
517-
elif self.num_epoch is None:
518-
raise ValueError(
519-
"Either training.numb_steps, training.num_epoch, or "
520-
"training.num_epoch_dict (multi-task only) must be set."
521-
)
522555
else:
523-
if self.num_epoch <= 0:
524-
raise ValueError("training.num_epoch must be positive.")
525-
if total_numb_batch <= 0:
556+
if self.num_steps is None:
526557
raise ValueError(
527-
"Total number of training batches must be positive."
558+
"Either training.numb_steps (multi-task only) or "
559+
"training.num_epoch_dict must be set."
528560
)
529-
self.num_steps = int(np.ceil(self.num_epoch * total_numb_batch))
530-
log.info(
531-
"Computed num_steps=%d from num_epoch=%s and total_numb_batch=%d.",
532-
self.num_steps,
533-
self.num_epoch,
534-
total_numb_batch,
561+
self.model_prob = resolve_model_prob(
562+
self.model_keys,
563+
training_params.get("model_prob"),
564+
training_data,
535565
)
536-
elif self.num_epoch is not None or (
537-
self.multi_task and self.num_epoch_dict is not None
538-
):
539-
log.warning(
540-
"Both training.numb_steps and training.num_epoch (or num_epoch_dict) are set; "
541-
"using numb_steps=%d.",
542-
self.num_steps,
543-
)
544566

545567
# Learning rate
546568
self.warmup_steps = training_params.get("warmup_steps", 0)

deepmd/pt/train/training.py

Lines changed: 87 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -317,6 +317,55 @@ def resolve_model_prob(
317317
raise ValueError("Sum of model prob must be larger than 0!")
318318
return model_prob / sum_prob
319319

320+
def resolve_model_prob_from_epochs(
    model_keys: list[str],
    num_epoch_dict_config: dict[str, Any],
    per_task_total: np.ndarray,
) -> tuple[np.ndarray, int, dict[str, float]]:
    """Derive per-task sampling probabilities and the total step count from epoch targets.

    Each task's target step count is ``epochs[task] * batches_per_epoch[task]``;
    sampling probabilities are made proportional to these targets so that every
    task completes its requested number of epochs over the same training run.

    Parameters
    ----------
    model_keys : list[str]
        Ordered task names; fixes the order of the returned probability array.
    num_epoch_dict_config : dict[str, Any]
        Mapping from task name to a positive, finite epoch target. Must contain
        an entry for every key in ``model_keys``; ``None`` values are rejected.
    per_task_total : np.ndarray
        1-D array of positive, finite batches-per-epoch counts, aligned with
        ``model_keys``.

    Returns
    -------
    tuple[np.ndarray, int, dict[str, float]]
        ``(model_prob, num_steps, per_task_steps)``: probabilities summing to 1,
        the ceiling of the summed target steps, and a map from task name to its
        (float) target step count.

    Raises
    ------
    ValueError
        If the epoch dict is empty or missing tasks, any epoch target or batch
        count is null/non-positive/non-finite, or the array shapes disagree.
    """
    if not num_epoch_dict_config:
        raise ValueError(
            "training.num_epoch_dict must be set for multi-task epochs."
        )
    missing = [k for k in model_keys if k not in num_epoch_dict_config]
    if missing:
        raise ValueError(
            "training.num_epoch_dict must specify all tasks; "
            f"missing: {missing}"
        )
    epoch_targets = np.zeros(len(model_keys), dtype=np.float64)
    for ii, model_key in enumerate(model_keys):
        epoch_value = num_epoch_dict_config[model_key]
        if epoch_value is None:
            # A null entry is a distinct configuration error from a bad number:
            # report it as missing rather than claiming a sign problem.
            raise ValueError(
                f"training.num_epoch_dict['{model_key}'] must be set to a positive value, got None."
            )
        epoch_value = float(epoch_value)
        if not np.isfinite(epoch_value) or epoch_value <= 0.0:
            raise ValueError(
                f"training.num_epoch_dict['{model_key}'] must be positive, got {epoch_value}."
            )
        epoch_targets[ii] = epoch_value
    per_task_total = np.asarray(per_task_total, dtype=np.float64)
    if per_task_total.ndim != 1:
        raise ValueError("Per-task total batches must be 1D.")
    if per_task_total.shape[0] != epoch_targets.shape[0]:
        raise ValueError("Per-task totals and epoch targets must match.")
    if not np.all(np.isfinite(per_task_total)):
        raise ValueError("Per-task total batches must be finite.")
    if np.any(per_task_total <= 0.0):
        raise ValueError("Per-task total batches must be positive.")
    # Target steps per task; sampling probabilities are proportional to these.
    per_task_steps = per_task_total * epoch_targets
    total_target_steps = float(np.sum(per_task_steps))
    if total_target_steps <= 0.0:
        raise ValueError("Sum of target steps must be positive.")
    model_prob = per_task_steps / total_target_steps
    # Round up so every task reaches at least its epoch target in expectation.
    num_steps = int(np.ceil(total_target_steps))
    per_task_steps_map = {
        model_key: float(per_task_steps[ii])
        for ii, model_key in enumerate(model_keys)
    }
    return model_prob, num_steps, per_task_steps_map
368+
320369
def single_model_stat(
321370
_model: Any,
322371
_data_stat_nbatch: int,
@@ -508,6 +557,24 @@ def get_lr(lr_params: dict[str, Any]) -> BaseLR:
508557
training_data.index,
509558
sampler_weights,
510559
)
560+
if self.num_steps is None:
561+
if self.num_epoch is None:
562+
raise ValueError(
563+
"Either training.numb_steps or training.num_epoch must be set."
564+
)
565+
if self.num_epoch <= 0:
566+
raise ValueError("training.num_epoch must be positive.")
567+
if total_numb_batch <= 0:
568+
raise ValueError(
569+
"Total number of training batches must be positive."
570+
)
571+
self.num_steps = int(np.ceil(self.num_epoch * total_numb_batch))
572+
log.info(
573+
"Computed num_steps=%d from num_epoch=%s and total_numb_batch=%d.",
574+
self.num_steps,
575+
self.num_epoch,
576+
total_numb_batch,
577+
)
511578
else:
512579
for model_key in self.model_keys:
513580
sampler_weights = to_numpy_array(
@@ -519,80 +586,35 @@ def get_lr(lr_params: dict[str, Any]) -> BaseLR:
519586
sampler_weights,
520587
)
521588
)
522-
self.model_prob = resolve_model_prob(
523-
self.model_keys,
524-
training_params.get("model_prob"),
525-
training_data,
526-
)
527-
total_numb_batch = int(
528-
np.ceil(np.sum(np.asarray(per_task_total) * self.model_prob))
529-
)
530-
if self.num_steps is None:
531-
# === Step 1. Check num_epoch_dict first (multi-task only) ===
532-
if self.multi_task and self.num_epoch_dict:
533-
missing = [k for k in self.model_keys if k not in self.num_epoch_dict]
534-
if missing:
535-
raise ValueError(
536-
f"training.num_epoch_dict must specify all tasks; missing: {missing}"
537-
)
538-
# Validate epoch values
539-
for model_key in self.model_keys:
540-
epoch_value = self.num_epoch_dict[model_key]
541-
if epoch_value is not None and epoch_value <= 0:
542-
raise ValueError(
543-
f"training.num_epoch_dict['{model_key}'] must be positive, got {epoch_value}."
544-
)
545-
# Compute steps needed for each task to complete its epochs
546-
per_task_steps: dict[str, float] = {}
547-
for ii, model_key in enumerate(self.model_keys):
548-
epoch_value = self.num_epoch_dict[model_key]
549-
if epoch_value is not None:
550-
if self.model_prob[ii] <= 0.0:
551-
raise ValueError(
552-
f"training.model_prob['{model_key}'] must be positive when num_epoch_dict targets it."
553-
)
554-
# steps_i = epoch_i * per_task_total[i] / model_prob[i]
555-
steps_i = epoch_value * per_task_total[ii] / self.model_prob[ii]
556-
per_task_steps[model_key] = float(steps_i)
557-
if not per_task_steps:
558-
raise ValueError(
559-
"training.num_epoch_dict must have at least one non-null epoch target."
560-
)
561-
self.num_steps = int(np.ceil(np.max(list(per_task_steps.values()))))
589+
if self.num_epoch_dict:
590+
(
591+
self.model_prob,
592+
self.num_steps,
593+
per_task_steps,
594+
) = resolve_model_prob_from_epochs(
595+
self.model_keys,
596+
self.num_epoch_dict,
597+
np.asarray(per_task_total, dtype=np.float64),
598+
)
562599
log.info(
563-
"Computed num_steps=%d from num_epoch_dict=%s with per-task steps: %s.",
600+
"Computed model_prob=%s and num_steps=%d from num_epoch_dict=%s "
601+
"with per-task target steps: %s.",
602+
self.model_prob,
564603
self.num_steps,
565604
self.num_epoch_dict,
566605
{k: int(np.ceil(v)) for k, v in per_task_steps.items()},
567606
)
568-
# === Step 2. Fall back to num_epoch ===
569-
elif self.num_epoch is None:
570-
raise ValueError(
571-
"Either training.numb_steps, training.num_epoch, or "
572-
"training.num_epoch_dict (multi-task only) must be set."
573-
)
574607
else:
575-
if self.num_epoch <= 0:
576-
raise ValueError("training.num_epoch must be positive.")
577-
if total_numb_batch <= 0:
608+
if self.num_steps is None:
578609
raise ValueError(
579-
"Total number of training batches must be positive."
610+
"Either training.numb_steps (multi-task only) or "
611+
"training.num_epoch_dict must be set."
580612
)
581-
self.num_steps = int(np.ceil(self.num_epoch * total_numb_batch))
582-
log.info(
583-
"Computed num_steps=%d from num_epoch=%s and total_numb_batch=%d.",
584-
self.num_steps,
585-
self.num_epoch,
586-
total_numb_batch,
613+
self.model_prob = resolve_model_prob(
614+
self.model_keys,
615+
training_params.get("model_prob"),
616+
training_data,
587617
)
588-
elif self.num_epoch is not None or (
589-
self.multi_task and self.num_epoch_dict is not None
590-
):
591-
log.warning(
592-
"Both training.numb_steps and training.num_epoch (or num_epoch_dict) are set; "
593-
"using numb_steps=%d.",
594-
self.num_steps,
595-
)
596618

597619
# Learning rate
598620
warmup_steps = training_params.get("warmup_steps", None)

deepmd/tf/entrypoints/train.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -304,12 +304,6 @@ def compute_total_numb_batch(nbatches, sys_probs) -> int:
304304
num_epoch,
305305
total_numb_batch,
306306
)
307-
elif num_epoch is not None:
308-
log.warning(
309-
"Both training.numb_steps and training.num_epoch are set; "
310-
"using numb_steps=%d.",
311-
stop_batch,
312-
)
313307
origin_type_map = jdata["model"].get("origin_type_map", None)
314308
if (
315309
origin_type_map is not None and not origin_type_map

0 commit comments

Comments
 (0)