Skip to content

Commit 273171a

Browse files
committed
fix
1 parent 14f6827 commit 273171a

2 files changed

Lines changed: 47 additions & 16 deletions

File tree

deepmd/pd/train/training.py

Lines changed: 26 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,14 @@ def compute_total_numb_batch(numb_batches, sampler_weights) -> int:
223223
raise ValueError("Sampler weights must sum to a positive value.")
224224
probs = weights / weight_sum
225225
nbatches = np.asarray(numb_batches, dtype=np.float64)
226+
if nbatches.ndim != 1:
227+
raise ValueError("Number of batches must be 1D.")
228+
if nbatches.size == 0:
229+
raise ValueError("Number of batches is empty.")
230+
if not np.all(np.isfinite(nbatches)):
231+
raise ValueError("Number of batches must be finite.")
232+
if np.any(nbatches < 0.0):
233+
raise ValueError("Number of batches must be non-negative.")
226234
if nbatches.shape[0] != probs.shape[0]:
227235
raise ValueError("Number of batches and sampler weights must match.")
228236
valid = probs > 0.0
@@ -239,6 +247,11 @@ def resolve_model_prob(
239247
) -> np.ndarray:
240248
model_prob = np.zeros(len(model_keys), dtype=np.float64)
241249
if model_prob_config:
250+
missing = [k for k in model_keys if k not in model_prob_config]
251+
if missing:
252+
raise ValueError(
253+
f"training.model_prob must specify all tasks; missing: {missing}"
254+
)
242255
for ii, model_key in enumerate(model_keys):
243256
if model_key in model_prob_config:
244257
model_prob[ii] = float(model_prob_config[model_key])
@@ -434,6 +447,7 @@ def get_lr(lr_params: dict[str, Any]) -> BaseLR:
434447
),
435448
)
436449

450+
per_task_total = []
437451
if not self.multi_task:
438452
sampler_weights = to_numpy_array(
439453
self.training_dataloader.batch_sampler.sampler.weights
@@ -443,7 +457,6 @@ def get_lr(lr_params: dict[str, Any]) -> BaseLR:
443457
sampler_weights,
444458
)
445459
else:
446-
per_task_total = []
447460
for model_key in self.model_keys:
448461
sampler_weights = to_numpy_array(
449462
self.training_dataloader[model_key].batch_sampler.sampler.weights
@@ -478,22 +491,27 @@ def get_lr(lr_params: dict[str, Any]) -> BaseLR:
478491
f"training.num_epoch_dict['{model_key}'] must be positive, got {epoch_value}."
479492
)
480493
# Compute steps needed for each task to complete its epochs
481-
per_task_steps = []
494+
per_task_steps: dict[str, float] = {}
482495
for ii, model_key in enumerate(self.model_keys):
483496
epoch_value = self.num_epoch_dict[model_key]
484497
if epoch_value is not None:
498+
if self.model_prob[ii] <= 0.0:
499+
raise ValueError(
500+
f"training.model_prob['{model_key}'] must be positive when num_epoch_dict targets it."
501+
)
485502
# steps_i = epoch_i * per_task_total[i] / model_prob[i]
486503
steps_i = epoch_value * per_task_total[ii] / self.model_prob[ii]
487-
per_task_steps.append(steps_i)
488-
self.num_steps = int(np.ceil(np.max(per_task_steps)))
504+
per_task_steps[model_key] = float(steps_i)
505+
if not per_task_steps:
506+
raise ValueError(
507+
"training.num_epoch_dict must have at least one non-null epoch target."
508+
)
509+
self.num_steps = int(np.ceil(np.max(list(per_task_steps.values()))))
489510
log.info(
490511
"Computed num_steps=%d from num_epoch_dict=%s with per-task steps: %s.",
491512
self.num_steps,
492513
self.num_epoch_dict,
493-
{
494-
k: int(np.ceil(v))
495-
for k, v in zip(self.model_keys, per_task_steps)
496-
},
514+
{k: int(np.ceil(v)) for k, v in per_task_steps.items()},
497515
)
498516
# === Step 2. Fall back to num_epoch ===
499517
elif self.num_epoch is None:

deepmd/pt/train/training.py

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -268,6 +268,14 @@ def compute_total_numb_batch(
268268
raise ValueError("Sampler weights must sum to a positive value.")
269269
probs = weights / weight_sum
270270
nbatches = np.asarray(numb_batches, dtype=np.float64)
271+
if nbatches.ndim != 1:
272+
raise ValueError("Number of batches must be 1D.")
273+
if nbatches.size == 0:
274+
raise ValueError("Number of batches is empty.")
275+
if not np.all(np.isfinite(nbatches)):
276+
raise ValueError("Number of batches must be finite.")
277+
if np.any(nbatches < 0.0):
278+
raise ValueError("Number of batches must be non-negative.")
271279
if nbatches.shape[0] != probs.shape[0]:
272280
raise ValueError("Number of batches and sampler weights must match.")
273281
valid = probs > 0.0
@@ -493,14 +501,14 @@ def get_lr(lr_params: dict[str, Any]) -> BaseLR:
493501
)
494502

495503
# Resolve training steps
504+
per_task_total = []
496505
if not self.multi_task:
497506
sampler_weights = to_numpy_array(self.training_dataloader.sampler.weights)
498507
total_numb_batch = compute_total_numb_batch(
499508
training_data.index,
500509
sampler_weights,
501510
)
502511
else:
503-
per_task_total = []
504512
for model_key in self.model_keys:
505513
sampler_weights = to_numpy_array(
506514
self.training_dataloader[model_key].sampler.weights
@@ -535,22 +543,27 @@ def get_lr(lr_params: dict[str, Any]) -> BaseLR:
535543
f"training.num_epoch_dict['{model_key}'] must be positive, got {epoch_value}."
536544
)
537545
# Compute steps needed for each task to complete its epochs
538-
per_task_steps = []
546+
per_task_steps: dict[str, float] = {}
539547
for ii, model_key in enumerate(self.model_keys):
540548
epoch_value = self.num_epoch_dict[model_key]
541549
if epoch_value is not None:
550+
if self.model_prob[ii] <= 0.0:
551+
raise ValueError(
552+
f"training.model_prob['{model_key}'] must be positive when num_epoch_dict targets it."
553+
)
542554
# steps_i = epoch_i * per_task_total[i] / model_prob[i]
543555
steps_i = epoch_value * per_task_total[ii] / self.model_prob[ii]
544-
per_task_steps.append(steps_i)
545-
self.num_steps = int(np.ceil(np.max(per_task_steps)))
556+
per_task_steps[model_key] = float(steps_i)
557+
if not per_task_steps:
558+
raise ValueError(
559+
"training.num_epoch_dict must have at least one non-null epoch target."
560+
)
561+
self.num_steps = int(np.ceil(np.max(list(per_task_steps.values()))))
546562
log.info(
547563
"Computed num_steps=%d from num_epoch_dict=%s with per-task steps: %s.",
548564
self.num_steps,
549565
self.num_epoch_dict,
550-
{
551-
k: int(np.ceil(v))
552-
for k, v in zip(self.model_keys, per_task_steps)
553-
},
566+
{k: int(np.ceil(v)) for k, v in per_task_steps.items()},
554567
)
555568
# === Step 2. Fall back to num_epoch ===
556569
elif self.num_epoch is None:

0 commit comments

Comments (0)