Skip to content

Commit a72b3af

Browse files
fix(jax): fix compatibility with flax 0.12 (#5067)
Since the latest TF and JAX have not been compatible with each other, I keep the old JAX and flax version in the CI. <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit ## Release Notes * **Chores** * Updated internal handling across model and descriptor modules to support Flax 0.12.0 and later versions with conditional runtime behavior based on detected Flax version. <sub>✏️ Tip: You can customize this high-level summary in your review settings.</sub> <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Signed-off-by: Jinzhe Zeng <jinzhe.zeng@ustc.edu.cn> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 0f06833 commit a72b3af

18 files changed

Lines changed: 216 additions & 3 deletions

deepmd/jax/atomic_model/base_atomic_model.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,18 @@
33
Any,
44
)
55

6+
from packaging.version import (
7+
Version,
8+
)
9+
610
from deepmd.jax.common import (
711
ArrayAPIVariable,
812
to_jax_array,
913
)
14+
from deepmd.jax.env import (
15+
flax_version,
16+
nnx,
17+
)
1018
from deepmd.jax.utils.exclude_mask import (
1119
AtomExcludeMask,
1220
PairExcludeMask,
@@ -18,6 +26,8 @@ def base_atomic_model_set_attr(name: str, value: Any) -> Any:
1826
value = to_jax_array(value)
1927
if value is not None:
2028
value = ArrayAPIVariable(value)
29+
elif Version(flax_version) >= Version("0.12.0"):
30+
value = nnx.data(value)
2131
elif name == "pair_excl" and value is not None:
2232
value = PairExcludeMask(value.ntypes, value.exclude_types)
2333
elif name == "atom_excl" and value is not None:

deepmd/jax/atomic_model/linear_atomic_model.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,10 @@
44
Optional,
55
)
66

7+
from packaging.version import (
8+
Version,
9+
)
10+
711
from deepmd.dpmodel.atomic_model.linear_atomic_model import (
812
DPZBLLinearEnergyAtomicModel as DPZBLLinearEnergyAtomicModelDP,
913
)
@@ -22,8 +26,10 @@
2226
to_jax_array,
2327
)
2428
from deepmd.jax.env import (
29+
flax_version,
2530
jax,
2631
jnp,
32+
nnx,
2733
)
2834

2935

@@ -33,13 +39,19 @@ def __setattr__(self, name: str, value: Any) -> None:
3339
value = base_atomic_model_set_attr(name, value)
3440
if name == "mapping_list":
3541
value = [ArrayAPIVariable(to_jax_array(vv)) for vv in value]
42+
if Version(flax_version) >= Version("0.12.0"):
43+
value = nnx.List([nnx.data(item) for item in value])
3644
elif name == "zbl_weight":
37-
value = ArrayAPIVariable(to_jax_array(value))
45+
# discard since it's only used in tests
46+
# to fix flax.errors.TraceContextError: Cannot mutate 'FlaxModule' from different trace level
47+
return
3848
elif name == "models":
3949
value = [
4050
DPAtomicModel.deserialize(value[0].serialize()),
4151
PairTabAtomicModel.deserialize(value[1].serialize()),
4252
]
53+
if Version(flax_version) >= Version("0.12.0"):
54+
value = nnx.List([nnx.data(item) for item in value])
4355
return super().__setattr__(name, value)
4456

4557
def forward_common_atomic(

deepmd/jax/atomic_model/pairtab_atomic_model.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,10 @@
44
Optional,
55
)
66

7+
from packaging.version import (
8+
Version,
9+
)
10+
711
from deepmd.dpmodel.atomic_model.pairtab_atomic_model import (
812
PairTabAtomicModel as PairTabAtomicModelDP,
913
)
@@ -16,8 +20,10 @@
1620
to_jax_array,
1721
)
1822
from deepmd.jax.env import (
23+
flax_version,
1924
jax,
2025
jnp,
26+
nnx,
2127
)
2228

2329

