Skip to content

Commit a935a0a

Browse files
committed
clear
1 parent 6ccac04 commit a935a0a

3 files changed

Lines changed: 15 additions & 2 deletions

File tree

deepmd/tf/entrypoints/change_bias.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,12 @@ def _change_bias_checkpoint_file(
199199
if stop_batch is None and num_epoch is not None:
200200
if num_epoch <= 0:
201201
raise ValueError("training.num_epoch must be positive.")
202+
# Apply sys_probs and auto_prob from original training config
203+
# to ensure stop_batch calculation matches the original training
204+
training_data_config = training_params.get("training_data", {})
205+
sys_probs = training_data_config.get("sys_probs", None)
206+
auto_prob = training_data_config.get("auto_prob", "prob_sys_size")
207+
data.set_sys_probs(sys_probs=sys_probs, auto_prob_style=auto_prob)
202208
total_numb_batch = compute_total_numb_batch(data.nbatches, data.sys_probs)
203209
if total_numb_batch <= 0:
204210
raise ValueError("Total number of training batches must be positive.")

deepmd/utils/argcheck.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3672,11 +3672,20 @@ def training_extra_check(data: dict | None) -> bool:
36723672
raise ValueError(
36733673
"training.num_epoch_dict is mutually exclusive with training.model_prob."
36743674
)
3675+
else:
3676+
if num_steps is None:
3677+
raise ValueError(
3678+
"Multi-task mode requires either training.numb_steps or training.num_epoch_dict."
3679+
)
36753680
else:
36763681
if num_steps is not None and num_epoch is not None:
36773682
raise ValueError(
36783683
"training.num_step and training.num_epoch are mutually exclusive."
36793684
)
3685+
if num_steps is None and num_epoch is None:
3686+
raise ValueError(
3687+
"Single-task mode requires either training.numb_steps or training.num_epoch."
3688+
)
36803689
return True
36813690

36823691
doc_training = "The training options."

source/tests/pt/test_sampler.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -440,8 +440,6 @@ def test_num_epoch_dict(self) -> None:
440440
sampler_2 = pt_dataloader.get_sampler_from_params(
441441
dataset_2, {"sys_probs": [0.4, 0.6], "auto_prob": "prob_sys_size"}
442442
)
443-
probs_1 = self._normalize_probs(np.asarray(sampler_1.weights))
444-
probs_2 = self._normalize_probs(np.asarray(sampler_2.weights))
445443

446444
# === Step 2. Compute per-task total_numb_batch ===
447445
per_task_total = np.array(

0 commit comments

Comments
 (0)