1111from deepmd .dpmodel .utils .seed import (
1212 child_seed ,
1313)
14+ from deepmd .pd .cxx_op import (
15+ ENABLE_CUSTOMIZED_OP ,
16+ paddle_ops_deepmd ,
17+ )
1418from deepmd .pd .model .descriptor .descriptor import (
1519 DescriptorBlock ,
1620)
3539from deepmd .pd .utils .exclude_mask import (
3640 PairExcludeMask ,
3741)
42+ from deepmd .pd .utils .spin import (
43+ concat_switch_virtual ,
44+ )
3845from deepmd .pd .utils .utils import (
3946 ActivationFn ,
4047)
4956 RepFlowLayer ,
5057)
5158
if ENABLE_CUSTOMIZED_OP:
    # The customized Paddle OP library was built: expose the real border op.
    paddle_ops_deepmd_border_op = paddle_ops_deepmd.border_op
else:

    def border_op(
        argument0,
        argument1,
        argument2,
        argument3,
        argument4,
        argument5,
        argument6,
        argument7,
        argument8,
    ) -> paddle.Tensor:
        """Stub standing in for the customized ``border_op``.

        Present only so the model can be traced/frozen when the customized
        Paddle OP library is unavailable; always raises when called.

        Raises
        ------
        NotImplementedError
            Unconditionally, since the customized OP library was not built.
        """
        raise NotImplementedError(
            "border_op is not available since customized Paddle OP library is not built when freezing the model. "
            "See documentation for DPA3 for details."
        )

    # NOTE: with this stand-in, a frozen model cannot actually be run using LAMMPS.
    paddle_ops_deepmd_border_op = border_op
