Skip to content

Commit c9101d1

Browse files
committed
fix ut
1 parent 34ad36d commit c9101d1

15 files changed

Lines changed: 64 additions & 0 deletions

deepmd/dpmodel/model/dipole_model.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ def call(
4444
fparam: Array | None = None,
4545
aparam: Array | None = None,
4646
do_atomic_virial: bool = False,
47+
charge_spin: Array | None = None,
4748
) -> dict[str, Array]:
4849
model_ret = self.call_common(
4950
coord,
@@ -52,6 +53,7 @@ def call(
5253
fparam=fparam,
5354
aparam=aparam,
5455
do_atomic_virial=do_atomic_virial,
56+
charge_spin=charge_spin,
5557
)
5658
model_predict = {}
5759
model_predict["dipole"] = model_ret["dipole"]
@@ -75,6 +77,7 @@ def call_lower(
7577
fparam: Array | None = None,
7678
aparam: Array | None = None,
7779
do_atomic_virial: bool = False,
80+
charge_spin: Array | None = None,
7881
) -> dict[str, Array]:
7982
model_ret = self.call_common_lower(
8083
extended_coord,
@@ -84,6 +87,7 @@ def call_lower(
8487
fparam=fparam,
8588
aparam=aparam,
8689
do_atomic_virial=do_atomic_virial,
90+
charge_spin=charge_spin,
8791
)
8892
model_predict = {}
8993
model_predict["dipole"] = model_ret["dipole"]

deepmd/dpmodel/model/dos_model.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ def call(
4444
fparam: Array | None = None,
4545
aparam: Array | None = None,
4646
do_atomic_virial: bool = False,
47+
charge_spin: Array | None = None,
4748
) -> dict[str, Array]:
4849
model_ret = self.call_common(
4950
coord,
@@ -52,6 +53,7 @@ def call(
5253
fparam=fparam,
5354
aparam=aparam,
5455
do_atomic_virial=do_atomic_virial,
56+
charge_spin=charge_spin,
5557
)
5658
model_predict = {}
5759
model_predict["atom_dos"] = model_ret["dos"]
@@ -69,6 +71,7 @@ def call_lower(
6971
fparam: Array | None = None,
7072
aparam: Array | None = None,
7173
do_atomic_virial: bool = False,
74+
charge_spin: Array | None = None,
7275
) -> dict[str, Array]:
7376
model_ret = self.call_common_lower(
7477
extended_coord,
@@ -78,6 +81,7 @@ def call_lower(
7881
fparam=fparam,
7982
aparam=aparam,
8083
do_atomic_virial=do_atomic_virial,
84+
charge_spin=charge_spin,
8185
)
8286
model_predict = {}
8387
model_predict["atom_dos"] = model_ret["dos"]

deepmd/dpmodel/model/dp_zbl_model.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ def call(
4646
fparam: Array | None = None,
4747
aparam: Array | None = None,
4848
do_atomic_virial: bool = False,
49+
charge_spin: Array | None = None,
4950
) -> dict[str, Array]:
5051
model_ret = self.call_common(
5152
coord,
@@ -54,6 +55,7 @@ def call(
5455
fparam=fparam,
5556
aparam=aparam,
5657
do_atomic_virial=do_atomic_virial,
58+
charge_spin=charge_spin,
5759
)
5860
model_predict = {}
5961
model_predict["atom_energy"] = model_ret["energy"]
@@ -77,6 +79,7 @@ def call_lower(
7779
fparam: Array | None = None,
7880
aparam: Array | None = None,
7981
do_atomic_virial: bool = False,
82+
charge_spin: Array | None = None,
8083
) -> dict[str, Array]:
8184
model_ret = self.call_common_lower(
8285
extended_coord,
@@ -86,6 +89,7 @@ def call_lower(
8689
fparam=fparam,
8790
aparam=aparam,
8891
do_atomic_virial=do_atomic_virial,
92+
charge_spin=charge_spin,
8993
)
9094
model_predict = {}
9195
model_predict["atom_energy"] = model_ret["energy"]

deepmd/dpmodel/model/polar_model.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ def call(
4444
fparam: Array | None = None,
4545
aparam: Array | None = None,
4646
do_atomic_virial: bool = False,
47+
charge_spin: Array | None = None,
4748
) -> dict[str, Array]:
4849
model_ret = self.call_common(
4950
coord,
@@ -52,6 +53,7 @@ def call(
5253
fparam=fparam,
5354
aparam=aparam,
5455
do_atomic_virial=do_atomic_virial,
56+
charge_spin=charge_spin,
5557
)
5658
model_predict = {}
5759
model_predict["polar"] = model_ret["polarizability"]
@@ -69,6 +71,7 @@ def call_lower(
6971
fparam: Array | None = None,
7072
aparam: Array | None = None,
7173
do_atomic_virial: bool = False,
74+
charge_spin: Array | None = None,
7275
) -> dict[str, Array]:
7376
model_ret = self.call_common_lower(
7477
extended_coord,
@@ -78,6 +81,7 @@ def call_lower(
7881
fparam=fparam,
7982
aparam=aparam,
8083
do_atomic_virial=do_atomic_virial,
84+
charge_spin=charge_spin,
8185
)
8286
model_predict = {}
8387
model_predict["polar"] = model_ret["polarizability"]

deepmd/dpmodel/model/property_model.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ def call(
5151
fparam: Array | None = None,
5252
aparam: Array | None = None,
5353
do_atomic_virial: bool = False,
54+
charge_spin: Array | None = None,
5455
) -> dict[str, Array]:
5556
model_ret = self.call_common(
5657
coord,
@@ -59,6 +60,7 @@ def call(
5960
fparam=fparam,
6061
aparam=aparam,
6162
do_atomic_virial=do_atomic_virial,
63+
charge_spin=charge_spin,
6264
)
6365
var_name = self.get_var_name()
6466
model_predict = {}
@@ -77,6 +79,7 @@ def call_lower(
7779
fparam: Array | None = None,
7880
aparam: Array | None = None,
7981
do_atomic_virial: bool = False,
82+
charge_spin: Array | None = None,
8083
) -> dict[str, Array]:
8184
model_ret = self.call_common_lower(
8285
extended_coord,
@@ -86,6 +89,7 @@ def call_lower(
8689
fparam=fparam,
8790
aparam=aparam,
8891
do_atomic_virial=do_atomic_virial,
92+
charge_spin=charge_spin,
8993
)
9094
var_name = self.get_var_name()
9195
model_predict = {}

deepmd/dpmodel/model/spin_model.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -579,6 +579,7 @@ def call_common(
579579
box: Array | None = None,
580580
fparam: Array | None = None,
581581
aparam: Array | None = None,
582+
charge_spin: Array | None = None,
582583
do_atomic_virial: bool = False,
583584
) -> dict[str, Array]:
584585
"""Return model prediction with raw internal keys.
@@ -624,6 +625,7 @@ def call_common(
624625
box,
625626
fparam=fparam,
626627
aparam=aparam,
628+
charge_spin=charge_spin,
627629
do_atomic_virial=do_atomic_virial,
628630
coord_corr_for_virial=coord_corr_for_virial,
629631
)
@@ -673,6 +675,7 @@ def call(
673675
box: Array | None = None,
674676
fparam: Array | None = None,
675677
aparam: Array | None = None,
678+
charge_spin: Array | None = None,
676679
do_atomic_virial: bool = False,
677680
) -> dict[str, Array]:
678681
"""Return model prediction with translated user-facing keys.
@@ -710,6 +713,7 @@ def call(
710713
box,
711714
fparam=fparam,
712715
aparam=aparam,
716+
charge_spin=charge_spin,
713717
do_atomic_virial=do_atomic_virial,
714718
)
715719
model_output_type = self.backbone_model.model_output_type()
@@ -747,6 +751,7 @@ def call_common_lower(
747751
mapping: Array | None = None,
748752
fparam: Array | None = None,
749753
aparam: Array | None = None,
754+
charge_spin: Array | None = None,
750755
do_atomic_virial: bool = False,
751756
) -> dict[str, Array]:
752757
"""Return model prediction with raw internal keys. Lower interface that takes
@@ -798,6 +803,7 @@ def call_common_lower(
798803
mapping=mapping_updated,
799804
fparam=fparam,
800805
aparam=aparam,
806+
charge_spin=charge_spin,
801807
do_atomic_virial=do_atomic_virial,
802808
extended_coord_corr=extended_coord_corr,
803809
)
@@ -851,6 +857,7 @@ def call_lower(
851857
mapping: Array | None = None,
852858
fparam: Array | None = None,
853859
aparam: Array | None = None,
860+
charge_spin: Array | None = None,
854861
do_atomic_virial: bool = False,
855862
) -> dict[str, Array]:
856863
"""Return model prediction with translated user-facing keys. Lower interface.
@@ -889,6 +896,7 @@ def call_lower(
889896
mapping=mapping,
890897
fparam=fparam,
891898
aparam=aparam,
899+
charge_spin=charge_spin,
892900
do_atomic_virial=do_atomic_virial,
893901
)
894902
model_output_type = self.backbone_model.model_output_type()

deepmd/pt/model/model/dipole_model.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ def forward(
6060
fparam: torch.Tensor | None = None,
6161
aparam: torch.Tensor | None = None,
6262
do_atomic_virial: bool = False,
63+
charge_spin: torch.Tensor | None = None,
6364
) -> dict[str, torch.Tensor]:
6465
model_ret = self.forward_common(
6566
coord,
@@ -68,6 +69,7 @@ def forward(
6869
fparam=fparam,
6970
aparam=aparam,
7071
do_atomic_virial=do_atomic_virial,
72+
charge_spin=charge_spin,
7173
)
7274
if self.get_fitting_net() is not None:
7375
model_predict = {}
@@ -97,6 +99,7 @@ def forward_lower(
9799
aparam: torch.Tensor | None = None,
98100
do_atomic_virial: bool = False,
99101
comm_dict: dict[str, torch.Tensor] | None = None,
102+
charge_spin: torch.Tensor | None = None,
100103
) -> dict[str, torch.Tensor]:
101104
model_ret = self.forward_common_lower(
102105
extended_coord,
@@ -108,6 +111,7 @@ def forward_lower(
108111
do_atomic_virial=do_atomic_virial,
109112
comm_dict=comm_dict,
110113
extra_nlist_sort=self.need_sorted_nlist_for_lower(),
114+
charge_spin=charge_spin,
111115
)
112116
if self.get_fitting_net() is not None:
113117
model_predict = {}

deepmd/pt/model/model/dos_model.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ def forward(
5252
fparam: torch.Tensor | None = None,
5353
aparam: torch.Tensor | None = None,
5454
do_atomic_virial: bool = False,
55+
charge_spin: torch.Tensor | None = None,
5556
) -> dict[str, torch.Tensor]:
5657
model_ret = self.forward_common(
5758
coord,
@@ -60,6 +61,7 @@ def forward(
6061
fparam=fparam,
6162
aparam=aparam,
6263
do_atomic_virial=do_atomic_virial,
64+
charge_spin=charge_spin,
6365
)
6466
if self.get_fitting_net() is not None:
6567
model_predict = {}
@@ -89,6 +91,7 @@ def forward_lower(
8991
aparam: torch.Tensor | None = None,
9092
do_atomic_virial: bool = False,
9193
comm_dict: dict[str, torch.Tensor] | None = None,
94+
charge_spin: torch.Tensor | None = None,
9295
) -> dict[str, torch.Tensor]:
9396
model_ret = self.forward_common_lower(
9497
extended_coord,
@@ -100,6 +103,7 @@ def forward_lower(
100103
do_atomic_virial=do_atomic_virial,
101104
comm_dict=comm_dict,
102105
extra_nlist_sort=self.need_sorted_nlist_for_lower(),
106+
charge_spin=charge_spin,
103107
)
104108
if self.get_fitting_net() is not None:
105109
model_predict = {}

deepmd/pt/model/model/dp_linear_model.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ def forward(
6565
fparam: torch.Tensor | None = None,
6666
aparam: torch.Tensor | None = None,
6767
do_atomic_virial: bool = False,
68+
charge_spin: torch.Tensor | None = None,
6869
) -> dict[str, torch.Tensor]:
6970
model_ret = self.forward_common(
7071
coord,
@@ -73,6 +74,7 @@ def forward(
7374
fparam=fparam,
7475
aparam=aparam,
7576
do_atomic_virial=do_atomic_virial,
77+
charge_spin=charge_spin,
7678
)
7779

7880
model_predict = {}
@@ -101,6 +103,7 @@ def forward_lower(
101103
aparam: torch.Tensor | None = None,
102104
do_atomic_virial: bool = False,
103105
comm_dict: dict[str, torch.Tensor] | None = None,
106+
charge_spin: torch.Tensor | None = None,
104107
) -> dict[str, torch.Tensor]:
105108
model_ret = self.forward_common_lower(
106109
extended_coord,
@@ -112,6 +115,7 @@ def forward_lower(
112115
do_atomic_virial=do_atomic_virial,
113116
comm_dict=comm_dict,
114117
extra_nlist_sort=self.need_sorted_nlist_for_lower(),
118+
charge_spin=charge_spin,
115119
)
116120

117121
model_predict = {}

deepmd/pt/model/model/dp_zbl_model.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ def forward(
6262
fparam: torch.Tensor | None = None,
6363
aparam: torch.Tensor | None = None,
6464
do_atomic_virial: bool = False,
65+
charge_spin: torch.Tensor | None = None,
6566
) -> dict[str, torch.Tensor]:
6667
model_ret = self.forward_common(
6768
coord,
@@ -70,6 +71,7 @@ def forward(
7071
fparam=fparam,
7172
aparam=aparam,
7273
do_atomic_virial=do_atomic_virial,
74+
charge_spin=charge_spin,
7375
)
7476

7577
model_predict = {}
@@ -98,6 +100,7 @@ def forward_lower(
98100
aparam: torch.Tensor | None = None,
99101
do_atomic_virial: bool = False,
100102
comm_dict: dict[str, torch.Tensor] | None = None,
103+
charge_spin: torch.Tensor | None = None,
101104
) -> dict[str, torch.Tensor]:
102105
model_ret = self.forward_common_lower(
103106
extended_coord,
@@ -109,6 +112,7 @@ def forward_lower(
109112
do_atomic_virial=do_atomic_virial,
110113
comm_dict=comm_dict,
111114
extra_nlist_sort=self.need_sorted_nlist_for_lower(),
115+
charge_spin=charge_spin,
112116
)
113117

114118
model_predict = {}

0 commit comments

Comments
 (0)