@@ -29,6 +35,8 @@ def __setattr__(self, name: str, value: Any) -> None:
2935
value = to_jax_array(value)
3036
if value is not None:
3137
value = ArrayAPIVariable(value)
38+
elif Version(flax_version) >= Version("0.12.0"):
39+
value = nnx.data(value)
3240
return super().__setattr__(name, value)
3341

3442
def forward_common_atomic(

deepmd/jax/descriptor/dpa1.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,10 @@
33
Any,
44
)
55

6+
from packaging.version import (
7+
Version,
8+
)
9+
610
from deepmd.dpmodel.descriptor.dpa1 import DescrptBlockSeAtten as DescrptBlockSeAttenDP
711
from deepmd.dpmodel.descriptor.dpa1 import DescrptDPA1 as DescrptDPA1DP
812
from deepmd.dpmodel.descriptor.dpa1 import GatedAttentionLayer as GatedAttentionLayerDP
@@ -20,6 +24,10 @@
2024
from deepmd.jax.descriptor.base_descriptor import (
2125
BaseDescriptor,
2226
)
27+
from deepmd.jax.env import (
28+
flax_version,
29+
nnx,
30+
)
2331
from deepmd.jax.utils.exclude_mask import (
2432
PairExcludeMask,
2533
)
@@ -58,6 +66,8 @@ def __setattr__(self, name: str, value: Any) -> None:
5866
value = [
5967
NeighborGatedAttentionLayer.deserialize(ii.serialize()) for ii in value
6068
]
69+
if Version(flax_version) >= Version("0.12.0"):
70+
value = nnx.List([nnx.data(item) for item in value])
6171
return super().__setattr__(name, value)
6272

6373

@@ -68,9 +78,13 @@ def __setattr__(self, name: str, value: Any) -> None:
6878
value = to_jax_array(value)
6979
if value is not None:
7080
value = ArrayAPIVariable(value)
81+
elif Version(flax_version) >= Version("0.12.0"):
82+
value = nnx.data(value)
7183
elif name in {"embeddings", "embeddings_strip"}:
7284
if value is not None:
7385
value = NetworkCollection.deserialize(value.serialize())
86+
elif Version(flax_version) >= Version("0.12.0"):
87+
value = nnx.data(value)
7488
elif name == "dpa1_attention":
7589
value = NeighborGatedAttention.deserialize(value.serialize())
7690
elif name == "env_mat":

deepmd/jax/descriptor/dpa2.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,10 @@
33
Any,
44
)
55

6+
from packaging.version import (
7+
Version,
8+
)
9+
610
from deepmd.dpmodel.descriptor.dpa2 import DescrptDPA2 as DescrptDPA2DP
711
from deepmd.dpmodel.utils.network import Identity as IdentityDP
812
from deepmd.dpmodel.utils.network import NativeLayer as NativeLayerDP
@@ -23,6 +27,10 @@
2327
from deepmd.jax.descriptor.se_t_tebd import (
2428
DescrptBlockSeTTebd,
2529
)
30+
from deepmd.jax.env import (
31+
flax_version,
32+
nnx,
33+
)
2634
from deepmd.jax.utils.network import (
2735
NativeLayer,
2836
)
@@ -39,18 +47,23 @@ def __setattr__(self, name: str, value: Any) -> None:
3947
value = to_jax_array(value)
4048
if value is not None:
4149
value = ArrayAPIVariable(value)
50+
elif Version(flax_version) >= Version("0.12.0"):
51+
value = nnx.data(value)
4252
elif name in {"repinit"}:
4353
value = DescrptBlockSeAtten.deserialize(value.serialize())
4454
elif name in {"repinit_three_body"}:
4555
if value is not None:
4656
value = DescrptBlockSeTTebd.deserialize(value.serialize())
57+
elif Version(flax_version) >= Version("0.12.0"):
58+
value = nnx.data(value)
4759
elif name in {"repformers"}:
4860
value = DescrptBlockRepformers.deserialize(value.serialize())
4961
elif name in {"type_embedding"}:
5062
value = TypeEmbedNet.deserialize(value.serialize())
5163
elif name in {"g1_shape_tranform", "tebd_transform"}:
5264
if value is None:
53-
pass
65+
if Version(flax_version) >= Version("0.12.0"):
66+
value = nnx.data(value)
5467
elif isinstance(value, NativeLayerDP):
5568
value = NativeLayer.deserialize(value.serialize())
5669
elif isinstance(value, IdentityDP):

