Skip to content

Commit 9eff90a

Browse files
Copilotnjzjz
andcommitted
fix(infer): return backend-specific models instead of converting to dpmodel BaseModel
Co-authored-by: njzjz <9496702+njzjz@users.noreply.github.com>
1 parent e5c7968 commit 9eff90a

3 files changed

Lines changed: 20 additions & 41 deletions

File tree

deepmd/pd/infer/deep_eval.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,6 @@
1111
import paddle
1212

1313
from deepmd.dpmodel.common import PRECISION_DICT as NP_PRECISION_DICT
14-
from deepmd.dpmodel.model.base_model import (
15-
BaseModel,
16-
)
1714
from deepmd.dpmodel.output_def import (
1815
ModelOutputDef,
1916
OutputVariableCategory,
@@ -510,16 +507,14 @@ def get_model_size(self) -> dict:
510507
}
511508

512509
def get_model(self):
513-
"""Get the Paddle model as BaseModel.
510+
"""Get the Paddle model.
514511
515512
Returns
516513
-------
517514
BaseModel
518-
The Paddle model converted to BaseModel.
515+
The Paddle model instance.
519516
"""
520-
# Convert Paddle model to BaseModel by serializing and deserializing
521-
model_dict = self.dp.model["Default"].serialize()
522-
return BaseModel.deserialize(model_dict)
517+
return self.dp.model["Default"]
523518

524519
def eval_descriptor(
525520
self,

deepmd/pt/infer/deep_eval.py

Lines changed: 3 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,6 @@
1313
import torch
1414

1515
from deepmd.dpmodel.common import PRECISION_DICT as NP_PRECISION_DICT
16-
from deepmd.dpmodel.model.base_model import (
17-
BaseModel,
18-
)
1916
from deepmd.dpmodel.output_def import (
2017
ModelOutputDef,
2118
OutputVariableCategory,
@@ -710,31 +707,14 @@ def get_observed_types(self) -> dict:
710707
}
711708

712709
def get_model(self):
713-
"""Get the PyTorch model as BaseModel.
710+
"""Get the PyTorch model.
714711
715712
Returns
716713
-------
717714
BaseModel
718-
The PyTorch model converted to BaseModel.
715+
The PyTorch model instance.
719716
"""
720-
# Convert PyTorch model to BaseModel by serializing and deserializing
721-
if str(self.model_path).endswith(".pth"):
722-
# For JIT models (.pth), we need to reconstruct the original model first
723-
from deepmd.pt.model.model import (
724-
get_model,
725-
)
726-
727-
# The JIT model should have model_def_script
728-
model_def_script = self.model_def_script
729-
model = get_model(model_def_script)
730-
# Load state dict with strict=False to handle compression info differences
731-
model.load_state_dict(self.dp.model["Default"].state_dict(), strict=False)
732-
model_dict = model.serialize()
733-
else:
734-
# For regular models (.pt), we can serialize directly
735-
model_dict = self.dp.model["Default"].serialize()
736-
737-
return BaseModel.deserialize(model_dict)
717+
return self.dp.model["Default"]
738718

739719
def eval_descriptor(
740720
self,

source/tests/infer/test_get_model.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -51,19 +51,23 @@ def test_get_model_backend_specific(self):
5151
model = self.dp.get_model()
5252

5353
if extension == ".pth":
54-
# For PyTorch models, should return nn.Module-like object
55-
import torch
54+
# For PyTorch models, should return the PyTorch BaseModel
55+
from deepmd.pt.model.model.model import BaseModel as PTBaseModel
5656

57+
self.assertIsInstance(
58+
model,
59+
PTBaseModel,
60+
"PyTorch model should return PyTorch BaseModel instance",
61+
)
62+
# Check if it has common model methods
5763
self.assertTrue(
58-
hasattr(model, "forward") or isinstance(model, torch.nn.Module),
59-
"PyTorch model should be or behave like nn.Module",
64+
hasattr(model, "get_type_map"),
65+
"PyTorch BaseModel should have get_type_map method",
66+
)
67+
self.assertTrue(
68+
hasattr(model, "get_rcut"),
69+
"PyTorch BaseModel should have get_rcut method",
6070
)
61-
# Check if it has model attribute (ModelWrapper)
62-
if hasattr(model, "model"):
63-
self.assertTrue(
64-
hasattr(model.model, "__getitem__"),
65-
"PyTorch ModelWrapper should have model dict",
66-
)
6771
elif extension == ".pb":
6872
# For TensorFlow models, should return graph
6973
try:

0 commit comments

Comments
 (0)