Skip to content

Commit 422799a

Browse files
cl126162clagemann126162ludgerpaehler
authored
Update Jax code (#244)
Updates the environments of the pure JAX backend with a more optimized version which fully utilizes JAX's features. --------- Co-authored-by: christian.lagemann <christian.lagemann@rwth-aachen.de> Co-authored-by: Ludger Paehler <paehlerludger@gmail.com>
1 parent 37da8d5 commit 422799a

10 files changed

Lines changed: 475 additions & 400 deletions

File tree

.gitignore

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -136,4 +136,4 @@ dmypy.json
136136
examples/firedrake/getting_started/output/
137137

138138
# VSCode
139-
.vscode/
139+
.vscode/

examples/jax/README.md

Lines changed: 49 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -30,11 +30,11 @@ jax/
3030
# Activate the GPU environment
3131
source /home/easybuild/venvs/hydrogym_gpu/bin/activate
3232

33-
# Test Kolmogorov flow
33+
# Test Kolmogorov flow (float64, 10 steps)
3434
cd getting_started/1_kolmogorov
3535
./run_kolmogorov_docker.sh
3636

37-
# Test channel flow
37+
# Test channel flow (float32, 5 steps)
3838
cd getting_started/2_channel
3939
./run_channel_docker.sh
4040

@@ -45,10 +45,46 @@ cd getting_started/3_ppo
4545

4646
## Available Environments
4747

48-
| Environment | Solver | Grid | Action | Observation | Reward |
49-
|---|---|---|---|---|---|
50-
| `KolmogorovFlow` | 2D pseudo-spectral | 64×64 | 4 body-force modes | 8×8 velocity probes | -(α·TKE + action penalty) |
51-
| `ChannelFlowSpectralEnv` | 3D pseudo-spectral | 72×72×72 | 24 wall jets | 8×8×2 near-wall velocities | -WSS (drag) |
48+
| Environment | Solver | Grid | Action | Observation | Reward | Default dtype |
49+
|---|---|---|---|---|---|---|
50+
| `KolmogorovFlow` | 2D pseudo-spectral | 64×64 | 4 body-force modes | 8×8 velocity probes | -(α·TKE + action penalty) | float64 |
51+
| `ChannelFlowSpectralEnv` | 3D pseudo-spectral | 72×72×72 | 24 wall jets | 8×8×2 near-wall velocities | -WSS (drag) | float32 |
52+
53+
## JIT Compilation
54+
55+
Both environments are JIT-compiled via `jax.jit` in the runner scripts, which compiles the full DNS rollout into a single GPU kernel:
56+
57+
```python
58+
jit_reset = jax.jit(env.reset_env)
59+
jit_step = jax.jit(env.step_env)
60+
61+
obs, state = jit_reset(key, params) # triggers compilation
62+
obs, state, reward, done, info = jit_step(key, state, action, params) # full GPU speed
63+
```
64+
65+
The first call compiles (takes ~1–2 minutes); all subsequent calls run at full GPU speed.
66+
67+
## Floating-Point Precision
68+
69+
| Environment | Recommended | Notes |
70+
|---|---|---|
71+
| `KolmogorovFlow` | `float64` | Pseudo-spectral 2D NS requires fp64 for JIT stability; fp32 may produce NaNs under XLA reordering |
72+
| `ChannelFlowSpectralEnv` | `float32` | Stable at fp32 with JIT; fp64 available but ~2x slower on A100 |
73+
74+
Override via `env_config`:
75+
```python
76+
# Kolmogorov: float64 is the default and required for JIT stability
77+
env = KolmogorovFlow(env_config={"dt": 5e-4}) # smaller dt for fp32 experiments
78+
79+
# Channel: toggle precision
80+
env = ChannelFlowSpectralEnv(env_config={"dtype": "float64"})
81+
```
82+
83+
Or via the bash scripts:
84+
```bash
85+
./run_kolmogorov_docker.sh minimize_tke 100 float32 # float32 (may diverge)
86+
./run_channel_docker.sh drag_reduction 10 float64 # float64
87+
```
5288

5389
## Typical Usage
5490

@@ -57,14 +93,19 @@ import jax
5793
import jax.numpy as jnp
5894
from hydrogym.jax.envs.kolmogorov import KolmogorovFlow
5995

96+
jax.config.update("jax_enable_x64", True) # required for Kolmogorov + JIT
97+
6098
env = KolmogorovFlow(env_config={}, flow_config={})
6199
params = env.default_params
62100

101+
jit_reset = jax.jit(env.reset_env)
102+
jit_step = jax.jit(env.step_env)
103+
63104
key = jax.random.PRNGKey(0)
64-
obs, state = env.reset_env(key, params)
105+
obs, state = jit_reset(key, params)
65106

66107
action = jnp.zeros((params.action_dim,))
67-
obs, state, reward, done, info = env.step_env(key, state, action, params)
108+
obs, state, reward, done, info = jit_step(key, state, action, params)
68109
```
69110

70111
**Note:** The channel flow environment downloads a fully turbulent initial field from Hugging Face Hub (`dynamicslab/HydroGym-environments`) on the first run and caches it at `~/.cache/hydrogym/`.

examples/jax/getting_started/1_kolmogorov/run_kolmogorov_docker.sh

Lines changed: 12 additions & 116 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,12 @@
1919
# unnecessarily large actuations and promotes efficient controllers.
2020
#
2121
# Usage:
22-
# ./run_kolmogorov_docker.sh # Objective 1: minimize TKE
23-
# ./run_kolmogorov_docker.sh maximize_tke # Objective 2: maximize TKE
24-
# ./run_kolmogorov_docker.sh no_actuation # Baseline: zero action
22+
# ./run_kolmogorov_docker.sh [mode] [num_steps] [dtype]
23+
#
24+
# ./run_kolmogorov_docker.sh # minimize TKE, 10 steps, float64
25+
# ./run_kolmogorov_docker.sh maximize_tke # maximize TKE
26+
# ./run_kolmogorov_docker.sh no_actuation 500 # baseline, 500 steps
27+
# ./run_kolmogorov_docker.sh minimize_tke 1000 float32 # float32 (fast, may diverge)
2528
#
2629
# Actuation:
2730
# The control input is the amplitude of four sinusoidal body-force modes
@@ -30,9 +33,9 @@
3033
# with wavenumbers k1,k2,k3,k4 = 4,5,6,7 (above the base forcing wavenumber).
3134
# Actions are clipped to [-0.5, 0.5].
3235
#
33-
# Output:
34-
# kolmogorov_<mode>.png -- vorticity snapshots comparing baseline vs
35-
# actuated trajectories
36+
# Precision:
37+
# float64 (default) -- required for JIT stability; matches non-JIT behavior
38+
# float32 -- faster but may produce NaNs due to solver instability under XLA
3639
#
3740

3841
set -e
@@ -45,114 +48,7 @@ source /home/easybuild/venvs/hydrogym_gpu/bin/activate
4548

4649
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
4750
MODE="${1:-minimize_tke}"
48-
NUM_STEPS=10
49-
50-
echo "=== Kolmogorov Flow JAX Environment ==="
51-
echo "Mode: $MODE"
52-
echo "Steps per run: $NUM_STEPS"
53-
echo ""
54-
55-
case "$MODE" in
56-
57-
minimize_tke)
58-
echo "Objective: Minimize TKE (suppress energy bursts)"
59-
echo " reward_alpha = 1.0 -> reward = -(TKE + action_penalty)"
60-
echo " Action: small forcing to damp energy transfer"
61-
echo ""
62-
python - <<PYEOF
63-
import jax
64-
import jax.numpy as jnp
65-
from hydrogym.jax.envs.kolmogorov import KolmogorovFlow
66-
67-
env = KolmogorovFlow(env_config={}, flow_config={})
68-
69-
# reward_alpha > 0: penalize TKE -> agent learns to suppress energy bursts
70-
params = env.default_params.replace(reward_alpha=1.0)
71-
72-
key = jax.random.PRNGKey(0)
73-
obs, state = env.reset_env(key, params)
74-
75-
action = jnp.array([-0.25, -0.03, 0.02, 0.01])
76-
77-
print(f"{'Step':>5} {'mean_TKE':>12} {'reward':>12}")
78-
print("-" * 35)
79-
for i in range($NUM_STEPS):
80-
key, subkey = jax.random.split(key)
81-
obs, state, reward, done, info = env.step_env(subkey, state, action, params)
82-
print(f"{i:>5} {float(info['mean_tke']):>12.4f} {float(reward):>12.4f}")
83-
PYEOF
84-
;;
85-
86-
maximize_tke)
87-
echo "Objective: Maximize TKE (enhance turbulent mixing)"
88-
echo " reward_alpha = -1.0 -> reward = TKE - action_penalty"
89-
echo " Action: forcing to drive the flow into a more turbulent regime"
90-
echo ""
91-
python - <<PYEOF
92-
import jax
93-
import jax.numpy as jnp
94-
from hydrogym.jax.envs.kolmogorov import KolmogorovFlow
95-
96-
env = KolmogorovFlow(env_config={}, flow_config={})
97-
98-
# reward_alpha < 0: reward proportional to TKE -> agent learns to increase mixing
99-
params = env.default_params.replace(reward_alpha=-1.0)
100-
101-
key = jax.random.PRNGKey(0)
102-
obs, state = env.reset_env(key, params)
103-
104-
action = jnp.array([0.25, 0.03, -0.02, -0.01])
105-
106-
print(f"{'Step':>5} {'mean_TKE':>12} {'reward':>12}")
107-
print("-" * 35)
108-
for i in range($NUM_STEPS):
109-
key, subkey = jax.random.split(key)
110-
obs, state, reward, done, info = env.step_env(subkey, state, action, params)
111-
print(f"{i:>5} {float(info['mean_tke']):>12.4f} {float(reward):>12.4f}")
112-
PYEOF
113-
;;
114-
115-
no_actuation)
116-
echo "Baseline: zero actuation (free turbulence evolution)"
117-
echo " reward_alpha = 1.0, action = [0, 0, 0, 0]"
118-
echo " Shows natural energy bursts without control"
119-
echo ""
120-
python - <<PYEOF
121-
import jax
122-
import jax.numpy as jnp
123-
from hydrogym.jax.envs.kolmogorov import KolmogorovFlow
124-
125-
env = KolmogorovFlow(env_config={}, flow_config={})
126-
params = env.default_params.replace(reward_alpha=1.0)
127-
128-
key = jax.random.PRNGKey(0)
129-
obs, state = env.reset_env(key, params)
130-
131-
action = jnp.zeros((params.action_dim,))
132-
133-
print(f"{'Step':>5} {'mean_TKE':>12} {'reward':>12}")
134-
print("-" * 35)
135-
for i in range($NUM_STEPS):
136-
key, subkey = jax.random.split(key)
137-
obs, state, reward, done, info = env.step_env(subkey, state, action, params)
138-
print(f"{i:>5} {float(info['mean_tke']):>12.4f} {float(reward):>12.4f}")
139-
PYEOF
140-
;;
141-
142-
*)
143-
echo "Unknown mode: $MODE"
144-
echo "Usage: $0 [minimize_tke|maximize_tke|no_actuation]"
145-
exit 1
146-
;;
147-
esac
148-
149-
EXIT_CODE=$?
150-
151-
echo ""
152-
if [ $EXIT_CODE -eq 0 ]; then
153-
echo "Completed successfully."
154-
else
155-
echo "Failed with exit code: $EXIT_CODE"
156-
fi
51+
NUM_STEPS="${2:-10}"
52+
DTYPE="${3:-float64}"
15753

158-
exit $EXIT_CODE
54+
python "$SCRIPT_DIR/test_kolmogorov_env.py" "$MODE" --num-steps "$NUM_STEPS" --dtype "$DTYPE"

0 commit comments

Comments
 (0)