Skip to content

Commit 8e9c2a5

Browse files
committed
fix: promote out_bias
1 parent f3e29fe commit 8e9c2a5

1 file changed

Lines changed: 63 additions & 19 deletions

File tree

deepmd/pt_expt/train/training.py

Lines changed: 63 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -70,22 +70,41 @@
7070

7171
log = logging.getLogger(__name__)
7272

73-
# Buffer names that differ per task after share_params; everything else in the
74-
# fitting net is literally the same Python object across shared tasks.
73+
# Buffer names in the fitting net that differ per task after share_params;
74+
# everything else in the fitting net is the same Python object across tasks.
7575
_TASK_SPECIFIC_BUFFER_NAMES: tuple[str, ...] = ("bias_atom_e", "case_embd")
7676

77+
# Buffer names in atomic_model that are per-task (energy/output statistics).
78+
# These live one level above the fitting net and are not reached by
79+
# fitting-net share_params, so they must also be promoted to FX placeholders.
80+
_ATOMIC_MODEL_TASK_BUFFER_NAMES: tuple[str, ...] = ("out_bias", "out_std")
81+
82+
# Prefix used in task_buf_order keys to distinguish atomic_model buffers
83+
# from fitting-net buffers.
84+
_AM_PREFIX = "am/"
85+
7786

7887
def _get_task_buffers(model: torch.nn.Module) -> dict[str, torch.Tensor]:
79-
"""Return per-task fitting-net buffers that vary across shared tasks."""
88+
"""Return per-task buffers (fitting net + atomic model) that vary across shared tasks."""
89+
result: dict[str, torch.Tensor] = {}
90+
# fitting-net task buffers
8091
try:
8192
fitting = model.get_fitting_net()
93+
for name in _TASK_SPECIFIC_BUFFER_NAMES:
94+
val = fitting._buffers.get(name)
95+
if val is not None and torch.is_tensor(val):
96+
result[name] = val.detach().clone()
8297
except AttributeError:
83-
return {}
84-
result: dict[str, torch.Tensor] = {}
85-
for name in _TASK_SPECIFIC_BUFFER_NAMES:
86-
val = getattr(fitting, name, None)
87-
if val is not None and torch.is_tensor(val):
88-
result[name] = val.detach().clone()
98+
pass
99+
# atomic_model task buffers (out_bias, out_std)
100+
try:
101+
am = model.atomic_model
102+
for name in _ATOMIC_MODEL_TASK_BUFFER_NAMES:
103+
val = am._buffers.get(name)
104+
if val is not None and torch.is_tensor(val):
105+
result[_AM_PREFIX + name] = val.detach().clone()
106+
except AttributeError:
107+
pass
89108
return result
90109

91110

@@ -292,13 +311,18 @@ def _trace_and_compile(
292311
tuple(task_buffers[k] for k in task_buf_order) if task_buffers else ()
293312
)
294313

295-
# Resolve fitting net once for buffer patching inside fn.
314+
# Resolve fitting net and atomic_model once for buffer patching inside fn.
296315
_fitting: torch.nn.Module | None = None
316+
_atomic_model: torch.nn.Module | None = None
297317
if task_buf_order:
298318
try:
299319
_fitting = model.get_fitting_net()
300320
except AttributeError:
301-
pass
321+
pass # no fitting net → no fitting-net buffers to patch
322+
try:
323+
_atomic_model = model.atomic_model
324+
except AttributeError:
325+
pass # no atomic_model → no atomic-model buffers to patch
302326

303327
def fn(
304328
extended_coord: torch.Tensor,
@@ -313,12 +337,20 @@ def fn(
313337
extended_coord = extended_coord.detach().requires_grad_(True)
314338
# Temporarily patch task-specific buffers with the proxy tensors so
315339
# make_fx records them as FX placeholders rather than baked-in constants.
316-
# This makes the compiled graph reusable for any buffer values.
340+
# Keys prefixed with _AM_PREFIX are atomic_model buffers; the rest are
341+
# fitting-net buffers.
317342
originals: dict[str, torch.Tensor | None] = {}
318-
if _fitting is not None and task_buf_order:
343+
if task_buf_order:
319344
for name, val in zip(task_buf_order, task_buf_vals):
320-
originals[name] = _fitting._buffers.get(name)
321-
_fitting._buffers[name] = val
345+
if name.startswith(_AM_PREFIX):
346+
actual = name[len(_AM_PREFIX):]
347+
if _atomic_model is not None:
348+
originals[name] = _atomic_model._buffers.get(actual)
349+
_atomic_model._buffers[actual] = val
350+
else:
351+
if _fitting is not None:
352+
originals[name] = _fitting._buffers.get(name)
353+
_fitting._buffers[name] = val
322354
try:
323355
return model.forward_lower(
324356
extended_coord,
@@ -331,7 +363,13 @@ def fn(
331363
)
332364
finally:
333365
for name, orig in originals.items():
334-
_fitting._buffers[name] = orig
366+
if name.startswith(_AM_PREFIX):
367+
actual = name[len(_AM_PREFIX):]
368+
if _atomic_model is not None:
369+
_atomic_model._buffers[actual] = orig
370+
else:
371+
if _fitting is not None:
372+
_fitting._buffers[name] = orig
335373

336374
# Pick a trace-time nframes that's unlikely to collide with any other
337375
# tensor dim in the graph. The symbolic tracer merges symbols that
@@ -491,9 +529,15 @@ def forward(
491529
if self._task_buf_order:
492530
try:
493531
_fitting = self.original_model.get_fitting_net()
494-
task_buf_vals: tuple = tuple(
495-
getattr(_fitting, name) for name in self._task_buf_order
496-
)
532+
_am = getattr(self.original_model, "atomic_model", None)
533+
_vals: list[torch.Tensor] = []
534+
for _name in self._task_buf_order:
535+
if _name.startswith(_AM_PREFIX):
536+
_actual = _name[len(_AM_PREFIX):]
537+
_vals.append(_am._buffers[_actual])
538+
else:
539+
_vals.append(getattr(_fitting, _name))
540+
task_buf_vals: tuple = tuple(_vals)
497541
except AttributeError:
498542
task_buf_vals = ()
499543
else:

0 commit comments

Comments
 (0)