Skip to content

Commit 86922bc

Browse files
committed
fix comments
1 parent 4aa0f11 commit 86922bc

6 files changed

Lines changed: 78 additions & 24 deletions

File tree

deepmd/dpmodel/descriptor/hybrid.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -349,6 +349,7 @@ def call(
349349
atype_ext,
350350
nl,
351351
mapping,
352+
fparam=fparam,
352353
comm_dict=comm_dict,
353354
charge_spin=charge_spin,
354355
)

deepmd/dpmodel/model/make_model.py

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -298,10 +298,10 @@ def call_common(
298298
The keys are defined by the `ModelOutputDef`.
299299
300300
"""
301-
cc, bb, fp, ap, input_prec = self._input_type_cast(
302-
coord, box=box, fparam=fparam, aparam=aparam
301+
cc, bb, fp, ap, cs, input_prec = self._input_type_cast(
302+
coord, box=box, fparam=fparam, aparam=aparam, charge_spin=charge_spin
303303
)
304-
del coord, box, fparam, aparam
304+
del coord, box, fparam, aparam, charge_spin
305305
model_predict = model_call_from_call_lower(
306306
call_lower=self.call_common_lower,
307307
rcut=self.get_rcut(),
@@ -315,7 +315,7 @@ def call_common(
315315
aparam=ap,
316316
do_atomic_virial=do_atomic_virial,
317317
coord_corr_for_virial=coord_corr_for_virial,
318-
charge_spin=charge_spin,
318+
charge_spin=cs,
319319
)
320320
model_predict = self._output_type_cast(model_predict, input_prec)
321321
return model_predict
@@ -377,10 +377,10 @@ def call_common_lower(
377377
nlist,
378378
extra_nlist_sort=self.need_sorted_nlist_for_lower(),
379379
)
380-
cc_ext, _, fp, ap, input_prec = self._input_type_cast(
381-
extended_coord, fparam=fparam, aparam=aparam
380+
cc_ext, _, fp, ap, cs, input_prec = self._input_type_cast(
381+
extended_coord, fparam=fparam, aparam=aparam, charge_spin=charge_spin
382382
)
383-
del extended_coord, fparam, aparam
383+
del extended_coord, fparam, aparam, charge_spin
384384
model_predict = self.forward_common_atomic(
385385
cc_ext,
386386
extended_atype,
@@ -391,7 +391,7 @@ def call_common_lower(
391391
do_atomic_virial=do_atomic_virial,
392392
extended_coord_corr=extended_coord_corr,
393393
comm_dict=comm_dict,
394-
charge_spin=charge_spin,
394+
charge_spin=cs,
395395
)
396396
model_predict = self._output_type_cast(model_predict, input_prec)
397397
return model_predict
@@ -482,7 +482,8 @@ def _input_type_cast(
482482
box: Array | None = None,
483483
fparam: Array | None = None,
484484
aparam: Array | None = None,
485-
) -> tuple[Array, Array | None, Array | None, Array | None, Any]:
485+
charge_spin: Array | None = None,
486+
) -> tuple[Array, Array | None, Array | None, Array | None, Array | None, Any]:
486487
"""Cast the input data to global float type."""
487488
xp = array_api_compat.array_namespace(coord)
488489
input_dtype = coord.dtype
@@ -494,17 +495,20 @@ def _input_type_cast(
494495
###
495496
_lst: list[Array | None] = [
496497
xp.astype(vv, input_dtype) if vv is not None else None
497-
for vv in [box, fparam, aparam]
498+
for vv in [box, fparam, aparam, charge_spin]
498499
]
499-
box, fparam, aparam = _lst
500+
box, fparam, aparam, charge_spin = _lst
500501
if input_dtype == global_dtype:
501-
return coord, box, fparam, aparam, input_dtype
502+
return coord, box, fparam, aparam, charge_spin, input_dtype
502503
else:
503504
return (
504505
xp.astype(coord, global_dtype),
505506
xp.astype(box, global_dtype) if box is not None else None,
506507
xp.astype(fparam, global_dtype) if fparam is not None else None,
507508
xp.astype(aparam, global_dtype) if aparam is not None else None,
509+
xp.astype(charge_spin, global_dtype)
510+
if charge_spin is not None
511+
else None,
508512
input_dtype,
509513
)
510514

deepmd/jax/model/base_model.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ def eval_output(
6868
mapping: jnp.ndarray | None,
6969
fparam: jnp.ndarray | None,
7070
aparam: jnp.ndarray | None,
71+
charge_spin_: jnp.ndarray | None,
7172
*,
7273
_kk: str = kk,
7374
_atom_axis: int = atom_axis,
@@ -79,7 +80,9 @@ def eval_output(
7980
mapping=mapping[None, ...] if mapping is not None else None,
8081
fparam=fparam[None, ...] if fparam is not None else None,
8182
aparam=aparam[None, ...] if aparam is not None else None,
82-
charge_spin=charge_spin,
83+
charge_spin=charge_spin_[None, ...]
84+
if charge_spin_ is not None
85+
else None,
8386
)
8487
return jnp.sum(atomic_ret[_kk][0], axis=_atom_axis)
8588

@@ -92,6 +95,7 @@ def eval_output(
9295
mapping,
9396
fparam,
9497
aparam,
98+
charge_spin,
9599
)
96100
# extended_force: [nf, nall, *def, 3]
97101
def_ndim = len(vdef.shape)
@@ -109,6 +113,7 @@ def eval_output(
109113
mapping,
110114
fparam,
111115
aparam,
116+
charge_spin,
112117
)
113118
kk_hessian = get_hessian_name(kk)
114119
model_predict[kk_hessian] = hessian
@@ -130,6 +135,7 @@ def eval_ce(
130135
mapping: jnp.ndarray | None,
131136
fparam: jnp.ndarray | None,
132137
aparam: jnp.ndarray | None,
138+
charge_spin_: jnp.ndarray | None,
133139
*,
134140
_kk: str = kk,
135141
_atom_axis: int = atom_axis - 1,
@@ -142,7 +148,9 @@ def eval_ce(
142148
mapping=mapping[None, ...] if mapping is not None else None,
143149
fparam=fparam[None, ...] if fparam is not None else None,
144150
aparam=aparam[None, ...] if aparam is not None else None,
145-
charge_spin=charge_spin,
151+
charge_spin=charge_spin_[None, ...]
152+
if charge_spin_ is not None
153+
else None,
146154
)
147155
nloc = nlist.shape[0]
148156
cc_loc = jax.lax.stop_gradient(cc_ext)[:nloc, ...]
@@ -160,6 +168,7 @@ def eval_ce(
160168
mapping,
161169
fparam,
162170
aparam,
171+
charge_spin,
163172
)
164173
# move the first 3 to the last
165174
# [nf, *def, nall, 3, 3]

deepmd/pt/model/descriptor/dpa3.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -204,14 +204,14 @@ def init_subclass_params(sub_data: Any, sub_class: Any) -> Any:
204204

205205
if self.add_chg_spin_ebd:
206206
self.act = ActivationFn(activation_function)
207-
# -100 ~ 100 is a conservative bound
207+
# charge range [-100, 99] mapped to indices [0, 199]
208208
self.chg_embedding = TypeEmbedNet(
209209
200,
210210
self.tebd_dim,
211211
precision=precision,
212212
seed=child_seed(seed, 3),
213213
)
214-
# 100 is a conservative upper bound
214+
# spin range [0, 99] mapped to indices [0, 99]
215215
self.spin_embedding = TypeEmbedNet(
216216
100,
217217
self.tebd_dim,
@@ -588,9 +588,19 @@ def forward(
588588
assert charge_spin is not None
589589
assert self.chg_embedding is not None
590590
assert self.spin_embedding is not None
591-
charge = charge_spin[:, 0].to(dtype=torch.int64) + 100
591+
charge = charge_spin[:, 0].to(dtype=torch.int64)
592592
spin = charge_spin[:, 1].to(dtype=torch.int64)
593-
chg_ebd = self.chg_embedding(charge)
593+
# Validate charge range [-100, 99] (200 embedding entries)
594+
if torch.any(charge < -100) or torch.any(charge > 99):
595+
raise ValueError(
596+
f"charge must be in range [-100, 99], got min={charge.min().item()}, max={charge.max().item()}"
597+
)
598+
# Validate spin range [0, 99] (100 embedding entries)
599+
if torch.any(spin < 0) or torch.any(spin >= 100):
600+
raise ValueError(
601+
f"spin must be in range [0, 99], got min={spin.min().item()}, max={spin.max().item()}"
602+
)
603+
chg_ebd = self.chg_embedding(charge + 100)
594604
spin_ebd = self.spin_embedding(spin)
595605
sys_cs_embd = self.act(
596606
self.mix_cs_mlp(torch.cat((chg_ebd, spin_ebd), dim=-1))

deepmd/pt/model/descriptor/hybrid.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -101,14 +101,20 @@ def __init__(
101101

102102
def get_dim_chg_spin(self) -> int:
103103
"""Returns the dimension of charge_spin input (0 if not supported)."""
104-
return 0
104+
return max(
105+
(descrpt.get_dim_chg_spin() for descrpt in self.descrpt_list), default=0
106+
)
105107

106108
def has_default_chg_spin(self) -> bool:
107109
"""Returns whether the descriptor has a default charge_spin value."""
108-
return False
110+
return any(descrpt.has_default_chg_spin() for descrpt in self.descrpt_list)
109111

110-
def get_default_chg_spin(self) -> None:
112+
def get_default_chg_spin(self) -> list[float] | None:
111113
"""Returns the default charge_spin value, or None."""
114+
for descrpt in self.descrpt_list:
115+
default = descrpt.get_default_chg_spin()
116+
if default is not None:
117+
return default
112118
return None
113119

114120
def get_rcut(self) -> float:
@@ -346,7 +352,13 @@ def forward(
346352
:, :, self.nlist_cut_idx[ii].to(atype_ext.device)
347353
]
348354
odescriptor, gr, g2, h2, sw = descrpt(
349-
coord_ext, atype_ext, nl, mapping, charge_spin=charge_spin
355+
coord_ext,
356+
atype_ext,
357+
nl,
358+
mapping,
359+
comm_dict=comm_dict,
360+
fparam=fparam,
361+
charge_spin=charge_spin,
350362
)
351363
out_descriptor.append(odescriptor)
352364
if gr is not None:

deepmd/pt_expt/infer/deep_eval.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -624,6 +624,12 @@ def eval(
624624
aparam
625625
The atomic parameter.
626626
The array should be of size nframes x natoms x dim_aparam.
627+
charge_spin
628+
The charge and spin values for each frame.
629+
The array should be of size nframes x 2, where the first column is charge
630+
and the second column is spin. If the model has add_chg_spin_ebd=True and
631+
no default_chg_spin is set, this parameter is required. If default_chg_spin
632+
is configured, this parameter is optional and will override the default.
627633
**kwargs
628634
Other parameters
629635
@@ -1056,8 +1062,14 @@ def _prepare_inputs(
10561062

10571063
# charge_spin handling: dedicated input, separate from fparam.
10581064
if charge_spin is not None:
1065+
charge_spin_arr = np.asarray(charge_spin)
1066+
if charge_spin_arr.shape != (nframes, 2):
1067+
raise ValueError(
1068+
f"charge_spin must have shape (nframes, 2), got {charge_spin_arr.shape}. "
1069+
f"Expected ({nframes}, 2) for {nframes} frame(s)."
1070+
)
10591071
charge_spin_t = torch.tensor(
1060-
np.asarray(charge_spin).reshape(nframes, 2),
1072+
charge_spin_arr,
10611073
dtype=torch.float64,
10621074
device=DEVICE,
10631075
)
@@ -1272,8 +1284,14 @@ def _eval_model_spin(
12721284

12731285
# charge_spin handling: dedicated input, separate from fparam.
12741286
if charge_spin is not None:
1287+
charge_spin_arr = np.asarray(charge_spin)
1288+
if charge_spin_arr.shape != (nframes, 2):
1289+
raise ValueError(
1290+
f"charge_spin must have shape (nframes, 2), got {charge_spin_arr.shape}. "
1291+
f"Expected ({nframes}, 2) for {nframes} frame(s)."
1292+
)
12751293
charge_spin_t = torch.tensor(
1276-
np.asarray(charge_spin).reshape(nframes, 2),
1294+
charge_spin_arr,
12771295
dtype=torch.float64,
12781296
device=DEVICE,
12791297
)

0 commit comments

Comments
 (0)