Skip to content

Commit 4ff3e3a

Browse files
author
nle18370
committed
bug fix on 2026-03-30
1 parent cc7a52c commit 4ff3e3a

16 files changed

Lines changed: 164 additions & 119 deletions

README.md

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,36 @@ Created and maintained by Makoto Takamoto
2626
Francesco Alesiani, Dirk Pflüger, and Mathias Niepert.
2727

2828
---
29+
## Compatibility with Newer Python, JAX, PyTorch, and CUDA Versions
30+
31+
While this project officially targets older versions:
32+
- Python 3.9
33+
- JAX 0.4.11
34+
- PyTorch 1.13.0
35+
- CUDA 11.7
36+
37+
we have verified that most components remain compatible with newer releases.:
38+
- Python 3.12
39+
- JAX 0.9.2
40+
- PyTorch 2.11.0
41+
- CUDA 13.0
42+
43+
In particular, the core codebase, forward training in `pdebench/model`, and data generation in `data_gen_NLE` work with these newer versions.
44+
Compatibility for the remaining components is still under investigation.
45+
_Last updated: 2026-03-30_
46+
47+
## Data Path Issue in Forward Training (`model/fno/utils.py` and `model/unet/utils.py`)
48+
49+
We identified an issue with data paths during forward training in `model/fno/utils.py` and `model/unet/utils.py`. This problem is caused by Hydra modifying the `data_path` at runtime.
50+
51+
To resolve this issue, please use the commented lines:
52+
- Lines 185–188 in `model/fno/utils.py`
53+
- Lines 185–187 in `model/unet/utils.py`
54+
55+
(Note that the above modification (use of "to_absolute_path" in hydra) is allowed Hydra version >= 0.11.0.)
56+
57+
_Last updated: 2026-03-30_
58+
2959

3060
## Datasets and Pretrained Models
3161

pdebench/data_gen/data_gen_NLE/CompressibleFluid/CFD_multi_Hydra.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,8 @@
159159
import jax.numpy as jnp
160160
from jax import device_put, jit, lax
161161
from omegaconf import DictConfig
162+
163+
sys.path.append("..") # if PDEBench cannot recognize the location of utils
162164
from utils import (
163165
Courant_HD,
164166
Courant_vis_HD,
@@ -180,8 +182,6 @@
180182
os.environ["TF_FORCE_GPU_ALLOW_GROWTH"] = "true"
181183
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = ".9"
182184

183-
sys.path.append("..")
184-
185185
# if double precision
186186
# from jax.config import config
187187
# config.update("jax_enable_x64", True)

pdebench/data_gen/data_gen_NLE/CompressibleFluid/run_trainset_1D.sh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@ nn=1
33
key=2020
44
while [ "$nn" -le 10 ]; do
55
CUDA_VISIBLE_DEVICES='0,1' python3 CFD_multi_Hydra.py +args=1D_Multi.yaml ++args.init_key="$key"
6-
nn=$(${nn} + 1)
7-
key=$(${key} + 1)
6+
nn=$((nn + 1))
7+
key=$((key + 1))
88
echo "$nn"
99
echo "$key"
1010
done

pdebench/data_gen/data_gen_NLE/CompressibleFluid/run_trainset_1DShock.sh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@ nn=1
33
key=2031
44
while [ "$nn" -le 10 ]; do
55
CUDA_VISIBLE_DEVICES='0,1' python3 CFD_multi_Hydra.py +args=1D_Multi_shock.yaml ++args.init_key="$key"
6-
nn=$(${nn} + 1)
7-
key=$(${key} + 1)
6+
nn=$((nn + 1))
7+
key=$((key + 1))
88
echo "$nn"
99
echo "$key"
1010
done

pdebench/data_gen/data_gen_NLE/CompressibleFluid/run_trainset_1D_trans.sh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@ nn=1
33
key=2020
44
while [ "$nn" -le 10 ]; do
55
CUDA_VISIBLE_DEVICES='0,1' python3 CFD_multi_Hydra.py +args=1D_Multi_trans.yaml ++args.init_key="$key"
6-
nn=$(${nn} + 1)
7-
key=$(${key} + 1)
6+
nn=$((nn + 1))
7+
key=$((key + 1))
88
echo "$nn"
99
echo "$key"
1010
done

pdebench/data_gen/data_gen_NLE/CompressibleFluid/run_trainset_2D.sh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@ key=2031
44
while [ "$nn" -le 100 ]; do
55
CUDA_VISIBLE_DEVICES='0,1,2,3' python3 CFD_multi_Hydra.py +args=2D_Multi_Rand.yaml ++args.init_key="$key"
66
#CUDA_VISIBLE_DEVICES='0,1,2,3' python3 CFD_multi_Hydra.py +args=2D_Multi_Rand_HR.yaml ++args.init_key="$key"
7-
nn=$(${nn} + 1)
8-
key=$(${key} + 1)
7+
nn=$((nn + 1))
8+
key=$((key + 1))
99
echo "$nn"
1010
echo "$key"
1111
done

pdebench/data_gen/data_gen_NLE/CompressibleFluid/run_trainset_2DTurb.sh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@ key=2031
44
#while [ "$nn" -le 100 ]; do
55
while [ "$nn" -le 55 ]; do
66
CUDA_VISIBLE_DEVICES='0,1,2,3' python3 CFD_multi_Hydra.py +args=2D_Multi_Turb.yaml ++args.init_key="$key"
7-
nn=$(${nn} + 1)
8-
key=$(${key} + 1)
7+
nn=$((nn + 1))
8+
key=$((key + 1))
99
echo "$nn"
1010
echo "$key"
1111
done

pdebench/data_gen/data_gen_NLE/CompressibleFluid/run_trainset_3D.sh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@ nn=1
33
key=2031
44
while [ "$nn" -le 10 ]; do
55
CUDA_VISIBLE_DEVICES='0,1' python3 CFD_multi_Hydra.py +args=3D_Multi_Rand.yaml ++args.init_key="$key"
6-
nn=$(${nn} + 1)
7-
key=$(${key} + 1)
6+
nn=$((nn + 1))
7+
key=$((key + 1))
88
echo "$nn"
99
echo "$key"
1010
done

pdebench/data_gen/data_gen_NLE/CompressibleFluid/run_trainset_3DTurb.sh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@ nn=1
33
key=2031
44
while [ "$nn" -le 6 ]; do
55
CUDA_VISIBLE_DEVICES='0,1,2,3' python3 CFD_multi_Hydra.py +args=3D_Multi_TurbM1.yaml ++args.init_key="$key"
6-
nn=$(${nn} + 1)
7-
key=$(${key} + 1)
6+
nn=$((nn + 1))
7+
key=$((key + 1))
88
echo "$nn"
99
echo "$key"
1010
done

pdebench/data_gen/data_gen_NLE/ReactionDiffusionEq/run_DarcyFlow2D.sh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@ key=2020
44
while [ "$nn" -le 50 ]; do
55
CUDA_VISIBLE_DEVICES='0,1' python3 reaction_diffusion_2D_multi_solution_Hydra.py +multi=config_2D.yaml ++multi.init_k\
66
ey="$key"
7-
nn=$(${nn} + 1)
8-
key=$(${key} + 1)
7+
nn=$((nn + 1))
8+
key=$((key + 1))
99
echo "$nn"
1010
echo "$key"
1111
done

0 commit comments

Comments
 (0)