diff --git a/deepmd/common.py b/deepmd/common.py index 5581f3533b..03afdbc2c2 100644 --- a/deepmd/common.py +++ b/deepmd/common.py @@ -35,6 +35,7 @@ ) __all__ = [ + "GLOBAL_NP_FLOAT_PRECISION", "VALID_ACTIVATION", "VALID_PRECISION", "expand_sys_str", @@ -249,16 +250,11 @@ def get_np_precision(precision: "_PRECISION") -> np.dtype: RuntimeError if string is invalid """ - if precision == "default": - return GLOBAL_NP_FLOAT_PRECISION - elif precision == "float16": - return np.float16 - elif precision == "float32": - return np.float32 - elif precision == "float64": - return np.float64 - else: - raise RuntimeError(f"{precision} is not a valid precision") + from deepmd.dpmodel.common import ( + get_xp_precision, + ) + + return get_xp_precision(np, precision) def symlink_prefix_files(old_prefix: str, new_prefix: str) -> None: