Skip to content

Commit 8fe0df6

Browse files
support DPA2/DPA3 inference
1 parent d2a46b7 commit 8fe0df6

12 files changed

Lines changed: 832 additions & 57 deletions

File tree

deepmd/pd/entrypoints/main.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -388,14 +388,25 @@ def freeze(
388388
model.forward_lower = paddle.jit.to_static(
389389
model.forward_lower,
390390
input_spec=[
391-
InputSpec([1, -1, 3], dtype="float64", name="coord"), # extended_coord
392-
InputSpec([1, -1], dtype="int32", name="atype"), # extended_atype
391+
InputSpec(
392+
[1, -1, 3], dtype="float64", name="extended_coord"
393+
), # extended_coord
394+
InputSpec(
395+
[1, -1], dtype="int32", name="extended_atype"
396+
), # extended_atype
393397
InputSpec([1, -1, -1], dtype="int32", name="nlist"), # nlist
394398
InputSpec([1, -1], dtype="int64", name="mapping"), # mapping
395399
None, # fparam
396400
None, # aparam
397401
True, # do_atomic_virial
398-
None, # comm_dict
402+
[
403+
InputSpec([-1], "int32", name="send_list"),
404+
InputSpec([-1], "int32", name="send_proc"),
405+
InputSpec([-1], "int32", name="recv_proc"),
406+
InputSpec([-1], "int32", name="send_num"),
407+
InputSpec([-1], "int32", name="recv_num"),
408+
InputSpec([-1], "int64", name="communicator"),
409+
], # comm_dict
399410
],
400411
full_graph=True,
401412
)

deepmd/pd/model/descriptor/dpa2.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -798,8 +798,8 @@ def forward(
798798
assert self.tebd_transform is not None
799799
g1 = g1 + self.tebd_transform(g1_inp)
800800
# mapping g1
801-
if comm_dict is None:
802-
assert mapping is not None
801+
if comm_dict is None or len(comm_dict) == 0:
802+
assert mapping.numel() > 0
803803
mapping_ext = (
804804
mapping.reshape([nframes, nall])
805805
.unsqueeze(-1)

deepmd/pd/model/descriptor/repflows.py

Lines changed: 107 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,10 @@
1111
from deepmd.dpmodel.utils.seed import (
1212
child_seed,
1313
)
14+
from deepmd.pd.cxx_op import (
15+
ENABLE_CUSTOMIZED_OP,
16+
paddle_ops_deepmd,
17+
)
1418
from deepmd.pd.model.descriptor.descriptor import (
1519
DescriptorBlock,
1620
)
@@ -35,6 +39,9 @@
3539
from deepmd.pd.utils.exclude_mask import (
3640
PairExcludeMask,
3741
)
42+
from deepmd.pd.utils.spin import (
43+
concat_switch_virtual,
44+
)
3845
from deepmd.pd.utils.utils import (
3946
ActivationFn,
4047
)
@@ -49,6 +56,29 @@
4956
RepFlowLayer,
5057
)
5158

59+
if not ENABLE_CUSTOMIZED_OP:
60+
61+
def border_op(
62+
argument0,
63+
argument1,
64+
argument2,
65+
argument3,
66+
argument4,
67+
argument5,
68+
argument6,
69+
argument7,
70+
argument8,
71+
) -> paddle.Tensor:
72+
raise NotImplementedError(
73+
"border_op is not available since customized Paddle OP library is not built when freezing the model. "
74+
"See documentation for DPA3 for details."
75+
)
76+
77+
# Note: this hack cannot actually save a model that can be run using LAMMPS.
78+
paddle_ops_deepmd_border_op = border_op
79+
else:
80+
paddle_ops_deepmd_border_op = paddle_ops_deepmd.border_op
81+
5282

5383
@DescriptorBlock.register("se_repflow")
5484
class DescrptBlockRepflows(DescriptorBlock):
@@ -418,13 +448,14 @@ def forward(
418448
):
419449
parallel_mode = comm_dict is not None
420450
if not parallel_mode:
421-
assert mapping is not None
451+
if paddle.in_dynamic_mode():
452+
assert mapping is not None and mapping.numel() > 0
422453
nframes, nloc, nnei = nlist.shape
423454
nall = extended_coord.reshape([nframes, -1]).shape[1] // 3
424455
atype = extended_atype[:, :nloc]
425456
# nb x nloc x nnei
426457
exclude_mask = self.emask(nlist, extended_atype)
427-
nlist = paddle.where(exclude_mask != 0, nlist, -1)
458+
nlist = paddle.where(exclude_mask != 0, nlist, paddle.full_like(nlist, -1))
428459
# nb x nloc x nnei x 4, nb x nloc x nnei x 3, nb x nloc x nnei x 1
429460
dmatrix, diff, sw = prod_env_mat(
430461
extended_coord,
@@ -447,7 +478,7 @@ def forward(
447478
:, :, : self.a_sel
448479
]
449480
a_nlist = nlist[:, :, : self.a_sel]
450-
a_nlist = paddle.where(a_dist_mask, a_nlist, -1)
481+
a_nlist = paddle.where(a_dist_mask, a_nlist, paddle.full_like(a_nlist, -1))
451482
_, a_diff, a_sw = prod_env_mat(
452483
extended_coord,
453484
a_nlist,
@@ -497,7 +528,8 @@ def forward(
497528
angle_input = cosine_ij.unsqueeze(-1) / (np.pi**0.5)
498529

499530
if not parallel_mode and self.use_loc_mapping:
500-
assert mapping is not None
531+
if paddle.in_dynamic_mode():
532+
assert mapping is not None and mapping.numel() > 0
501533
# convert nlist from nall to nloc index
502534
nlist = paddle.take_along_axis(
503535
mapping,
@@ -542,7 +574,8 @@ def forward(
542574

543575
# nb x nall x n_dim
544576
if not parallel_mode:
545-
assert mapping is not None
577+
if paddle.in_dynamic_mode():
578+
assert mapping is not None and mapping.numel() > 0
546579
mapping = (
547580
mapping.reshape([nframes, nall])
548581
.unsqueeze(-1)
@@ -552,14 +585,81 @@ def forward(
552585
# node_ebd: nb x nloc x n_dim
553586
# node_ebd_ext: nb x nall x n_dim [OR] nb x nloc x n_dim when not parallel_mode
554587
if not parallel_mode:
555-
assert mapping is not None
588+
if paddle.in_dynamic_mode():
589+
assert mapping is not None and mapping.numel() > 0
556590
node_ebd_ext = (
557591
paddle.take_along_axis(node_ebd, mapping, 1, broadcast=False)
558592
if not self.use_loc_mapping
559593
else node_ebd
560594
)
561595
else:
562-
raise NotImplementedError("Not implemented")
596+
assert len(comm_dict) >= 6
597+
has_spin = len(comm_dict) >= 7
598+
if not has_spin:
599+
n_padding = nall - nloc
600+
# node_ebd = paddle.nn.functional.pad(
601+
# node_ebd.squeeze(0), [0, 0, 0, n_padding], value=0.0
602+
# )
603+
# [nframes, nloc, tebd_dim]
604+
_shapes = node_ebd.shape[1:]
605+
_shapes[1] = n_padding
606+
node_ebd = paddle.concat(
607+
[node_ebd, paddle.zeros(_shapes, dtype=node_ebd.dtype)],
608+
axis=1,
609+
)
610+
real_nloc = nloc
611+
real_nall = nall
612+
else:
613+
# for spin
614+
real_nloc = nloc // 2
615+
real_nall = nall // 2
616+
real_n_padding = real_nall - real_nloc
617+
node_ebd_real, node_ebd_virtual = paddle.split(
618+
node_ebd, [real_nloc, real_nloc], axis=1
619+
)
620+
# mix_node_ebd: nb x real_nloc x (n_dim * 2)
621+
mix_node_ebd = paddle.concat(
622+
[node_ebd_real, node_ebd_virtual], axis=2
623+
)
624+
# nb x real_nall x (n_dim * 2)
625+
node_ebd = paddle.nn.functional.pad(
626+
mix_node_ebd.squeeze(0), (0, 0, 0, real_n_padding), value=0.0
627+
)
628+
629+
assert len(comm_dict) >= 6
630+
# assert "send_list" in comm_dict
631+
# assert "send_proc" in comm_dict
632+
# assert "recv_proc" in comm_dict
633+
# assert "send_num" in comm_dict
634+
# assert "recv_num" in comm_dict
635+
# assert "communicator" in comm_dict
636+
ret = paddle_ops_deepmd_border_op(
637+
comm_dict[0],
638+
comm_dict[1],
639+
comm_dict[2],
640+
comm_dict[3],
641+
comm_dict[4],
642+
node_ebd,
643+
comm_dict[5],
644+
paddle.to_tensor(
645+
real_nloc,
646+
dtype=paddle.int32,
647+
place=paddle.CPUPlace(),
648+
), # should be int of c++, placed on cpu
649+
paddle.to_tensor(
650+
real_nall - real_nloc,
651+
dtype=paddle.int32,
652+
place=paddle.CPUPlace(),
653+
), # should be int of c++, placed on cpu
654+
)
655+
node_ebd_ext = ret[0].unsqueeze(0)
656+
if has_spin:
657+
node_ebd_real_ext, node_ebd_virtual_ext = paddle.split(
658+
node_ebd_ext, [n_dim, n_dim], axis=2
659+
)
660+
node_ebd_ext = concat_switch_virtual(
661+
node_ebd_real_ext, node_ebd_virtual_ext, real_nloc
662+
)
563663
node_ebd, edge_ebd, angle_ebd = ll.forward(
564664
node_ebd_ext,
565665
edge_ebd,

deepmd/pd/model/descriptor/repformer_layer.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,7 @@ def _make_nei_g1(
110110
# index: nb x (nloc x nnei) x ng1
111111
index = nlist.reshape([nb, nloc * nnei]).unsqueeze(-1).expand([-1, -1, ng1])
112112
# gg1 : nb x (nloc x nnei) x ng1
113+
# print(g1_ext.shape, index.shape)
113114
gg1 = paddle.take_along_axis(g1_ext, indices=index, axis=1, broadcast=False)
114115
# gg1 : nb x nloc x nnei x ng1
115116
gg1 = gg1.reshape([nb, nloc, nnei, ng1])

deepmd/pd/model/descriptor/repformers.py

Lines changed: 108 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,10 @@
1010
from deepmd.dpmodel.utils.seed import (
1111
child_seed,
1212
)
13+
from deepmd.pd.cxx_op import (
14+
ENABLE_CUSTOMIZED_OP,
15+
paddle_ops_deepmd,
16+
)
1317
from deepmd.pd.model.descriptor.descriptor import (
1418
DescriptorBlock,
1519
)
@@ -31,6 +35,9 @@
3135
from deepmd.pd.utils.exclude_mask import (
3236
PairExcludeMask,
3337
)
38+
from deepmd.pd.utils.spin import (
39+
concat_switch_virtual,
40+
)
3441
from deepmd.pd.utils.utils import (
3542
ActivationFn,
3643
)
@@ -45,6 +52,29 @@
4552
RepformerLayer,
4653
)
4754

55+
if not ENABLE_CUSTOMIZED_OP:
56+
57+
def border_op(
58+
argument0,
59+
argument1,
60+
argument2,
61+
argument3,
62+
argument4,
63+
argument5,
64+
argument6,
65+
argument7,
66+
argument8,
67+
) -> paddle.Tensor:
68+
raise NotImplementedError(
69+
"border_op is not available since customized Paddle OP library is not built when freezing the model. "
70+
"See documentation for DPA3 for details."
71+
)
72+
73+
# Note: this hack cannot actually save a model that can be run using LAMMPS.
74+
paddle_ops_deepmd_border_op = border_op
75+
else:
76+
paddle_ops_deepmd_border_op = paddle_ops_deepmd.border_op
77+
4878

4979
@DescriptorBlock.register("se_repformer")
5080
@DescriptorBlock.register("se_uni")
@@ -380,9 +410,10 @@ def forward(
380410
type_embedding: Optional[paddle.Tensor] = None,
381411
comm_dict: Optional[dict[str, paddle.Tensor]] = None,
382412
):
383-
if comm_dict is None:
384-
assert mapping is not None
385-
assert extended_atype_embd is not None
413+
if comm_dict is None or len(comm_dict) == 0:
414+
if paddle.in_dynamic_mode():
415+
assert mapping is not None and mapping.numel() > 0
416+
assert extended_atype_embd is not None or extended_atype_embd.numel() > 0
386417
nframes, nloc, nnei = nlist.shape
387418
nall = extended_coord.reshape([nframes, -1]).shape[1] // 3
388419
atype = extended_atype[:, :nloc]
@@ -406,7 +437,7 @@ def forward(
406437
sw = sw.masked_fill(~nlist_mask, 0.0)
407438

408439
# [nframes, nloc, tebd_dim]
409-
if comm_dict is None:
440+
if comm_dict is None or len(comm_dict) == 0:
410441
if paddle.in_dynamic_mode():
411442
assert isinstance(extended_atype_embd, paddle.Tensor) # for jit
412443
atype_embd = extended_atype_embd[:, :nloc, :]
@@ -432,8 +463,9 @@ def forward(
432463
# if the a neighbor is real or not is indicated by nlist_mask
433464
nlist[nlist == -1] = 0
434465
# nb x nall x ng1
435-
if comm_dict is None:
436-
assert mapping is not None
466+
if comm_dict is None or len(comm_dict) == 0:
467+
if paddle.in_dynamic_mode():
468+
assert mapping is not None and mapping.numel() > 0
437469
mapping = (
438470
mapping.reshape([nframes, nall])
439471
.unsqueeze(-1)
@@ -442,13 +474,80 @@ def forward(
442474
for idx, ll in enumerate(self.layers):
443475
# g1: nb x nloc x ng1
444476
# g1_ext: nb x nall x ng1
445-
if comm_dict is None:
446-
assert mapping is not None
477+
if comm_dict is None or len(comm_dict) == 0:
478+
if paddle.in_dynamic_mode():
479+
assert mapping is not None and mapping.numel() > 0
447480
g1_ext = paddle.take_along_axis(
448481
g1, axis=1, indices=mapping, broadcast=False
449482
)
450483
else:
451-
raise NotImplementedError("Not implemented yet")
484+
has_spin = len(comm_dict) >= 7
485+
if not has_spin:
486+
n_padding = nall - nloc
487+
# g1 = paddle.nn.functional.pad(
488+
# g1.squeeze(0), (0, 0, 0, n_padding), value=0.0
489+
# )
490+
_shapes = g1.shape[1:]
491+
_shapes[1] = n_padding
492+
g1 = paddle.concat(
493+
[g1.squeeze(0), paddle.zeros(_shapes, dtype=g1.dtype)],
494+
axis=1,
495+
)
496+
real_nloc = nloc
497+
real_nall = nall
498+
else:
499+
# for spin
500+
real_nloc = nloc // 2
501+
real_nall = nall // 2
502+
real_n_padding = real_nall - real_nloc
503+
g1_real, g1_virtual = paddle.split(
504+
g1, [real_nloc, real_nloc], axis=1
505+
)
506+
# mix_g1: nb x real_nloc x (ng1 * 2)
507+
mix_g1 = paddle.concat([g1_real, g1_virtual], axis=2)
508+
# nb x real_nall x (ng1 * 2)
509+
g1 = paddle.nn.functional.pad(
510+
mix_g1.squeeze(0), (0, 0, 0, real_n_padding), value=0.0
511+
)
512+
513+
# assert "send_list" in comm_dict
514+
# assert "send_proc" in comm_dict
515+
# assert "recv_proc" in comm_dict
516+
# assert "send_num" in comm_dict
517+
# assert "recv_num" in comm_dict
518+
# assert "communicator" in comm_dict
519+
# print(f"g1.shape = ", g1.shape)
520+
ret = paddle_ops_deepmd_border_op(
521+
comm_dict[0],
522+
comm_dict[1],
523+
comm_dict[2],
524+
comm_dict[3],
525+
comm_dict[4],
526+
g1,
527+
comm_dict[5],
528+
paddle.to_tensor(
529+
real_nloc,
530+
dtype=paddle.int32,
531+
place=paddle.CPUPlace(),
532+
), # should be int of c++, placed on cpu
533+
paddle.to_tensor(
534+
real_nall - real_nloc,
535+
dtype=paddle.int32,
536+
place=paddle.CPUPlace(),
537+
), # should be int of c++, placed on cpu
538+
)
539+
# print(f"ret.shape = {ret.shape}")
540+
# print(f"ret[0].shape = ", ret[0].shape)
541+
g1_ext = ret.unsqueeze(0)
542+
# print(f"g1_ext.shape = ", g1_ext.shape)
543+
# exit()
544+
if has_spin:
545+
g1_real_ext, g1_virtual_ext = paddle.split(
546+
g1_ext, [ng1, ng1], dim=2
547+
)
548+
g1_ext = concat_switch_virtual(
549+
g1_real_ext, g1_virtual_ext, real_nloc
550+
)
452551

453552
g1, g2, h2 = ll.forward(
454553
g1_ext,

0 commit comments

Comments
 (0)