|
1 | 1 | # SPDX-License-Identifier: LGPL-3.0-or-later |
2 | 2 | """ |
3 | | -SeZM: The descriptor of smooth equivariant Zone-bridging Model. |
| 3 | +SeZM descriptor: Smooth Equivariant Zone-bridging Model. |
4 | 4 |
|
5 | 5 | PyTorch backend |
6 | 6 |
|
7 | | -This implementation is designed around two non-negotiables: |
| 7 | +This implementation is designed around two goals: |
8 | 8 |
|
9 | 9 | 1) Conservative forces: the descriptor is fully differentiable, so forces are obtained as gradients of the energy. |
10 | | -2) Speed-first inference: edge geometry and Wigner-D rotation blocks are computed |
| 10 | +2) Efficient inference: edge geometry and Wigner-D rotation blocks are computed |
11 | 11 | exactly once per `forward()` and reused by all interaction blocks. |
12 | 12 |
|
13 | 13 | Shared descriptor building blocks are re-exported by `sezm_nn/__init__.py`. |
|
117 | 117 |
|
118 | 118 | @BaseDescriptor.register("SeZM") |
119 | 119 | @BaseDescriptor.register("sezm") |
| 120 | +@BaseDescriptor.register("DPA4") |
120 | 121 | @BaseDescriptor.register("dpa4") |
121 | 122 | class DescrptSeZM(BaseDescriptor, nn.Module): |
122 | 123 | """ |
123 | | - SeZM: The descriptor of smooth equivariant Zone-bridging Model for DeePMD-kit. |
| 124 | + SeZM descriptor. |
124 | 125 |
|
125 | 126 | Execution outline |
126 | 127 | ----------------- |
@@ -242,8 +243,8 @@ class DescrptSeZM(BaseDescriptor, nn.Module): |
242 | 243 | - DepthAttnRes: input-dependent query projection |
243 | 244 | - EnvironmentInitialEmbedding: |
244 | 245 | rbf_proj_layer1/2 and g_layer1/2 |
245 | | - Attention projections in SO2Convolution |
246 | | - (attn_radial_logit_proj, attn_output_gate_proj) are always bias-free. |
| 246 | + Attention logit and output-gate parameters in SO(2) convolution are |
| 247 | + always bias-free. |
247 | 248 | layer_scale |
248 | 249 | If True, apply learnable LayerScale (init 1e-3) on residual branches: |
249 | 250 | - SO(2) branch: per-focus-channel scales `(n_focus, focus_dim)` |
@@ -292,9 +293,11 @@ class DescrptSeZM(BaseDescriptor, nn.Module): |
292 | 293 | ``True`` only when ``s2_activation[1]=True``. The final ``l=0`` output |
293 | 294 | FFN always keeps this user-provided value. |
294 | 295 | use_amp |
295 | | - If True, use automatic mixed precision (AMP) with bfloat16 on CUDA. |
296 | | - This does not provide accelerations under fp32 precision but will decrease |
297 | | - the memory usage, while preserving model accuracy. |
| 296 | + If True, use automatic mixed precision (AMP) with bfloat16 on CUDA |
| 297 | + during training. This can improve speed and reduce memory usage. |
| 298 | + Enabling this option is recommended on GPUs with native bfloat16 support. |
| 299 | + Disable it on GPUs without native bfloat16 support to avoid runtime |
| 300 | + errors or additional conversion overhead. |
298 | 301 | exclude_types |
299 | 302 | List of excluded type pairs. |
300 | 303 | precision |
@@ -1554,8 +1557,11 @@ def _compute_mode_ctx(self, device: torch.device) -> Generator[None, None, None] |
1554 | 1557 | Notes |
1555 | 1558 | ----- |
1556 | 1559 | - When `use_amp=True` and the model is in training mode, enables |
1557 | | - torch.autocast with bfloat16 on CUDA. |
1558 | | - - Only affects autocast-eligible operations (matmul, conv, etc.). |
| 1560 | + torch.autocast with bfloat16 on CUDA. This can improve speed and |
| 1561 | + reduce memory usage on GPUs with native bfloat16 support. |
| 1562 | + Disable AMP on GPUs without native bfloat16 support to avoid runtime |
| 1563 | + errors or additional conversion overhead. |
| 1564 | + - Only affects autocast-eligible operations. |
1559 | 1565 | - Does nothing during inference (`self.training=False`), on non-CUDA |
1560 | 1566 | devices, or when `use_amp=False`. |
1561 | 1567 |
|
|
0 commit comments