@@ -75,6 +75,7 @@ def test(
7575 detail_file : str ,
7676 atomic : bool ,
7777 head : str | None = None ,
78+ output_latent_charge : bool = False ,
7879 ** kwargs : Any ,
7980) -> None :
8081 """Test model predictions.
@@ -185,6 +186,7 @@ def test(
185186 detail_file ,
186187 atomic ,
187188 append_detail = (cc != 0 ),
189+ output_latent_charge = output_latent_charge ,
188190 )
189191 elif isinstance (dp , DeepDOS ):
190192 err = test_dos (
@@ -305,6 +307,7 @@ def test_ener(
305307 detail_file : str | None ,
306308 has_atom_ener : bool ,
307309 append_detail : bool = False ,
310+ output_latent_charge : bool = False ,
308311) -> tuple [list [np .ndarray ], list [int ]]:
309312 """Test energy type model.
310313
@@ -402,32 +405,46 @@ def test_ener(
402405 efield = efield ,
403406 mixed_type = mixed_type ,
404407 spin = spin ,
408+ output_latent_charge = output_latent_charge ,
405409 )
406410 energy = ret [0 ]
407411 force = ret [1 ]
408412 virial = ret [2 ]
409413 energy = energy .reshape ([numb_test , 1 ])
410414 force = force .reshape ([numb_test , - 1 ])
411415 virial = virial .reshape ([numb_test , 9 ])
412- if dp .has_hessian :
413- hessian = ret [3 ]
414- hessian = hessian .reshape ([numb_test , - 1 ])
416+ idx = 3
415417 if has_atom_ener :
416- ae = ret [3 ]
417- av = ret [4 ]
418+ ae = ret [idx ]
419+ idx += 1
420+ av = ret [idx ]
421+ idx += 1
418422 ae = ae .reshape ([numb_test , - 1 ])
419423 av = av .reshape ([numb_test , - 1 ])
420424 if dp .has_spin :
421- force_m = ret [5 ]
425+ force_m = ret [idx ]
426+ idx += 1
427+ mask_mag = ret [idx ]
428+ idx += 1
422429 force_m = force_m .reshape ([numb_test , - 1 ])
423- mask_mag = ret [6 ]
424430 mask_mag = mask_mag .reshape ([numb_test , - 1 ])
425431 else :
426432 if dp .has_spin :
427- force_m = ret [3 ]
433+ force_m = ret [idx ]
434+ idx += 1
435+ mask_mag = ret [idx ]
436+ idx += 1
428437 force_m = force_m .reshape ([numb_test , - 1 ])
429- mask_mag = ret [4 ]
430438 mask_mag = mask_mag .reshape ([numb_test , - 1 ])
439+ if dp .has_hessian :
440+ hessian = ret [idx ]
441+ idx += 1
442+ hessian = hessian .reshape ([numb_test , - 1 ])
443+ latent_charge = None
444+ if output_latent_charge :
445+ latent_charge = ret [idx ]
446+ idx += 1
447+ latent_charge = latent_charge .reshape ([numb_test , - 1 ])
431448 out_put_spin = dp .get_ntypes_spin () != 0 or dp .has_spin
432449 if out_put_spin :
433450 if dp .get_ntypes_spin () != 0 : # old tf support for spin
@@ -659,6 +676,13 @@ def test_ener(
659676 header = f"{ system } : data_h pred_h (3Na*3Na matrix in row-major order)" ,
660677 append = append_detail ,
661678 )
679+ if output_latent_charge and latent_charge is not None :
680+ save_txt_file (
681+ detail_path .with_suffix (".q.out" ),
682+ latent_charge ,
683+ header = f"{ system } : pred_q (latent charge per atom)" ,
684+ append = append_detail ,
685+ )
662686
663687 return dict_to_return
664688
0 commit comments