Skip to content

Commit 00dcd58

Browse files
committed
feat: enable configurable bfloat16 autocasting in ModelWrapper
1 parent 9025d0b commit 00dcd58

2 files changed

Lines changed: 4 additions & 1 deletion

File tree

deepmd/pt/train/wrapper.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
Union,
66
)
77

8+
from deepmd.pt.utils.env import BF16_AUTOCAST
89
import torch
910

1011
if torch.__version__.startswith("2"):
@@ -136,7 +137,7 @@ def share_params(self, shared_links, resume=False) -> None:
136137
f"Shared params of {model_key_base}.{class_type_base} and {model_key_link}.{class_type_link}!"
137138
)
138139

139-
@torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True)
140+
@torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=BF16_AUTOCAST)
140141
def forward(
141142
self,
142143
coord,

deepmd/pt/utils/env.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
CACHE_PER_SYS = 5 # keep at most so many sets per sys in memory
3636
ENERGY_BIAS_TRAINABLE = True
3737
CUSTOM_OP_USE_JIT = False
38+
BF16_AUTOCAST = False
3839

3940
PRECISION_DICT = {
4041
"float16": torch.float16,
@@ -76,6 +77,7 @@
7677
torch.set_num_threads(intra_nthreads)
7778

7879
__all__ = [
80+
"BF16_AUTOCAST",
7981
"CACHE_PER_SYS",
8082
"CUSTOM_OP_USE_JIT",
8183
"DEFAULT_PRECISION",

0 commit comments

Comments
 (0)