Skip to content

Commit 9c1ef51

Browse files
akwariiJaGeo
andauthored
Batched workflows for TorchSim (#1505)
* Support batching for the elastic workflow * Update model construction due to API changes * Remove unsupported models * Remove legacy models from TorchSimModelType * Add tests for torchsim model wrappers * Revert mattersim changes * skip test if torchsim is not installed * Minimalistic elastic workflow for torchsim * Add tests * WIP: adapt common elastic jobs * Fix Optimizer being casted to its string value when passed to a jobflow.job * squeeze the stress tensor * Support TorchSim symmetry constraints * Convert the stress to kbar * Unit conversion without ASE * Batched phonon workflow * Remove ase and torchsim dependencies from common phonon jobs * Use lighter model * Fix tests * Disable torchsim in test-non-ase jobs * Fix phonon maker type check and improve readability * Include cell forces to the convergence check * Add values check on different phonon derived properties * Improve socket keyword documentation * Add elastic workflow tutorial * Loosen test tolerance --------- Co-authored-by: J. George <JaGeo@users.noreply.github.com>
1 parent ae78248 commit 9c1ef51

23 files changed

Lines changed: 1003 additions & 322 deletions

File tree

.github/workflows/testing.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,7 @@ jobs:
168168
# However this `splitting-algorithm` means that tests cannot depend sensitively on the order they're executed in.
169169
run: |
170170
micromamba activate a2
171-
pytest --durations=5 -n auto --splits 3 --group ${{ matrix.split }} --durations-path tests/.pytest-split-durations --splitting-algorithm least_duration --ignore=tests/ase --ignore=tests/openff_md --ignore=tests/openmm_md --ignore=tests/forcefields --cov=atomate2 --cov-report=xml
171+
pytest --durations=5 -n auto --splits 3 --group ${{ matrix.split }} --durations-path tests/.pytest-split-durations --splitting-algorithm least_duration --ignore=tests/ase --ignore=tests/openff_md --ignore=tests/openmm_md --ignore=tests/forcefields --ignore=tests/torchsim --cov=atomate2 --cov-report=xml
172172
173173
174174
- uses: codecov/codecov-action@v1

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ ase-ext = ["tblite>=0.3.0; platform_system=='Linux'"]
6363
forcefields-demo = ["chgnet>=0.3.8","atomate2[ase]"]
6464

6565
torchsim = [
66-
"torch-sim-atomistic==0.6.0; python_version >= '3.12'"
66+
"torch-sim-atomistic[symmetry]==0.6.0; python_version >= '3.12'"
6767
]
6868
jdftx = ["pymatgen==2026.5.4"]
6969
approxneb = ["pymatgen-analysis-diffusion>=2024.7.15"]

src/atomate2/common/flows/elastic.py

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424

2525
from atomate2.aims.jobs.base import BaseAimsMaker
2626
from atomate2.forcefields.jobs import ForceFieldRelaxMaker
27+
from atomate2.torchsim import TorchSimOptimizeMaker
2728
from atomate2.vasp.jobs.base import BaseVaspMaker
2829

2930

@@ -72,26 +73,38 @@ class BaseElasticMaker(Maker, ABC):
7273
Keyword arguments passed to :obj:`fit_elastic_tensor`.
7374
task_document_kwargs : dict
7475
Additional keyword args passed to :obj:`.ElasticDocument.from_stresses()`.
76+
socket : bool
77+
If True, uses the socket-io interface to run all deformations in a single
78+
job, reducing overhead. In the specific case of TorchSim, this enables batching
79+
of all structure relaxations.
80+
Note: socket=True is not supported for BaseVaspMaker.
7581
"""
7682

7783
name: str = "elastic"
7884
order: int = 2
7985
sym_reduce: bool = True
8086
symprec: float = SETTINGS.SYMPREC
81-
bulk_relax_maker: BaseAimsMaker | BaseVaspMaker | ForceFieldRelaxMaker | None = None
82-
elastic_relax_maker: BaseAimsMaker | BaseVaspMaker | ForceFieldRelaxMaker = (
83-
None # constant volume optimization
84-
)
87+
bulk_relax_maker: (
88+
BaseAimsMaker
89+
| BaseVaspMaker
90+
| ForceFieldRelaxMaker
91+
| TorchSimOptimizeMaker
92+
| None
93+
) = None
94+
elastic_relax_maker: (
95+
BaseAimsMaker | BaseVaspMaker | ForceFieldRelaxMaker | TorchSimOptimizeMaker
96+
) = None # constant volume optimization
8597
max_failed_deformations: int | float | None = None
8698
generate_elastic_deformations_kwargs: dict = field(default_factory=dict)
8799
fit_elastic_tensor_kwargs: dict = field(default_factory=dict)
88100
task_document_kwargs: dict = field(default_factory=dict)
101+
socket: bool = False
89102

90103
def make(
91104
self,
92105
structure: Structure,
93106
prev_dir: str | Path | None = None,
94-
equilibrium_stress: Matrix3D = None,
107+
equilibrium_stress: Matrix3D | None = None,
95108
conventional: bool = False,
96109
) -> Flow:
97110
"""
@@ -135,15 +148,16 @@ def make(
135148
**self.generate_elastic_deformations_kwargs,
136149
)
137150

138-
vasp_deformation_calcs = run_elastic_deformations(
151+
deformation_calcs = run_elastic_deformations(
139152
structure,
140153
deformations.output,
141154
elastic_relax_maker=self.elastic_relax_maker,
142155
prev_dir=prev_dir,
156+
socket=self.socket,
143157
)
144158
fit_tensor = fit_elastic_tensor(
145159
structure,
146-
vasp_deformation_calcs.output,
160+
deformation_calcs.output,
147161
equilibrium_stress=equilibrium_stress,
148162
order=self.order,
149163
symprec=self.symprec if self.sym_reduce else None,
@@ -156,7 +170,7 @@ def make(
156170
# allow some of the deformations to fail
157171
fit_tensor.config.on_missing_references = OnMissing.NONE
158172

159-
jobs += [deformations, vasp_deformation_calcs, fit_tensor]
173+
jobs += [deformations, deformation_calcs, fit_tensor]
160174

161175
return Flow(
162176
jobs=jobs,

src/atomate2/common/flows/phonons.py

Lines changed: 29 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727

2828
from atomate2.aims.jobs.base import BaseAimsMaker
2929
from atomate2.forcefields.jobs import ForceFieldRelaxMaker, ForceFieldStaticMaker
30+
from atomate2.torchsim.core import TorchSimOptimizeMaker, TorchSimStaticMaker
3031
from atomate2.vasp.jobs.base import BaseVaspMaker
3132

3233
SUPPORTED_CODES = frozenset(("vasp", "aims", "forcefields", "ase", "torchsim"))
@@ -99,13 +100,16 @@ class BasePhononMaker(Maker, ABC):
99100
A maker to perform a tight relaxation on the bulk.
100101
Set to ``None`` to skip the
101102
bulk relaxation
102-
static_energy_maker: .ForceFieldRelaxMaker, .BaseAimsMaker, .BaseVaspMaker, or None
103+
static_energy_maker: .ForceFieldRelaxMaker, .BaseAimsMaker, .BaseVaspMaker,
104+
.TorchSimStaticMaker, or None
103105
A maker to perform the computation of the DFT energy on the bulk.
104106
Set to ``None`` to skip the
105107
static energy computation
106-
born_maker: .ForceFieldStaticMaker, .BaseAsimsMaker, .BaseVaspMaker, or None
108+
born_maker: .ForceFieldStaticMaker, .BaseAsimsMaker, .BaseVaspMaker,
109+
.TorchSimStaticMaker, or None
107110
Maker to compute the BORN charges.
108-
phonon_displacement_maker: .ForceFieldStaticMaker, .BaseAimsMaker, .BaseVaspMaker
111+
phonon_displacement_maker: .ForceFieldStaticMaker, .BaseAimsMaker, .BaseVaspMaker,
112+
.TorchSimStaticMaker
109113
Maker used to compute the forces for a supercell.
110114
generate_frequencies_eigenvectors_kwargs : dict
111115
Keyword arguments passed to :obj:`generate_frequencies_eigenvectors`.
@@ -132,7 +136,10 @@ class BasePhononMaker(Maker, ABC):
132136
store_force_constants: bool
133137
if True, force constants will be stored
134138
socket: bool
135-
If True, use the socket/batch for the calculation
139+
If True, uses the socket-io interface to run all displacements in a single
140+
job, reducing overhead. In the specific case of TorchSim, this enables batching
141+
of all static structure evaluations.
142+
Note: socket=True is not supported for BaseVaspMaker.
136143
"""
137144

138145
name: str = "phonon"
@@ -145,14 +152,26 @@ class BasePhononMaker(Maker, ABC):
145152
allow_orthorhombic: bool = False
146153
get_supercell_size_kwargs: dict = field(default_factory=dict)
147154
use_symmetrized_structure: Literal["primitive", "conventional"] | None = None
148-
bulk_relax_maker: ForceFieldRelaxMaker | BaseVaspMaker | BaseAimsMaker | None = None
149-
static_energy_maker: ForceFieldRelaxMaker | BaseVaspMaker | BaseAimsMaker | None = (
150-
None
151-
)
152-
born_maker: ForceFieldStaticMaker | BaseVaspMaker | None = None
153-
phonon_displacement_maker: ForceFieldStaticMaker | BaseVaspMaker | BaseAimsMaker = (
155+
bulk_relax_maker: (
156+
ForceFieldRelaxMaker
157+
| BaseVaspMaker
158+
| BaseAimsMaker
159+
| TorchSimOptimizeMaker
160+
| None
161+
) = None
162+
static_energy_maker: (
163+
ForceFieldRelaxMaker
164+
| BaseVaspMaker
165+
| BaseAimsMaker
166+
| TorchSimStaticMaker
167+
| None
168+
) = None
169+
born_maker: ForceFieldStaticMaker | BaseVaspMaker | TorchSimStaticMaker | None = (
154170
None
155171
)
172+
phonon_displacement_maker: (
173+
ForceFieldStaticMaker | BaseVaspMaker | BaseAimsMaker | TorchSimStaticMaker
174+
) = None
156175
create_thermal_displacements: bool = True
157176
generate_frequencies_eigenvectors_kwargs: dict = field(
158177
default_factory=lambda: {

src/atomate2/common/jobs/elastic.py

Lines changed: 77 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818
from atomate2 import SETTINGS
1919
from atomate2.common.analysis.elastic import get_default_strain_states
2020
from atomate2.common.schemas.elastic import ElasticDocument
21+
from atomate2.common.utils import check_class_name
22+
from atomate2.vasp.jobs.base import BaseVaspMaker
2123

2224
if TYPE_CHECKING:
2325
from pathlib import Path
@@ -26,7 +28,7 @@
2628
from pymatgen.core.structure import Structure
2729

2830
from atomate2.forcefields.jobs import ForceFieldRelaxMaker
29-
from atomate2.vasp.jobs.base import BaseVaspMaker
31+
from atomate2.torchsim import TorchSimOptimizeMaker
3032

3133

3234
logger = logging.getLogger(__name__)
@@ -104,14 +106,19 @@ def run_elastic_deformations(
104106
structure: Structure,
105107
deformations: list[Deformation],
106108
prev_dir: str | Path | None = None,
107-
prev_dir_argname: str = None,
108-
elastic_relax_maker: BaseVaspMaker | ForceFieldRelaxMaker = None,
109+
prev_dir_argname: str | None = None,
110+
elastic_relax_maker: BaseVaspMaker
111+
| ForceFieldRelaxMaker
112+
| TorchSimOptimizeMaker = None,
113+
socket: bool = False,
109114
) -> Response:
110115
"""
111116
Run elastic deformations.
112117
113-
Note, this job will replace itself with N relaxation calculations, where N is
114-
the number of deformations.
118+
Note, this job will replace itself with N relaxation calculations,
119+
or a single batched calculation using the socket interface to run all
120+
deformations simultaneously. This results in lower overhead as well as
121+
parallel relaxation of the deformations for TorchSim.
115122
116123
Parameters
117124
----------
@@ -121,46 +128,81 @@ def run_elastic_deformations(
121128
The deformations to apply.
122129
prev_dir : str or Path or None
123130
A previous directory to use for copying outputs.
124-
prev_dir_argname: str
125-
argument name for the prev_dir variable
126-
elastic_relax_maker : .BaseVaspMaker or .ForceFieldRelaxMaker
127-
A VaspMaker or a ForceFieldMaker to use to generate the elastic relaxation jobs.
131+
prev_dir_argname: str or None
132+
Argument name for the prev_dir variable.
133+
elastic_relax_maker : .BaseVaspMaker, .ForceFieldRelaxMaker, or
134+
.TorchSimOptimizeMaker
135+
A VaspMaker, ForceFieldMaker, or TorchSimMaker to use to generate the elastic
136+
relaxation jobs.
137+
socket : bool
138+
If True, uses the socket-io interface to run all deformations in a single
139+
job, reducing overhead. In the specific case of TorchSim, this enables batching
140+
of all structure relaxations.
141+
Note: socket=True is not supported for BaseVaspMaker.
128142
"""
129-
relaxations = []
143+
num_deformations = len(deformations)
144+
elastic_jobs = []
130145
outputs = []
131-
for idx, deformation in enumerate(deformations):
132-
# deform the structure
146+
147+
if socket and isinstance(elastic_relax_maker, BaseVaspMaker):
148+
raise ValueError("socket=True is not supported for BaseVaspMaker.")
149+
150+
deformed_structures = []
151+
for deformation in deformations:
133152
dst = DeformStructureTransformation(deformation=deformation)
134153
ts = TransformedStructure(structure, transformations=[dst])
135-
deformed_structure = ts.final_structure
154+
deformed_structures.append(ts.final_structure)
136155

137156
with contextlib.suppress(Exception):
138-
# write details of the transformation to the transformations.json file
139-
# this file will automatically get added to the task document and allow
140-
# the elastic builder to reconstruct the elastic document; note the ":" is
141-
# automatically converted to a "." in the filename.
157+
# Write details of the transformation to the transformations.json file
142158
elastic_relax_maker.write_additional_data["transformations:json"] = ts
143159

144-
elastic_job_kwargs = {}
145-
if prev_dir is not None and prev_dir_argname is not None:
146-
elastic_job_kwargs[prev_dir_argname] = prev_dir
147-
# create the job
148-
relax_job = elastic_relax_maker.make(deformed_structure, **elastic_job_kwargs)
149-
relax_job.append_name(f" {idx + 1}/{len(deformations)}")
150-
relaxations.append(relax_job)
160+
elastic_job_kwargs = {}
161+
if prev_dir is not None and prev_dir_argname is not None:
162+
elastic_job_kwargs[prev_dir_argname] = prev_dir
151163

152-
# extract the outputs we want
153-
output = {
154-
"stress": relax_job.output.output.stress,
155-
"deformation": deformation,
156-
"uuid": relax_job.output.uuid,
157-
"job_dir": relax_job.output.dir_name,
158-
}
164+
if socket:
165+
batched_job = elastic_relax_maker.make(
166+
deformed_structures, **elastic_job_kwargs
167+
)
168+
batched_job.append_name(" batched_socket")
169+
elastic_jobs.append(batched_job)
170+
171+
for idx, deformation in enumerate(deformations):
172+
if check_class_name(elastic_relax_maker, "TorchSimOptimizeMaker"):
173+
output = {
174+
"stress": batched_job.output.output.stress[idx],
175+
"deformation": deformation,
176+
"uuid": batched_job.output.uuid,
177+
"job_dir": batched_job.output.dir_name,
178+
}
179+
else:
180+
output = {
181+
"stress": batched_job.output[idx].output.stress,
182+
"deformation": deformation,
183+
"uuid": batched_job.output[idx].uuid,
184+
"job_dir": batched_job.output[idx].dir_name,
185+
}
186+
outputs.append(output)
187+
188+
else:
189+
for idx, deformed_structure in enumerate(deformed_structures):
190+
relax_job = elastic_relax_maker.make(
191+
deformed_structure, **elastic_job_kwargs
192+
)
193+
relax_job.append_name(f" {idx + 1}/{num_deformations}")
194+
elastic_jobs.append(relax_job)
159195

160-
outputs.append(output)
196+
output = {
197+
"stress": relax_job.output.output.stress,
198+
"deformation": deformations[idx],
199+
"uuid": relax_job.output.uuid,
200+
"job_dir": relax_job.output.dir_name,
201+
}
202+
outputs.append(output)
161203

162-
relax_flow = Flow(relaxations, outputs)
163-
return Response(replace=relax_flow)
204+
elastic_flow = Flow(elastic_jobs, outputs)
205+
return Response(replace=elastic_flow)
164206

165207

166208
@job(output_schema=ElasticDocument)
@@ -221,7 +263,7 @@ def fit_elastic_tensor(
221263
failed_uuids.append(data["uuid"])
222264
continue
223265

224-
stresses.append(Stress(stress_sign_factor * np.array(data["stress"])))
266+
stresses.append(Stress(stress_sign_factor * np.squeeze(data["stress"])))
225267
deformations.append(Deformation(data["deformation"]))
226268
uuids.append(data["uuid"])
227269
job_dirs.append(data["job_dir"])

0 commit comments

Comments
 (0)