Skip to content

Commit 8e7b5b8

Browse files
committed
attention ablation scripts
1 parent 6c0babe commit 8e7b5b8

3 files changed

Lines changed: 556 additions & 0 deletions

File tree

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
#!/bin/bash
2+
# evaluate_ablations.sh
3+
# Discovers and evaluates all ablation variants under an ablations/ directory.
4+
# Reads shared config (fasta paths, biotype CSV) from the parent experiment config.
5+
#
6+
# Usage:
7+
# bash 01b_evaluate_ablations.sh \
8+
# --ablations_dir <path/to/ablations> \
9+
# --eval_script <path/to/01_evaluate_cv_folds.sh> \
10+
# --biotype_csv <path/to/biotype.csv> \
11+
# --gencode_version v47|v49 \
12+
# [--n_folds 5] \
13+
# [--device cuda:0]
14+
15+
set -euo pipefail
16+
17+
ABLATIONS_DIR=""
18+
EVAL_SCRIPT=""
19+
BIOTYPE_CSV=""
20+
GENCODE_VERSION=""
21+
N_FOLDS=5
22+
DEVICE="cuda:0"
23+
24+
while [[ $# -gt 0 ]]; do
25+
case $1 in
26+
--ablations_dir) ABLATIONS_DIR="$2"; shift 2 ;;
27+
--eval_script) EVAL_SCRIPT="$2"; shift 2 ;;
28+
--biotype_csv) BIOTYPE_CSV="$2"; shift 2 ;;
29+
--gencode_version) GENCODE_VERSION="$2"; shift 2 ;;
30+
--n_folds) N_FOLDS="$2"; shift 2 ;;
31+
--device) DEVICE="$2"; shift 2 ;;
32+
*) echo "Unknown argument: $1"; exit 1 ;;
33+
esac
34+
done
35+
36+
# Validate required args
37+
for var in ABLATIONS_DIR EVAL_SCRIPT BIOTYPE_CSV GENCODE_VERSION; do
38+
[[ -z "${!var}" ]] && { echo "ERROR: --${var,,} required"; exit 1; }
39+
done
40+
[[ -d "$ABLATIONS_DIR" ]] || { echo "ERROR: $ABLATIONS_DIR not found"; exit 1; }
41+
[[ -f "$EVAL_SCRIPT" ]] || { echo "ERROR: $EVAL_SCRIPT not found"; exit 1; }
42+
43+
# Read fasta paths from parent experiment config (one level above ablations/)
44+
PARENT_CONFIG="$(dirname "$ABLATIONS_DIR")/config.json"
45+
[[ -f "$PARENT_CONFIG" ]] || { echo "ERROR: parent config not found at $PARENT_CONFIG"; exit 1; }
46+
47+
LNC_TEST_FASTA=$(python3 -c "import json,sys; c=json.load(open('$PARENT_CONFIG')); print(c['data']['lnc_test_fasta'])")
48+
PC_TEST_FASTA=$(python3 -c "import json,sys; c=json.load(open('$PARENT_CONFIG')); print(c['data']['pc_test_fasta'])")
49+
50+
echo "=== Ablation evaluation ==="
51+
echo " Ablations dir : $ABLATIONS_DIR"
52+
echo " Eval script : $EVAL_SCRIPT"
53+
echo " GENCODE version: $GENCODE_VERSION"
54+
echo " lnc test fasta : $LNC_TEST_FASTA"
55+
echo " pc test fasta : $PC_TEST_FASTA"
56+
echo " Biotype CSV : $BIOTYPE_CSV"
57+
echo " Device : $DEVICE"
58+
echo "==========================="
59+
60+
n_total=0; n_skipped=0; n_done=0; n_failed=0
61+
62+
for variant_dir in "$ABLATIONS_DIR"/*/; do
63+
[[ -d "$variant_dir" ]] || continue
64+
variant=$(basename "$variant_dir")
65+
66+
if [[ ! -f "$variant_dir/fold_results.json" ]]; then
67+
echo " [SKIP] $variant — training incomplete"
68+
((n_skipped++)); ((n_total++)); continue
69+
fi
70+
71+
if [[ -f "$variant_dir/evaluation_csv/test_predictions.csv" ]]; then
72+
echo " [SKIP] $variant — already evaluated"
73+
((n_skipped++)); ((n_total++)); continue
74+
fi
75+
76+
echo " [RUN] $variant"
77+
if bash "$EVAL_SCRIPT" \
78+
--experiment_dir "$variant_dir" \
79+
--config "$variant_dir/config.json" \
80+
--biotype_csv "$BIOTYPE_CSV" \
81+
--lnc_test_fasta "$LNC_TEST_FASTA" \
82+
--pc_test_fasta "$PC_TEST_FASTA" \
83+
--n_folds "$N_FOLDS" \
84+
--device "$DEVICE" \
85+
--model_label "$variant" \
86+
--gencode_version "$GENCODE_VERSION" \
87+
>> "$variant_dir/eval.log" 2>&1; then
88+
echo " [OK] $variant"
89+
((n_done++))
90+
else
91+
echo " [FAIL] $variant — see $variant_dir/eval.log"
92+
((n_failed++))
93+
fi
94+
((n_total++))
95+
done
96+
97+
echo ""
98+
echo "=== Summary: $n_total variants | $n_done evaluated | $n_skipped skipped | $n_failed failed ==="
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
#!/bin/bash
2+
#SBATCH --job-name=train_bvae_g47_features_attn_ablation_seqonly # Job name
3+
#SBATCH --output=logs/train_bvae_g47_features_attn_ablation_seqonly.out # Standard output log
4+
#SBATCH --error=logs/train_bvae_g47_features_attn_ablation_seqonly.err # Standard error log
5+
#SBATCH --gres=gpu:nvidia_h100_nvl # Request 1 GPU
6+
#SBATCH --partition=gpu # Partition to submit to (e.g., GPU queue)
7+
#SBATCH --time=64:00:00 # Time limit day:hrs:min:sec
8+
#SBATCH --mem=64G
9+
10+
# Load modules (if necessary)
11+
module load nvidia/cuda/12.1 # Load CUDA module (adjust version as needed)
12+
eval "$(conda shell.bash hook)"
13+
conda activate beta_lncrna
14+
15+
cd /mnt/cbib/LNClassifier/DL_benchmark
16+
17+
# Run from project root
18+
python -m experiments.train.main_features_attn_ablation --base_config configs/beta_vae_features_attn_g47.json --variants seq_only seq_te seq_nonb --device cuda:0 --n_folds 5
19+
#python -m experiments.train.main_features_attn_ablation --base_config configs/beta_vae_features_attn_g49.json --variants seq_only seq_te seq_nonb --device cuda:0 --n_folds 5
20+
21+
22+
echo "Training complete at $(date)"
23+
echo "Job ID: $SLURM_JOB_ID"

0 commit comments

Comments
 (0)