Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 7 additions & 8 deletions examples/jaxfluids/test_jaxfluids_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
which serves as the common base class.

JAXFluidsFlowEnv has the following arguments:
- environment_name: Required. Name of the enviroment.
- environment_name: Required. Name of the environment.
- hf_repo_id: Hugging Face repository (default: 'dynamicslab/HydroGym-environments')

- use_clean_cache: Use clean cache directory (default: True)
Expand Down Expand Up @@ -46,7 +46,7 @@
are part of the observation
- is_scale_observations: Optional. Boolean indicating whether observations are scaled to [0, 1].
- target_fn: Optional. Target thrust vector function. Choose either 'sine' or 'step'.

The Nozzle3D environment has the following additional arguments:
- num_actuators: Required. Integer number of actuators. Must be between 4 and 12.
- secondary_pressure_ratio: Optional. Float, must be between 0.7 and 0.9.
Expand All @@ -57,7 +57,7 @@
are part of the observation
- is_scale_observations: Optional. Boolean indicating whether observations are scaled to [0, 1].
- target_fn: Optional. Target thrust vector function. Choose either 'sine' or 'step'.

"""

import os
Expand All @@ -68,7 +68,7 @@
def main():
env_config = {
"environment_name": "Nozzle2D_coarse",
"configuration_file": os.path.abspath("environment_config.yaml")
"configuration_file": os.path.abspath("environment_config.yaml"),
}

env = Nozzle2D(env_config=env_config)
Expand All @@ -77,12 +77,11 @@ def main():
env.render()

for i in range(1000):

# Random action
# action = env.action_space.sample()
# action = env.action_space.sample()

# Fixed action
action = [0.0, 0.5]
action = [0.0, 0.5]

observation, reward, terminated, truncated, info = env.step(action)

Expand All @@ -96,4 +95,4 @@ def main():


if __name__ == "__main__":
main()
main()
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#!/usr/bin/env bash
#
# Run NEK5000 PettingZoo tests with MPMD coupling.
# Config is loaded automatically from HuggingFace (environment_config.yaml).
#
# Usage:
# ./run_pettingzoo_docker.sh # Test only
Expand Down
7 changes: 0 additions & 7 deletions examples/nek/getting_started/6_zeroshot_wing_demo/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ __This is a deployment/evaluation demo only (no training). The template and cont
## What the script does

`test_nek_pettingzoo.py`:

- loads a base `NekEnv` via `NekEnv.from_hf(...)` and wraps it with `make_pettingzoo_env(...)`
- builds one controller per entry in `POLICY_SPECS` (from `meta_policy_small_wing_template.py`)
- assigns each controller to actuator agents by `x_range` and `side` (`SS` means `y > 0`, `PS` means `y < 0`)
Expand Down Expand Up @@ -42,7 +41,6 @@ obs_dict, rewards_dict, terminations, truncations, infos = env.step(actions)
## Usage

### Recommended: use the runner script

From `6_zeroshot_wing_demo/`:

```bash
Expand All @@ -58,7 +56,6 @@ mpirun -np 1 python test_nek_pettingzoo.py : -np 12 nek5000
```

Legacy policy template + run root:

```bash
mpirun -np 1 python test_nek_pettingzoo.py \
--policy-template ./meta_policy_small_wing_template.py \
Expand All @@ -68,7 +65,6 @@ mpirun -np 1 python test_nek_pettingzoo.py \
```

Useful overrides:

- `--policy-template PATH` (defaults to `./meta_policy_small_wing_template.py`)
- `--env ENV_NAME` (defaults from template `ENV_NAME`)
- `--nproc NPROC` (defaults from template `NPROC`)
Expand All @@ -82,15 +78,13 @@ Useful overrides:
The template defines a lightweight legacy-`MetaPolicy.py`-style configuration.

Required top-level variables:

- `ENV_NAME`
- `NPROC`
- `NUM_STEPS`
- `POLICY_ROOT` (default for `--policy-root`)
- `POLICY_SPECS` (list of policy group dicts)

Each `POLICY_SPECS` entry supports:

- `name`
- `x_range: [x_min, x_max]`
- `side: "SS"` (y>0) or `"PS"` (y<0)
Expand All @@ -101,7 +95,6 @@ Each `POLICY_SPECS` entry supports:
- RL algorithms only: `agent_run_name`, `policy`, and/or `model_path`

Algorithm semantics:

- `ZERO` outputs an all-zero action (no model needed)
- `BL` outputs a constant action equal to `action_max` (no model needed)
- `PPO`/`TD3`/`DDPG` load a Stable-Baselines3 model from `model_path`/`POLICY_ROOT`
Expand Down
Empty file.
Empty file.
Empty file.
2 changes: 1 addition & 1 deletion hydrogym/jaxfluids/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from .envs.nozzle import Nozzle2D, Nozzle3D
from .envs.nozzle import Nozzle2D, Nozzle3D
18 changes: 7 additions & 11 deletions hydrogym/jaxfluids/env_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,11 @@
from pathlib import Path
from typing import Dict, Optional, Tuple, Union

from jaxfluids_rl.jxf_env import JAXFluidsEnv, RenderMode
from omegaconf import OmegaConf

from hydrogym.data_manager import HFDataManager

from jaxfluids_rl.jxf_env import JAXFluidsEnv, RenderMode


class ConfigError(Exception):
"""Exception raised for configuration-related errors."""
Expand All @@ -21,7 +20,7 @@ class JAXFluidsFlowEnv(JAXFluidsEnv):
Base JAXFluidsFlowEnv with Hugging Face Hub integration for configuration management.

Arguments:
- environment_name: Required. Name of the enviroment.
- environment_name: Required. Name of the environment.
- hf_repo_id: Hugging Face repository (default: 'dynamicslab/HydroGym-environments')

- use_clean_cache: Use clean cache directory (default: True)
Expand All @@ -35,7 +34,6 @@ class JAXFluidsFlowEnv(JAXFluidsEnv):
"""

