Skip to content

Commit 0e792ac

Browse files
fix(pd): add charge_spin stub methods and param to dpa1/dpa2/se_a/se_t_tebd descriptors
Co-authored-by: HydrogenSulfate <23737287+HydrogenSulfate@users.noreply.github.com>
1 parent fe9f626 commit 0e792ac

4 files changed

Lines changed: 52 additions & 3 deletions

File tree

deepmd/pd/model/descriptor/dpa1.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -366,6 +366,18 @@ def get_buffer_type_map(self) -> paddle.Tensor:
366366
"""
367367
return self.buffer_type_map
368368

369+
def get_dim_chg_spin(self) -> int:
370+
"""Returns the dimension of charge_spin input (0 if not supported)."""
371+
return 0
372+
373+
def has_default_chg_spin(self) -> bool:
374+
"""Returns whether the descriptor has a default charge_spin value."""
375+
return False
376+
377+
def get_default_chg_spin(self) -> None:
378+
"""Returns the default charge_spin value, or None."""
379+
return None
380+
369381
def get_dim_out(self) -> int:
370382
"""Returns the output dimension."""
371383
ret = self.se_atten.get_dim_out()
@@ -627,6 +639,7 @@ def forward(
627639
mapping: paddle.Tensor | None = None,
628640
comm_dict: list[paddle.Tensor] | None = None,
629641
fparam: paddle.Tensor | None = None,
642+
charge_spin: paddle.Tensor | None = None,
630643
) -> tuple[
631644
paddle.Tensor,
632645
paddle.Tensor | None,

deepmd/pd/model/descriptor/dpa2.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -333,6 +333,18 @@ def init_subclass_params(sub_data: dict | Any, sub_class: type) -> Any:
333333
param.stop_gradient = not trainable
334334
self.compress = False
335335

336+
def get_dim_chg_spin(self) -> int:
337+
"""Returns the dimension of charge_spin input (0 if not supported)."""
338+
return 0
339+
340+
def has_default_chg_spin(self) -> bool:
341+
"""Returns whether the descriptor has a default charge_spin value."""
342+
return False
343+
344+
def get_default_chg_spin(self) -> None:
345+
"""Returns the default charge_spin value, or None."""
346+
return None
347+
336348
def get_rcut(self) -> float:
337349
"""Returns the cut-off radius."""
338350
return self.rcut
@@ -734,10 +746,8 @@ def forward(
734746
mapping: paddle.Tensor | None = None,
735747
comm_dict: list[paddle.Tensor] | None = None,
736748
fparam: paddle.Tensor | None = None,
749+
charge_spin: paddle.Tensor | None = None,
737750
) -> tuple[
738-
paddle.Tensor,
739-
paddle.Tensor | None,
740-
paddle.Tensor | None,
741751
paddle.Tensor | None,
742752
paddle.Tensor | None,
743753
]:

deepmd/pd/model/descriptor/se_a.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,18 @@ def __init__(
119119
seed=seed,
120120
)
121121

122+
def get_dim_chg_spin(self) -> int:
123+
"""Returns the dimension of charge_spin input (0 if not supported)."""
124+
return 0
125+
126+
def has_default_chg_spin(self) -> bool:
127+
"""Returns whether the descriptor has a default charge_spin value."""
128+
return False
129+
130+
def get_default_chg_spin(self) -> None:
131+
"""Returns the default charge_spin value, or None."""
132+
return None
133+
122134
def get_rcut(self) -> float:
123135
"""Returns the cut-off radius."""
124136
return self.sea.get_rcut()
@@ -289,6 +301,7 @@ def forward(
289301
mapping: paddle.Tensor | None = None,
290302
comm_dict: list[paddle.Tensor] | None = None,
291303
fparam: paddle.Tensor | None = None,
304+
charge_spin: paddle.Tensor | None = None,
292305
) -> tuple[
293306
paddle.Tensor,
294307
paddle.Tensor | None,

deepmd/pd/model/descriptor/se_t_tebd.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,18 @@ def __init__(
190190
for param in self.parameters():
191191
param.stop_gradient = not trainable
192192

193+
def get_dim_chg_spin(self) -> int:
194+
"""Returns the dimension of charge_spin input (0 if not supported)."""
195+
return 0
196+
197+
def has_default_chg_spin(self) -> bool:
198+
"""Returns whether the descriptor has a default charge_spin value."""
199+
return False
200+
201+
def get_default_chg_spin(self) -> None:
202+
"""Returns the default charge_spin value, or None."""
203+
return None
204+
193205
def get_rcut(self) -> float:
194206
"""Returns the cut-off radius."""
195207
return self.se_ttebd.get_rcut()
@@ -438,6 +450,7 @@ def forward(
438450
mapping: paddle.Tensor | None = None,
439451
comm_dict: list[paddle.Tensor] | None = None,
440452
fparam: paddle.Tensor | None = None,
453+
charge_spin: paddle.Tensor | None = None,
441454
) -> paddle.Tensor:
442455
"""Compute the descriptor.
443456

0 commit comments

Comments
 (0)