diff --git a/deepmd/pt/model/descriptor/repflow_layer.py b/deepmd/pt/model/descriptor/repflow_layer.py index f109109cfd..cd2d79f52f 100644 --- a/deepmd/pt/model/descriptor/repflow_layer.py +++ b/deepmd/pt/model/descriptor/repflow_layer.py @@ -424,7 +424,7 @@ def optim_angle_update( sub_edge_update_ik = torch.matmul(edge_ebd, sub_edge_ik) result_update = ( - bias + bias.to(sub_angle_update.dtype) + sub_node_update.unsqueeze(2).unsqueeze(3) + sub_edge_update_ij.unsqueeze(2) + sub_edge_update_ik.unsqueeze(3) @@ -463,7 +463,10 @@ def optim_edge_update( sub_edge_update = torch.matmul(edge_ebd, edge) result_update = ( - bias + sub_node_update.unsqueeze(2) + sub_edge_update + sub_node_ext_update + bias.to(sub_node_update.dtype) + + sub_node_update.unsqueeze(2) + + sub_edge_update + + sub_node_ext_update ) return result_update diff --git a/deepmd/pt/train/wrapper.py b/deepmd/pt/train/wrapper.py index 9a2cbff295..8666f642fc 100644 --- a/deepmd/pt/train/wrapper.py +++ b/deepmd/pt/train/wrapper.py @@ -7,6 +7,10 @@ import torch +from deepmd.pt.utils.env import ( + BF16_AUTOCAST, +) + if torch.__version__.startswith("2"): import torch._dynamo @@ -136,6 +140,7 @@ def share_params(self, shared_links, resume=False) -> None: f"Shared params of {model_key_base}.{class_type_base} and {model_key_link}.{class_type_link}!" ) + @torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=BF16_AUTOCAST) def forward( self, coord, diff --git a/deepmd/pt/utils/env.py b/deepmd/pt/utils/env.py index 185bb1add3..63b0ecca07 100644 --- a/deepmd/pt/utils/env.py +++ b/deepmd/pt/utils/env.py @@ -35,6 +35,7 @@ CACHE_PER_SYS = 5 # keep at most so many sets per sys in memory ENERGY_BIAS_TRAINABLE = True CUSTOM_OP_USE_JIT = False +BF16_AUTOCAST = False PRECISION_DICT = { "float16": torch.float16, @@ -76,6 +77,7 @@ torch.set_num_threads(intra_nthreads) __all__ = [ + "BF16_AUTOCAST", "CACHE_PER_SYS", "CUSTOM_OP_USE_JIT", "DEFAULT_PRECISION", diff --git a/deepmd/pt/utils/utils.py b/deepmd/pt/utils/utils.py index 85988e3523..4377939275 100644 --- a/deepmd/pt/utils/utils.py +++ b/deepmd/pt/utils/utils.py @@ -93,7 +93,7 @@ def forward(ctx, x, threshold, slope, const_val): ctx.threshold = threshold ctx.slope = slope ctx.const_val = const_val - return silut_forward_script(x, threshold, slope, const_val) + return silut_forward_script(x, threshold, slope, const_val).to(x.dtype) @staticmethod def backward(ctx, grad_output):