Skip to content

Commit 6dc3a39

Browse files
committed
fix: remove unused type embedding for bias_q and simplify bias handling in LRFittingNet
1 parent 67837f4 commit 6dc3a39

1 file changed

Lines changed: 18 additions & 58 deletions

File tree

deepmd/pt/model/task/lr_fitting.py

Lines changed: 18 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
)
1717
from deepmd.pt.model.network.network import (
1818
TypeEmbedNet,
19-
TypeEmbedNetConsistent,
2019
)
2120
from deepmd.pt.model.task.fitting import (
2221
Fitting,
@@ -130,7 +129,6 @@ def __init__(
130129
type_map: list[str] | None = None,
131130
use_aparam_as_mask: bool = False,
132131
default_fparam: list[float] | None = None,
133-
use_type_embed_for_bias_q: bool | None = None,
134132
**kwargs: Any,
135133
) -> None:
136134
super().__init__()
@@ -176,17 +174,9 @@ def __init__(
176174
assert self.ntypes == bias_atom_e.shape[0], "Element count mismatches!"
177175
self.register_buffer("bias_atom_e", bias_atom_e)
178176

179-
if bias_atom_q is not None:
180-
self.use_type_embed_for_bias_q = False
181-
elif use_type_embed_for_bias_q is None:
182-
self.use_type_embed_for_bias_q = True
183-
else:
184-
self.use_type_embed_for_bias_q = use_type_embed_for_bias_q
185-
self.bias_atom_q_type_embed: TypeEmbedNet | None = None
186-
self.bias_atom_q: torch.nn.Parameter | None = None
187-
if self.use_type_embed_for_bias_q:
188-
# Build per-type lr bias from a type-only embedding network.
189-
self.bias_atom_q_type_embed = TypeEmbedNet(
177+
if bias_atom_q is None:
178+
# No external bias provided: learn per-type bias via TypeEmbedNet.
179+
self.bias_atom_q: torch.nn.Parameter | TypeEmbedNet = TypeEmbedNet(
190180
type_nums=self.ntypes,
191181
embed_dim=self.lr_net_dim_out,
192182
precision=self.precision,
@@ -195,9 +185,6 @@ def __init__(
195185
trainable=self.trainable,
196186
)
197187
else:
198-
if bias_atom_q is None:
199-
# small random initialization to break saddle point
200-
bias_atom_q = np.random.randn(self.ntypes, self.lr_net_dim_out) * 0.01
201188
bias_atom_q = torch.tensor(
202189
bias_atom_q, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=device
203190
)
@@ -340,8 +327,7 @@ def change_type_map(
340327
)
341328
self.bias_atom_e = torch.cat([self.bias_atom_e, extend_bias_atom_e], dim=0)
342329

343-
if not self.use_type_embed_for_bias_q:
344-
assert self.bias_atom_q is not None
330+
if isinstance(self.bias_atom_q, torch.nn.Parameter):
345331
extend_shape_q = [len(type_map), *list(self.bias_atom_q.shape[1:])]
346332
extend_bias_atom_q = torch.zeros(
347333
extend_shape_q,
@@ -354,15 +340,13 @@ def change_type_map(
354340
)
355341

356342
self.bias_atom_e = self.bias_atom_e[remap_index]
357-
if self.use_type_embed_for_bias_q:
358-
assert self.bias_atom_q_type_embed is not None
359-
self.bias_atom_q_type_embed.change_type_map(type_map=type_map)
360-
else:
361-
assert self.bias_atom_q is not None
343+
if isinstance(self.bias_atom_q, torch.nn.Parameter):
362344
self.bias_atom_q = torch.nn.Parameter(
363345
self.bias_atom_q.data[remap_index],
364346
requires_grad=bool(self.trainable),
365347
)
348+
else:
349+
self.bias_atom_q.change_type_map(type_map=type_map)
366350

367351
def serialize(self) -> dict:
368352
"""Serialize the fitting to dict."""
@@ -388,12 +372,6 @@ def serialize(self) -> dict:
388372
"nets_lr": self.filter_layers_lr.serialize(),
389373
"rcond": self.rcond,
390374
"exclude_types": self.exclude_types,
391-
"use_type_embed_for_bias_q": self.use_type_embed_for_bias_q,
392-
"bias_atom_q_type_embed": (
393-
self.bias_atom_q_type_embed.embedding.serialize()
394-
if self.bias_atom_q_type_embed is not None
395-
else None
396-
),
397375
"@variables": {
398376
"bias_atom_e": to_numpy_array(self.bias_atom_e),
399377
"bias_atom_q": to_numpy_array(self._get_bias_atom_q_table()),
@@ -421,26 +399,15 @@ def serialize(self) -> dict:
421399
@classmethod
422400
def deserialize(cls, data: dict) -> "LRFittingNet":
423401
data = data.copy()
424-
use_type_embed_for_bias_q = data.get("use_type_embed_for_bias_q", False)
425-
data["use_type_embed_for_bias_q"] = use_type_embed_for_bias_q
426-
bias_atom_q_type_embed = data.pop("bias_atom_q_type_embed", None)
402+
# Compatibility with old checkpoints.
403+
data.pop("use_type_embed_for_bias_q", None)
404+
data.pop("bias_atom_q_type_embed", None)
427405
variables = data.pop("@variables")
428406
nets_sr = data.pop("nets_sr")
429407
nets_lr = data.pop("nets_lr")
430408
obj = cls(**data)
431-
if obj.use_type_embed_for_bias_q and bias_atom_q_type_embed is not None:
432-
assert obj.bias_atom_q_type_embed is not None
433-
obj.bias_atom_q_type_embed.embedding = TypeEmbedNetConsistent.deserialize(
434-
bias_atom_q_type_embed
435-
)
436409
for kk in variables.keys():
437410
if variables[kk] is not None:
438-
if (
439-
kk == "bias_atom_q"
440-
and obj.use_type_embed_for_bias_q
441-
and bias_atom_q_type_embed is not None
442-
):
443-
continue
444411
obj[kk] = to_torch_tensor(variables[kk])
445412
obj.filter_layers_sr = NetworkCollection.deserialize(nets_sr)
446413
obj.filter_layers_lr = NetworkCollection.deserialize(nets_lr)
@@ -500,15 +467,10 @@ def __setitem__(self, key: str, value: torch.Tensor) -> None:
500467
self.bias_atom_e = value
501468
elif key in ["bias_atom_q"]:
502469
value = value.view([self.ntypes, self._lr_net_out_dim()])
503-
if self.bias_atom_q is None:
504-
self.use_type_embed_for_bias_q = False
505-
self.bias_atom_q_type_embed = None
506-
self.bias_atom_q = torch.nn.Parameter(
507-
value,
508-
requires_grad=bool(self.trainable),
509-
)
510-
else:
511-
self.bias_atom_q.data.copy_(value)
470+
self.bias_atom_q = torch.nn.Parameter(
471+
value,
472+
requires_grad=bool(self.trainable),
473+
)
512474
elif key in ["fparam_avg"]:
513475
self.fparam_avg = value
514476
elif key in ["fparam_inv_std"]:
@@ -565,22 +527,20 @@ def _compress_bias_atom_q(self, bias: torch.Tensor) -> torch.Tensor:
565527
return self.bias_atom_q_bound * torch.tanh(bias / self.bias_atom_q_bound)
566528

567529
def _get_bias_atom_q_table(self) -> torch.Tensor:
568-
if self.bias_atom_q is not None:
530+
if isinstance(self.bias_atom_q, torch.nn.Parameter):
569531
return self._compress_bias_atom_q(self.bias_atom_q)
570-
assert self.bias_atom_q_type_embed is not None
571532
# `TypeEmbedNet` appends one zero-padding row; keep only real atom types.
572-
bias_table = self.bias_atom_q_type_embed.get_full_embedding(self.bias_atom_e.device)[
533+
bias_table = self.bias_atom_q.get_full_embedding(self.bias_atom_e.device)[
573534
: self.ntypes
574535
]
575536
return self._compress_bias_atom_q(bias_table)
576537

577538
def _get_lr_bias(self, atype: torch.Tensor) -> torch.Tensor:
578539
atype_long = atype.to(torch.long)
579-
if self.bias_atom_q is not None:
540+
if isinstance(self.bias_atom_q, torch.nn.Parameter):
580541
return self._compress_bias_atom_q(self.bias_atom_q[atype_long].to(self.prec))
581-
assert self.bias_atom_q_type_embed is not None
582542
return self._compress_bias_atom_q(
583-
self.bias_atom_q_type_embed(atype_long).to(self.prec)
543+
self.bias_atom_q(atype_long).to(self.prec)
584544
)
585545

586546
def _extend_f_avg_std(self, xx: torch.Tensor, nb: int) -> torch.Tensor:

0 commit comments

Comments
 (0)