Skip to content

Commit 44f9cb9

Browse files
committed
fix jit
1 parent 86922bc commit 44f9cb9

2 files changed

Lines changed: 37 additions & 8 deletions

File tree

deepmd/dpmodel/descriptor/hybrid.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,23 @@ def get_rcut(self) -> float:
123123
"""Returns the cut-off radius."""
124124
return np.max([descrpt.get_rcut() for descrpt in self.descrpt_list]).item()
125125

126+
def get_dim_chg_spin(self) -> int:
127+
"""Returns the dimension of charge_spin input (0 if not supported)."""
128+
return max(
129+
(descrpt.get_dim_chg_spin() for descrpt in self.descrpt_list), default=0
130+
)
131+
132+
def has_default_chg_spin(self) -> bool:
133+
"""Returns whether the descriptor has a default charge_spin value."""
134+
return any(descrpt.has_default_chg_spin() for descrpt in self.descrpt_list)
135+
136+
def get_default_chg_spin(self) -> list[float] | None:
137+
"""Returns the default charge_spin value, or None."""
138+
for descrpt in self.descrpt_list:
139+
if descrpt.has_default_chg_spin():
140+
return descrpt.get_default_chg_spin()
141+
return None
142+
126143
def get_rcut_smth(self) -> float:
127144
"""Returns the radius where the neighbor information starts to smoothly decay to 0."""
128145
# may not be a good idea...

deepmd/pt/model/descriptor/hybrid.py

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -101,20 +101,32 @@ def __init__(
101101

102102
def get_dim_chg_spin(self) -> int:
103103
"""Returns the dimension of charge_spin input (0 if not supported)."""
104-
return max(
105-
(descrpt.get_dim_chg_spin() for descrpt in self.descrpt_list), default=0
106-
)
104+
# JIT-compiled via DPAtomicModel.get_dim_chg_spin; avoid generator
105+
# expressions and `max(..., default=...)` which TorchScript rejects.
106+
dim: int = 0
107+
for descrpt in self.descrpt_list:
108+
d = descrpt.get_dim_chg_spin()
109+
if d > dim:
110+
dim = d
111+
return dim
107112

108113
def has_default_chg_spin(self) -> bool:
109114
"""Returns whether the descriptor has a default charge_spin value."""
110-
return any(descrpt.has_default_chg_spin() for descrpt in self.descrpt_list)
115+
# JIT-compiled via DPAtomicModel.has_default_chg_spin; keep as an
116+
# explicit loop instead of `any(generator)` for TorchScript.
117+
for descrpt in self.descrpt_list:
118+
if descrpt.has_default_chg_spin():
119+
return True
120+
return False
111121

112-
def get_default_chg_spin(self) -> list[float] | None:
122+
@torch.jit.export
123+
def get_default_chg_spin(self) -> Optional[torch.Tensor]: # noqa: UP045
113124
"""Returns the default charge_spin value, or None."""
125+
# JIT-compiled via DPAtomicModel.get_default_chg_spin; the caller
126+
# invokes `.unsqueeze(0)` on the result, so return a Tensor (not list).
114127
for descrpt in self.descrpt_list:
115-
default = descrpt.get_default_chg_spin()
116-
if default is not None:
117-
return default
128+
if descrpt.has_default_chg_spin():
129+
return descrpt.get_default_chg_spin()
118130
return None
119131

120132
def get_rcut(self) -> float:

0 commit comments

Comments
 (0)