Skip to content

Commit c7edd6f

Browse files
committed
feat: impl part hessian calc
1 parent 2770e37 commit c7edd6f

1 file changed

Lines changed: 50 additions & 19 deletions

File tree

deepmd/pt/model/model/make_hessian_model.py

Lines changed: 50 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -113,25 +113,8 @@ def forward_common(
113113
): # use force to calculate energy hessian
114114
force: torch.Tensor = ret["energy_derv_r"]\
115115
.squeeze(-2) # nf x nloc x 3
116-
nf, nloc, _ = force.shape
117-
hess = (
118-
torch.autograd.grad(
119-
outputs=force,
120-
inputs=coord,
121-
grad_outputs=torch.eye(
122-
nloc * 3, device=force.device, dtype=force.dtype
123-
)
124-
.view(nloc * 3, nloc, 3)
125-
.unsqueeze(1) # (nloc * 3, 1, nloc, 3)
126-
.expand(-1, nf, -1, -1), # (nloc * 3, nf, nloc, 3)
127-
create_graph=self.training,
128-
retain_graph=True,
129-
is_grads_batched=True,
130-
)[0]
131-
.swapaxes(0, 1) # (nf, nloc * 3, nloc, 3)
132-
.view(nf, 1, nloc * 3, nloc * 3) # (nf, 1, nloc * 3, nloc * 3)
133-
)
134-
hess = {get_hessian_name("energy"): -hess} # negative sign for force
116+
hess = self._cal_e_hessian_block(force, coord)
117+
hess = {get_hessian_name(name="energy"): -hess} # negative sign for force
135118
else:
136119
hess = self._cal_hessian_all(
137120
coord,
@@ -143,6 +126,54 @@ def forward_common(
143126
ret.update(hess)
144127
return ret
145128

129+
def _cal_e_hessian_block(
130+
self,
131+
force: torch.Tensor,
132+
coord: torch.Tensor,
133+
slice: slice = slice(None),
134+
) -> torch.Tensor:
135+
force = force[:, slice, :]
136+
nf, nslice, _ = force.shape
137+
_, nloc, _ = coord.shape
138+
hess = (
139+
torch.autograd.grad(
140+
outputs=force,
141+
inputs=coord,
142+
grad_outputs=torch.eye(
143+
nslice * 3, device=force.device, dtype=force.dtype
144+
)
145+
.view(nslice * 3, nslice, 3)
146+
.unsqueeze(1) # (nslice * 3, 1, nslice, 3)
147+
.expand(-1, nf, -1, -1), # (nslice * 3, nf, nslice, 3)
148+
create_graph=self.training,
149+
retain_graph=True,
150+
is_grads_batched=True,
151+
)[0] # (nslice * 3, nf, nloc, 3)
152+
.swapaxes(0, 1) # (nf, nslice * 3, nloc, 3)
153+
.view(nf, 1, nslice * 3, nloc * 3)
154+
)
155+
return hess
156+
157+
def _cal_e_hessian_loop(
158+
self,
159+
force: torch.Tensor,
160+
coord: torch.Tensor,
161+
) -> torch.Tensor:
162+
hess = torch.zeros(
163+
*force.shape, *coord.shape[-2:], device=force.device, dtype=force.dtype
164+
) # nf, nloc, 3, nloc, 3
165+
for nloc in range(coord.shape[-2]):
166+
for i in range(3):
167+
hess[:, nloc, i] = torch.autograd.grad(
168+
outputs=force[:, nloc, i],
169+
inputs=coord,
170+
grad_outputs=torch.ones_like(force[:, nloc, i]),
171+
create_graph=self.training,
172+
retain_graph=True,
173+
)[0]
174+
nloc = coord.shape[-2]
175+
return hess.view(-1, 1, nloc * 3, nloc * 3)
176+
146177
def _cal_hessian_all(
147178
self,
148179
coord: torch.Tensor,

0 commit comments

Comments
 (0)