
Commit 049fb95

update doc
1 parent 7b84ec8 commit 049fb95

19 files changed

Lines changed: 341 additions & 875 deletions

deepmd/pt/entrypoints/freeze_pt2.py

Lines changed: 2 additions & 2 deletions
@@ -4,8 +4,8 @@
 SeZM relies on a nested ``autograd.grad(create_graph=True)`` inside
 ``fit_output_to_model_output``; TorchScript cannot represent that
 graph, so DPA4 / SeZM checkpoints are routed through AOTInductor instead.
-The output archive layout matches the ``pt_expt`` convention and is
-consumed directly by ``DeepPotPTExpt.cc`` without any C++ change.
+The output archive layout follows the ``pt_expt`` convention, including the
+metadata consumed by ``DeepPotPTExpt.cc`` and ``DeepSpinPTExpt.cc``.
 
 Tracing runs on CPU (``make_fx`` with ``_allow_non_fake_inputs=True``
 is brittle on CUDA because the proxy-tensor dispatcher does not set

deepmd/pt/model/descriptor/sezm.py

Lines changed: 17 additions & 11 deletions
@@ -1,13 +1,13 @@
 # SPDX-License-Identifier: LGPL-3.0-or-later
 """
-SeZM: The descriptor of smooth equivariant Zone-bridging Model.
+SeZM descriptor: Smooth Equivariant Zone-bridging Model.
 
 PyTorch backend
 
-This implementation is designed around two non-negotiables:
+This implementation is designed around two goals:
 
 1) Conservative forces: the descriptor is computed from differentiable energy.
-2) Speed-first inference: edge geometry and Wigner-D rotation blocks are computed
+2) Efficient inference: edge geometry and Wigner-D rotation blocks are computed
 exactly once per `forward()` and reused by all interaction blocks.
 
 Shared descriptor building blocks are re-exported by `sezm_nn/__init__.py`.
@@ -117,10 +117,11 @@
 
 @BaseDescriptor.register("SeZM")
 @BaseDescriptor.register("sezm")
+@BaseDescriptor.register("DPA4")
 @BaseDescriptor.register("dpa4")
 class DescrptSeZM(BaseDescriptor, nn.Module):
 """
-SeZM: The descriptor of smooth equivariant Zone-bridging Model for DeePMD-kit.
+SeZM descriptor.
 
 Execution outline
 -----------------
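The stacked `register` decorators above bind several names to one class. The following is a minimal sketch of that pattern, not DeePMD-kit's actual registry: `Registry`, `descriptors`, and `DescrptSketch` are hypothetical names introduced only for illustration, and lookup is assumed to be an exact string match.

```python
from typing import Callable, Dict, Type


class Registry:
    """Hypothetical stand-in for a plugin registry (not DeePMD-kit code)."""

    def __init__(self) -> None:
        self._classes: Dict[str, Type] = {}

    def register(self, key: str) -> Callable[[Type], Type]:
        # Return a decorator that stores the class under `key` and hands
        # the class back unchanged, so decorators can be stacked.
        def decorator(cls: Type) -> Type:
            self._classes[key] = cls
            return cls

        return decorator

    def get(self, key: str) -> Type:
        # Assumption: exact, case-sensitive lookup.
        return self._classes[key]


descriptors = Registry()


@descriptors.register("SeZM")
@descriptors.register("sezm")
@descriptors.register("DPA4")
@descriptors.register("dpa4")
class DescrptSketch:
    """Placeholder class standing in for the real descriptor."""


names = ["SeZM", "sezm", "DPA4", "dpa4"]
resolved = {name: descriptors.get(name) for name in names}
```

Under the exact-match assumption, each spelling needs its own registration, which would explain why the commit adds `"DPA4"` next to the existing `"dpa4"` alias.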
@@ -242,8 +243,8 @@ class DescrptSeZM(BaseDescriptor, nn.Module):
 - DepthAttnRes: input-dependent query projection
 - EnvironmentInitialEmbedding:
 rbf_proj_layer1/2 and g_layer1/2
-Attention projections in SO2Convolution
-(attn_radial_logit_proj, attn_output_gate_proj) are always bias-free.
+Attention logit and output-gate parameters in SO(2) convolution are
+always bias-free.
 layer_scale
 If True, apply learnable LayerScale (init 1e-3) on residual branches:
 - SO(2) branch: per-focus-channel scales `(n_focus, focus_dim)`
@@ -292,9 +293,11 @@ class DescrptSeZM(BaseDescriptor, nn.Module):
 ``True`` only when ``s2_activation[1]=True``. The final ``l=0`` output
 FFN always keeps this user-provided value.
 use_amp
-If True, use automatic mixed precision (AMP) with bfloat16 on CUDA.
-This does not provide accelerations under fp32 precision but will decrease
-the memory usage, while preserving model accuracy.
+If True, use automatic mixed precision (AMP) with bfloat16 on CUDA
+during training. This can improve speed and reduce memory usage.
+Enabling this option is recommended on GPUs with native bfloat16 support.
+Disable it on GPUs without native bfloat16 support to avoid runtime
+errors or additional conversion overhead.
 exclude_types
 List of excluded type pairs.
 precision
@@ -1554,8 +1557,11 @@ def _compute_mode_ctx(self, device: torch.device) -> Generator[None, None, None]
 Notes
 -----
 - When `use_amp=True` and the model is in training mode, enables
-torch.autocast with bfloat16 on CUDA.
-- Only affects autocast-eligible operations (matmul, conv, etc.).
+torch.autocast with bfloat16 on CUDA. This can improve speed and
+reduce memory usage on GPUs with native bfloat16 support.
+Disable AMP on GPUs without native bfloat16 support to avoid runtime
+errors or additional conversion overhead.
+- Only affects autocast-eligible operations.
 - Does nothing during inference (`self.training=False`), on non-CUDA
 devices, or when `use_amp=False`.
 
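The notes above describe when autocast is active. Below is a minimal sketch of that gating logic, assuming only what the docstring states. `compute_mode_ctx` is a hypothetical free function, not the repository's `_compute_mode_ctx` method, whose body is not shown in this diff.

```python
import contextlib
from typing import ContextManager

import torch


def compute_mode_ctx(
    use_amp: bool, training: bool, device: torch.device
) -> ContextManager:
    """Return a bfloat16 autocast context only for CUDA training with AMP on."""
    if use_amp and training and device.type == "cuda":
        # Only autocast-eligible operations run in bfloat16 inside this context.
        return torch.autocast(device_type="cuda", dtype=torch.bfloat16)
    # Inference, non-CUDA devices, or use_amp=False: do nothing.
    return contextlib.nullcontext()


layer = torch.nn.Linear(4, 2)  # freshly built modules start in training mode
x = torch.randn(3, 4)          # CPU tensor, so the context is a no-op here
with compute_mode_ctx(use_amp=True, training=layer.training, device=x.device):
    y = layer(x)
```

A caller that wants to follow the advice about native bfloat16 support could additionally check `torch.cuda.is_bf16_supported()` before setting `use_amp=True`. Whether the repository performs such a check is not stated in the source.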

deepmd/pt/model/model/sezm_model.py

Lines changed: 4 additions & 1 deletion
@@ -559,6 +559,7 @@ def _rebuild_graph_module(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
 
 @BaseModel.register("SeZM")
 @BaseModel.register("sezm")
+@BaseModel.register("DPA4")
 @BaseModel.register("dpa4")
 class SeZMModel(DPModelCommon, SeZMModel_):
 """
@@ -570,7 +571,9 @@ class SeZMModel(DPModelCommon, SeZMModel_):
 standard neighbor list and traces the local graph with ``make_fx`` for
 higher-order force training. Evaluation/inference compile usage is
 controlled by the `DP_COMPILE_INFER` environment variable read at model
-initialization time.
+initialization time. This path is experimental, requires ``torch==2.11``,
+may still expose PyTorch compiler bugs, and can improve training speed by
+roughly 2-3x on supported workloads.
 """
 
 model_type = "SeZM"
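Because `DP_COMPILE_INFER` is read at model initialization time, changing it after a model has been built should not affect that model. The sketch below illustrates this timing only. `CompileFlagSketch` is a hypothetical class, and treating `"1"` as "enabled" is an assumption, since the diff does not show how the value is parsed.

```python
import os


class CompileFlagSketch:
    """Hypothetical model stub that captures the flag once, in __init__."""

    def __init__(self) -> None:
        # Assumption: "1" means enabled; the real parsing rule is not shown.
        self.compile_infer = os.environ.get("DP_COMPILE_INFER", "0") == "1"


os.environ["DP_COMPILE_INFER"] = "1"
model = CompileFlagSketch()           # flag captured here

os.environ["DP_COMPILE_INFER"] = "0"  # too late to affect `model`
late_model = CompileFlagSketch()      # a new instance sees the new value
```

In practice this means the variable would need to be exported before the process constructs the model, for example in the shell that launches it.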