deepmd/jax/descriptor/dpa3.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,10 @@
33
Any,
44
)
55

6+
from packaging.version import (
7+
Version,
8+
)
9+
610
from deepmd.dpmodel.descriptor.dpa3 import DescrptDPA3 as DescrptDPA3DP
711
from deepmd.jax.common import (
812
ArrayAPIVariable,
@@ -15,6 +19,10 @@
1519
from deepmd.jax.descriptor.repflows import (
1620
DescrptBlockRepflows,
1721
)
22+
from deepmd.jax.env import (
23+
flax_version,
24+
nnx,
25+
)
1826
from deepmd.jax.utils.type_embed import (
1927
TypeEmbedNet,
2028
)
@@ -28,6 +36,8 @@ def __setattr__(self, name: str, value: Any) -> None:
2836
value = to_jax_array(value)
2937
if value is not None:
3038
value = ArrayAPIVariable(value)
39+
elif Version(flax_version) >= Version("0.12.0"):
40+
value = nnx.data(value)
3141
elif name in {"repflows"}:
3242
value = DescrptBlockRepflows.deserialize(value.serialize())
3343
elif name in {"type_embedding"}:

deepmd/jax/descriptor/hybrid.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,10 @@
33
Any,
44
)
55

6+
from packaging.version import (
7+
Version,
8+
)
9+
610
from deepmd.dpmodel.descriptor.hybrid import DescrptHybrid as DescrptHybridDP
711
from deepmd.jax.common import (
812
ArrayAPIVariable,
@@ -12,6 +16,10 @@
1216
from deepmd.jax.descriptor.base_descriptor import (
1317
BaseDescriptor,
1418
)
19+
from deepmd.jax.env import (
20+
flax_version,
21+
nnx,
22+
)
1523

1624

1725
@BaseDescriptor.register("hybrid")
@@ -20,7 +28,11 @@ class DescrptHybrid(DescrptHybridDP):
2028
def __setattr__(self, name: str, value: Any) -> None:
2129
if name in {"nlist_cut_idx"}:
2230
value = [ArrayAPIVariable(to_jax_array(vv)) for vv in value]
31+
if Version(flax_version) >= Version("0.12.0"):
32+
value = nnx.List([nnx.data(item) for item in value])
2333
elif name in {"descrpt_list"}:
2434
value = [BaseDescriptor.deserialize(vv.serialize()) for vv in value]
35+
if Version(flax_version) >= Version("0.12.0"):
36+
value = nnx.List([nnx.data(item) for item in value])
2537

2638
return super().__setattr__(name, value)

deepmd/jax/descriptor/repflows.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,10 @@
33
Any,
44
)
55

6+
from packaging.version import (
7+
Version,
8+
)
9+
610
from deepmd.dpmodel.descriptor.repflows import (
711
DescrptBlockRepflows as DescrptBlockRepflowsDP,
812
)
@@ -12,6 +16,10 @@
1216
flax_module,
1317
to_jax_array,
1418
)
19+
from deepmd.jax.env import (
20+
flax_version,
21+
nnx,
22+
)
1523
from deepmd.jax.utils.exclude_mask import (
1624
PairExcludeMask,
1725
)
@@ -27,8 +35,12 @@ def __setattr__(self, name: str, value: Any) -> None:
2735
value = to_jax_array(value)
2836
if value is not None:
2937
value = ArrayAPIVariable(value)
38+
elif Version(flax_version) >= Version("0.12.0"):
39+
value = nnx.data(value)
3040
elif name in {"layers"}:
3141
value = [RepFlowLayer.deserialize(layer.serialize()) for layer in value]
42+
if Version(flax_version) >= Version("0.12.0"):
43+
value = nnx.List([nnx.data(item) for item in value])
3244
elif name in {"edge_embd", "angle_embd"}:
3345
value = NativeLayer.deserialize(value.serialize())
3446
elif name in {"env_mat_edge", "env_mat_angle"}:
@@ -58,8 +70,12 @@ def __setattr__(self, name: str, value: Any) -> None:
5870
}:
5971
if value is not None:
6072
value = NativeLayer.deserialize(value.serialize())
73+
elif Version(flax_version) >= Version("0.12.0"):
74+
value = nnx.data(value)
6175
elif name in {"n_residual", "e_residual", "a_residual"}:
6276
value = [ArrayAPIVariable(to_jax_array(vv)) for vv in value]
77+
if Version(flax_version) >= Version("0.12.0"):
78+
value = nnx.List([nnx.data(item) for item in value])
6379
else:
6480
pass
6581
return super().__setattr__(name, value)

