2020from deepmd .dpmodel .array_api import (
2121 Array ,
2222 xp_take_along_axis ,
23+ xp_take_first_n ,
2324)
2425from deepmd .dpmodel .common import (
2526 cast_precision ,
@@ -344,6 +345,7 @@ def __init__(
344345 self .concat_output_tebd = concat_output_tebd
345346 self .trainable = trainable
346347 self .precision = precision
348+ self .compress = False
347349
348350 def get_rcut (self ) -> float :
349351 """Returns the cut-off radius."""
@@ -535,7 +537,7 @@ def call(
535537 (nf , nall , self .tebd_dim ),
536538 )
537539 # nfnl x tebd_dim
538- atype_embd = atype_embd_ext [:, : nloc , :]
540+ atype_embd = xp_take_first_n ( atype_embd_ext , 1 , nloc )
539541 grrg , g2 , h2 , rot_mat , sw = self .se_atten (
540542 nlist ,
541543 coord_ext ,
@@ -557,7 +559,7 @@ def serialize(self) -> dict:
557559 data = {
558560 "@class" : "Descriptor" ,
559561 "type" : "dpa1" ,
560- "@version" : 2 ,
562+ "@version" : 3 if self . compress else 2 ,
561563 "rcut" : obj .rcut ,
562564 "rcut_smth" : obj .rcut_smth ,
563565 "sel" : obj .sel ,
@@ -602,20 +604,36 @@ def serialize(self) -> dict:
602604 }
603605 if obj .tebd_input_mode in ["strip" ]:
604606 data .update ({"embeddings_strip" : obj .embeddings_strip .serialize ()})
607+ if self .compress :
608+ compress_dict : dict = {
609+ "@variables" : {
610+ "type_embd_data" : to_numpy_array (self .type_embd_data ),
611+ },
612+ "geo_compress" : self .geo_compress ,
613+ }
614+ if self .geo_compress :
615+ compress_dict ["@variables" ]["compress_data" ] = [
616+ to_numpy_array (d ) for d in self .compress_data
617+ ]
618+ compress_dict ["@variables" ]["compress_info" ] = [
619+ to_numpy_array (i ) for i in self .compress_info
620+ ]
621+ data ["compress" ] = compress_dict
605622 return data
606623
607624 @classmethod
608625 def deserialize (cls , data : dict ) -> "DescrptDPA1" :
609626 """Deserialize from dict."""
610627 data = data .copy ()
611- check_version_compatibility (data .pop ("@version" ), 2 , 1 )
628+ check_version_compatibility (data .pop ("@version" ), 3 , 1 )
612629 data .pop ("@class" )
613630 data .pop ("type" )
614631 variables = data .pop ("@variables" )
615632 embeddings = data .pop ("embeddings" )
616633 type_embedding = data .pop ("type_embedding" )
617634 attention_layers = data .pop ("attention_layers" )
618635 env_mat = data .pop ("env_mat" )
636+ compress = data .pop ("compress" , None )
619637 tebd_input_mode = data ["tebd_input_mode" ]
620638 if tebd_input_mode in ["strip" ]:
621639 embeddings_strip = data .pop ("embeddings_strip" )
@@ -637,8 +655,20 @@ def deserialize(cls, data: dict) -> "DescrptDPA1":
637655 obj .se_atten .dpa1_attention = NeighborGatedAttention .deserialize (
638656 attention_layers
639657 )
658+ if compress is not None :
659+ obj ._load_compress_data (compress )
640660 return obj
641661
662+ def _load_compress_data (self , compress : dict ) -> None :
663+ """Load compression state from serialized data."""
664+ variables = compress ["@variables" ]
665+ self .type_embd_data = variables ["type_embd_data" ]
666+ self .geo_compress = compress .get ("geo_compress" , False )
667+ if self .geo_compress :
668+ self .compress_data = variables ["compress_data" ]
669+ self .compress_info = variables ["compress_info" ]
670+ self .compress = True
671+
642672 @classmethod
643673 def update_sel (
644674 cls ,
@@ -1057,7 +1087,7 @@ def call(
10571087 self .stddev [...],
10581088 )
10591089 nf , nloc , nnei , _ = dmatrix .shape
1060- atype = atype_ext [:, : nloc ]
1090+ atype = xp_take_first_n ( atype_ext , 1 , nloc )
10611091 exclude_mask = self .emask .build_type_exclude_mask (nlist , atype_ext )
10621092 # nfnl x nnei
10631093 exclude_mask = xp .reshape (exclude_mask , (nf * nloc , nnei ))
@@ -1076,6 +1106,12 @@ def call(
10761106 nlist_masked = xp .where (nlist_mask , nlist , xp .zeros_like (nlist ))
10771107 ng = self .neuron [- 1 ]
10781108 nt = self .tebd_dim
1109+
1110+ # Gather neighbor info using xp_take_along_axis along axis=1.
1111+ # This avoids flat (nf*nall,) indexing that creates Ne(nall, nloc)
1112+ # constraints in torch.export, breaking NoPbc (nall == nloc).
1113+ nlist_2d = xp .reshape (nlist_masked , (nf , nloc * nnei )) # (nf, nloc*nnei)
1114+
10791115 # nfnl x nnei x 4
10801116 rr = xp .reshape (dmatrix , (nf * nloc , nnei , 4 ))
10811117 rr = rr * xp .astype (exclude_mask [:, :, None ], rr .dtype )
@@ -1084,15 +1120,16 @@ def call(
10841120 if self .tebd_input_mode in ["concat" ]:
10851121 # nfnl x tebd_dim
10861122 atype_embd = xp .reshape (
1087- atype_embd_ext [:, : nloc , :] , (nf * nloc , self .tebd_dim )
1123+ xp_take_first_n ( atype_embd_ext , 1 , nloc ) , (nf * nloc , self .tebd_dim )
10881124 )
10891125 # nfnl x nnei x tebd_dim
10901126 atype_embd_nnei = xp .tile (atype_embd [:, xp .newaxis , :], (1 , nnei , 1 ))
1091- index = xp .tile (
1092- xp .reshape (nlist_masked , (nf , - 1 , 1 )), (1 , 1 , self .tebd_dim )
1127+ # Gather neighbor type embeddings: (nf, nall, tebd_dim) -> (nf, nloc*nnei, tebd_dim)
1128+ nlist_idx_tebd = xp .tile (nlist_2d [:, :, xp .newaxis ], (1 , 1 , self .tebd_dim ))
1129+ atype_embd_nlist = xp_take_along_axis (
1130+ atype_embd_ext , nlist_idx_tebd , axis = 1
10931131 )
10941132 # nfnl x nnei x tebd_dim
1095- atype_embd_nlist = xp_take_along_axis (atype_embd_ext , index , axis = 1 )
10961133 atype_embd_nlist = xp .reshape (
10971134 atype_embd_nlist , (nf * nloc , nnei , self .tebd_dim )
10981135 )
@@ -1111,10 +1148,9 @@ def call(
11111148 assert self .embeddings_strip is not None
11121149 assert type_embedding is not None
11131150 ntypes_with_padding = type_embedding .shape [0 ]
1114- # nf x (nl x nnei)
1115- nlist_index = xp .reshape (nlist_masked , (nf , nloc * nnei ))
1116- # nf x (nl x nnei)
1117- nei_type = xp_take_along_axis (atype_ext , nlist_index , axis = 1 )
1151+ # Gather neighbor types: (nf, nall) -> (nf, nloc*nnei)
1152+ nei_type = xp_take_along_axis (atype_ext , nlist_2d , axis = 1 )
1153+ nei_type = xp .reshape (nei_type , (- 1 ,)) # (nf * nloc * nnei,)
11181154 # (nf x nl x nnei) x ng
11191155 nei_type_index = xp .tile (xp .reshape (nei_type , (- 1 , 1 )), (1 , ng ))
11201156 if self .type_one_side :
0 commit comments