Skip to content

Commit fc455c9

Browse files
committed
fix(jax): fix compatibility with flax 0.12
1 parent 1ee33c8 commit fc455c9

19 files changed

Lines changed: 202 additions & 5 deletions

File tree

.github/workflows/test_python.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ jobs:
3232
source/install/uv_with_retry.sh pip install --system torch -i https://download.pytorch.org/whl/cpu
3333
export TENSORFLOW_ROOT=$(python -c 'import importlib.util,pathlib;print(pathlib.Path(importlib.util.find_spec("tensorflow").origin).parent)')
3434
export PYTORCH_ROOT=$(python -c 'import torch;print(torch.__path__[0])')
35-
source/install/uv_with_retry.sh pip install --system -e .[test,jax] mpi4py "jax==0.5.0;python_version>='3.10'"
35+
source/install/uv_with_retry.sh pip install --system -e .[test,jax] mpi4py "jax==0.8.1;python_version>='3.10'"
3636
source/install/uv_with_retry.sh pip install --system -U setuptools
3737
source/install/uv_with_retry.sh pip install --system horovod --no-build-isolation
3838
source/install/uv_with_retry.sh pip install --system --pre "paddlepaddle==3.0.0" -i https://www.paddlepaddle.org.cn/packages/stable/cpu/

deepmd/jax/atomic_model/linear_atomic_model.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,13 @@
2525
jax,
2626
jnp,
2727
)
28+
from packaging.version import (
29+
Version,
30+
)
31+
from deepmd.jax.env import (
32+
flax_version,
33+
nnx,
34+
)
2835

2936

3037
@flax_module
@@ -33,13 +40,17 @@ def __setattr__(self, name: str, value: Any) -> None:
3340
value = base_atomic_model_set_attr(name, value)
3441
if name == "mapping_list":
3542
value = [ArrayAPIVariable(to_jax_array(vv)) for vv in value]
43+
if Version(flax_version) >= Version("0.12.0"):
44+
value = nnx.List([nnx.data(item) for item in value])
3645
elif name == "zbl_weight":
3746
value = ArrayAPIVariable(to_jax_array(value))
3847
elif name == "models":
3948
value = [
4049
DPAtomicModel.deserialize(value[0].serialize()),
4150
PairTabAtomicModel.deserialize(value[1].serialize()),
4251
]
52+
if Version(flax_version) >= Version("0.12.0"):
53+
value = nnx.List([nnx.data(item) for item in value])
4354
return super().__setattr__(name, value)
4455

4556
def forward_common_atomic(

deepmd/jax/descriptor/dpa1.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,13 @@
3131
from deepmd.jax.utils.type_embed import (
3232
TypeEmbedNet,
3333
)
34+
from packaging.version import (
35+
Version,
36+
)
37+
from deepmd.jax.env import (
38+
flax_version,
39+
nnx,
40+
)
3441

3542

3643
@flax_module
@@ -58,6 +65,8 @@ def __setattr__(self, name: str, value: Any) -> None:
5865
value = [
5966
NeighborGatedAttentionLayer.deserialize(ii.serialize()) for ii in value
6067
]
68+
if Version(flax_version) >= Version("0.12.0"):
69+
value = nnx.List([nnx.data(item) for item in value])
6170
return super().__setattr__(name, value)
6271

6372

@@ -71,6 +80,8 @@ def __setattr__(self, name: str, value: Any) -> None:
7180
elif name in {"embeddings", "embeddings_strip"}:
7281
if value is not None:
7382
value = NetworkCollection.deserialize(value.serialize())
83+
elif Version(flax_version) >= Version("0.12.0"):
84+
value = nnx.data(value)
7485
elif name == "dpa1_attention":
7586
value = NeighborGatedAttention.deserialize(value.serialize())
7687
elif name == "env_mat":

deepmd/jax/descriptor/dpa2.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,13 @@
2929
from deepmd.jax.utils.type_embed import (
3030
TypeEmbedNet,
3131
)
32+
from packaging.version import (
33+
Version,
34+
)
35+
from deepmd.jax.env import (
36+
flax_version,
37+
nnx,
38+
)
3239

3340

3441
@BaseDescriptor.register("dpa2")
@@ -44,13 +51,16 @@ def __setattr__(self, name: str, value: Any) -> None:
4451
elif name in {"repinit_three_body"}:
4552
if value is not None:
4653
value = DescrptBlockSeTTebd.deserialize(value.serialize())
54+
elif Version(flax_version) >= Version("0.12.0"):
55+
value = nnx.data(value)
4756
elif name in {"repformers"}:
4857
value = DescrptBlockRepformers.deserialize(value.serialize())
4958
elif name in {"type_embedding"}:
5059
value = TypeEmbedNet.deserialize(value.serialize())
5160
elif name in {"g1_shape_tranform", "tebd_transform"}:
5261
if value is None:
53-
pass
62+
if Version(flax_version) >= Version("0.12.0"):
63+
value = nnx.data(value)
5464
elif isinstance(value, NativeLayerDP):
5565
value = NativeLayer.deserialize(value.serialize())
5666
elif isinstance(value, IdentityDP):

deepmd/jax/descriptor/hybrid.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,15 +12,25 @@
1212
from deepmd.jax.descriptor.base_descriptor import (
1313
BaseDescriptor,
1414
)
15-
15+
from packaging.version import (
16+
Version,
17+
)
18+
from deepmd.jax.env import (
19+
flax_version,
20+
nnx,
21+
)
1622

1723
@BaseDescriptor.register("hybrid")
1824
@flax_module
1925
class DescrptHybrid(DescrptHybridDP):
2026
def __setattr__(self, name: str, value: Any) -> None:
2127
if name in {"nlist_cut_idx"}:
2228
value = [ArrayAPIVariable(to_jax_array(vv)) for vv in value]
29+
if Version(flax_version) >= Version("0.12.0"):
30+
value = nnx.List([nnx.data(item) for item in value])
2331
elif name in {"descrpt_list"}:
2432
value = [BaseDescriptor.deserialize(vv.serialize()) for vv in value]
33+
if Version(flax_version) >= Version("0.12.0"):
34+
value = nnx.List([nnx.data(item) for item in value])
2535

2636
return super().__setattr__(name, value)

deepmd/jax/descriptor/repflows.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,13 @@
1818
from deepmd.jax.utils.network import (
1919
NativeLayer,
2020
)
21+
from packaging.version import (
22+
Version,
23+
)
24+
from deepmd.jax.env import (
25+
flax_version,
26+
nnx,
27+
)
2128

2229

2330
@flax_module
@@ -29,6 +36,8 @@ def __setattr__(self, name: str, value: Any) -> None:
2936
value = ArrayAPIVariable(value)
3037
elif name in {"layers"}:
3138
value = [RepFlowLayer.deserialize(layer.serialize()) for layer in value]
39+
if Version(flax_version) >= Version("0.12.0"):
40+
value = nnx.List([nnx.data(item) for item in value])
3241
elif name in {"edge_embd", "angle_embd"}:
3342
value = NativeLayer.deserialize(value.serialize())
3443
elif name in {"env_mat_edge", "env_mat_angle"}:
@@ -58,8 +67,12 @@ def __setattr__(self, name: str, value: Any) -> None:
5867
}:
5968
if value is not None:
6069
value = NativeLayer.deserialize(value.serialize())
70+
elif Version(flax_version) >= Version("0.12.0"):
71+
value = nnx.data(value)
6172
elif name in {"n_residual", "e_residual", "a_residual"}:
6273
value = [ArrayAPIVariable(to_jax_array(vv)) for vv in value]
74+
if Version(flax_version) >= Version("0.12.0"):
75+
value = nnx.List([nnx.data(item) for item in value])
6376
else:
6477
pass
6578
return super().__setattr__(name, value)

deepmd/jax/descriptor/repformers.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,13 @@
2727
LayerNorm,
2828
NativeLayer,
2929
)
30+
from packaging.version import (
31+
Version,
32+
)
33+
from deepmd.jax.env import (
34+
flax_version,
35+
nnx,
36+
)
3037

3138

