Skip to content

Commit 85dc0f6

Browse files
Copilotnjzjz
andcommitted
fix(pt): implement get_task_dim method in GeneralFitting base class for proper inheritance
Co-authored-by: njzjz <9496702+njzjz@users.noreply.github.com>
1 parent 835220a commit 85dc0f6

8 files changed

Lines changed: 6 additions & 189 deletions

File tree

checkpoint

Lines changed: 0 additions & 1 deletion
This file was deleted.

deepmd/pt/model/model/dipole_model.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -126,9 +126,3 @@ def forward_lower(
126126
else:
127127
model_predict = model_ret
128128
return model_predict
129-
130-
@torch.jit.export
131-
def get_task_dim(self) -> int:
132-
"""Get the output dimension of the dipole model."""
133-
# For dipole models, the output dimension is always 3 (x, y, z components)
134-
return 3

deepmd/pt/model/task/fitting.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -657,3 +657,8 @@ def _forward_common(
657657
outs = torch.where(mask[:, :, None], outs, 0.0)
658658
results.update({self.var_name: outs})
659659
return results
660+
661+
@torch.jit.export
662+
def get_task_dim(self) -> int:
663+
"""Get the output dimension of the fitting net."""
664+
return self._net_out_dim()

input_v2_compat.json

Lines changed: 0 additions & 69 deletions
This file was deleted.

out.json

Lines changed: 0 additions & 111 deletions
This file was deleted.

source/api_cc/tests/test_deeptensor_pt.cc

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,8 +53,7 @@ class TestInferDeepTensorPt : public ::testing::Test {
5353
deepmd::DeepTensor dt;
5454

5555
void SetUp() override {
56-
std::string file_name =
57-
"../../tests/infer/deepdipole_pt_with_get_task_dim.pth";
56+
std::string file_name = "../../tests/infer/deepdipole_pt.pth";
5857
dt.init(file_name);
5958
};
6059

-2.14 KB
Binary file not shown.
-124 KB
Binary file not shown.

0 commit comments

Comments
 (0)