1818from atomate2 import SETTINGS
1919from atomate2 .common .analysis .elastic import get_default_strain_states
2020from atomate2 .common .schemas .elastic import ElasticDocument
21+ from atomate2 .common .utils import check_class_name
22+ from atomate2 .vasp .jobs .base import BaseVaspMaker
2123
2224if TYPE_CHECKING :
2325 from pathlib import Path
2628 from pymatgen .core .structure import Structure
2729
2830 from atomate2 .forcefields .jobs import ForceFieldRelaxMaker
29- from atomate2 .vasp . jobs . base import BaseVaspMaker
31+ from atomate2 .torchsim import TorchSimOptimizeMaker
3032
3133
3234logger = logging .getLogger (__name__ )
@@ -104,14 +106,19 @@ def run_elastic_deformations(
104106 structure : Structure ,
105107 deformations : list [Deformation ],
106108 prev_dir : str | Path | None = None ,
107- prev_dir_argname : str = None ,
108- elastic_relax_maker : BaseVaspMaker | ForceFieldRelaxMaker = None ,
109+ prev_dir_argname : str | None = None ,
110+ elastic_relax_maker : BaseVaspMaker
111+ | ForceFieldRelaxMaker
112+ | TorchSimOptimizeMaker = None ,
113+ socket : bool = False ,
109114) -> Response :
110115 """
111116 Run elastic deformations.
112117
113- Note, this job will replace itself with N relaxation calculations, where N is
114- the number of deformations.
118+ Note, this job will replace itself with N relaxation calculations,
119+ or a single batched calculation using the socket interface to run all
120+ deformations simultaneously. This results in lower overhead as well as
121+ parallel relaxation of the deformations for TorchSim.
115122
116123 Parameters
117124 ----------
@@ -121,46 +128,81 @@ def run_elastic_deformations(
121128 The deformations to apply.
122129 prev_dir : str or Path or None
123130 A previous directory to use for copying outputs.
124- prev_dir_argname: str
125- argument name for the prev_dir variable
126- elastic_relax_maker : .BaseVaspMaker or .ForceFieldRelaxMaker
127- A VaspMaker or a ForceFieldMaker to use to generate the elastic relaxation jobs.
131+ prev_dir_argname: str or None
132+ Argument name for the prev_dir variable.
133+ elastic_relax_maker : .BaseVaspMaker, .ForceFieldRelaxMaker, or
134+ .TorchSimOptimizeMaker
135+ A VaspMaker, ForceFieldMaker, or TorchSimMaker to use to generate the elastic
136+ relaxation jobs.
137+ socket : bool
138+ If True, uses the socket-io interface to run all deformations in a single
139+ job, reducing overhead. In the specific case of TorchSim, this enables batching
140+ of all structure relaxations.
141+ Note: socket=True is not supported for BaseVaspMaker.
128142 """
129- relaxations = []
143+ num_deformations = len (deformations )
144+ elastic_jobs = []
130145 outputs = []
131- for idx , deformation in enumerate (deformations ):
132- # deform the structure
146+
147+ if socket and isinstance (elastic_relax_maker , BaseVaspMaker ):
148+ raise ValueError ("socket=True is not supported for BaseVaspMaker." )
149+
150+ deformed_structures = []
151+ for deformation in deformations :
133152 dst = DeformStructureTransformation (deformation = deformation )
134153 ts = TransformedStructure (structure , transformations = [dst ])
135- deformed_structure = ts .final_structure
154+ deformed_structures . append ( ts .final_structure )
136155
137156 with contextlib .suppress (Exception ):
138- # write details of the transformation to the transformations.json file
139- # this file will automatically get added to the task document and allow
140- # the elastic builder to reconstruct the elastic document; note the ":" is
141- # automatically converted to a "." in the filename.
157+ # Write details of the transformation to the transformations.json file
142158 elastic_relax_maker .write_additional_data ["transformations:json" ] = ts
143159
144- elastic_job_kwargs = {}
145- if prev_dir is not None and prev_dir_argname is not None :
146- elastic_job_kwargs [prev_dir_argname ] = prev_dir
147- # create the job
148- relax_job = elastic_relax_maker .make (deformed_structure , ** elastic_job_kwargs )
149- relax_job .append_name (f" { idx + 1 } /{ len (deformations )} " )
150- relaxations .append (relax_job )
160+ elastic_job_kwargs = {}
161+ if prev_dir is not None and prev_dir_argname is not None :
162+ elastic_job_kwargs [prev_dir_argname ] = prev_dir
151163
152- # extract the outputs we want
153- output = {
154- "stress" : relax_job .output .output .stress ,
155- "deformation" : deformation ,
156- "uuid" : relax_job .output .uuid ,
157- "job_dir" : relax_job .output .dir_name ,
158- }
164+ if socket :
165+ batched_job = elastic_relax_maker .make (
166+ deformed_structures , ** elastic_job_kwargs
167+ )
168+ batched_job .append_name (" batched_socket" )
169+ elastic_jobs .append (batched_job )
170+
171+ for idx , deformation in enumerate (deformations ):
172+ if check_class_name (elastic_relax_maker , "TorchSimOptimizeMaker" ):
173+ output = {
174+ "stress" : batched_job .output .output .stress [idx ],
175+ "deformation" : deformation ,
176+ "uuid" : batched_job .output .uuid ,
177+ "job_dir" : batched_job .output .dir_name ,
178+ }
179+ else :
180+ output = {
181+ "stress" : batched_job .output [idx ].output .stress ,
182+ "deformation" : deformation ,
183+ "uuid" : batched_job .output [idx ].uuid ,
184+ "job_dir" : batched_job .output [idx ].dir_name ,
185+ }
186+ outputs .append (output )
187+
188+ else :
189+ for idx , deformed_structure in enumerate (deformed_structures ):
190+ relax_job = elastic_relax_maker .make (
191+ deformed_structure , ** elastic_job_kwargs
192+ )
193+ relax_job .append_name (f" { idx + 1 } /{ num_deformations } " )
194+ elastic_jobs .append (relax_job )
159195
160- outputs .append (output )
196+ output = {
197+ "stress" : relax_job .output .output .stress ,
198+ "deformation" : deformations [idx ],
199+ "uuid" : relax_job .output .uuid ,
200+ "job_dir" : relax_job .output .dir_name ,
201+ }
202+ outputs .append (output )
161203
162- relax_flow = Flow (relaxations , outputs )
163- return Response (replace = relax_flow )
204+ elastic_flow = Flow (elastic_jobs , outputs )
205+ return Response (replace = elastic_flow )
164206
165207
166208@job (output_schema = ElasticDocument )
@@ -221,7 +263,7 @@ def fit_elastic_tensor(
221263 failed_uuids .append (data ["uuid" ])
222264 continue
223265
224- stresses .append (Stress (stress_sign_factor * np .array (data ["stress" ])))
266+ stresses .append (Stress (stress_sign_factor * np .squeeze (data ["stress" ])))
225267 deformations .append (Deformation (data ["deformation" ]))
226268 uuids .append (data ["uuid" ])
227269 job_dirs .append (data ["job_dir" ])
0 commit comments