Skip to content

Commit 17b4c8a

Browse files
committed
make charge_spin last-input
1 parent e98b6a8 commit 17b4c8a

16 files changed

Lines changed: 45 additions & 45 deletions

deepmd/dpmodel/model/ener_model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -86,8 +86,8 @@ def call(
8686
box: Array | None = None,
8787
fparam: Array | None = None,
8888
aparam: Array | None = None,
89-
charge_spin: Array | None = None,
9089
do_atomic_virial: bool = False,
90+
charge_spin: Array | None = None,
9191
) -> dict[str, Array]:
9292
model_ret = self.call_common(
9393
coord,
@@ -121,8 +121,8 @@ def call_lower(
121121
mapping: Array | None = None,
122122
fparam: Array | None = None,
123123
aparam: Array | None = None,
124-
charge_spin: Array | None = None,
125124
do_atomic_virial: bool = False,
125+
charge_spin: Array | None = None,
126126
) -> dict[str, Array]:
127127
model_ret = self.call_common_lower(
128128
extended_coord,

deepmd/dpmodel/model/spin_model.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -579,8 +579,8 @@ def call_common(
579579
box: Array | None = None,
580580
fparam: Array | None = None,
581581
aparam: Array | None = None,
582-
charge_spin: Array | None = None,
583582
do_atomic_virial: bool = False,
583+
charge_spin: Array | None = None,
584584
) -> dict[str, Array]:
585585
"""Return model prediction with raw internal keys.
586586
@@ -675,8 +675,8 @@ def call(
675675
box: Array | None = None,
676676
fparam: Array | None = None,
677677
aparam: Array | None = None,
678-
charge_spin: Array | None = None,
679678
do_atomic_virial: bool = False,
679+
charge_spin: Array | None = None,
680680
) -> dict[str, Array]:
681681
"""Return model prediction with translated user-facing keys.
682682
@@ -751,8 +751,8 @@ def call_common_lower(
751751
mapping: Array | None = None,
752752
fparam: Array | None = None,
753753
aparam: Array | None = None,
754-
charge_spin: Array | None = None,
755754
do_atomic_virial: bool = False,
755+
charge_spin: Array | None = None,
756756
) -> dict[str, Array]:
757757
"""Return model prediction with raw internal keys. Lower interface that takes
758758
extended atomic coordinates, types and spins, nlist, and mapping
@@ -857,8 +857,8 @@ def call_lower(
857857
mapping: Array | None = None,
858858
fparam: Array | None = None,
859859
aparam: Array | None = None,
860-
charge_spin: Array | None = None,
861860
do_atomic_virial: bool = False,
861+
charge_spin: Array | None = None,
862862
) -> dict[str, Array]:
863863
"""Return model prediction with translated user-facing keys. Lower interface.
864864

deepmd/pt/infer/deep_eval.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -408,7 +408,7 @@ def eval(
408408
request_defs = self._get_request_defs(atomic)
409409
if "spin" not in kwargs or kwargs["spin"] is None:
410410
out = self._eval_func(self._eval_model, numb_test, natoms)(
411-
coords, cells, atom_types, fparam, aparam, charge_spin, request_defs
411+
coords, cells, atom_types, fparam, aparam, request_defs, charge_spin
412412
)
413413
else:
414414
out = self._eval_func(self._eval_model_spin, numb_test, natoms)(
@@ -418,8 +418,8 @@ def eval(
418418
np.array(kwargs["spin"]),
419419
fparam,
420420
aparam,
421-
charge_spin,
422421
request_defs,
422+
charge_spin,
423423
)
424424
return dict(
425425
zip(
@@ -520,8 +520,8 @@ def _eval_model(
520520
atom_types: np.ndarray,
521521
fparam: np.ndarray | None,
522522
aparam: np.ndarray | None,
523-
charge_spin: np.ndarray | None,
524523
request_defs: list[OutputVariableDef],
524+
charge_spin: np.ndarray | None,
525525
) -> tuple[np.ndarray, ...]:
526526
model = self.dp.to(DEVICE)
527527
prec = NP_PRECISION_DICT[RESERVED_PRECISION_DICT[GLOBAL_PT_FLOAT_PRECISION]]
@@ -604,8 +604,8 @@ def _eval_model_spin(
604604
spins: np.ndarray,
605605
fparam: np.ndarray | None,
606606
aparam: np.ndarray | None,
607-
charge_spin: np.ndarray | None,
608607
request_defs: list[OutputVariableDef],
608+
charge_spin: np.ndarray | None,
609609
) -> tuple[np.ndarray, ...]:
610610
model = self.dp.to(DEVICE)
611611

deepmd/pt/model/model/spin_model.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -526,8 +526,8 @@ def forward_common(
526526
box: torch.Tensor | None = None,
527527
fparam: torch.Tensor | None = None,
528528
aparam: torch.Tensor | None = None,
529-
charge_spin: torch.Tensor | None = None,
530529
do_atomic_virial: bool = False,
530+
charge_spin: torch.Tensor | None = None,
531531
) -> dict[str, torch.Tensor]:
532532
nframes, nloc = atype.shape
533533
coord_updated, atype_updated, coord_corr_for_virial = self.process_spin_input(
@@ -580,10 +580,10 @@ def forward_common_lower(
580580
mapping: torch.Tensor | None = None,
581581
fparam: torch.Tensor | None = None,
582582
aparam: torch.Tensor | None = None,
583-
charge_spin: torch.Tensor | None = None,
584583
do_atomic_virial: bool = False,
585584
comm_dict: dict[str, torch.Tensor] | None = None,
586585
extra_nlist_sort: bool = False,
586+
charge_spin: torch.Tensor | None = None,
587587
) -> dict[str, torch.Tensor]:
588588
nframes, nloc = nlist.shape[:2]
589589
(
@@ -699,8 +699,8 @@ def forward(
699699
box: torch.Tensor | None = None,
700700
fparam: torch.Tensor | None = None,
701701
aparam: torch.Tensor | None = None,
702-
charge_spin: torch.Tensor | None = None,
703702
do_atomic_virial: bool = False,
703+
charge_spin: torch.Tensor | None = None,
704704
) -> dict[str, torch.Tensor]:
705705
model_ret = self.forward_common(
706706
coord,
@@ -735,9 +735,9 @@ def forward_lower(
735735
mapping: torch.Tensor | None = None,
736736
fparam: torch.Tensor | None = None,
737737
aparam: torch.Tensor | None = None,
738-
charge_spin: torch.Tensor | None = None,
739738
do_atomic_virial: bool = False,
740739
comm_dict: dict[str, torch.Tensor] | None = None,
740+
charge_spin: torch.Tensor | None = None,
741741
) -> dict[str, torch.Tensor]:
742742
model_ret = self.forward_common_lower(
743743
extended_coord,

deepmd/pt_expt/model/dipole_model.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,8 @@ def forward(
4444
box: torch.Tensor | None = None,
4545
fparam: torch.Tensor | None = None,
4646
aparam: torch.Tensor | None = None,
47-
charge_spin: torch.Tensor | None = None,
4847
do_atomic_virial: bool = False,
48+
charge_spin: torch.Tensor | None = None,
4949
) -> dict[str, torch.Tensor]:
5050
model_ret = self.call_common(
5151
coord,
@@ -77,8 +77,8 @@ def forward_lower(
7777
mapping: torch.Tensor | None = None,
7878
fparam: torch.Tensor | None = None,
7979
aparam: torch.Tensor | None = None,
80-
charge_spin: torch.Tensor | None = None,
8180
do_atomic_virial: bool = False,
81+
charge_spin: torch.Tensor | None = None,
8282
) -> dict[str, torch.Tensor]:
8383
model_ret = self.call_common_lower(
8484
extended_coord,
@@ -129,8 +129,8 @@ def forward_lower_exportable(
129129
mapping: torch.Tensor | None = None,
130130
fparam: torch.Tensor | None = None,
131131
aparam: torch.Tensor | None = None,
132-
charge_spin: torch.Tensor | None = None,
133132
do_atomic_virial: bool = False,
133+
charge_spin: torch.Tensor | None = None,
134134
**make_fx_kwargs: Any,
135135
) -> torch.nn.Module:
136136
model = self

deepmd/pt_expt/model/dos_model.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,8 @@ def forward(
4444
box: torch.Tensor | None = None,
4545
fparam: torch.Tensor | None = None,
4646
aparam: torch.Tensor | None = None,
47-
charge_spin: torch.Tensor | None = None,
4847
do_atomic_virial: bool = False,
48+
charge_spin: torch.Tensor | None = None,
4949
) -> dict[str, torch.Tensor]:
5050
model_ret = self.call_common(
5151
coord,
@@ -71,8 +71,8 @@ def forward_lower(
7171
mapping: torch.Tensor | None = None,
7272
fparam: torch.Tensor | None = None,
7373
aparam: torch.Tensor | None = None,
74-
charge_spin: torch.Tensor | None = None,
7574
do_atomic_virial: bool = False,
75+
charge_spin: torch.Tensor | None = None,
7676
) -> dict[str, torch.Tensor]:
7777
model_ret = self.call_common_lower(
7878
extended_coord,
@@ -109,8 +109,8 @@ def forward_lower_exportable(
109109
mapping: torch.Tensor | None = None,
110110
fparam: torch.Tensor | None = None,
111111
aparam: torch.Tensor | None = None,
112-
charge_spin: torch.Tensor | None = None,
113112
do_atomic_virial: bool = False,
113+
charge_spin: torch.Tensor | None = None,
114114
**make_fx_kwargs: Any,
115115
) -> torch.nn.Module:
116116
model = self

deepmd/pt_expt/model/dp_linear_model.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,8 @@ def forward(
4747
box: torch.Tensor | None = None,
4848
fparam: torch.Tensor | None = None,
4949
aparam: torch.Tensor | None = None,
50-
charge_spin: torch.Tensor | None = None,
5150
do_atomic_virial: bool = False,
51+
charge_spin: torch.Tensor | None = None,
5252
) -> dict[str, torch.Tensor]:
5353
model_ret = self.call_common(
5454
coord,
@@ -80,8 +80,8 @@ def forward_lower(
8080
mapping: torch.Tensor | None = None,
8181
fparam: torch.Tensor | None = None,
8282
aparam: torch.Tensor | None = None,
83-
charge_spin: torch.Tensor | None = None,
8483
do_atomic_virial: bool = False,
84+
charge_spin: torch.Tensor | None = None,
8585
) -> dict[str, torch.Tensor]:
8686
model_ret = self.call_common_lower(
8787
extended_coord,
@@ -134,8 +134,8 @@ def forward_lower_exportable(
134134
mapping: torch.Tensor | None = None,
135135
fparam: torch.Tensor | None = None,
136136
aparam: torch.Tensor | None = None,
137-
charge_spin: torch.Tensor | None = None,
138137
do_atomic_virial: bool = False,
138+
charge_spin: torch.Tensor | None = None,
139139
**make_fx_kwargs: Any,
140140
) -> torch.nn.Module:
141141
model = self

deepmd/pt_expt/model/dp_zbl_model.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,8 @@ def forward(
4444
box: torch.Tensor | None = None,
4545
fparam: torch.Tensor | None = None,
4646
aparam: torch.Tensor | None = None,
47-
charge_spin: torch.Tensor | None = None,
4847
do_atomic_virial: bool = False,
48+
charge_spin: torch.Tensor | None = None,
4949
) -> dict[str, torch.Tensor]:
5050
model_ret = self.call_common(
5151
coord,
@@ -77,8 +77,8 @@ def forward_lower(
7777
mapping: torch.Tensor | None = None,
7878
fparam: torch.Tensor | None = None,
7979
aparam: torch.Tensor | None = None,
80-
charge_spin: torch.Tensor | None = None,
8180
do_atomic_virial: bool = False,
81+
charge_spin: torch.Tensor | None = None,
8282
) -> dict[str, torch.Tensor]:
8383
model_ret = self.call_common_lower(
8484
extended_coord,
@@ -131,8 +131,8 @@ def forward_lower_exportable(
131131
mapping: torch.Tensor | None = None,
132132
fparam: torch.Tensor | None = None,
133133
aparam: torch.Tensor | None = None,
134-
charge_spin: torch.Tensor | None = None,
135134
do_atomic_virial: bool = False,
135+
charge_spin: torch.Tensor | None = None,
136136
**make_fx_kwargs: Any,
137137
) -> torch.nn.Module:
138138
model = self

deepmd/pt_expt/model/ener_model.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -57,8 +57,8 @@ def forward(
5757
box: torch.Tensor | None = None,
5858
fparam: torch.Tensor | None = None,
5959
aparam: torch.Tensor | None = None,
60-
charge_spin: torch.Tensor | None = None,
6160
do_atomic_virial: bool = False,
61+
charge_spin: torch.Tensor | None = None,
6262
) -> dict[str, torch.Tensor]:
6363
model_ret = self.call_common(
6464
coord,
@@ -92,8 +92,8 @@ def forward_lower(
9292
mapping: torch.Tensor | None = None,
9393
fparam: torch.Tensor | None = None,
9494
aparam: torch.Tensor | None = None,
95-
charge_spin: torch.Tensor | None = None,
9695
do_atomic_virial: bool = False,
96+
charge_spin: torch.Tensor | None = None,
9797
) -> dict[str, torch.Tensor]:
9898
model_ret = self.call_common_lower(
9999
extended_coord,
@@ -148,8 +148,8 @@ def forward_lower_exportable(
148148
mapping: torch.Tensor | None = None,
149149
fparam: torch.Tensor | None = None,
150150
aparam: torch.Tensor | None = None,
151-
charge_spin: torch.Tensor | None = None,
152151
do_atomic_virial: bool = False,
152+
charge_spin: torch.Tensor | None = None,
153153
**make_fx_kwargs: Any,
154154
) -> torch.nn.Module:
155155
"""Trace ``forward_lower`` into an exportable module.

deepmd/pt_expt/model/make_model.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -61,8 +61,8 @@ def _cal_hessian_ext(
6161
mapping: torch.Tensor | None,
6262
fparam: torch.Tensor | None,
6363
aparam: torch.Tensor | None,
64-
charge_spin: torch.Tensor | None = None,
6564
create_graph: bool = False,
65+
charge_spin: torch.Tensor | None = None,
6666
) -> torch.Tensor:
6767
"""Compute hessian of reduced output w.r.t. extended coordinates.
6868
@@ -285,9 +285,9 @@ def forward_common_atomic(
285285
mapping: torch.Tensor | None = None,
286286
fparam: torch.Tensor | None = None,
287287
aparam: torch.Tensor | None = None,
288-
charge_spin: torch.Tensor | None = None,
289288
do_atomic_virial: bool = False,
290289
extended_coord_corr: torch.Tensor | None = None,
290+
charge_spin: torch.Tensor | None = None,
291291
) -> dict[str, torch.Tensor]:
292292
atomic_ret = self.atomic_model.forward_common_atomic(
293293
extended_coord,
@@ -338,8 +338,8 @@ def forward_common_lower_exportable(
338338
mapping: torch.Tensor | None = None,
339339
fparam: torch.Tensor | None = None,
340340
aparam: torch.Tensor | None = None,
341-
charge_spin: torch.Tensor | None = None,
342341
do_atomic_virial: bool = False,
342+
charge_spin: torch.Tensor | None = None,
343343
**make_fx_kwargs: Any,
344344
) -> torch.nn.Module:
345345
"""Trace ``forward_common_lower`` into an exportable module.

0 commit comments

Comments
 (0)