Skip to content

Commit d87256f

Browse files
committed
feat: add lebedev quadrature
1 parent 5908295 commit d87256f

13 files changed

Lines changed: 699 additions & 76 deletions

File tree

deepmd/pt/model/descriptor/sezm.py

Lines changed: 27 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,6 @@
9595
has_lora,
9696
np_safe,
9797
nvtx_range,
98-
resolve_s2_grid_resolution,
9998
safe_numpy_to_tensor,
10099
)
101100

@@ -257,12 +256,16 @@ class DescrptSeZM(BaseDescriptor, nn.Module):
257256
``activation_function="silu"``.
258257
``ffn_enabled=True`` makes the block-internal FFN path use
259258
``activation_function="silu"`` and ``glu_activation=True``.
259+
S2-grid resolutions are resolved automatically per block. The e3nn
260+
product grid uses ``[2 * mmax + 4, ceil_even(3 * lmax + 2)]`` in the
261+
SO(2) branch, and the FFN branch lifts it to a square
262+
``[max(R_phi, R_theta), max(R_phi, R_theta)]`` grid. Lebedev branches
263+
use the smallest packaged rule with precision at least ``3 * lmax``.
260264
The final ``l=0`` output FFN is unchanged.
261-
s2_grid_resolution
262-
Two-element list ``[R_phi, R_theta]`` used by the S2-grid activation.
263-
If omitted, it is resolved from the first block ``(lmax, mmax)`` after
264-
schedule parsing as
265-
``[2 * mmax + 4, ceil_even(3 * lmax + 2)]``.
265+
lebedev_quadrature
266+
Two booleans ``[so2_enabled, ffn_enabled]`` aligned with
267+
``s2_activation``. If enabled for a branch, that branch uses Lebedev
268+
quadrature instead of the e3nn product grid in its S2 projector.
266269
activation_function
267270
Base activation function for helper MLPs, the SO(2) gated activation
268271
path, and the final ``l=0`` output FFN.
@@ -349,7 +352,7 @@ def __init__(
349352
full_attn_res: str = "none",
350353
block_attn_res: str = "none",
351354
s2_activation: list[bool] | None = None,
352-
s2_grid_resolution: list[int] | None = None,
355+
lebedev_quadrature: list[bool] | None = None,
353356
activation_function: str = "silu",
354357
glu_activation: bool = True,
355358
use_amp: bool = True,
@@ -419,12 +422,25 @@ def __init__(
419422
"`s2_activation` must be a list[bool] of length 2: [so2_activation, ffn_activation]"
420423
)
421424
self.s2_activation = list(s2_activation)
425+
if lebedev_quadrature is None:
426+
lebedev_quadrature = [False, False]
427+
if not isinstance(lebedev_quadrature, list) or len(lebedev_quadrature) != 2:
428+
raise ValueError(
429+
"`lebedev_quadrature` must be a list[bool] of length 2: [so2_quadrature, ffn_quadrature]"
430+
)
431+
if any(not isinstance(flag, bool) for flag in lebedev_quadrature):
432+
raise ValueError(
433+
"`lebedev_quadrature` must be a list[bool] of length 2: [so2_quadrature, ffn_quadrature]"
434+
)
435+
self.lebedev_quadrature = list(lebedev_quadrature)
422436
self.activation_function = str(activation_function)
423437
self.glu_activation = bool(glu_activation)
424438

425439
# === Split effective activation config by branch ===
426440
self.so2_s2_activation = self.s2_activation[0]
427441
self.ffn_s2_activation = self.s2_activation[1]
442+
self.so2_lebedev_quadrature = self.lebedev_quadrature[0]
443+
self.ffn_lebedev_quadrature = self.lebedev_quadrature[1]
428444
self.so2_activation_function = (
429445
"silu" if self.so2_s2_activation else self.activation_function
430446
)
@@ -512,11 +528,6 @@ def __init__(
512528

513529
# === L/M schedules ===
514530
self._init_lm_schedules(lmax, n_blocks, l_schedule, mmax, m_schedule)
515-
self.s2_grid_resolution = resolve_s2_grid_resolution(
516-
self.lmax,
517-
self.mmax,
518-
s2_grid_resolution,
519-
)
520531
self.ebed_dims = [get_so3_dim_of_lmax(l) for l in self.l_schedule]
521532
self.rad_sizes_per_block = [l + 1 for l in self.l_schedule]
522533

@@ -716,7 +727,8 @@ def __init__(
716727
block_attn_res=self.block_attn_res_mode,
717728
so2_s2_activation=self.so2_s2_activation,
718729
ffn_s2_activation=self.ffn_s2_activation,
719-
s2_grid_resolution=self.s2_grid_resolution,
730+
so2_lebedev_quadrature=self.so2_lebedev_quadrature,
731+
ffn_lebedev_quadrature=self.ffn_lebedev_quadrature,
720732
n_atten_head=self.n_atten_head,
721733
mixed_attention=self.mixed_attention,
722734
legacy_attention=self.legacy_attention,
@@ -1770,7 +1782,7 @@ def serialize(self) -> dict[str, Any]:
17701782
"full_attn_res": self.full_attn_res_mode,
17711783
"block_attn_res": self.block_attn_res_mode,
17721784
"s2_activation": self.s2_activation,
1773-
"s2_grid_resolution": self.s2_grid_resolution,
1785+
"lebedev_quadrature": self.lebedev_quadrature,
17741786
"activation_function": self.activation_function,
17751787
"glu_activation": self.glu_activation,
17761788
"precision": RESERVED_PRECISION_DICT[self.dtype],
@@ -1803,6 +1815,7 @@ def deserialize(cls, data: dict[str, Any]) -> DescrptSeZM:
18031815
config = data.pop("config")
18041816
variables = data.pop("@variables")
18051817
data.pop("env_mat", None)
1818+
config.pop("s2_grid_resolution", None)
18061819
obj = cls(**config)
18071820
template = obj.state_dict()
18081821
state = {

deepmd/pt/model/descriptor/sezm_nn/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,10 @@
5656
project_Dt_from_m,
5757
so3_packed_index,
5858
)
59+
from .lebedev import (
60+
LEBEDEV_PRECISION_TO_NPOINTS,
61+
load_lebedev_rule,
62+
)
5963
from .lora import (
6064
LoRASO2,
6165
LoRASO3,
@@ -109,6 +113,7 @@
109113

110114
__all__ = [
111115
"ATTN_RES_MODES",
116+
"LEBEDEV_PRECISION_TO_NPOINTS",
112117
"BridgingSwitch",
113118
"C3CutoffEnvelope",
114119
"ChannelLinear",
@@ -159,6 +164,7 @@
159164
"get_so3_dim_of_lmax",
160165
"has_lora",
161166
"init_trunc_normal_fan_in_out",
167+
"load_lebedev_rule",
162168
"map_degree_idx",
163169
"merge_lora_into_base",
164170
"np_safe",

0 commit comments

Comments
 (0)