Skip to content

Commit 750a6d0

Browse files
committed
fix
1 parent 8cef984 commit 750a6d0

2 files changed

Lines changed: 47 additions & 16 deletions

File tree

deepmd/pd/train/training.py

Lines changed: 26 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,14 @@ def compute_total_numb_batch(numb_batches, sampler_weights) -> int:
223223
raise ValueError("Sampler weights must sum to a positive value.")
224224
probs = weights / weight_sum
225225
nbatches = np.asarray(numb_batches, dtype=np.float64)
226+
if nbatches.ndim != 1:
227+
raise ValueError("Number of batches must be 1D.")
228+
if nbatches.size == 0:
229+
raise ValueError("Number of batches is empty.")
230+
if not np.all(np.isfinite(nbatches)):
231+
raise ValueError("Number of batches must be finite.")
232+
if np.any(nbatches < 0.0):
233+
raise ValueError("Number of batches must be non-negative.")
226234
if nbatches.shape[0] != probs.shape[0]:
227235
raise ValueError("Number of batches and sampler weights must match.")
228236
valid = probs > 0.0
@@ -239,6 +247,11 @@ def resolve_model_prob(
239247
) -> np.ndarray:
240248
model_prob = np.zeros(len(model_keys), dtype=np.float64)
241249
if model_prob_config:
250+
missing = [k for k in model_keys if k not in model_prob_config]
251+
if missing:
252+
raise ValueError(
253+
f"training.model_prob must specify all tasks; missing: {missing}"
254+
)
242255
for ii, model_key in enumerate(model_keys):
243256
if model_key in model_prob_config:
244257
model_prob[ii] = float(model_prob_config[model_key])
@@ -437,6 +450,7 @@ def get_lr(lr_params):
437450
),
438451
)
439452

453+
per_task_total = []
440454
if not self.multi_task:
441455
sampler_weights = to_numpy_array(
442456
self.training_dataloader.batch_sampler.sampler.weights
@@ -446,7 +460,6 @@ def get_lr(lr_params):
446460
sampler_weights,
447461
)
448462
else:
449-
per_task_total = []
450463
for model_key in self.model_keys:
451464
sampler_weights = to_numpy_array(
452465
self.training_dataloader[model_key].batch_sampler.sampler.weights
@@ -481,22 +494,27 @@ def get_lr(lr_params):
481494
f"training.num_epoch_dict['{model_key}'] must be positive, got {epoch_value}."
482495
)
483496
# Compute steps needed for each task to complete its epochs
484-
per_task_steps = []
497+
per_task_steps: dict[str, float] = {}
485498
for ii, model_key in enumerate(self.model_keys):
486499
epoch_value = self.num_epoch_dict[model_key]
487500
if epoch_value is not None:
501+
if self.model_prob[ii] <= 0.0:
502+
raise ValueError(
503+
f"training.model_prob['{model_key}'] must be positive when num_epoch_dict targets it."
504+
)
488505
# steps_i = epoch_i * per_task_total[i] / model_prob[i]
489506
steps_i = epoch_value * per_task_total[ii] / self.model_prob[ii]
490-
per_task_steps.append(steps_i)
491-
self.num_steps = int(np.ceil(np.max(per_task_steps)))
507+
per_task_steps[model_key] = float(steps_i)
508+
if not per_task_steps:
509+
raise ValueError(
510+
"training.num_epoch_dict must have at least one non-null epoch target."
511+
)
512+
self.num_steps = int(np.ceil(np.max(list(per_task_steps.values()))))
492513
log.info(
493514
"Computed num_steps=%d from num_epoch_dict=%s with per-task steps: %s.",
494515
self.num_steps,
495516
self.num_epoch_dict,
496-
{
497-
k: int(np.ceil(v))
498-
for k, v in zip(self.model_keys, per_task_steps)
499-
},
517+
{k: int(np.ceil(v)) for k, v in per_task_steps.items()},
500518
)
501519
# === Step 2. Fall back to num_epoch ===
502520
elif self.num_epoch is None:

deepmd/pt/train/training.py

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -268,6 +268,14 @@ def compute_total_numb_batch(
268268
raise ValueError("Sampler weights must sum to a positive value.")
269269
probs = weights / weight_sum
270270
nbatches = np.asarray(numb_batches, dtype=np.float64)
271+
if nbatches.ndim != 1:
272+
raise ValueError("Number of batches must be 1D.")
273+
if nbatches.size == 0:
274+
raise ValueError("Number of batches is empty.")
275+
if not np.all(np.isfinite(nbatches)):
276+
raise ValueError("Number of batches must be finite.")
277+
if np.any(nbatches < 0.0):
278+
raise ValueError("Number of batches must be non-negative.")
271279
if nbatches.shape[0] != probs.shape[0]:
272280
raise ValueError("Number of batches and sampler weights must match.")
273281
valid = probs > 0.0
@@ -493,14 +501,14 @@ def get_lr(lr_params: dict[str, Any]) -> BaseLR:
493501
)
494502

495503
# Resolve training steps
504+
per_task_total = []
496505
if not self.multi_task:
497506
sampler_weights = to_numpy_array(self.training_dataloader.sampler.weights)
498507
total_numb_batch = compute_total_numb_batch(
499508
training_data.index,
500509
sampler_weights,
501510
)
502511
else:
503-
per_task_total = []
504512
for model_key in self.model_keys:
505513
sampler_weights = to_numpy_array(
506514
self.training_dataloader[model_key].sampler.weights
@@ -535,22 +543,27 @@ def get_lr(lr_params: dict[str, Any]) -> BaseLR:
535543
f"training.num_epoch_dict['{model_key}'] must be positive, got {epoch_value}."
536544
)
537545
# Compute steps needed for each task to complete its epochs
538-
per_task_steps = []
546+
per_task_steps: dict[str, float] = {}
539547
for ii, model_key in enumerate(self.model_keys):
540548
epoch_value = self.num_epoch_dict[model_key]
541549
if epoch_value is not None:
550+
if self.model_prob[ii] <= 0.0:
551+
raise ValueError(
552+
f"training.model_prob['{model_key}'] must be positive when num_epoch_dict targets it."
553+
)
542554
# steps_i = epoch_i * per_task_total[i] / model_prob[i]
543555
steps_i = epoch_value * per_task_total[ii] / self.model_prob[ii]
544-
per_task_steps.append(steps_i)
545-
self.num_steps = int(np.ceil(np.max(per_task_steps)))
556+
per_task_steps[model_key] = float(steps_i)
557+
if not per_task_steps:
558+
raise ValueError(
559+
"training.num_epoch_dict must have at least one non-null epoch target."
560+
)
561+
self.num_steps = int(np.ceil(np.max(list(per_task_steps.values()))))
546562
log.info(
547563
"Computed num_steps=%d from num_epoch_dict=%s with per-task steps: %s.",
548564
self.num_steps,
549565
self.num_epoch_dict,
550-
{
551-
k: int(np.ceil(v))
552-
for k, v in zip(self.model_keys, per_task_steps)
553-
},
566+
{k: int(np.ceil(v)) for k, v in per_task_steps.items()},
554567
)
555568
# === Step 2. Fall back to num_epoch ===
556569
elif self.num_epoch is None:

0 commit comments

Comments
 (0)