Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions deepmd/pt/model/descriptor/repflow_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,7 +424,7 @@ def optim_angle_update(
sub_edge_update_ik = torch.matmul(edge_ebd, sub_edge_ik)

result_update = (
bias
bias.to(sub_angle_update.dtype)
+ sub_node_update.unsqueeze(2).unsqueeze(3)
+ sub_edge_update_ij.unsqueeze(2)
+ sub_edge_update_ik.unsqueeze(3)
Expand Down Expand Up @@ -463,7 +463,10 @@ def optim_edge_update(
sub_edge_update = torch.matmul(edge_ebd, edge)

result_update = (
bias + sub_node_update.unsqueeze(2) + sub_edge_update + sub_node_ext_update
bias.to(sub_node_update.dtype)
+ sub_node_update.unsqueeze(2)
+ sub_edge_update
+ sub_node_ext_update
)
return result_update

Expand Down
5 changes: 5 additions & 0 deletions deepmd/pt/train/wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@

import torch

from deepmd.pt.utils.env import (
BF16_AUTOCAST,
)

if torch.__version__.startswith("2"):
import torch._dynamo

Expand Down Expand Up @@ -136,6 +140,7 @@ def share_params(self, shared_links, resume=False) -> None:
f"Shared params of {model_key_base}.{class_type_base} and {model_key_link}.{class_type_link}!"
)

@torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=BF16_AUTOCAST)
def forward(
self,
coord,
Expand Down
2 changes: 2 additions & 0 deletions deepmd/pt/utils/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
CACHE_PER_SYS = 5 # keep at most so many sets per sys in memory
ENERGY_BIAS_TRAINABLE = True
CUSTOM_OP_USE_JIT = False
BF16_AUTOCAST = False

PRECISION_DICT = {
"float16": torch.float16,
Expand Down Expand Up @@ -76,6 +77,7 @@
torch.set_num_threads(intra_nthreads)

__all__ = [
"BF16_AUTOCAST",
"CACHE_PER_SYS",
"CUSTOM_OP_USE_JIT",
"DEFAULT_PRECISION",
Expand Down
2 changes: 1 addition & 1 deletion deepmd/pt/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def forward(ctx, x, threshold, slope, const_val):
ctx.threshold = threshold
ctx.slope = slope
ctx.const_val = const_val
return silut_forward_script(x, threshold, slope, const_val)
return silut_forward_script(x, threshold, slope, const_val).to(x.dtype)

@staticmethod
def backward(ctx, grad_output):
Expand Down