Skip to content

Commit babd11c

Browse files
committed
feat: use bfloat16 with torch.autocast on training
1 parent 43e0288 commit babd11c

3 files changed

Lines changed: 8 additions & 3 deletions

File tree

deepmd/pt/model/descriptor/repflow_layer.py

Lines changed: 6 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -424,7 +424,7 @@ def optim_angle_update(
424424
sub_edge_update_ik = torch.matmul(edge_ebd, sub_edge_ik)
425425

426426
result_update = (
427-
bias
427+
bias.to(sub_angle_update.dtype)
428428
+ sub_node_update.unsqueeze(2).unsqueeze(3)
429429
+ sub_edge_update_ij.unsqueeze(2)
430430
+ sub_edge_update_ik.unsqueeze(3)
@@ -463,7 +463,10 @@ def optim_edge_update(
463463
sub_edge_update = torch.matmul(edge_ebd, edge)
464464

465465
result_update = (
466-
bias + sub_node_update.unsqueeze(2) + sub_edge_update + sub_node_ext_update
466+
bias.to(sub_node_update.dtype)
467+
+ sub_node_update.unsqueeze(2)
468+
+ sub_edge_update
469+
+ sub_node_ext_update
467470
)
468471
return result_update
469472

@@ -679,6 +682,7 @@ def forward(
679682
)
680683
)
681684

685+
a_sw.to(edge_angle_update.dtype)
682686
# nb x nloc x a_nnei x a_nnei x e_dim
683687
weighted_edge_angle_update = (
684688
a_sw[..., None, None] * a_sw[..., None, :, None] * edge_angle_update

deepmd/pt/train/wrapper.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -136,6 +136,7 @@ def share_params(self, shared_links, resume=False) -> None:
136136
f"Shared params of {model_key_base}.{class_type_base} and {model_key_link}.{class_type_link}!"
137137
)
138138

139+
@torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True)
139140
def forward(
140141
self,
141142
coord,

deepmd/pt/utils/utils.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -93,7 +93,7 @@ def forward(ctx, x, threshold, slope, const_val):
9393
ctx.threshold = threshold
9494
ctx.slope = slope
9595
ctx.const_val = const_val
96-
return silut_forward_script(x, threshold, slope, const_val)
96+
return silut_forward_script(x, threshold, slope, const_val).to(x.dtype)
9797

9898
@staticmethod
9999
def backward(ctx, grad_output):

0 commit comments

Comments
 (0)