diff --git a/deepmd/dpmodel/infer/deep_eval.py b/deepmd/dpmodel/infer/deep_eval.py index b307f2f15b..f460c6062e 100644 --- a/deepmd/dpmodel/infer/deep_eval.py +++ b/deepmd/dpmodel/infer/deep_eval.py @@ -137,7 +137,7 @@ def model_type(self) -> type["DeepEvalWrapper"]: return DeepDOS elif "dipole" in model_output_type: return DeepDipole - elif "polar" in model_output_type: + elif "polar" in model_output_type or "polarizability" in model_output_type: return DeepPolar elif "wfc" in model_output_type: return DeepWFC diff --git a/deepmd/jax/infer/deep_eval.py b/deepmd/jax/infer/deep_eval.py index 92ed78a13e..fbd8860c0c 100644 --- a/deepmd/jax/infer/deep_eval.py +++ b/deepmd/jax/infer/deep_eval.py @@ -158,7 +158,7 @@ def model_type(self) -> type["DeepEvalWrapper"]: return DeepDOS elif "dipole" in model_output_type: return DeepDipole - elif "polar" in model_output_type: + elif "polar" in model_output_type or "polarizability" in model_output_type: return DeepPolar elif "wfc" in model_output_type: return DeepWFC diff --git a/deepmd/pt/infer/deep_eval.py b/deepmd/pt/infer/deep_eval.py index f3e52cdac0..d8173e4570 100644 --- a/deepmd/pt/infer/deep_eval.py +++ b/deepmd/pt/infer/deep_eval.py @@ -248,7 +248,7 @@ def model_type(self) -> type["DeepEvalWrapper"]: return DeepDOS elif "dipole" in model_output_type: return DeepDipole - elif "polar" in model_output_type: + elif "polar" in model_output_type or "polarizability" in model_output_type: return DeepPolar elif "global_polar" in model_output_type: return DeepGlobalPolar