@@ -80,7 +80,9 @@ def eval_model(
8080 assert isinstance (spins , torch .Tensor ), err_msg
8181 assert isinstance (atom_types , torch .Tensor ) or isinstance (atom_types , list )
8282 if isinstance (atom_types , torch .Tensor ):
83- atom_types = atom_types .clone ().detach ().to (dtype = torch .int32 , device = DEVICE )
83+ atom_types = (
84+ atom_types .clone ().detach ().to (dtype = torch .int32 , device = DEVICE )
85+ )
8486 else :
8587 atom_types = torch .tensor (atom_types , dtype = torch .int32 , device = DEVICE )
8688 elif isinstance (coords , np .ndarray ):
@@ -105,15 +107,27 @@ def eval_model(
105107 natoms = len (atom_types [0 ])
106108
107109 if isinstance (coords , torch .Tensor ):
108- coord_input = coords .reshape ([- 1 , natoms , 3 ]).clone ().detach ().to (dtype = GLOBAL_PT_FLOAT_PRECISION , device = DEVICE )
110+ coord_input = (
111+ coords .reshape ([- 1 , natoms , 3 ])
112+ .clone ()
113+ .detach ()
114+ .to (dtype = GLOBAL_PT_FLOAT_PRECISION , device = DEVICE )
115+ )
109116 else :
110117 coord_input = torch .tensor (
111- coords .reshape ([- 1 , natoms , 3 ]), dtype = GLOBAL_PT_FLOAT_PRECISION , device = DEVICE
118+ coords .reshape ([- 1 , natoms , 3 ]),
119+ dtype = GLOBAL_PT_FLOAT_PRECISION ,
120+ device = DEVICE ,
112121 )
113122 spin_input = None
114123 if spins is not None :
115124 if isinstance (spins , torch .Tensor ):
116- spin_input = spins .reshape ([- 1 , natoms , 3 ]).clone ().detach ().to (dtype = GLOBAL_PT_FLOAT_PRECISION , device = DEVICE )
125+ spin_input = (
126+ spins .reshape ([- 1 , natoms , 3 ])
127+ .clone ()
128+ .detach ()
129+ .to (dtype = GLOBAL_PT_FLOAT_PRECISION , device = DEVICE )
130+ )
117131 else :
118132 spin_input = torch .tensor (
119133 spins .reshape ([- 1 , natoms , 3 ]),
@@ -133,10 +147,17 @@ def eval_model(
133147 else :
134148 pbc = True
135149 if isinstance (cells , torch .Tensor ):
136- box_input = cells .reshape ([- 1 , 3 , 3 ]).clone ().detach ().to (dtype = GLOBAL_PT_FLOAT_PRECISION , device = DEVICE )
150+ box_input = (
151+ cells .reshape ([- 1 , 3 , 3 ])
152+ .clone ()
153+ .detach ()
154+ .to (dtype = GLOBAL_PT_FLOAT_PRECISION , device = DEVICE )
155+ )
137156 else :
138157 box_input = torch .tensor (
139- cells .reshape ([- 1 , 3 , 3 ]), dtype = GLOBAL_PT_FLOAT_PRECISION , device = DEVICE
158+ cells .reshape ([- 1 , 3 , 3 ]),
159+ dtype = GLOBAL_PT_FLOAT_PRECISION ,
160+ device = DEVICE ,
140161 )
141162 num_iter = int ((nframes + infer_batch_size - 1 ) / infer_batch_size )
142163
0 commit comments