Skip to content

Commit 31c6008

Browse files
committed
add prob charge
1 parent 7f2a952 commit 31c6008

6 files changed

Lines changed: 70 additions & 9 deletions

File tree

deepmd/entrypoints/test.py

Lines changed: 33 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@ def test(
7575
detail_file: str,
7676
atomic: bool,
7777
head: str | None = None,
78+
output_latent_charge: bool = False,
7879
**kwargs: Any,
7980
) -> None:
8081
"""Test model predictions.
@@ -185,6 +186,7 @@ def test(
185186
detail_file,
186187
atomic,
187188
append_detail=(cc != 0),
189+
output_latent_charge=output_latent_charge,
188190
)
189191
elif isinstance(dp, DeepDOS):
190192
err = test_dos(
@@ -305,6 +307,7 @@ def test_ener(
305307
detail_file: str | None,
306308
has_atom_ener: bool,
307309
append_detail: bool = False,
310+
output_latent_charge: bool = False,
308311
) -> tuple[list[np.ndarray], list[int]]:
309312
"""Test energy type model.
310313
@@ -402,32 +405,46 @@ def test_ener(
402405
efield=efield,
403406
mixed_type=mixed_type,
404407
spin=spin,
408+
output_latent_charge=output_latent_charge,
405409
)
406410
energy = ret[0]
407411
force = ret[1]
408412
virial = ret[2]
409413
energy = energy.reshape([numb_test, 1])
410414
force = force.reshape([numb_test, -1])
411415
virial = virial.reshape([numb_test, 9])
412-
if dp.has_hessian:
413-
hessian = ret[3]
414-
hessian = hessian.reshape([numb_test, -1])
416+
idx = 3
415417
if has_atom_ener:
416-
ae = ret[3]
417-
av = ret[4]
418+
ae = ret[idx]
419+
idx += 1
420+
av = ret[idx]
421+
idx += 1
418422
ae = ae.reshape([numb_test, -1])
419423
av = av.reshape([numb_test, -1])
420424
if dp.has_spin:
421-
force_m = ret[5]
425+
force_m = ret[idx]
426+
idx += 1
427+
mask_mag = ret[idx]
428+
idx += 1
422429
force_m = force_m.reshape([numb_test, -1])
423-
mask_mag = ret[6]
424430
mask_mag = mask_mag.reshape([numb_test, -1])
425431
else:
426432
if dp.has_spin:
427-
force_m = ret[3]
433+
force_m = ret[idx]
434+
idx += 1
435+
mask_mag = ret[idx]
436+
idx += 1
428437
force_m = force_m.reshape([numb_test, -1])
429-
mask_mag = ret[4]
430438
mask_mag = mask_mag.reshape([numb_test, -1])
439+
if dp.has_hessian:
440+
hessian = ret[idx]
441+
idx += 1
442+
hessian = hessian.reshape([numb_test, -1])
443+
latent_charge = None
444+
if output_latent_charge:
445+
latent_charge = ret[idx]
446+
idx += 1
447+
latent_charge = latent_charge.reshape([numb_test, -1])
431448
out_put_spin = dp.get_ntypes_spin() != 0 or dp.has_spin
432449
if out_put_spin:
433450
if dp.get_ntypes_spin() != 0: # old tf support for spin
@@ -659,6 +676,13 @@ def test_ener(
659676
header=f"{system}: data_h pred_h (3Na*3Na matrix in row-major order)",
660677
append=append_detail,
661678
)
679+
if output_latent_charge and latent_charge is not None:
680+
save_txt_file(
681+
detail_path.with_suffix(".q.out"),
682+
latent_charge,
683+
header=f"{system}: pred_q (latent charge per atom)",
684+
append=append_detail,
685+
)
662686

663687
return dict_to_return
664688

