Skip to content

Commit f3e29fe

Browse files
committed
fix: buffer register
1 parent 4cee0bf commit f3e29fe

1 file changed

Lines changed: 13 additions & 7 deletions

File tree

deepmd/pt_expt/train/training.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -439,10 +439,9 @@ def __init__(
439439
self.original_model = original_model
440440
self.compiled_forward_lower = compiled_forward_lower
441441
self._task_buf_order = task_buf_order
442-
if task_buf_order and task_buffers:
443-
for name in task_buf_order:
444-
if name in task_buffers:
445-
self.register_buffer(f"_task_{name}", task_buffers[name])
442+
# task_buffers is intentionally not stored: buffers are read from
443+
# original_model.get_fitting_net() at forward time so that weight
444+
# updates (load_state_dict, optimiser steps) are always reflected.
446445

447446
def forward(
448447
self,
@@ -489,9 +488,16 @@ def forward(
489488
ext_coord = ext_coord.reshape(nframes, -1, 3)
490489
ext_coord = ext_coord.detach().requires_grad_(True)
491490

492-
task_buf_vals = tuple(
493-
getattr(self, f"_task_{name}") for name in self._task_buf_order
494-
)
491+
if self._task_buf_order:
492+
try:
493+
_fitting = self.original_model.get_fitting_net()
494+
task_buf_vals: tuple = tuple(
495+
getattr(_fitting, name) for name in self._task_buf_order
496+
)
497+
except AttributeError:
498+
task_buf_vals = ()
499+
else:
500+
task_buf_vals = ()
495501
result = self.compiled_forward_lower(
496502
ext_coord,
497503
ext_atype,

0 commit comments

Comments
 (0)