def _init_from_hf(self, env_config: dict) -> None:

# Initialize HF data manager
self.hf_repo_id = env_config.get("hf_repo_id", "dynamicslab/HydroGym-environments")
self.local_fallback_dir = env_config.get("local_fallback_dir", None)
Expand All @@ -45,15 +43,15 @@ def _init_from_hf(self, env_config: dict) -> None:
repo_id=self.hf_repo_id,
local_fallback_dir=self.local_fallback_dir,
use_clean_cache=self.use_clean_cache,
fallback_profile="JAXFLUIDS"
fallback_profile="JAXFLUIDS",
)

# Environment identification
self.environment_name = env_config.get("environment_name")

if not self.environment_name:
raise ConfigError("'environment_name' must be specified in env_config")

# Download/get environment configuration
self.env_data_path = self._setup_environment_data()

Expand All @@ -65,11 +63,10 @@ def _init_from_hf(self, env_config: dict) -> None:
f"No configuration file found for environment '{self.environment_name}'. "
f"Expected config.yaml in: {self.env_data_path}"
)

# Load configuration from HF
self.conf = OmegaConf.load(self.configuration_file)


def _setup_environment_data(self) -> str:
"""
Download and setup environment data from HF Hub.
Expand All @@ -95,7 +92,6 @@ def _setup_environment_data(self) -> str:
return env_path
except Exception as e:
raise ConfigError(f"Failed to setup environment data for {self.environment_name}: {e}")


def _resolve_configuration_file(self, config_file_input: Optional[str]) -> Optional[str]:
"""
Expand Down Expand Up @@ -154,7 +150,7 @@ def _resolve_configuration_file(self, config_file_input: Optional[str]) -> Optio
f" - Current directory: {os.getcwd()}\n"
f" - Environment directory: {self.env_data_path}"
)

def _find_configuration_file(self) -> Optional[str]:
"""
Auto-detect configuration file in the environment data directory.
Expand Down Expand Up @@ -192,4 +188,4 @@ def _find_configuration_file(self) -> Optional[str]:
if os.path.exists(self.env_data_path):
print(f"Available files: {os.listdir(self.env_data_path)}")

return None
return None
Loading
Loading