@@ -101,20 +101,32 @@ def __init__(
101101
102102 def get_dim_chg_spin (self ) -> int :
103103 """Returns the dimension of charge_spin input (0 if not supported)."""
104- return max (
105- (descrpt .get_dim_chg_spin () for descrpt in self .descrpt_list ), default = 0
106- )
104+ # JIT-compiled via DPAtomicModel.get_dim_chg_spin; avoid generator
105+ # expressions and `max(..., default=...)` which TorchScript rejects.
106+ dim : int = 0
107+ for descrpt in self .descrpt_list :
108+ d = descrpt .get_dim_chg_spin ()
109+ if d > dim :
110+ dim = d
111+ return dim
107112
108113 def has_default_chg_spin (self ) -> bool :
109114 """Returns whether the descriptor has a default charge_spin value."""
110- return any (descrpt .has_default_chg_spin () for descrpt in self .descrpt_list )
115+ # JIT-compiled via DPAtomicModel.has_default_chg_spin; keep as an
116+ # explicit loop instead of `any(generator)` for TorchScript.
117+ for descrpt in self .descrpt_list :
118+ if descrpt .has_default_chg_spin ():
119+ return True
120+ return False
111121
112- def get_default_chg_spin (self ) -> list [float ] | None :
122+ @torch .jit .export
123+ def get_default_chg_spin (self ) -> Optional [torch .Tensor ]: # noqa: UP045
113124 """Returns the default charge_spin value, or None."""
125+ # JIT-compiled via DPAtomicModel.get_default_chg_spin; the caller
126+ # invokes `.unsqueeze(0)` on the result, so return a Tensor (not list).
114127 for descrpt in self .descrpt_list :
115- default = descrpt .get_default_chg_spin ()
116- if default is not None :
117- return default
128+ if descrpt .has_default_chg_spin ():
129+ return descrpt .get_default_chg_spin ()
118130 return None
119131
120132 def get_rcut (self ) -> float :
0 commit comments