deepmd/jax/descriptor/repformers.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,10 @@
33
Any,
44
)
55

6+
from packaging.version import (
7+
Version,
8+
)
9+
610
from deepmd.dpmodel.descriptor.repformers import (
711
Atten2EquiVarApply as Atten2EquiVarApplyDP,
812
)
@@ -20,6 +24,10 @@
2024
flax_module,
2125
to_jax_array,
2226
)
27+
from deepmd.jax.env import (
28+
flax_version,
29+
nnx,
30+
)
2331
from deepmd.jax.utils.exclude_mask import (
2432
PairExcludeMask,
2533
)
@@ -36,8 +44,12 @@ def __setattr__(self, name: str, value: Any) -> None:
3644
value = to_jax_array(value)
3745
if value is not None:
3846
value = ArrayAPIVariable(value)
47+
elif Version(flax_version) >= Version("0.12.0"):
48+
value = nnx.data(value)
3949
elif name in {"layers"}:
4050
value = [RepformerLayer.deserialize(layer.serialize()) for layer in value]
51+
if Version(flax_version) >= Version("0.12.0"):
52+
value = nnx.List([nnx.data(item) for item in value])
4153
elif name == "g2_embd":
4254
value = NativeLayer.deserialize(value.serialize())
4355
elif name == "env_mat":
@@ -87,21 +99,35 @@ def __setattr__(self, name: str, value: Any) -> None:
8799
if name in {"linear1", "linear2", "g1_self_mlp", "proj_g1g2", "proj_g1g1g2"}:
88100
if value is not None:
89101
value = NativeLayer.deserialize(value.serialize())
102+
elif Version(flax_version) >= Version("0.12.0"):
103+
value = nnx.data(value)
90104
elif name in {"g1_residual", "g2_residual", "h2_residual"}:
91105
value = [ArrayAPIVariable(to_jax_array(vv)) for vv in value]
106+
if Version(flax_version) >= Version("0.12.0"):
107+
value = nnx.List([nnx.data(item) for item in value])
92108
elif name in {"attn2g_map"}:
93109
if value is not None:
94110
value = Atten2Map.deserialize(value.serialize())
111+
elif Version(flax_version) >= Version("0.12.0"):
112+
value = nnx.data(value)
95113
elif name in {"attn2_mh_apply"}:
96114
if value is not None:
97115
value = Atten2MultiHeadApply.deserialize(value.serialize())
116+
elif Version(flax_version) >= Version("0.12.0"):
117+
value = nnx.data(value)
98118
elif name in {"attn2_lm"}:
99119
if value is not None:
100120
value = LayerNorm.deserialize(value.serialize())
121+
elif Version(flax_version) >= Version("0.12.0"):
122+
value = nnx.data(value)
101123
elif name in {"attn2_ev_apply"}:
102124
if value is not None:
103125
value = Atten2EquiVarApply.deserialize(value.serialize())
126+
elif Version(flax_version) >= Version("0.12.0"):
127+
value = nnx.data(value)
104128
elif name in {"loc_attn"}:
105129
if value is not None:
106130
value = LocalAtten.deserialize(value.serialize())
131+
elif Version(flax_version) >= Version("0.12.0"):
132+
value = nnx.data(value)
107133
return super().__setattr__(name, value)

0 commit comments

Comments
 (0)