Skip to content

Commit 40a32a5

Browse files
committed
feat: add basic torch_npu support for dp training
1 parent ab6e300 commit 40a32a5

2 files changed

Lines changed: 5 additions & 4 deletions

File tree

deepmd/env.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -30,15 +30,15 @@
3030

3131

3232
# FLOAT_PREC
33-
dp_float_prec = os.environ.get("DP_INTERFACE_PREC", "high").lower()
33+
dp_float_prec = os.environ.get("DP_INTERFACE_PREC", "low").lower()
3434
if dp_float_prec in ("high", ""):
3535
# default is high
3636
GLOBAL_NP_FLOAT_PRECISION = np.float64
3737
GLOBAL_ENER_FLOAT_PRECISION = np.float64
3838
global_float_prec = "double"
3939
elif dp_float_prec == "low":
4040
GLOBAL_NP_FLOAT_PRECISION = np.float32
41-
GLOBAL_ENER_FLOAT_PRECISION = np.float64
41+
GLOBAL_ENER_FLOAT_PRECISION = np.float32
4242
global_float_prec = "float"
4343
else:
4444
raise RuntimeError(

deepmd/pt/utils/env.py

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -5,6 +5,7 @@
55

66
import numpy as np
77
import torch
8+
import torch_npu
89

910
from deepmd.common import (
1011
VALID_PRECISION,
@@ -37,10 +38,10 @@
3738
LOCAL_RANK = os.environ.get("LOCAL_RANK")
3839
LOCAL_RANK = int(0 if LOCAL_RANK is None else LOCAL_RANK)
3940

40-
if os.environ.get("DEVICE") == "cpu" or torch.cuda.is_available() is False:
41+
if os.environ.get("DEVICE") == "cpu" or torch_npu.npu.is_available() is False:
4142
DEVICE = torch.device("cpu")
4243
else:
43-
DEVICE = torch.device(f"cuda:{LOCAL_RANK}")
44+
DEVICE = torch.device(f"npu:{LOCAL_RANK}")
4445

4546
JIT = False
4647
CACHE_PER_SYS = 5 # keep at most so many sets per sys in memory

0 commit comments

Comments
 (0)