Skip to content

Commit 6f6493d

Browse files
committed
debug for rebase
1 parent 29ed198 commit 6f6493d

4 files changed

Lines changed: 36 additions & 2 deletions

File tree

deepmd/pt/model/model/density_model.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ def forward(
6060
fparam: Optional[torch.Tensor] = None,
6161
aparam: Optional[torch.Tensor] = None,
6262
do_atomic_virial: bool = False,
63+
charge_spin: Optional[torch.Tensor] = None,
6364
) -> dict[str, torch.Tensor]:
6465
model_ret = self.forward_common(
6566
coord,
@@ -87,5 +88,6 @@ def forward_lower(
8788
aparam: Optional[torch.Tensor] = None,
8889
do_atomic_virial: bool = False,
8990
comm_dict: Optional[dict[str, torch.Tensor]] = None,
91+
charge_spin: Optional[torch.Tensor] = None,
9092
):
9193
raise NotImplementedError

deepmd/pt/model/model/make_density_model.py

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,26 @@ def model_output_type(self) -> list[str]:
101101
vars.append(kk)
102102
return vars
103103

104+
@torch.jit.export
105+
def has_chg_spin_ebd(self) -> bool:
106+
"""Check if the model has charge spin embedding."""
107+
return self.atomic_model.has_chg_spin_ebd()
108+
109+
@torch.jit.export
110+
def get_dim_chg_spin(self) -> int:
111+
"""Get the dimension of charge_spin input."""
112+
return self.atomic_model.get_dim_chg_spin()
113+
114+
@torch.jit.export
115+
def has_default_chg_spin(self) -> bool:
116+
"""Check if the model has default charge_spin values."""
117+
return self.atomic_model.has_default_chg_spin()
118+
119+
@torch.jit.export
120+
def get_default_chg_spin(self) -> torch.Tensor | None:
121+
"""Get the default charge_spin values."""
122+
return self.atomic_model.get_default_chg_spin()
123+
104124
# cannot use the name forward. torch script does not work
105125
def forward_common(
106126
self,
@@ -111,6 +131,7 @@ def forward_common(
111131
fparam: Optional[torch.Tensor] = None,
112132
aparam: Optional[torch.Tensor] = None,
113133
do_atomic_virial: bool = False,
134+
charge_spin: Optional[torch.Tensor] = None,
114135
) -> dict[str, torch.Tensor]:
115136
"""Return model prediction.
116137
@@ -145,6 +166,7 @@ def forward_common(
145166
coord, grid, box=box, fparam=fparam, aparam=aparam
146167
)
147168
del coord, grid, box, fparam, aparam
169+
gg = gg.view(gg.shape[0], -1, 3)
148170
(
149171
extended_coord,
150172
extended_atype,
@@ -158,7 +180,7 @@ def forward_common(
158180
mixed_types=self.mixed_types(),
159181
box=bb,
160182
)
161-
grid_type = torch.zeros(gg.shape[:-1], device=gg.device, dtype=atype.dtype)
183+
grid_type = torch.zeros(gg.shape[0], gg.shape[1], device=gg.device, dtype=atype.dtype)
162184
grid_nlist = build_directional_neighbor_list(
163185
gg,
164186
grid_type,
@@ -233,6 +255,7 @@ def forward_common_lower(
233255
mapping: Optional[torch.Tensor] = None,
234256
fparam: Optional[torch.Tensor] = None,
235257
aparam: Optional[torch.Tensor] = None,
258+
charge_spin: Optional[torch.Tensor] = None,
236259
do_atomic_virial: bool = False,
237260
comm_dict: Optional[dict[str, torch.Tensor]] = None,
238261
extra_nlist_sort: bool = False,
@@ -572,9 +595,14 @@ def compute_or_load_stat(
572595
self,
573596
sampled_func,
574597
stat_file_path: Optional[DPPath] = None,
598+
preset_observed_type: list[str] | None = None,
575599
):
576600
"""Compute or load the statistics."""
577-
return self.atomic_model.compute_or_load_stat(sampled_func, stat_file_path)
601+
return self.atomic_model.compute_or_load_stat(
602+
sampled_func,
603+
stat_file_path,
604+
preset_observed_type=preset_observed_type,
605+
)
578606

579607
def get_sel(self) -> list[int]:
580608
"""Returns the number of selected atoms for each type."""

deepmd/pt/train/wrapper.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
import logging
33
from typing import (
44
Any,
5+
Optional,
6+
Union,
57
)
68

79
import torch

deepmd/utils/data.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -966,6 +966,8 @@ def _load_single_data(
966966
data = mmap_obj[frame_idx].copy().astype(dtype, copy=False)
967967

968968
try:
969+
if key in ("grid", "density"):
970+
return np.float32(1.0), data
969971
if vv["atomic"]:
970972
# Handle type_sel logic
971973
if vv["type_sel"] is not None:

0 commit comments

Comments
 (0)