5282
5383@DescriptorBlock .register ("se_repflow" )
5484class DescrptBlockRepflows (DescriptorBlock ):
@@ -418,13 +448,14 @@ def forward(
418448 ):
419449 parallel_mode = comm_dict is not None
420450 if not parallel_mode :
421- assert mapping is not None
451+ if paddle .in_dynamic_mode ():
452+ assert mapping is not None and mapping .numel () > 0
422453 nframes , nloc , nnei = nlist .shape
423454 nall = extended_coord .reshape ([nframes , - 1 ]).shape [1 ] // 3
424455 atype = extended_atype [:, :nloc ]
425456 # nb x nloc x nnei
426457 exclude_mask = self .emask (nlist , extended_atype )
427- nlist = paddle .where (exclude_mask != 0 , nlist , - 1 )
458+ nlist = paddle .where (exclude_mask != 0 , nlist , paddle . full_like ( nlist , - 1 ) )
428459 # nb x nloc x nnei x 4, nb x nloc x nnei x 3, nb x nloc x nnei x 1
429460 dmatrix , diff , sw = prod_env_mat (
430461 extended_coord ,
@@ -447,7 +478,7 @@ def forward(
447478 :, :, : self .a_sel
448479 ]
449480 a_nlist = nlist [:, :, : self .a_sel ]
450- a_nlist = paddle .where (a_dist_mask , a_nlist , - 1 )
481+ a_nlist = paddle .where (a_dist_mask , a_nlist , paddle . full_like ( a_nlist , - 1 ) )
451482 _ , a_diff , a_sw = prod_env_mat (
452483 extended_coord ,
453484 a_nlist ,
@@ -497,7 +528,8 @@ def forward(
497528 angle_input = cosine_ij .unsqueeze (- 1 ) / (np .pi ** 0.5 )
498529
499530 if not parallel_mode and self .use_loc_mapping :
500- assert mapping is not None
531+ if paddle .in_dynamic_mode ():
532+ assert mapping is not None and mapping .numel () > 0
501533 # convert nlist from nall to nloc index
502534 nlist = paddle .take_along_axis (
503535 mapping ,
@@ -542,7 +574,8 @@ def forward(
542574
543575 # nb x nall x n_dim
544576 if not parallel_mode :
545- assert mapping is not None
577+ if paddle .in_dynamic_mode ():
578+ assert mapping is not None and mapping .numel () > 0
546579 mapping = (
547580 mapping .reshape ([nframes , nall ])
548581 .unsqueeze (- 1 )
@@ -552,14 +585,81 @@ def forward(
552585 # node_ebd: nb x nloc x n_dim
553586 # node_ebd_ext: nb x nall x n_dim [OR] nb x nloc x n_dim when not parallel_mode
554587 if not parallel_mode :
555- assert mapping is not None
588+ if paddle .in_dynamic_mode ():
589+ assert mapping is not None and mapping .numel () > 0
556590 node_ebd_ext = (
557591 paddle .take_along_axis (node_ebd , mapping , 1 , broadcast = False )
558592 if not self .use_loc_mapping
559593 else node_ebd
560594 )
561595 else :
562- raise NotImplementedError ("Not implemented" )
596+ assert len (comm_dict ) >= 6
597+ has_spin = len (comm_dict ) >= 7
598+ if not has_spin :
599+ n_padding = nall - nloc
600+ # node_ebd = paddle.nn.functional.pad(
601+ # node_ebd.squeeze(0), [0, 0, 0, n_padding], value=0.0
602+ # )
603+ # [nframes, nloc, tebd_dim]
604+ _shapes = node_ebd .shape [1 :]
605+ _shapes [1 ] = n_padding
606+ node_ebd = paddle .concat (
607+ [node_ebd , paddle .zeros (_shapes , dtype = node_ebd .dtype )],
608+ axis = 1 ,
609+ )
610+ real_nloc = nloc
611+ real_nall = nall
612+ else :
613+ # for spin
614+ real_nloc = nloc // 2
615+ real_nall = nall // 2
616+ real_n_padding = real_nall - real_nloc
617+ node_ebd_real , node_ebd_virtual = paddle .split (
618+ node_ebd , [real_nloc , real_nloc ], axis = 1
619+ )
620+ # mix_node_ebd: nb x real_nloc x (n_dim * 2)
621+ mix_node_ebd = paddle .concat (
622+ [node_ebd_real , node_ebd_virtual ], axis = 2
623+ )
624+ # nb x real_nall x (n_dim * 2)
625+ node_ebd = paddle .nn .functional .pad (
626+ mix_node_ebd .squeeze (0 ), (0 , 0 , 0 , real_n_padding ), value = 0.0
627+ )
628+
629+ assert len (comm_dict ) >= 6
630+ # assert "send_list" in comm_dict
631+ # assert "send_proc" in comm_dict
632+ # assert "recv_proc" in comm_dict
633+ # assert "send_num" in comm_dict
634+ # assert "recv_num" in comm_dict
635+ # assert "communicator" in comm_dict
636+ ret = paddle_ops_deepmd_border_op (
637+ comm_dict [0 ],
638+ comm_dict [1 ],
639+ comm_dict [2 ],
640+ comm_dict [3 ],
641+ comm_dict [4 ],
642+ node_ebd ,
643+ comm_dict [5 ],
644+ paddle .to_tensor (
645+ real_nloc ,
646+ dtype = paddle .int32 ,
647+ place = paddle .CPUPlace (),
648+ ), # should be int of c++, placed on cpu
649+ paddle .to_tensor (
650+ real_nall - real_nloc ,
651+ dtype = paddle .int32 ,
652+ place = paddle .CPUPlace (),
653+ ), # should be int of c++, placed on cpu
654+ )
655+ node_ebd_ext = ret [0 ].unsqueeze (0 )
656+ if has_spin :
657+ node_ebd_real_ext , node_ebd_virtual_ext = paddle .split (
658+ node_ebd_ext , [n_dim , n_dim ], axis = 2
659+ )
660+ node_ebd_ext = concat_switch_virtual (
661+ node_ebd_real_ext , node_ebd_virtual_ext , real_nloc
662+ )
563663 node_ebd , edge_ebd , angle_ebd = ll .forward (
564664 node_ebd_ext ,
565665 edge_ebd ,
0 commit comments