+import math
+
 import numpy as np
 import torch
 import torch.nn as nn
+import torch.nn.functional as F

 from deepmd.pt.utils import (
     env,
@@ -524,6 +525,115 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         return P * self.norm.type(x.dtype)


+class AnglePriorEncoder(nn.Module):
+    """
+    Smooth delta encoder for bond angles a ∈ [0, π] (radians). The forward
+    pass takes cos(a) and recovers the angle internally via arccos.
+
+    - Fixed 10 prior centers (in radians) taken from common molecular/material geometries.
+    - Kernel: Gaussian RBF on the linear difference (no periodic wrapping).
+    - Output: the input cosine concatenated with a normalized similarity
+      vector of length 10 (a smooth one-hot code).
+    - Optional: learnable global width (sigma).
+
+    Centers (degrees → radians):
+     1) 180.0  : linear (sp), also octahedral/trans positions
+     2) 120.0  : trigonal planar (sp2), graphene etc.
+     3) 109.47 : ideal tetrahedral (sp3)
+     4) 104.5  : water H-O-H
+     5) 106.7  : ammonia H-N-H (trigonal pyramidal)
+     6) 90.0   : square planar / octahedral adjacent
+     7) 180.0  : (duplicate center kept intentionally for prior emphasis)
+     8) 120.0  : trigonal bipyramidal equatorial-equatorial
+     9) 90.0   : trigonal bipyramidal axial-equatorial
+    10) 60.0   : cyclopropane strained angle
+    """
+
+    def __init__(
+        self,
+        sigma_deg: float = 6.0,  # initial Gaussian width in degrees
+        learn_sigma: bool = True,  # make sigma trainable if desired
+        normalize: str | None = "softmax",
+        eps: float = 1e-9,
+    ):
+        super().__init__()
+        assert normalize in ("softmax", "l1", None)
+        self.normalize = normalize
+        self.eps = eps
+
+        # --- Fixed prior centers (degrees) ---
+        centers_deg = torch.tensor(
+            [180.0, 120.0, 109.47, 104.5, 106.7, 90.0, 180.0, 120.0, 90.0, 60.0],
+            dtype=env.GLOBAL_PT_FLOAT_PRECISION,
+            device=env.DEVICE,
+        )
+
+        # Convert to radians and store as a buffer: shape (K,)
+        centers_rad = centers_deg * (torch.pi / 180.0)
+        self.register_buffer("centers", centers_rad)  # (10,)
+
+        # --- Width parameter (global sigma, radians) ---
+        sigma_rad = float(sigma_deg) * math.pi / 180.0
+
+        # Softplus parameterization keeps sigma > 0
+        def inv_softplus(x: float) -> float:
+            x = max(x, 1e-12)
+            return float(math.log(math.expm1(x)))
+
+        raw = torch.tensor(
+            inv_softplus(sigma_rad),
+            dtype=env.GLOBAL_PT_FLOAT_PRECISION,
+            device=env.DEVICE,
+        )
+        if learn_sigma:
+            self.raw_sigma = nn.Parameter(data=raw)
+        else:
+            self.register_buffer("raw_sigma", raw)
+
+    @property
+    def sigma(self) -> torch.Tensor:
+        """Current positive width (radians)."""
+        return F.softplus(self.raw_sigma) + 1e-12
+
+    # @torch.no_grad()
+    # def auto_sigma_from_centers(self, factor: float = 0.6, min_sigma_deg: float = 1.0):
+    #     """
+    #     Set a reasonable global sigma from the center spacing on [0, π]:
+    #     median nearest-neighbor distance x factor, with a lower bound.
+    #     """
+    #     c = self.centers  # (K,)
+    #     # Pairwise |c_i - c_j|
+    #     dmat = torch.abs(c[:, None] - c[None, :])
+    #     # Mask out self-distances (note: duplicated centers still give 0 here)
+    #     dmat = dmat + torch.eye(c.numel(), dtype=c.dtype, device=c.device) * 1e6
+    #     dmin = dmat.min(dim=1).values  # nearest-neighbor distance per center
+    #     # Use the median spacing for a single global sigma
+    #     sigma = torch.clamp(
+    #         torch.median(dmin) * factor, min=min_sigma_deg * math.pi / 180.0
+    #     )
+    #     # Write back through the inverse softplus
+    #     self.raw_sigma.copy_(torch.log(torch.expm1(sigma)))
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        x: (..., 1) tensor of angle cosines, i.e. cos(a) with a ∈ [0, π].
+        returns: (..., 1 + 10) tensor: the input cosine concatenated with the
+        similarity/weight vector over the prior centers.
+        """
+        # Recover the angle; the clamp guards acos against round-off outside [-1, 1]
+        theta = torch.acos(x.clamp(-1.0, 1.0))
+        centers = self.centers.type(x.dtype)
+        s = self.sigma.type(x.dtype)
+        # Linear difference against every center (no periodicity): (..., K)
+        diff = theta - centers
+        # Gaussian kernel
+        sims = torch.exp(-0.5 * (diff / s).pow(2))  # (..., K)
+
+        # Normalization
+        if self.normalize is None:
+            codes = sims
+        elif self.normalize == "softmax":
+            # softmax of log-similarities, i.e. an eps-stabilized L1 normalization
+            codes = F.softmax(torch.log(sims + self.eps), dim=-1)
+        elif self.normalize == "l1":
+            codes = sims / (sims.sum(dim=-1, keepdim=True) + self.eps)
+        else:
+            raise ValueError(f"Unknown normalization: {self.normalize}")
+        return torch.cat([x, codes], dim=-1)
+
+
 def find_normalization(name: str, dim: int | None = None) -> nn.Module | None:
     """Return a normalization function using name."""
     if name is None: