Skip to content

Commit 920dc29

Browse files
committed
Initial commit
0 parents  commit 920dc29

195 files changed

Lines changed: 42594 additions & 0 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

.gitignore

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
*.py[co]
2+
__pycache__/
3+
4+
.tox/
5+
.pytest_cache/
6+
.mypy_cache/
7+
.ruff_cache/
8+
9+
dist/
10+
build/
11+
*.egg-info/
12+
13+
*.bak
14+
*.wpu
15+
16+
.coverage
17+
htmlcov
18+
pytest_report*.html
19+
20+
node_modules
21+
.ipynb_checkpoints
22+
23+
*.zip
24+
25+
checkpoint_*
26+
.env
27+
.claude/
28+
29+
# Rust:
30+
*/target/
31+
Cargo.lock
32+
*.rs.bk

AGENTS.md

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
Documentation TOCs (start with the one relevant to your task):
2+
- `README.md` shows basic usage
3+
- `docs/HighJax docs.md` — Top-level TOC
4+
- `docs/HighJax/HighJax docs.md` — HighJax environment (state, observations, reward, NPCs)
5+
- `docs/Octane/Octane docs.md` — Octane TUI explorer (navigation, key bindings, rendering)
6+
7+
More things:
8+
- Critical for writing code: `docs/HighJax coding conventions.md` Don't write code without consulting this and abiding to it.
9+
- Run tests like `JAX_PLATFORMS=cpu pytest -n 12 <...other pytest args if needed>`

AUTHORS

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
Ram Rachum
2+
University of California, Berkeley

CLAUDE.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
../AGENTS.md

LICENSE

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
MIT License
2+
3+
Copyright (c) 2024 The HighJax authors
4+
5+
Permission is hereby granted, free of charge, to any person obtaining a copy
6+
of this software and associated documentation files (the "Software"), to deal
7+
in the Software without restriction, including without limitation the rights
8+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9+
copies of the Software, and to permit persons to whom the Software is
10+
furnished to do so, subject to the following conditions:
11+
12+
The above copyright notice and this permission notice shall be included in all
13+
copies or substantial portions of the Software.
14+
15+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21+
SOFTWARE.

README.md

Lines changed: 162 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,162 @@
1+
# HighJax: Highway Driving environment for Reinforcement Learning research
2+
3+
<p align="center">
4+
<img src="misc/videos/demo.webp" alt="HighJax PPO training demo"><br/>
5+
<em>PPO agent learning to drive on a 4-lane highway</em>
6+
</p>
7+
8+
HighJax is an autonomous driving environment for Reinforcement Learning research. It's a JAX implementation of the [HighwayEnv](https://github.com/Farama-Foundation/HighwayEnv). HighJax provides a fully JIT-compilable and vectorizable highway driving simulation.
9+
10+
Besides being much faster than the original, it provides Octane, a Rust-based TUI for examining your experiment runs. Octane provides an interface for defining behaviors and then measuring how much each policy exhibits them.
11+
12+
HighJax was produced as part of our research project about [BXRL:Behavior-Explainable Reinforcement Learning](https://arxiv.org/abs/XXXX.XXXXX).
13+
14+
## Installation
15+
16+
```bash
17+
pip install highjax # Minimal installation
18+
pip install "highjax[cuda12]" # Including GPU support
19+
pip install "highjax[trainer]" # Including PPO implementation
20+
pip install "highjax[cuda12,trainer]" # Including both
21+
```
22+
23+
## Quick Start
24+
25+
```python
26+
import jax
27+
import highjax
28+
29+
env, params = highjax.make('highjax-v0')
30+
key = jax.random.PRNGKey(0)
31+
obs, state = env.reset(key, params)
32+
obs, state, reward, done, info = env.step(key, state, 1, params) # IDLE
33+
```
34+
35+
## Using with JAX RL Libraries
36+
37+
HighJax follows the [gymnax](https://github.com/RobertTLange/gymnax) API, so it works with JAX RL frameworks that expect gymnax-style environments:
38+
39+
- [PureJaxRL](https://github.com/luchris429/purejaxrl) — drop-in gymnax replacement (no PureJaxRL install needed), see [`examples/use_purejaxrl.py`](examples/use_purejaxrl.py)
40+
- [Stoix](https://github.com/EdanToledo/Stoix) — via `stoa` gymnax adapter, see [`examples/use_stoix.py`](examples/use_stoix.py)
41+
- [Rejax](https://github.com/keraJLi/rejax) — pass env object directly, see [`examples/use_rejax.py`](examples/use_rejax.py)
42+
43+
## Training
44+
45+
Train a PPO agent via the CLI:
46+
47+
```bash
48+
highjax-trainer train
49+
```
50+
51+
Key options:
52+
53+
| Flag | Default | Description |
54+
|---------------------|---------|--------------------------------------|
55+
| `--n-epochs` / `-e` | 300 | Training epochs |
56+
| `--n-es` | 400 | Parallel episodes per epoch |
57+
| `--n-ts` | 40 | Timesteps per episode |
58+
| `--seed` / `-s` | 0 | Random seed |
59+
| `--actor-lr` | 3e-4 | Actor learning rate |
60+
| `--critic-lr` | 3e-3 | Critic learning rate |
61+
| `--n-npcs` | 50 | NPC vehicles |
62+
| `--no-trek` || Disable trek recording |
63+
| `--n-sample-es` | 1 | Episodes to sample per epoch for trek|
64+
| `--trek-path` | auto | Custom trek directory path |
65+
| `--discount` | 0.95 | Discount factor (gamma) |
66+
| `--n-lanes` | 4 | Number of highway lanes |
67+
68+
Training automatically records episode data to `~/.highjax/t/` for browsing with Octane (the TUI). Use `--no-trek` to disable.
69+
70+
Here's a snazzy one-liner that will let you explore the results of the current experiment run using [VisiData](https://github.com/saulpw/visidata):
71+
72+
```bash
73+
pip install visidata
74+
vd "$(ls -d ~/.highjax/t/2*/ | tail -1)"/epochia.pq
75+
```
76+
77+
Use the following command line to produce similar results as seen in Figure 2 of the paper:
78+
79+
```bash
80+
highjax-trainer train --n-es 128 --n-ts 400 --n-epochs 300 --target-kld 0.0005
81+
```
82+
83+
## Octane (Episode Browser)
84+
85+
This repo also includes Octane, which is a Rust-based TUI for browsing HighJax experiments.
86+
87+
### Installation
88+
89+
```bash
90+
sudo apt-get install build-essential # C toolchain (needed by Rust)
91+
sudo apt-get install ffmpeg # Needed for `octane animate`
92+
git clone https://github.com/HumanCompatibleAI/HighJax # Clone this repo
93+
cd HighJax
94+
curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh # Install Rust
95+
source "$HOME/.cargo/env"
96+
cd octane && cargo build --release # Build Octane
97+
alias octane="$(readlink -f octane/target/release/octane)"
98+
```
99+
100+
The binary will be at `octane/target/release/octane`.
101+
102+
### Usage
103+
104+
After training, launch Octane to see all the experiments you ran with `highjax-trainer`:
105+
106+
```bash
107+
octane
108+
```
109+
110+
### Figures
111+
112+
Use Octane to make figures for your paper:
113+
114+
```bash
115+
octane draw -t ~/.highjax/t/2026-03-15-20-02-25-101327 --epoch 300 -e 0 --timestep 19 --theme light \
116+
--zoom 1.8 --png ~/figure.png
117+
```
118+
119+
<p align="center">
120+
<img src="misc/images/figure.png" alt="Octane figure output" width="428"><br/>
121+
</p>
122+
123+
### Behavior crafting
124+
125+
Octane includes a behavior explorer for defining measurable policy properties. While watching an episode, press `b` to capture a scenario — mark which actions you want (positive weight) or don't want (negative weight) at that traffic state. Name it, and Octane saves the behavior to `~/.highjax/behaviors/`. The next time you run `highjax-trainer train`, all discovered behaviors are evaluated every epoch and their scores are recorded as `behavior.{name}` columns in `epochia.parquet`.
126+
127+
<p align="center">
128+
<img src="misc/images/behavior_tui.png" alt="Behavior crafting dialog in Octane" width="364"><br/>
129+
<em>Defining a behavior scenario in Octane</em>
130+
</p>
131+
132+
Press `B` (Shift-B) to open the full Behavior Explorer tab.
133+
134+
See the [Octane docs](docs/Octane/Octane%20docs.md) for full details.
135+
136+
## Documentation
137+
138+
Full documentation is in the `docs/` folder:
139+
140+
- [HighJax environment docs](docs/HighJax/HighJax%20docs.md) — state, observations, reward, NPCs, physics
141+
- [Octane TUI docs](docs/Octane/Octane%20docs.md) — episode browser, configuration, key bindings
142+
- [Coding conventions](docs/HighJax%20coding%20conventions.md) — naming, array indices, style
143+
144+
## Examples
145+
146+
- `examples/basic_usage.py` — Create env, reset, step, print observations
147+
- `examples/train_ppo.py` — Train a PPO agent and evaluate it
148+
- `examples/use_purejaxrl.py` — PureJaxRL integration (vectorized scan loop)
149+
- `examples/use_stoix.py` — Stoix integration (via stoa gymnax adapter)
150+
- `examples/use_rejax.py` — Rejax integration (JIT-compiled training, vmapped seeds)
151+
152+
## Citation
153+
154+
If you use HighJax in your research, please cite:
155+
156+
```bibtex
157+
@article{rachum2025bxrl,
158+
title={BXRL: Behavior-Explainable Reinforcement Learning},
159+
author={Rachum, Ram and Amitai, Yotam and Nakar, Yonatan and Mirsky, Reuth and Allen, Cameron},
160+
year={2025}
161+
}
162+
```

docs/HighJax coding conventions.md

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
# HighJax Coding Conventions
2+
3+
Related: [[HighJax docs]]
4+
5+
Naming conventions, style rules, and JAX-specific patterns used throughout the HighJax codebase.
6+
7+
## Array naming: `foo_by_bar_by_baz`
8+
9+
Almost all JAX arrays use the pattern `foo_by_bar_by_baz`. Each `by_something` is an axis (or sometimes multiple axes). Axes go right-to-left: `foo_by_bar_by_baz` is a 2D array where axis 0 is baz, axis 1 is bar, and the value is foo.
10+
11+
Examples:
12+
13+
- `reward_by_e_by_t` -- shape (n_ts, n_es), reward at each (timestep, episode)
14+
- `p_by_action_by_e` -- shape (n_es, *action_shape), action probabilities
15+
16+
Multi-dimensional items like `position` or `action` occupy multiple axes. The number of axes that `action` occupies depends on the environment.
17+
18+
**This applies to ALL arrays** -- including intermediate variables, temporaries, results, diffs, distances. Even a simple subtraction result should be named `diff_by_e_by_t`, not `diff`.
19+
20+
## Dimension size naming: `n_foos_per_bar`
21+
22+
When unpacking array shapes to get dimension sizes, use `n_foos_per_bar`:
23+
24+
- `foos`: plural of what's being counted
25+
- `bar`: the containing unit
26+
27+
Examples:
28+
29+
```python
30+
n_es_per_epoch, n_cells_per_e, n_tokens_per_vocabulary = logit_by_token_by_cell_by_e.shape
31+
n_es_per_epoch, n_cells_per_e = token_by_cell_by_e.shape
32+
```
33+
34+
## Index variables: `i_foo`
35+
36+
When naming a variable that's an index number, use `i_foo`. However, these are exceptions that don't need the `i_` prefix: `epoch`, `t`, `e`.
37+
38+
## Shorthands
39+
40+
Only use shorthands that already exist in the codebase. Don't invent new ones. Established shorthands:
41+
42+
- `p` for probability
43+
- `e` for episode
44+
- `t` for timestep
45+
- `ft` for flat timestep (full flattened pool, in minibatch code)
46+
- `mt` for minibatch timestep (within one minibatch slice)
47+
- `ts` for timesteps (plural, in CLI args like `--n-ts`)
48+
- `es` for episodes (plural, in CLI args like `--n-es`)
49+
- `v` for value estimate
50+
- `vf` for value function
51+
- `kld` for KL divergence (never bare `kl`)
52+
- `obs` for observation (matches gymnax API convention)
53+
- `nz` for normalized (prefix, e.g. `nz_speed`, `nz_return`, `nz_advantage`)
54+
- `theta` for model parameters (neural network param dicts)
55+
- `vital` for alive-mask arrays (not post-crash)
56+
- `tendency` for log-probability of chosen action
57+
- `epilogue` for post-final-step values (e.g. `epilogue_v_by_agent_by_e`)
58+
- `lunge` for action dimension in multi-discrete action spaces
59+
- `deed` for a choice within a lunge
60+
- `mb` for minibatch (in axis names like `_by_mt_by_mb`)
61+
- `sweep` for one pass over all minibatches
62+
63+
Don't shorten `position` to `pos`, etc.
64+
65+
## Code style
66+
67+
- Python 3.12+ required
68+
- Single quotes everywhere, unless there's a quote-in-quote situation
69+
- Maximum line length: 100 characters
70+
- Type annotations using builtins (`list`, `tuple`, `dict`), not `List`, `Tuple`, `Dict`
71+
- `from __future__ import annotations` at the top of every file
72+
- Import order: `__future__` > stdlib > third-party > highjax
73+
- snake_case for variables/functions, PascalCase for classes
74+
75+
## Docstrings and comments
76+
77+
The `highjax` environment package has docstrings on public API functions (reset, step, etc.) since it serves as a library. The `highjax_trainer` package generally avoids function docstrings — the code should be self-explanatory. Add comments sparingly, only when the code is genuinely difficult to understand otherwise.
78+
79+
## JAX JIT
80+
81+
Many functions that process arrays need to be JIT-compiled. This means:
82+
83+
- No Python control flow on array values (use `jnp.where` instead of `if`)
84+
- Be careful with loops (use `jax.lax.scan` or `jax.vmap` instead)
85+
- Use `jax.Array` and pytree-compatible data structures
86+
87+
## Flax dataclasses
88+
89+
HighJax uses `@flax.struct.dataclass` for most data classes. These are JAX-compatible (pytree-registered) frozen dataclasses. Fields can be marked `pytree_node=False` via `flax.struct.field(pytree_node=False)` to exclude from JAX tracing (e.g., config objects).
90+
91+
## Minibatch PPO pipeline
92+
93+
The gradient computation pipeline has a specific data flow with its own naming:
94+
95+
```
96+
Ascender -> SweepMaster (flatten to _by_ft) -> Sweeper (shuffle to _by_mt_by_minibatch) -> Minibatcher (_by_mt per minibatch)
97+
```
98+
99+
The Minibatcher computes the composite actor objective (PPO clipped surrogate + entropy) and produces gradients via `jax.grad`. The critic is updated separately.
100+
101+
## Testing
102+
103+
- **Golden tests** (`test_golden_runs/`): Deterministic training runs with exact expected values. When the training pipeline changes, these need regeneration. Each test defines its own `train()` function; run it, capture the new values, update `golden_data`.
104+
- **Unit tests**: Everything else -- estimators, objectives, masking, freezing, trainer integration, etc.

docs/HighJax docs.md

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
# HighJax Docs
2+
3+
Top-level table of contents for HighJax documentation.
4+
5+
## Environment
6+
7+
[[HighJax environment]] — State, observations, reward, NPCs, physics.
8+
9+
## Training
10+
11+
[[Trainer docs]] — PPO pipeline, epochia.parquet, trek recording, behaviors.
12+
13+
## TUI Explorer
14+
15+
[[Octane docs]] — Octane TUI: navigation, key bindings, rendering.
16+
17+
## Coding
18+
19+
[[HighJax coding conventions]] — Naming conventions, style rules, JAX patterns.

0 commit comments

Comments
 (0)