@@ -113,25 +113,8 @@ def forward_common(
113113 ): # use force to calculate energy hessian
114114 force : torch .Tensor = ret ["energy_derv_r" ]\
115115 .squeeze (- 2 ) # nf x nloc x 3
116- nf , nloc , _ = force .shape
117- hess = (
118- torch .autograd .grad (
119- outputs = force ,
120- inputs = coord ,
121- grad_outputs = torch .eye (
122- nloc * 3 , device = force .device , dtype = force .dtype
123- )
124- .view (nloc * 3 , nloc , 3 )
125- .unsqueeze (1 ) # (nloc * 3, 1, nloc, 3)
126- .expand (- 1 , nf , - 1 , - 1 ), # (nloc * 3, nf, nloc, 3)
127- create_graph = self .training ,
128- retain_graph = True ,
129- is_grads_batched = True ,
130- )[0 ]
131- .swapaxes (0 , 1 ) # (nf, nloc * 3, nloc, 3)
132- .view (nf , 1 , nloc * 3 , nloc * 3 ) # (nf, 1, nloc * 3, nloc * 3)
133- )
134- hess = {get_hessian_name ("energy" ): - hess } # negative sign for force
116+ hess = self ._cal_e_hessian_block (force , coord )
117+ hess = {get_hessian_name (name = "energy" ): - hess } # negative sign for force
135118 else :
136119 hess = self ._cal_hessian_all (
137120 coord ,
@@ -143,6 +126,54 @@ def forward_common(
143126 ret .update (hess )
144127 return ret
145128
129+ def _cal_e_hessian_block (
130+ self ,
131+ force : torch .Tensor ,
132+ coord : torch .Tensor ,
133+ slice : slice = slice (None ),
134+ ) -> torch .Tensor :
135+ force = force [:, slice , :]
136+ nf , nslice , _ = force .shape
137+ _ , nloc , _ = coord .shape
138+ hess = (
139+ torch .autograd .grad (
140+ outputs = force ,
141+ inputs = coord ,
142+ grad_outputs = torch .eye (
143+ nslice * 3 , device = force .device , dtype = force .dtype
144+ )
145+ .view (nslice * 3 , nslice , 3 )
146+ .unsqueeze (1 ) # (nslice * 3, 1, nslice, 3)
147+ .expand (- 1 , nf , - 1 , - 1 ), # (nslice * 3, nf, nslice, 3)
148+ create_graph = self .training ,
149+ retain_graph = True ,
150+ is_grads_batched = True ,
151+ )[0 ] # (nslice * 3, nf, nloc, 3)
152+ .swapaxes (0 , 1 ) # (nf, nslice * 3, nloc, 3)
153+ .view (nf , 1 , nslice * 3 , nloc * 3 )
154+ )
155+ return hess
156+
157+ def _cal_e_hessian_loop (
158+ self ,
159+ force : torch .Tensor ,
160+ coord : torch .Tensor ,
161+ ) -> torch .Tensor :
162+ hess = torch .zeros (
163+ * force .shape , * coord .shape [- 2 :], device = force .device , dtype = force .dtype
164+ ) # nf, nloc, 3, nloc, 3
165+ for nloc in range (coord .shape [- 2 ]):
166+ for i in range (3 ):
167+ hess [:, nloc , i ] = torch .autograd .grad (
168+ outputs = force [:, nloc , i ],
169+ inputs = coord ,
170+ grad_outputs = torch .ones_like (force [:, nloc , i ]),
171+ create_graph = self .training ,
172+ retain_graph = True ,
173+ )[0 ]
174+ nloc = coord .shape [- 2 ]
175+ return hess .view (- 1 , 1 , nloc * 3 , nloc * 3 )
176+
146177 def _cal_hessian_all (
147178 self ,
148179 coord : torch .Tensor ,
0 commit comments