deepmd/infer/deep_eval.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@ class DeepEvalBackend(ABC):
7878
"global_polar": "global_polar",
7979
"wfc": "wfc",
8080
"energy_derv_r_derv_r": "hessian",
81+
"latent_charge": "latent_charge",
8182
}
8283

8384
@abstractmethod

deepmd/infer/deep_pot.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -203,13 +203,34 @@ def eval(
203203
nframes,
204204
natoms,
205205
) = self._standard_input(coords, cells, atom_types, fparam, aparam, mixed_type)
206+
output_latent_charge = kwargs.pop("output_latent_charge", False)
207+
extra_request_defs = []
208+
if output_latent_charge:
209+
# Try to get dim_out_lr from the model fitting net
210+
dim_out_lr = 1
211+
model = self.deep_eval.get_model()
212+
if hasattr(model, "atomic_model") and hasattr(
213+
model.atomic_model, "fitting_net"
214+
):
215+
dim_out_lr = getattr(model.atomic_model.fitting_net, "dim_out_lr", 1)
216+
extra_request_defs.append(
217+
OutputVariableDef(
218+
"latent_charge",
219+
shape=[dim_out_lr],
220+
reducible=False,
221+
r_differentiable=False,
222+
c_differentiable=False,
223+
atomic=True,
224+
)
225+
)
206226
results = self.deep_eval.eval(
207227
coords,
208228
cells,
209229
atom_types,
210230
atomic,
211231
fparam=fparam,
212232
aparam=aparam,
233+
extra_request_defs=extra_request_defs,
213234
**kwargs,
214235
)
215236
energy = results["energy_redu"].reshape(nframes, 1)
@@ -251,6 +272,9 @@ def eval(
251272
nframes, 3 * natoms, 3 * natoms
252273
)
253274
result = (*list(result), hessian)
275+
if output_latent_charge and "latent_charge" in results:
276+
latent_charge = results["latent_charge"].reshape(nframes, natoms, -1)
277+
result = (*result, latent_charge)
254278
return result
255279

256280

deepmd/main.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -454,6 +454,12 @@ def main_parser() -> argparse.ArgumentParser:
454454
default=False,
455455
help="(Supported backend: PyTorch) Disable JIT compilation when loading the model.",
456456
)
457+
parser_tst.add_argument(
458+
"--output-latent-charge",
459+
action="store_true",
460+
default=False,
461+
help="Output latent charge predicted by SOG models to the detail file.",
462+
)
457463

458464
# * eval_desc script ***************************************************************
459465
parser_eval_desc = subparsers.add_parser(

deepmd/pt/infer/deep_eval.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -384,6 +384,8 @@ def eval(
384384
coords, atom_types, len(atom_types.shape) > 1
385385
)
386386
request_defs = self._get_request_defs(atomic)
387+
if "extra_request_defs" in kwargs:
388+
request_defs = request_defs + kwargs["extra_request_defs"]
387389
if "spin" not in kwargs or kwargs["spin"] is None:
388390
out = self._eval_func(self._eval_model, numb_test, natoms)(
389391
coords, cells, atom_types, fparam, aparam, request_defs

deepmd/pt/model/model/sog_model.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -485,6 +485,8 @@ def forward(
485485
model_predict["mask"] = model_ret["mask"]
486486
if self._hessian_enabled:
487487
model_predict["hessian"] = model_ret["energy_derv_r_derv_r"].squeeze(-2)
488+
if "latent_charge" in model_ret:
489+
model_predict["latent_charge"] = model_ret["latent_charge"]
488490
else:
489491
model_predict = model_ret
490492
model_predict["updated_coord"] += coord
@@ -528,6 +530,8 @@ def forward_lower(
528530
else:
529531
assert model_ret["dforce"] is not None
530532
model_predict["dforce"] = model_ret["dforce"]
533+
if "latent_charge" in model_ret:
534+
model_predict["latent_charge"] = model_ret["latent_charge"]
531535
else:
532536
model_predict = model_ret
533537
return model_predict

0 commit comments

Comments
 (0)