2727 LayerNorm ,
2828 NativeLayer ,
2929)
30+ from packaging .version import (
31+ Version ,
32+ )
33+ from deepmd .jax .env import (
34+ flax_version ,
35+ nnx ,
36+ )
3037
3138
3239@flax_module
@@ -38,6 +45,8 @@ def __setattr__(self, name: str, value: Any) -> None:
3845 value = ArrayAPIVariable (value )
3946 elif name in {"layers" }:
4047 value = [RepformerLayer .deserialize (layer .serialize ()) for layer in value ]
48+ if Version (flax_version ) >= Version ("0.12.0" ):
49+ value = nnx .List ([nnx .data (item ) for item in value ])
4150 elif name == "g2_embd" :
4251 value = NativeLayer .deserialize (value .serialize ())
4352 elif name == "env_mat" :
@@ -87,21 +96,35 @@ def __setattr__(self, name: str, value: Any) -> None:
8796 if name in {"linear1" , "linear2" , "g1_self_mlp" , "proj_g1g2" , "proj_g1g1g2" }:
8897 if value is not None :
8998 value = NativeLayer .deserialize (value .serialize ())
99+ elif Version (flax_version ) >= Version ("0.12.0" ):
100+ value = nnx .data (value )
90101 elif name in {"g1_residual" , "g2_residual" , "h2_residual" }:
91102 value = [ArrayAPIVariable (to_jax_array (vv )) for vv in value ]
103+ if Version (flax_version ) >= Version ("0.12.0" ):
104+ value = nnx .List ([nnx .data (item ) for item in value ])
92105 elif name in {"attn2g_map" }:
93106 if value is not None :
94107 value = Atten2Map .deserialize (value .serialize ())
108+ elif Version (flax_version ) >= Version ("0.12.0" ):
109+ value = nnx .data (value )
95110 elif name in {"attn2_mh_apply" }:
96111 if value is not None :
97112 value = Atten2MultiHeadApply .deserialize (value .serialize ())
113+ elif Version (flax_version ) >= Version ("0.12.0" ):
114+ value = nnx .data (value )
98115 elif name in {"attn2_lm" }:
99116 if value is not None :
100117 value = LayerNorm .deserialize (value .serialize ())
118+ elif Version (flax_version ) >= Version ("0.12.0" ):
119+ value = nnx .data (value )
101120 elif name in {"attn2_ev_apply" }:
102121 if value is not None :
103122 value = Atten2EquiVarApply .deserialize (value .serialize ())
123+ elif Version (flax_version ) >= Version ("0.12.0" ):
124+ value = nnx .data (value )
104125 elif name in {"loc_attn" }:
105126 if value is not None :
106127 value = LocalAtten .deserialize (value .serialize ())
128+ elif Version (flax_version ) >= Version ("0.12.0" ):
129+ value = nnx .data (value )
107130 return super ().__setattr__ (name , value )
0 commit comments