
Commit c796862

feat(pd): support gradient accumulation (#4920)
Support gradient accumulation for the Paddle backend.

## Summary by CodeRabbit

- **New Features**
  - Configurable gradient accumulation (acc_freq) that batches optimizer updates, optional gradient clipping, and multi-GPU gradient sync to occur at the configured interval; acc_freq=1 preserves prior behavior.
- **Documentation**
  - Added argument docs and a Paddle backend notice describing acc_freq.
- **Tests**
  - Added tests exercising gradient accumulation and updated test cleanup.
1 parent 996d192 commit c796862

3 files changed

Lines changed: 49 additions & 17 deletions
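With acc_freq = N, every step still runs loss.backward(), but gradient clipping, the multi-GPU allreduce, and the optimizer update fire only on every N-th step, as the training.py diff below shows. A tiny illustration of the update condition (not part of the commit):

# Illustration only: which step ids trigger a parameter update for a given acc_freq.
# Mirrors the `(_step_id + 1) % self.acc_freq == 0` condition added in training.py.
acc_freq = 4
update_steps = [s for s in range(12) if (s + 1) % acc_freq == 0]
print(update_steps)  # [3, 7, 11] -> one optimizer step per 4 training steps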


deepmd/pd/train/training.py

Lines changed: 22 additions & 15 deletions
@@ -133,6 +133,9 @@ def __init__(

         # Iteration config
         self.num_steps = training_params["numb_steps"]
+        self.acc_freq: int = training_params.get(
+            "acc_freq", 1
+        )  # gradient accumulation steps
         self.disp_file = training_params.get("disp_file", "lcurve.out")
         self.disp_freq = training_params.get("disp_freq", 1000)
         self.save_ckpt = training_params.get("save_ckpt", "model.ckpt")
@@ -744,7 +747,6 @@ def step(_step_id, task_key="Default") -> None:
             _lr = self.lr_exp
             cur_lr = _lr.value(_step_id)
             pref_lr = cur_lr
-            self.optimizer.clear_grad(set_to_zero=False)

             with nvprof_context(enable_profiling, "Fetching data"):
                 input_dict, label_dict, log_dict = self.get_data(
@@ -780,22 +782,27 @@ def step(_step_id, task_key="Default") -> None:
                 with nvprof_context(enable_profiling, "Backward pass"):
                     loss.backward()

-                # fuse + allreduce manually before optimization if use DDP + no_sync
-                # details in https://github.com/PaddlePaddle/Paddle/issues/48898#issuecomment-1343838622
-                if self.world_size > 1:
-                    hpu.fused_allreduce_gradients(list(self.wrapper.parameters()), None)
-
-                if self.gradient_max_norm > 0.0:
-                    with nvprof_context(enable_profiling, "Gradient clip"):
-                        paddle.nn.utils.clip_grad_norm_(
-                            self.wrapper.parameters(),
-                            self.gradient_max_norm,
-                            error_if_nonfinite=True,
+                # gradient accumulation
+                if (_step_id + 1) % self.acc_freq == 0:
+                    # fuse + allreduce manually before optimization if use DDP + no_sync
+                    # details in https://github.com/PaddlePaddle/Paddle/issues/48898#issuecomment-1343838622
+                    if self.world_size > 1:
+                        hpu.fused_allreduce_gradients(
+                            list(self.wrapper.parameters()), None
                         )

-                with nvprof_context(enable_profiling, "Adam update"):
-                    self.optimizer.step()
-                self.scheduler.step()
+                    if self.gradient_max_norm > 0.0:
+                        with nvprof_context(enable_profiling, "Gradient clip"):
+                            paddle.nn.utils.clip_grad_norm_(
+                                self.wrapper.parameters(),
+                                self.gradient_max_norm,
+                                error_if_nonfinite=True,
+                            )
+
+                    with nvprof_context(enable_profiling, "Adam update"):
+                        self.optimizer.step()
+                        self.optimizer.clear_grad(set_to_zero=False)
+                    self.scheduler.step()

             else:
                 raise ValueError(f"Not supported optimizer type '{self.opt_type}'")
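For readers unfamiliar with the pattern, the sketch below reproduces the accumulation logic of this diff in a self-contained toy loop. It is a minimal sketch, not the DeepMD trainer: the model, data, and learning rate are placeholders, and the DDP allreduce and nvprof contexts are omitted. Note that the diff does not rescale the accumulated gradients, so gradients from acc_freq consecutive batches are summed before the single optimizer step.

# Minimal, self-contained sketch of the accumulation pattern added in this commit.
# Toy model/data and a plain Adam optimizer are stand-ins for the real trainer.
import paddle

model = paddle.nn.Linear(3, 1)
optimizer = paddle.optimizer.Adam(learning_rate=1e-3, parameters=model.parameters())
acc_freq = 4    # accumulate gradients over 4 steps before each update
max_norm = 1.0  # counterpart of gradient_max_norm; <= 0 would disable clipping

for step_id in range(16):
    x = paddle.randn([8, 3])
    y = paddle.randn([8, 1])
    loss = paddle.nn.functional.mse_loss(model(x), y)
    loss.backward()  # gradients from successive batches are summed in place

    if (step_id + 1) % acc_freq == 0:
        if max_norm > 0.0:
            paddle.nn.utils.clip_grad_norm_(
                model.parameters(), max_norm, error_if_nonfinite=True
            )
        optimizer.step()
        optimizer.clear_grad()  # reset gradient buffers only after the update

Because clear_grad now runs only after an update, the backward passes in between keep adding to the same gradient buffers, which is exactly what moving clear_grad out of the top of step() accomplishes in the diff.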

deepmd/utils/argcheck.py

Lines changed: 9 additions & 0 deletions
@@ -40,6 +40,7 @@

 doc_only_tf_supported = "(Supported Backend: TensorFlow) "
 doc_only_pt_supported = "(Supported Backend: PyTorch) "
+doc_only_pd_supported = "(Supported Backend: Paddle) "
 # descriptors
 doc_loc_frame = "Defines a local frame at each atom, and the compute the descriptor as local coordinates under this frame."
 doc_se_e2_a = "Used by the smooth edition of Deep Potential. The full relative coordinates are used to construct the descriptor."
@@ -3167,6 +3168,7 @@ def training_args(
     doc_kf_blocksize = "The blocksize for the Kalman filter."
     doc_model_prob = "The visiting probability of each model for each training step in the multi-task mode."
     doc_data_dict = "The multiple definition of the data, used in the multi-task mode."
+    doc_acc_freq = "Gradient accumulation steps (number of steps to accumulate gradients before performing an update)."

     arg_training_data = training_data_args()
     arg_validation_data = validation_data_args()
@@ -3269,6 +3271,13 @@ def training_args(
             optional=True,
             doc=doc_only_pt_supported + doc_gradient_max_norm,
         ),
+        Argument(
+            "acc_freq",
+            int,
+            optional=True,
+            default=1,
+            doc=doc_only_pd_supported + doc_acc_freq,
+        ),
     ]
     variants = [
         Variant(
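In practice, acc_freq is set in the "training" section of the input script; per the argument doc it is honored by the Paddle backend only, and the default of 1 reproduces the old behavior. A hypothetical fragment of such a configuration, shown as the Python dict it becomes once loaded (values other than acc_freq are placeholders, not from this commit):

# Hypothetical "training" section with gradient accumulation enabled.
training_config = {
    "training_data": {"systems": ["path/to/data_0"]},
    "numb_steps": 100000,
    "acc_freq": 4,  # Paddle backend only; default 1 keeps the previous behavior
}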

source/tests/pd/test_training.py

Lines changed: 18 additions & 2 deletions
@@ -150,9 +150,25 @@ def setUp(self) -> None:
         self.config["model"] = deepcopy(model_se_e2_a)
         self.config["training"]["numb_steps"] = 1
         self.config["training"]["save_freq"] = 1
-        # import paddle
         enable_prim(True)
-        # assert paddle.framework.core._is_eager_prim_enabled()
+
+    def tearDown(self) -> None:
+        DPTrainTest.tearDown(self)
+
+
+class TestEnergyModelGradientAccumulation(unittest.TestCase, DPTrainTest):
+    def setUp(self) -> None:
+        input_json = str(Path(__file__).parent / "water/se_atten.json")
+        with open(input_json) as f:
+            self.config = json.load(f)
+        data_file = [str(Path(__file__).parent / "water/data/data_0")]
+        self.config["training"]["training_data"]["systems"] = data_file
+        self.config["training"]["validation_data"]["systems"] = data_file
+        self.config["model"] = deepcopy(model_se_e2_a)
+        self.config["training"]["numb_steps"] = 1
+        self.config["training"]["save_freq"] = 1
+        self.config["training"]["acc_freq"] = 4
+        enable_prim(True)

     def tearDown(self) -> None:
         DPTrainTest.tearDown(self)
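To exercise only the new test class, the standard pytest node-id syntax can be used; a small sketch via pytest's Python entry point (assuming pytest is the runner for these unittest-style tests):

# Run just the new gradient-accumulation test; node id taken from the diff above.
import pytest

pytest.main(
    ["-v", "source/tests/pd/test_training.py::TestEnergyModelGradientAccumulation"]
)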
