
Commit 466bf1a

Added Ablation Studies

- Variants: No Phase Encoding, No Dendritic Routing, No Plasticity, and Full (Phase Encoding + Dendritic Routing + Plasticity).
- Datasets:
  1. Environments: MountainCar-v0, Pendulum-v1, CartPole-v1, Acrobot-v1.
  2. The datasets weight random trajectories more heavily than the expert and medium ones.
  3. A1, B1, and F1 are the scripts that generate, process, and verify the datasets.
- The architecture is unchanged.
- Scripts and compute: 4 environments x 5 models = 20 scripts, run on a T4 GPU in Google Colab.

Co-Authored-By: Aditi <76613793+aditi0x@users.noreply.github.com>
1 parent a8e5f89 commit 466bf1a

20 files changed

Lines changed: 1746 additions & 0 deletions

ablation_studies/README.md

Lines changed: 96 additions & 0 deletions
@@ -0,0 +1,96 @@
# Phase 2: Ablation Study & Baseline Comparison

This directory contains the code and resources for the Phase 2 ablation study of the SNN-DT project, which now includes a comprehensive comparison against baseline models.

## 1. Installation

Ensure you have the required dependencies installed from the main project's `requirements.txt`:

```bash
pip install -r ../requirements.txt
```

You will also need to install `stable-baselines3`, `pyyaml`, `tqdm`, and `gymnasium`:

```bash
pip install stable-baselines3 pyyaml tqdm gymnasium
```
## 2. Dataset Generation

The new dataset generation process is a multi-stage pipeline that produces high-quality, return-stratified datasets.

### Step 1: Generate Raw Trajectories

First, generate the raw trajectories from random, medium, and expert policies:

```bash
python scripts/A1_generate_trajectories.py
```

This saves the raw trajectories to `ablation_studies/datasets/raw`.

### Step 2: Process Datasets

Next, process the raw trajectories to create the final, stratified datasets:

```bash
python scripts/B1_process_datasets.py
```

This creates `stratified_dataset.npz` and `random_heavy_dataset.npz` in `ablation_studies/datasets/processed` for each environment.

### Step 3: Verify Datasets

Finally, verify the quality of the generated datasets:

```bash
python scripts/F1_verify_datasets.py
```

This generates distribution plots in `ablation_studies/datasets/verification_plots` and prints a spike sanity check to the console.
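As a quick illustration of what inspecting these processed archives might look like, here is a hedged Python sketch that builds a tiny synthetic `.npz` in the same spirit, reloads it, and bins episode returns. The key names (`observations`, `actions`, `returns`) and shapes are assumptions for the demo, not guaranteed to match what `B1_process_datasets.py` actually writes.

```python
import numpy as np

# Build a small synthetic dataset standing in for a processed archive.
# Key names and shapes are illustrative assumptions.
rng = np.random.default_rng(0)
demo = {
    "observations": rng.normal(size=(100, 20, 4)),   # (episodes, N, obs_dim)
    "actions": rng.integers(0, 2, size=(100, 20)),   # discrete actions
    "returns": rng.uniform(0.0, 500.0, size=(100,)), # per-episode returns
}
np.savez("stratified_dataset_demo.npz", **demo)

# Reload and list the stored arrays, as you might for the real dataset.
data = np.load("stratified_dataset_demo.npz")
for key in data.files:
    print(key, data[key].shape)

# Quick stratification check: returns should span low/mid/high bins.
counts, edges = np.histogram(data["returns"], bins=3)
print("return bins:", counts)
```

Swapping in the real path and key names gives a fast sanity check before training.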
## 3. Run Experiments

To run an experiment, use the `run_experiment.py` script with the desired variant, environment, and seed.

### Ablation Variants

**Example:**

```bash
python run_experiment.py --variant full --env CartPole-v1 --seed 1001
```

### Baseline Models

**Example:**

```bash
python run_experiment.py --variant dt --env CartPole-v1 --seed 1001
```
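The two example commands above share a common flag surface. A minimal sketch of how such a CLI could be parsed follows; the `choices` list mirrors the variant names used in this README and is an assumption about what `run_experiment.py` validates internally.

```python
import argparse

def parse_args(argv=None):
    # Sketch of the CLI surface run_experiment.py appears to expose.
    p = argparse.ArgumentParser(description="Run one ablation/baseline experiment")
    p.add_argument("--variant", required=True,
                   choices=["full", "no_phase", "no_routing", "no_plasticity",
                            "dt", "snn_dt", "iql", "cql"])
    p.add_argument("--env", required=True)
    p.add_argument("--seed", type=int, default=1001)
    return p.parse_args(argv)

# Same invocation as the ablation example above, driven programmatically.
args = parse_args(["--variant", "full", "--env", "CartPole-v1", "--seed", "1001"])
print(args.variant, args.env, args.seed)
```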
### Full Experimental Run

You can run all experiments using a simple shell loop:

```bash
for variant in full no_phase no_routing no_plasticity dt snn_dt iql cql; do
  for env in CartPole-v1 Acrobot-v1 Pendulum-v1; do
    for seed in 1001 1002 1003; do
      echo "--- Running $variant on $env with seed $seed ---"
      python run_experiment.py --variant "$variant" --env "$env" --seed "$seed"
    done
  done
done
```
## 4. Post-process Results

After the experiments are complete, you can generate the plots and summary tables using the `post_process.py` script:

```bash
python scripts/post_process.py
```

This saves the figures to the `ablation_studies/figures` directory.
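As a rough picture of the kind of reduction `post_process.py` might perform, here is a hedged numpy sketch that collapses per-seed evaluation returns into a mean and standard deviation across seeds. The numbers and the dict layout are made up for the demo and do not come from the project's result files.

```python
import numpy as np

# Hypothetical per-seed evaluation returns for one (variant, env) cell.
runs = {
    1001: np.array([200.0, 210.0, 190.0]),
    1002: np.array([205.0, 195.0, 200.0]),
    1003: np.array([198.0, 202.0, 206.0]),
}

# Reduce each seed's rollouts to a mean, then summarize across seeds.
per_seed_means = np.array([r.mean() for r in runs.values()])
print(f"{per_seed_means.mean():.1f} +/- {per_seed_means.std(ddof=1):.1f}")
```

Reporting the sample standard deviation (`ddof=1`) across seeds is the usual convention for tables like these.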
Lines changed: 2 additions & 0 deletions
@@ -0,0 +1,2 @@
model:
  name: cql
Lines changed: 2 additions & 0 deletions
@@ -0,0 +1,2 @@
model:
  name: dt
Lines changed: 6 additions & 0 deletions
@@ -0,0 +1,6 @@
phase_encoder:
  enabled: true
routing:
  enabled: true
local_plasticity:
  enabled: true
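For illustration, the three flags in a variant config like this one could gate the corresponding model components. The sketch below is hypothetical: the flag names come from the config files, but the gating logic is an assumption, not taken from the training code (it uses the `no_routing` variant's flags as the example).

```python
# Mirrors a variant YAML after yaml.safe_load; here, the no_routing variant.
variant_cfg = {
    "phase_encoder": {"enabled": True},
    "routing": {"enabled": False},
    "local_plasticity": {"enabled": True},
}

def active_components(cfg):
    """Names of components switched on for this ablation variant."""
    return sorted(name for name, section in cfg.items() if section.get("enabled"))

print(active_components(variant_cfg))
```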
Lines changed: 2 additions & 0 deletions
@@ -0,0 +1,2 @@
model:
  name: iql
Lines changed: 6 additions & 0 deletions
@@ -0,0 +1,6 @@
phase_encoder:
  enabled: false
routing:
  enabled: true
local_plasticity:
  enabled: true
Lines changed: 6 additions & 0 deletions
@@ -0,0 +1,6 @@
phase_encoder:
  enabled: true
routing:
  enabled: true
local_plasticity:
  enabled: false
Lines changed: 6 additions & 0 deletions
@@ -0,0 +1,6 @@
phase_encoder:
  enabled: true
routing:
  enabled: false
local_plasticity:
  enabled: true
Lines changed: 2 additions & 0 deletions
@@ -0,0 +1,2 @@
model:
  name: snn_dt
Lines changed: 35 additions & 0 deletions
@@ -0,0 +1,35 @@
envs: ["CartPole-v1", "Acrobot-v1", "Pendulum-v1", "MountainCar-v0"]
datasets_per_env:
  steps: 10000
  mix: "50% expert / 50% random"
sequence_length_N: 20
spiking_window_T: 10
hidden_dim_d: 128
num_layers_L: 2
num_heads_H: 4
batch_size: 64
optimizer: AdamW
lr: 3e-4
weight_decay: 1e-2
epochs: 50
local_lr_eta_local: 0.05
surrogate_slope_k: 10
spike_energy_pJ: 5.0
seeds: [1001, 1002, 1003]
log_interval_steps: 250
checkpoint_interval_epochs: 5
eval_rollouts: 50

# --- Model-specific configs ---
iql:
  tau: 0.005
  temperature: 3.0
  expectile: 0.7
  hidden_size: 256
cql:
  tau: 0.005
  temperature: 1.0
  hidden_size: 256
  with_lagrange: false
  cql_weight: 1.0
  target_action_gap: 10.0
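Since `pyyaml` is in the dependency list, a likely way this config is consumed is `yaml.safe_load`. The sketch below is an illustration, not the project's actual loader; it reproduces a fragment of the config inline and flags one PyYAML gotcha worth knowing: YAML 1.1 floats need a decimal point before the exponent, so `3e-4` loads as a string unless written `3.0e-4`.

```python
import yaml

# Inline fragment of the config above, for a self-contained demo.
cfg_text = """
batch_size: 64
lr: 3e-4
seeds: [1001, 1002, 1003]
iql:
  tau: 0.005
  expectile: 0.7
"""
cfg = yaml.safe_load(cfg_text)

# PyYAML resolves "3e-4" as a string (no dot before the exponent),
# so coerce explicitly before handing it to an optimizer.
lr = float(cfg["lr"])

# Model-specific sections hang off the top level by name.
model_cfg = cfg.get("iql", {})
print(type(cfg["lr"]).__name__, lr, model_cfg["expectile"])
```

Writing the learning rate as `3.0e-4` in the YAML would avoid the coercion entirely.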
