|
95 | 95 | has_lora, |
96 | 96 | np_safe, |
97 | 97 | nvtx_range, |
98 | | - resolve_s2_grid_resolution, |
99 | 98 | safe_numpy_to_tensor, |
100 | 99 | ) |
101 | 100 |
|
@@ -257,12 +256,16 @@ class DescrptSeZM(BaseDescriptor, nn.Module): |
257 | 256 | ``activation_function="silu"``. |
258 | 257 | ``ffn_enabled=True`` makes the block-internal FFN path use |
259 | 258 | ``activation_function="silu"`` and ``glu_activation=True``. |
| 259 | + S2-grid resolutions are resolved automatically per block. The e3nn |
| 260 | + product grid uses ``[2 * mmax + 4, ceil_even(3 * lmax + 2)]`` in the |
| 261 | + SO(2) branch, and the FFN branch lifts it to a square |
| 262 | + ``[max(R_phi, R_theta), max(R_phi, R_theta)]`` grid. Lebedev branches |
| 263 | + use the smallest packaged rule with precision at least ``3 * lmax``. |
260 | 264 | The final ``l=0`` output FFN is unchanged. |
261 | | - s2_grid_resolution |
262 | | - Two-element list ``[R_phi, R_theta]`` used by the S2-grid activation. |
263 | | - If omitted, it is resolved from the first block ``(lmax, mmax)`` after |
264 | | - schedule parsing as |
265 | | - ``[2 * mmax + 4, ceil_even(3 * lmax + 2)]``. |
| 265 | + lebedev_quadrature |
| 266 | + Two booleans ``[so2_enabled, ffn_enabled]`` aligned with |
| 267 | + ``s2_activation``. If enabled for a branch, that branch uses Lebedev |
| 268 | + quadrature instead of the e3nn product grid in its S2 projector. |
266 | 269 | activation_function |
267 | 270 | Base activation function for helper MLPs, the SO(2) gated activation |
268 | 271 | path, and the final ``l=0`` output FFN. |
@@ -349,7 +352,7 @@ def __init__( |
349 | 352 | full_attn_res: str = "none", |
350 | 353 | block_attn_res: str = "none", |
351 | 354 | s2_activation: list[bool] | None = None, |
352 | | - s2_grid_resolution: list[int] | None = None, |
| 355 | + lebedev_quadrature: list[bool] | None = None, |
353 | 356 | activation_function: str = "silu", |
354 | 357 | glu_activation: bool = True, |
355 | 358 | use_amp: bool = True, |
@@ -419,12 +422,25 @@ def __init__( |
419 | 422 | "`s2_activation` must be a list[bool] of length 2: [so2_activation, ffn_activation]" |
420 | 423 | ) |
421 | 424 | self.s2_activation = list(s2_activation) |
| 425 | + if lebedev_quadrature is None: |
| 426 | + lebedev_quadrature = [False, False] |
| 427 | + if not isinstance(lebedev_quadrature, list) or len(lebedev_quadrature) != 2: |
| 428 | + raise ValueError( |
| 429 | + "`lebedev_quadrature` must be a list[bool] of length 2: [so2_quadrature, ffn_quadrature]" |
| 430 | + ) |
| 431 | + if any(not isinstance(flag, bool) for flag in lebedev_quadrature): |
| 432 | + raise ValueError( |
| 433 | + "`lebedev_quadrature` must be a list[bool] of length 2: [so2_quadrature, ffn_quadrature]" |
| 434 | + ) |
| 435 | + self.lebedev_quadrature = list(lebedev_quadrature) |
422 | 436 | self.activation_function = str(activation_function) |
423 | 437 | self.glu_activation = bool(glu_activation) |
424 | 438 |
|
425 | 439 | # === Split effective activation config by branch === |
426 | 440 | self.so2_s2_activation = self.s2_activation[0] |
427 | 441 | self.ffn_s2_activation = self.s2_activation[1] |
| 442 | + self.so2_lebedev_quadrature = self.lebedev_quadrature[0] |
| 443 | + self.ffn_lebedev_quadrature = self.lebedev_quadrature[1] |
428 | 444 | self.so2_activation_function = ( |
429 | 445 | "silu" if self.so2_s2_activation else self.activation_function |
430 | 446 | ) |
@@ -512,11 +528,6 @@ def __init__( |
512 | 528 |
|
513 | 529 | # === L/M schedules === |
514 | 530 | self._init_lm_schedules(lmax, n_blocks, l_schedule, mmax, m_schedule) |
515 | | - self.s2_grid_resolution = resolve_s2_grid_resolution( |
516 | | - self.lmax, |
517 | | - self.mmax, |
518 | | - s2_grid_resolution, |
519 | | - ) |
520 | 531 | self.ebed_dims = [get_so3_dim_of_lmax(l) for l in self.l_schedule] |
521 | 532 | self.rad_sizes_per_block = [l + 1 for l in self.l_schedule] |
522 | 533 |
|
@@ -716,7 +727,8 @@ def __init__( |
716 | 727 | block_attn_res=self.block_attn_res_mode, |
717 | 728 | so2_s2_activation=self.so2_s2_activation, |
718 | 729 | ffn_s2_activation=self.ffn_s2_activation, |
719 | | - s2_grid_resolution=self.s2_grid_resolution, |
| 730 | + so2_lebedev_quadrature=self.so2_lebedev_quadrature, |
| 731 | + ffn_lebedev_quadrature=self.ffn_lebedev_quadrature, |
720 | 732 | n_atten_head=self.n_atten_head, |
721 | 733 | mixed_attention=self.mixed_attention, |
722 | 734 | legacy_attention=self.legacy_attention, |
@@ -1770,7 +1782,7 @@ def serialize(self) -> dict[str, Any]: |
1770 | 1782 | "full_attn_res": self.full_attn_res_mode, |
1771 | 1783 | "block_attn_res": self.block_attn_res_mode, |
1772 | 1784 | "s2_activation": self.s2_activation, |
1773 | | - "s2_grid_resolution": self.s2_grid_resolution, |
| 1785 | + "lebedev_quadrature": self.lebedev_quadrature, |
1774 | 1786 | "activation_function": self.activation_function, |
1775 | 1787 | "glu_activation": self.glu_activation, |
1776 | 1788 | "precision": RESERVED_PRECISION_DICT[self.dtype], |
@@ -1803,6 +1815,7 @@ def deserialize(cls, data: dict[str, Any]) -> DescrptSeZM: |
1803 | 1815 | config = data.pop("config") |
1804 | 1816 | variables = data.pop("@variables") |
1805 | 1817 | data.pop("env_mat", None) |
| 1818 | + config.pop("s2_grid_resolution", None) |
1806 | 1819 | obj = cls(**config) |
1807 | 1820 | template = obj.state_dict() |
1808 | 1821 | state = { |
|
0 commit comments