3239
@flax_module
@@ -38,6 +45,8 @@ def __setattr__(self, name: str, value: Any) -> None:
3845
value = ArrayAPIVariable(value)
3946
elif name in {"layers"}:
4047
value = [RepformerLayer.deserialize(layer.serialize()) for layer in value]
48+
if Version(flax_version) >= Version("0.12.0"):
49+
value = nnx.List([nnx.data(item) for item in value])
4150
elif name == "g2_embd":
4251
value = NativeLayer.deserialize(value.serialize())
4352
elif name == "env_mat":
@@ -87,21 +96,35 @@ def __setattr__(self, name: str, value: Any) -> None:
8796
if name in {"linear1", "linear2", "g1_self_mlp", "proj_g1g2", "proj_g1g1g2"}:
8897
if value is not None:
8998
value = NativeLayer.deserialize(value.serialize())
99+
elif Version(flax_version) >= Version("0.12.0"):
100+
value = nnx.data(value)
90101
elif name in {"g1_residual", "g2_residual", "h2_residual"}:
91102
value = [ArrayAPIVariable(to_jax_array(vv)) for vv in value]
103+
if Version(flax_version) >= Version("0.12.0"):
104+
value = nnx.List([nnx.data(item) for item in value])
92105
elif name in {"attn2g_map"}:
93106
if value is not None:
94107
value = Atten2Map.deserialize(value.serialize())
108+
elif Version(flax_version) >= Version("0.12.0"):
109+
value = nnx.data(value)
95110
elif name in {"attn2_mh_apply"}:
96111
if value is not None:
97112
value = Atten2MultiHeadApply.deserialize(value.serialize())
113+
elif Version(flax_version) >= Version("0.12.0"):
114+
value = nnx.data(value)
98115
elif name in {"attn2_lm"}:
99116
if value is not None:
100117
value = LayerNorm.deserialize(value.serialize())
118+
elif Version(flax_version) >= Version("0.12.0"):
119+
value = nnx.data(value)
101120
elif name in {"attn2_ev_apply"}:
102121
if value is not None:
103122
value = Atten2EquiVarApply.deserialize(value.serialize())
123+
elif Version(flax_version) >= Version("0.12.0"):
124+
value = nnx.data(value)
104125
elif name in {"loc_attn"}:
105126
if value is not None:
106127
value = LocalAtten.deserialize(value.serialize())
128+
elif Version(flax_version) >= Version("0.12.0"):
129+
value = nnx.data(value)
107130
return super().__setattr__(name, value)

deepmd/jax/descriptor/se_e2_a.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,14 @@
1919
NetworkCollection,
2020
)
2121

22+
from packaging.version import (
23+
Version,
24+
)
25+
from deepmd.jax.env import (
26+
flax_version,
27+
nnx,
28+
)
29+
2230

2331
@BaseDescriptor.register("se_e2_a")
2432
@BaseDescriptor.register("se_a")
@@ -32,6 +40,8 @@ def __setattr__(self, name: str, value: Any) -> None:
3240
elif name in {"embeddings"}:
3341
if value is not None:
3442
value = NetworkCollection.deserialize(value.serialize())
43+
elif Version(flax_version) >= Version("0.12.0"):
44+
value = nnx.data(value)
3545
elif name == "env_mat":
3646
# env_mat doesn't store any value
3747
pass

deepmd/jax/descriptor/se_e2_r.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,14 @@
1919
NetworkCollection,
2020
)
2121

22+
from packaging.version import (
23+
Version,
24+
)
25+
from deepmd.jax.env import (
26+
flax_version,
27+
nnx,
28+
)
29+
2230

2331
@BaseDescriptor.register("se_e2_r")
2432
@BaseDescriptor.register("se_r")
@@ -32,6 +40,8 @@ def __setattr__(self, name: str, value: Any) -> None:
3240
elif name in {"embeddings"}:
3341
if value is not None:
3442
value = NetworkCollection.deserialize(value.serialize())
43+
elif Version(flax_version) >= Version("0.12.0"):
44+
value = nnx.data(value)
3545
elif name == "env_mat":
3646
# env_mat doesn't store any value
3747
pass

deepmd/jax/env.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import jax.numpy as jnp
88
from flax import (
99
nnx,
10+
__version__ as flax_version,
1011
)
1112
from jax import export as jax_export
1213

@@ -23,4 +24,5 @@
2324
"jax_export",
2425
"jnp",
2526
"nnx",
27+
"flax_version",
2628
]

0 commit comments

Comments
 (0)