
Commit cce397b

fix: use numb_epoch key and correct model_prob doc

- Change `num_epoch` to `numb_epoch` in `change_bias.py` and `train.py`
- Fix `model_prob` default description in `multi-task-training.md`

1 parent 119b3bd

File tree

4 files changed: +4 −4 lines changed

deepmd/tf/entrypoints/change_bias.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -195,7 +195,7 @@ def _change_bias_checkpoint_file(
     # Get stop_batch and origin_type_map like in train.py
     training_params = jdata.get("training", {})
     stop_batch = training_params.get("numb_steps")
-    num_epoch = training_params.get("num_epoch")
+    num_epoch = training_params.get("numb_epoch")
     if stop_batch is None and num_epoch is not None:
         if num_epoch <= 0:
             raise ValueError("training.num_epoch must be positive.")
```

deepmd/tf/entrypoints/train.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -259,7 +259,7 @@ def _do_work(
     # get training info
     training_params = jdata["training"]
     stop_batch = training_params.get("numb_steps")
-    num_epoch = training_params.get("num_epoch")
+    num_epoch = training_params.get("numb_epoch")
     if stop_batch is None:
         if num_epoch is None:
             raise ValueError(
```
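The two fixes above share one root cause: `dict.get` returns `None` for a missing key, so looking up `num_epoch` while the input file spells the key `numb_epoch` silently disabled the epoch-based fallback. A minimal sketch of that lookup pattern (not the actual DeePMD-kit code; `steps_per_epoch` is a hypothetical stand-in for whatever the real code derives from the dataset size):

```python
def resolve_stop_batch(training_params: dict, steps_per_epoch: int):
    """Resolve the number of training steps, falling back to epochs.

    Illustrative only: `steps_per_epoch` stands in for the value the
    real code computes from the training data.
    """
    stop_batch = training_params.get("numb_steps")
    num_epoch = training_params.get("numb_epoch")  # the corrected key
    if stop_batch is None and num_epoch is not None:
        if num_epoch <= 0:
            raise ValueError("training.num_epoch must be positive.")
        stop_batch = num_epoch * steps_per_epoch
    return stop_batch


# Before the fix, {"numb_epoch": 2} was looked up under "num_epoch",
# yielded None, and the fallback branch never ran.
print(resolve_stop_batch({"numb_steps": 1000}, 100))  # 1000
print(resolve_stop_batch({"numb_epoch": 2}, 100))     # 200
```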

doc/train/multi-task-training.md

Lines changed: 1 addition & 1 deletion

```diff
@@ -79,7 +79,7 @@ Specifically, there are several parts that need to be modified:

 - (Optional) {ref}`training/model_prob <training/model_prob>`: The sampling weight settings corresponding to each `model_key`, i.e., the probability weight in the training step.
   You can specify any positive real number weight for each task. The higher the weight, the higher the probability of being sampled in each training.
-  This setting is optional, and if not set, tasks will be sampled with equal weights. It is only used when `num_epoch_dict` is not set.
+  This setting is optional, and if not set, tasks will be sampled with weights proportional to the number of systems in each task. It is only used when `num_epoch_dict` is not set.

 - (Optional) {ref}`training/num_epoch_dict <training/num_epoch_dict>`: The number of training epochs for each model branch, specified as a dictionary mapping `model_key` to epoch values (can be fractional).
   This allows different tasks to train for different numbers of epochs, which is particularly useful for multi-task fine-tuning scenarios
```
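The corrected default described above can be illustrated with a small sketch: weights proportional to each task's system count, normalized to probabilities. The function name and input shape are illustrative, not DeePMD-kit's actual API:

```python
def default_model_prob(systems_per_task: dict) -> dict:
    """Toy version of the documented default when model_prob is unset:
    sample each task with probability proportional to its system count."""
    total = sum(systems_per_task.values())
    return {key: n / total for key, n in systems_per_task.items()}


# A task with three times as many systems is sampled three times as often.
print(default_model_prob({"water": 3, "metal": 1}))
# {'water': 0.75, 'metal': 0.25}
```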

source/tests/pt/test_sampler.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -35,7 +35,7 @@ class _SerialPool:
     def __init__(self, *args, **kwargs) -> None:
         pass

-    def __enter__(self) -> "_SerialPool":
+    def __enter__(self) -> "_SerialPool":  # noqa: PYI034
         return self

     def __exit__(self, exc_type, exc, tb) -> bool:
```
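The `# noqa: PYI034` added above suppresses Ruff's non-self-return-type rule, which flags `__enter__` for returning the literal class name instead of `typing.Self` (the `Self` annotation stays correct in subclasses, a string annotation of the class name does not). A minimal sketch of the `Self`-based alternative the rule recommends, using a toy class rather than the test helper itself:

```python
from __future__ import annotations  # keep annotations lazy for older Pythons

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from typing import Self  # Python 3.11+; only needed by type checkers


class SerialPool:
    """Toy context manager standing in for the test's _SerialPool."""

    def __enter__(self) -> Self:
        return self

    def __exit__(self, exc_type, exc, tb) -> bool:
        return False  # do not suppress exceptions


with SerialPool() as pool:
    print(isinstance(pool, SerialPool))  # True
```

Keeping the string annotation with a targeted `noqa`, as the commit does, avoids touching the runtime behavior of the test helper.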
