Conversation
1. implementation of SOG-Net and LES, based on the local charge setting 2. direct calculation of force and virial, instead of autograd of energy todo: 1. improve of computational efficiency of kernel function(mainly at NUFFT) 2. more experiments to verify the correctness and effectiveness of the implementation 3. implementation on the LAMMPS end
for more information, see https://pre-commit.ci
for more information, see https://pre-commit.ci
|
Caution Review failedThe pull request is closed. ℹ️ Recent review info⚙️ Run configurationConfiguration used: Repository UI Review profile: CHILL Plan: Pro Run ID: 📒 Files selected for processing (26)
📝 WalkthroughWalkthroughAdds three energy model families (LES, LR, SOG): new fitting networks, atomic-model wrappers, top-level PyTorch models with NUFFT-based frame-correction pipelines, argument schema, examples/benchmarks, tests, and packaging/entrypoint updates. Exports and registry entries are wired across model, task, and atomic_model packages. Changes
Sequence Diagram(s)sequenceDiagram
participant Client as Caller (coord, atype, box)
participant Extender as Extender / NeighborList
participant Descriptor as Descriptor
participant Fitting as FittingNet (SR / LR)
participant FrameCorr as FrameCorrection (NUFFT)
participant ModelOut as Model Output
Client->>Extender: extend_input_and_build_neighbor_list(coord, atype, box)
Extender->>Descriptor: compute per-atom descriptor features (g2,h2,rot_mat,...)
Descriptor->>Fitting: per-atom descriptor features (+fparam/aparam)
Fitting->>Fitting: SR forward -> energy_redu\nLR forward -> latent_charge
Fitting->>FrameCorr: per-frame coords + latent_charge + box
Note over FrameCorr: fractional coords -> k-grid -> Gaussian damping\nNUFFT type-1(/type-2) -> corr_redu (+ optional force/virial)
FrameCorr->>ModelOut: frame correction bundle (corr_redu, force_local, virial_local)
Fitting->>ModelOut: base energy and derivative tensors
ModelOut->>Client: combined energy, force, virial (with frame correction applied)
Estimated code review effort🎯 4 (Complex) | ⏱️ ~75 minutes Possibly related PRs
Suggested reviewers
✨ Finishing Touches🧪 Generate unit tests (beta)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
There was a problem hiding this comment.
Pull request overview
This PR adds PyTorch implementations of SOG-Net and LES long-range corrections driven by learned per-atom “latent charge”, including direct NUFFT-based force/virial computation and accompanying tests/examples.
Changes:
- Introduces SOG/LES fitting nets, atomic models, and model wrappers that add a NUFFT-based per-frame long-range correction (energy/force/virial).
- Adds working-layer unit tests for SOG/LES and a water SOG training + profiling example bundle.
- Updates CLI arg schema and packaging/dependency configuration to expose the new fitting types and finufft requirements.
Reviewed changes
Copilot reviewed 24 out of 26 changed files in this pull request and generated 13 comments.
Show a summary per file
| File | Description |
|---|---|
| source/tests/pt/model/test_sog_working_layer.py | Adds SOG working-layer tests (frame correction + forward/forward_lower consistency). |
| source/tests/pt/model/test_les_working_layer.py | Adds LES working-layer tests (frame correction + consistency + grad path). |
| pyproject.toml | Renames distribution + CLI script; adds finufft deps to a pinned PyTorch GPU group. |
| examples/water/sog/sog.hdf5 | Adds SOG example statistics file. |
| examples/water/sog/README.md | Documents how to run the SOG water example. |
| examples/water/sog/profile_sog_whatif.py | Profiles timing under alternative model hyperparameters. |
| examples/water/sog/profile_sog_timing.py | Provides detailed timing breakdown (incl. NUFFT) via monkeypatching. |
| examples/water/sog/input_torch.json | Adds a full SOG training configuration for the water example. |
| examples/water/sog/compare_sog_dpa3_timing.py | Compares runtime vs DPA3 and isolates frame-correction overhead. |
| examples/water/sog/check_sog_consistency_with_cace.py | Script intended to compare SOG correction against an external reference. |
| examples/water/sog/ab_retain_graph.py | Benchmarks an alternative autograd-based correction application path. |
| examples/water/dpa3/input_torch_copy.json | Adds a copied DPA3 config used for timing comparisons. |
| examples/water/dpa3/dpa3.hdf5 | Adds DPA3 example statistics file. |
| deepmd/utils/argcheck.py | Registers new sog_energy / les_energy fitting args. |
| deepmd/pt/model/task/sog_energy_fitting.py | Implements SOGEnergyFittingNet (SR energy + LR latent charge + kernel params). |
| deepmd/pt/model/task/lr_fitting.py | Adds a generic SR+LR fitting base used by SOG/LES. |
| deepmd/pt/model/task/les_energy_fitting.py | Implements LESEnergyFittingNet (SR energy + LR latent charge + sigma param). |
| deepmd/pt/model/task/init.py | Exposes LR/SOG/LES fitting nets in the task package API. |
| deepmd/pt/model/model/sog_model.py | Adds SOGEnergyModel with NUFFT-based frame correction (energy/force/virial). |
| deepmd/pt/model/model/les_model.py | Adds LESEnergyModel with NUFFT-based frame correction (energy/force/virial). |
| deepmd/pt/model/model/init.py | Wires new fitting types into get_standard_model and imports new models. |
| deepmd/pt/model/atomic_model/sog_atomic_model.py | Adds SOGEnergyAtomicModel to produce SR energy + latent charge. |
| deepmd/pt/model/atomic_model/lr_energy_atomic_model.py | Adds/updates an auxiliary SR+property correction atomic model. |
| deepmd/pt/model/atomic_model/les_atomic_model.py | Adds LESEnergyAtomicModel to produce SR energy + latent charge. |
| deepmd/pt/model/atomic_model/init.py | Exports new atomic models. |
| deepmd/pt_expt/train/training.py | Propagates disp_freq to modules supporting set_debug_print_freq. |
Comments suppressed due to low confidence (1)
pyproject.toml:179
- In the
pin_pytorch_gpudependency group, newly addedpytorch-finufft>=.../cufinufft>=...are lower-bounded rather than pinned, which is inconsistent with the intent implied by the group name (and with otherpin_*groups that use exact pins). Consider pinning exact versions here (or moving these deps to a non-pinned extra) to keep environments reproducible.
"cufinufft>=2.5.0; platform_system=='Linux' and platform_machine=='x86_64'",
]
pin_jax = [
"jax==0.5.0;python_version>='3.10'",
]
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| doc=doc_only_pt_supported + doc_trainable, | ||
| ), | ||
| Argument( | ||
| "rcond", | ||
| [float, type(None)], |
There was a problem hiding this comment.
les_energy fitting args define shift/amplitude, but LESEnergyFittingNet.__init__ expects sigma (and does not accept shift/amplitude). This will make valid LES configs fail validation or mislead users. Replace these arguments with a sigma argument (plus doc/default) to match the actual LES fitting net API.
| Any, | ||
| ) | ||
|
|
||
| import pytorch_finufft |
There was a problem hiding this comment.
Importing pytorch_finufft at module import time makes deepmd.pt.model.model.sog_model (and anything that imports it) unusable unless the optional NUFFT dependency is installed. This also breaks the unit tests' skip logic, because the model import will raise before the skip decorator can run. Consider wrapping the import in a try/except ImportError and raising a clear runtime error only when the NUFFT correction path is actually executed (or lazily importing inside _compute_sog_frame_correction_bundle).
| import pytorch_finufft | |
| try: | |
| import pytorch_finufft | |
| except ImportError: # Optional dependency; only required for NUFFT-based corrections. | |
| class _MissingPytorchFinufftProxy: | |
| def __getattr__(self, name: str) -> Any: | |
| raise RuntimeError( | |
| "pytorch_finufft is not installed but is required for NUFFT-based " | |
| "corrections in SOGEnergyModel. Please install pytorch_finufft to " | |
| "use this functionality." | |
| ) | |
| pytorch_finufft = _MissingPytorchFinufftProxy() # type: ignore[assignment] |
| Any, | ||
| ) | ||
|
|
||
| import pytorch_finufft |
There was a problem hiding this comment.
Importing pytorch_finufft at module import time makes deepmd.pt.model.model.les_model unusable unless the optional NUFFT dependency is installed, and will also break any tests that try to skip when finufft is absent (the import fails before skip logic runs). Wrap this in try/except ImportError and defer the hard failure until the frame-correction code path is invoked (or lazily import inside _compute_les_frame_correction_bundle).
| import pytorch_finufft | |
| try: | |
| import pytorch_finufft | |
| except ImportError: | |
| class _MissingPytorchFinufft: | |
| """Placeholder for optional pytorch_finufft dependency. | |
| Any attempt to access attributes on this object will raise an | |
| ImportError at runtime, deferring the hard failure until the | |
| NUFFT-based frame-correction code path is actually used. | |
| """ | |
| def __getattr__(self, name: str) -> Any: # type: ignore[override] | |
| raise ImportError( | |
| "pytorch_finufft is required for LES frame correction but is " | |
| "not installed. Please install pytorch_finufft to use this " | |
| "functionality." | |
| ) | |
| pytorch_finufft = _MissingPytorchFinufft() # type: ignore[assignment] |
| [project] | ||
| name = "deepmd-kit" | ||
| name = "deepmd-kit-dev" | ||
| dynamic = ["version", "optional-dependencies", "scripts", "readme"] | ||
| description = "A deep learning package for many-body potential energy representation and molecular dynamics" |
There was a problem hiding this comment.
Changing the distributable name to deepmd-kit-dev is a breaking packaging change and doesn't seem related to the PR description (SOG/LES implementation). Unless this is intentional for a separate dev distribution, it will break downstream installs and tooling expecting deepmd-kit. Consider reverting the project name change or splitting it into a dedicated packaging PR.
| @@ -142,7 +142,7 @@ jax = [ | |||
| ] | |||
|
|
|||
| [tool.deepmd_build_backend.scripts] | |||
There was a problem hiding this comment.
Renaming the CLI entry point from dp to dp_dev is a breaking user-facing change (and examples/docs still use dp). If this is meant to be the same tool, keep the dp script name (or provide both dp and dp_dev aliases) to avoid breaking existing workflows.
| [tool.deepmd_build_backend.scripts] | |
| [tool.deepmd_build_backend.scripts] | |
| dp = "deepmd.main:main" |
| model_predict_lower = self.forward_common_lower( | ||
| extended_coord, | ||
| extended_atype, | ||
| nlist, | ||
| mapping, |
There was a problem hiding this comment.
Same issue as in SOGEnergyModel.forward: this else: is paired with if self.do_grad_c('energy'), so forces requested without virials will hit the else branch and overwrite model_predict['force'] with model_ret['dforce']. Adjust conditionals so dforce is only used when do_grad_r is false.
| SOGEnergyModel, | ||
| ) | ||
| from .spin_model import ( | ||
| SpinEnergyModel, | ||
| SpinModel, |
There was a problem hiding this comment.
Unconditionally importing the SOG/LES model modules here will also unconditionally import pytorch_finufft (currently imported at module import time inside sog_model.py/les_model.py). That makes import deepmd.pt.model.model fail for users who don't have finufft installed, even if they never use SOG/LES. Consider guarding these imports or lazily importing the models only when needed.
| from deepmd.pt.model.descriptor.se_a import ( | ||
| DescrptSeA, | ||
| ) | ||
| from deepmd.pt.model.model.sog_model import ( | ||
| SOGEnergyModel, |
There was a problem hiding this comment.
Even though the test class is skipped when finufft isn't installed, SOGEnergyModel is imported at module import time. Since deepmd.pt.model.model.sog_model currently imports pytorch_finufft unconditionally, the test module will still fail to import before the skip decorator can run. Either make sog_model tolerate missing finufft (preferred) or delay importing SOGEnergyModel until after the skip condition is evaluated.
| from deepmd.pt.model.descriptor.se_a import ( | ||
| DescrptSeA, | ||
| ) | ||
| from deepmd.pt.model.model.les_model import ( | ||
| LESEnergyModel, |
There was a problem hiding this comment.
Same import-time issue as the SOG test: LESEnergyModel is imported at module import time, but deepmd.pt.model.model.les_model currently imports pytorch_finufft unconditionally. If finufft isn't installed, the test module will fail to import before the skip decorator can take effect. Prefer fixing les_model to handle missing finufft or delay importing LESEnergyModel until after the skip check.
| def main() -> None: | ||
| cfg = json.loads(pathlib.Path("examples/water/sog/input_torch.json").read_text())[ | ||
| "model" | ||
| ] | ||
| dev = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
There was a problem hiding this comment.
This script hard-codes a local absolute path (/data/zyjin/.../cace/modules/sog.py), which will fail for other users and in CI. Consider taking the path from a CLI argument / environment variable, or removing this script from the repo if it's only intended for local debugging.
todo:
Summary by CodeRabbit
New Features
Tests
Chores