File tree Expand file tree Collapse file tree
deepmd/pt/model/descriptor Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -435,9 +435,8 @@ def optim_angle_update(
435435 def optim_edge_update (
436436 self ,
437437 node_ebd : torch .Tensor ,
438- node_ebd_ext : torch .Tensor ,
438+ nei_node_ebd : torch .Tensor ,
439439 edge_ebd : torch .Tensor ,
440- nlist : torch .Tensor ,
441440 feat : str = "node" ,
442441 ) -> torch .Tensor :
443442 if feat == "node" :
@@ -455,10 +454,8 @@ def optim_edge_update(
455454
456455 # nf * nloc * node/edge_dim
457456 sub_node_update = torch .matmul (node_ebd , node )
458- # nf * nloc * nnei * node/edge_dim
459- gathered_node_ebd_ext = _make_nei_g1 (node_ebd_ext , nlist )
460- # nf * nloc * nnei * node/edge_dim
461- sub_node_ext_update = torch .matmul (gathered_node_ebd_ext , node_ext )
457+ # nf * nloc * node/edge_dim
458+ sub_node_ext_update = torch .matmul (nei_node_ebd , node_ext )
462459 # nf * nloc * nnei * node/edge_dim
463460 sub_edge_update = torch .matmul (edge_ebd , edge )
464461
@@ -577,9 +574,8 @@ def forward(
577574 node_edge_update = self .act (
578575 self .optim_edge_update (
579576 node_ebd ,
580- node_ebd_ext ,
577+ nei_node_ebd ,
581578 edge_ebd ,
582- nlist ,
583579 "node" ,
584580 )
585581 ) * sw .unsqueeze (- 1 )
@@ -605,9 +601,8 @@ def forward(
605601 edge_self_update = self .act (
606602 self .optim_edge_update (
607603 node_ebd ,
608- node_ebd_ext ,
604+ nei_node_ebd ,
609605 edge_ebd ,
610- nlist ,
611606 "edge" ,
612607 )
613608 )
You can’t perform that action at this time.
0 commit comments