Skip to content

Commit e534b96

Browse files
committed
feat(pt): add fixed gaussian angle
1 parent bf01e3c commit e534b96

5 files changed

Lines changed: 139 additions & 3 deletions

File tree

deepmd/dpmodel/descriptor/dpa3.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@ def __init__(
8686
dropout_rate: float = 0.1,
8787
angle_use_sh_init: bool = False,
8888
angle_sh_init_lmax: int = 3,
89+
angle_use_fixed_gaussian: bool = False,
8990
) -> None:
9091
r"""The constructor for the RepFlowArgs class which defines the parameters of the repflow block in DPA3 descriptor.
9192
@@ -220,6 +221,7 @@ def __init__(
220221
self.dropout_rate = dropout_rate
221222
self.angle_use_sh_init = angle_use_sh_init
222223
self.angle_sh_init_lmax = angle_sh_init_lmax
224+
self.angle_use_fixed_gaussian = angle_use_fixed_gaussian
223225
assert (
224226
fix_stat_std == 0.3
225227
), "fix_stat_std is not implemented in this version, please use skip_stat instead."

deepmd/pt/model/descriptor/dpa3.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -225,6 +225,7 @@ def init_subclass_params(sub_data, sub_class):
225225
dropout_rate=self.repflow_args.dropout_rate,
226226
angle_use_sh_init=self.repflow_args.angle_use_sh_init,
227227
angle_sh_init_lmax=self.repflow_args.angle_sh_init_lmax,
228+
angle_use_fixed_gaussian=self.repflow_args.angle_use_fixed_gaussian,
228229
exclude_types=exclude_types,
229230
env_protection=env_protection,
230231
precision=precision,

deepmd/pt/model/descriptor/repflows.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
prod_env_mat,
1818
)
1919
from deepmd.pt.model.network.mlp import (
20+
AnglePriorEncoder,
2021
AngleSH,
2122
MLPLayer,
2223
)
@@ -168,6 +169,7 @@ def __init__(
168169
dropout_rate: float = 0.1,
169170
angle_use_sh_init: bool = False,
170171
angle_sh_init_lmax: int = 3,
172+
angle_use_fixed_gaussian: bool = False,
171173
seed: Optional[Union[int, list[int]]] = None,
172174
) -> None:
173175
r"""
@@ -302,6 +304,13 @@ def __init__(
302304
else:
303305
self.angle_sh = None
304306

307+
self.angle_use_fixed_gaussian = angle_use_fixed_gaussian
308+
if self.angle_use_fixed_gaussian:
309+
self.angle_gaussian_encoder = AnglePriorEncoder(
310+
sigma_deg=6.0, learn_sigma=False, normalize=None
311+
)
312+
else:
313+
self.angle_gaussian_encoder = None
305314
self.use_env_envelope = use_env_envelope
306315
self.use_new_sw = use_new_sw
307316
self.use_force_embedding = use_force_embedding
@@ -468,14 +477,17 @@ def __init__(
468477
self.e_dim,
469478
]
470479
self.edge_embd = RadialMLP(edge_channels_list)
471-
if not self.angle_use_sh_init:
480+
481+
if self.angle_use_sh_init:
482+
angle_input_dim = self.angle_sh_init_lmax + 1
483+
elif self.angle_use_fixed_gaussian:
484+
angle_input_dim = 10 + 1
485+
else:
472486
angle_input_dim = (
473487
len(self.angle_multi_freq_list_float) + 1
474488
if not self.angle_init_use_sin
475489
else 2 * (len(self.angle_multi_freq_list_float) + 1)
476490
)
477-
else:
478-
angle_input_dim = self.angle_sh_init_lmax + 1
479491

480492
self.angle_embd = MLPLayer(
481493
angle_input_dim,
@@ -999,6 +1011,11 @@ def forward(
9991011
assert self.angle_sh is not None
10001012
# nf x nloc x a_nnei x a_nnei x sh_sim [OR] n_angle x sh_sim
10011013
angle_input = self.angle_sh(angle_input * (torch.pi**0.5))
1014+
elif self.angle_use_fixed_gaussian:
1015+
assert not self.angle_init_use_sin and not self.angle_use_multi_freq
1016+
assert self.angle_gaussian_encoder is not None
1017+
# nf x nloc x a_nnei x a_nnei x 11 [OR] n_angle x 11
1018+
angle_input = self.angle_gaussian_encoder(angle_input)
10021019

10031020
# nf x nloc x a_nnei x a_nnei x a_dim [OR] n_angle x a_dim
10041021
angle_ebd = self.angle_embd(angle_input)

deepmd/pt/model/network/mlp.py

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import numpy as np
1010
import torch
1111
import torch.nn as nn
12+
import torch.nn.functional as F
1213

1314
from deepmd.pt.utils import (
1415
env,
@@ -524,6 +525,115 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
524525
return P * self.norm.type(x.dtype)
525526

526527

528+
class AnglePriorEncoder(nn.Module):
529+
"""
530+
Smooth delta encoder for bond angles a ∈ [0, π] (radians).
531+
- Fixed 10 prior centers (in radians) taken from common molecular/material geometries.
532+
- Kernel: Gaussian RBF on linear difference (no periodic wrapping).
533+
- Output: softmax-normalized similarity vector of length 10 (smooth one-hot).
534+
- Optional: learnable global width (sigma).
535+
536+
Centers (degrees → radians):
537+
1) 180.0 : linear (sp), also octahedral/trans positions
538+
2) 120.0 : trigonal planar (sp2), graphene etc.
539+
3) 109.47 : ideal tetrahedral (sp3)
540+
4) 104.5 : water H-O-H
541+
5) 106.7 : ammonia H-N-H (trigonal pyramidal)
542+
6) 90.0 : square planar / octahedral adjacent
543+
7) 180.0 : (duplicate center kept intentionally for prior emphasis)
544+
8) 120.0 : trigonal bipyramidal equatorial-equatorial
545+
9) 90.0 : trigonal bipyramidal axial-equatorial
546+
10) 60.0 : cyclopropane strained angle
547+
"""
548+
549+
def __init__(
550+
self,
551+
sigma_deg: float = 6.0, # initial Gaussian width in degrees
552+
learn_sigma: bool = True, # make sigma trainable if desired
553+
normalize: Optional[str] = "softmax",
554+
eps: float = 1e-9,
555+
):
556+
super().__init__()
557+
assert normalize in ("softmax", "l1", None)
558+
self.normalize = normalize
559+
self.eps = eps
560+
561+
# --- Fixed prior centers (degrees) ---
562+
centers_deg = torch.tensor(
563+
[180.0, 120.0, 109.47, 104.5, 106.7, 90.0, 180.0, 120.0, 90.0, 60.0],
564+
dtype=env.GLOBAL_PT_FLOAT_PRECISION,
565+
device=device,
566+
)
567+
568+
# Convert to radians and store as buffer: shape (K,)
569+
centers_rad = centers_deg * (torch.pi / 180.0)
570+
self.register_buffer("centers", centers_rad) # (10,)
571+
572+
# --- Width parameter (global sigma, radians) ---
573+
sigma_rad = float(sigma_deg) * math.pi / 180.0
574+
575+
# Softplus parameterization to keep sigma > 0
576+
def inv_softplus(x: float) -> float:
577+
x = max(x, 1e-12)
578+
return float(math.log(math.exp(x) - 1.0))
579+
580+
raw = torch.tensor(
581+
inv_softplus(sigma_rad), dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=device
582+
)
583+
if learn_sigma:
584+
self.raw_sigma = nn.Parameter(data=raw)
585+
else:
586+
self.register_buffer("raw_sigma", raw)
587+
588+
@property
589+
def sigma(self) -> torch.Tensor:
590+
"""Current positive width (radians)."""
591+
return F.softplus(self.raw_sigma) + 1e-12
592+
593+
# @torch.no_grad()
594+
# def auto_sigma_from_centers(self, factor: float = 0.6, min_sigma_deg: float = 1.0):
595+
# """
596+
# Set a reasonable global sigma from center spacing on [0, π].
597+
# Uses median nearest-neighbor distance x factor, with a lower bound.
598+
# """
599+
# c = self.centers # (K,)
600+
# # Pairwise |c_i - c_j|
601+
# dmat = torch.abs(c[:, None] - c[None, :])
602+
# # Ignore self-distance
603+
# dmat = dmat + torch.eye(c.numel(), dtype=c.dtype, device=c.device) * 1e6
604+
# dmin = dmat.min(dim=1).values # nearest neighbor distance per center
605+
# # Use median spacing to get a single global sigma
606+
# sigma = torch.clamp(torch.median(dmin) * factor,
607+
# min=min_sigma_deg * math.pi / 180.0)
608+
# # Write into raw_sigma (inverse softplus)
609+
# with torch.no_grad():
610+
# self.raw_sigma.copy_(torch.log(torch.exp(sigma) - 1.0))
611+
612+
def forward(self, x: torch.Tensor) -> torch.Tensor:
613+
"""
614+
a: (...,) tensor of angles in radians, expected in [0, π].
615+
returns: (..., 10) similarity/weight vector.
616+
"""
617+
theta = torch.acos(x)
618+
centers = self.centers.type(x.dtype)
619+
s = self.sigma.type(x.dtype)
620+
# Linear difference (no periodicity)
621+
diff = theta - centers # (..., K)
622+
# Gaussian kernel
623+
sims = torch.exp(-0.5 * (diff / s).pow(2)) # (..., K)
624+
625+
# Normalization
626+
if self.normalize is None:
627+
codes = sims
628+
elif self.normalize == "softmax":
629+
codes = F.softmax(torch.log(sims + self.eps), dim=-1)
630+
elif self.normalize == "l1":
631+
codes = sims / (sims.sum(dim=-1, keepdim=True) + self.eps)
632+
else:
633+
raise ValueError(f"Unknown normalization: {self.normalize}")
634+
return torch.cat([x, codes], dim=-1)
635+
636+
527637
def find_normalization(name: str, dim: int | None = None) -> nn.Module | None:
528638
"""Return an normalization function using name."""
529639
if name is None:

deepmd/utils/argcheck.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1623,6 +1623,12 @@ def dpa3_repflow_args():
16231623
optional=True,
16241624
default=3,
16251625
),
1626+
Argument(
1627+
"angle_use_fixed_gaussian",
1628+
bool,
1629+
optional=True,
1630+
default=False,
1631+
),
16261632
Argument(
16271633
"use_dynamic_sel",
16281634
bool,

0 commit comments

Comments
 (0)