Skip to content

Commit e98e1c3

Browse files
authored
fix(pt): typo in epoch training (#5410)
<!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **Bug Fixes** * Fixed configuration lookup for the epoch count so training now reads and respects the configured number of epochs. <!-- end of auto-generated comment: release notes by coderabbit.ai -->
1 parent 5d9cbdf commit e98e1c3

2 files changed

Lines changed: 2 additions & 2 deletions

File tree

deepmd/pd/train/training.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ def __init__(
141141

142142
# Iteration config
143143
self.num_steps = training_params.get("numb_steps")
144-
self.num_epoch = training_params.get("num_epoch")
144+
self.num_epoch = training_params.get("numb_epoch")
145145
self.num_epoch_dict = training_params.get("num_epoch_dict")
146146
self.acc_freq: int = training_params.get(
147147
"acc_freq", 1

deepmd/pt/train/training.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,7 @@ def __init__(
164164

165165
# Iteration config
166166
self.num_steps = training_params.get("numb_steps")
167-
self.num_epoch = training_params.get("num_epoch")
167+
self.num_epoch = training_params.get("numb_epoch")
168168
self.num_epoch_dict = training_params.get("num_epoch_dict")
169169
self.disp_file = training_params.get("disp_file", "lcurve.out")
170170
self.disp_freq = training_params.get("disp_freq", 1000)

0 commit comments

Comments
 (0)