Skip to content

Commit b19783a

Browse files
committed
Refactor install_gcn.sh to streamline CUDA detection and improve environment setup
1 parent 44a9de1 commit b19783a

1 file changed

Lines changed: 37 additions & 89 deletions

File tree

devtools/install_gcn.sh

Lines changed: 37 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,6 @@ done
4646

4747
# ── determine CUDA vs CPU ───────────────────────────────────────────────────
4848
if [[ -n "$TSGCN_CUDA_REQ" ]]; then
49-
# user override
5049
case "$TSGCN_CUDA_REQ" in
5150
9.2|10.1|10.2|11.0)
5251
CUDA="cudatoolkit=${TSGCN_CUDA_REQ}"
@@ -57,28 +56,19 @@ if [[ -n "$TSGCN_CUDA_REQ" ]]; then
5756
exit 1
5857
;;
5958
esac
60-
6159
elif $FORCE_CPU; then
6260
CUDA="cpuonly"
6361
CUDA_VERSION="cpu"
64-
6562
else
66-
# auto-detect via nvcc
6763
if command -v nvcc &>/dev/null; then
6864
VER=$(nvcc --version | grep -oP "release \K[0-9]+\.[0-9]+")
69-
echo "Detected nvcc CUDA $VER"
7065
CUDA="cudatoolkit=$VER"
7166
CUDA_VERSION="cu${VER/./}"
72-
73-
# or via nvidia-smi
7467
elif command -v nvidia-smi &>/dev/null; then
7568
VER=$(nvidia-smi --query-gpu=driver_version --format=csv,noheader | head -n1 | cut -d. -f1-2)
76-
echo "Detected NVIDIA-driver CUDA $VER"
7769
CUDA="cudatoolkit=$VER"
7870
CUDA_VERSION="cu${VER/./}"
79-
8071
else
81-
echo "No CUDA toolchain found: defaulting to CPU build"
8272
CUDA="cpuonly"
8373
CUDA_VERSION="cpu"
8474
fi
@@ -89,9 +79,7 @@ echo "→ Installing with $CUDA on platform $CUDA_VERSION"
8979
# ── functions ─────────────────────────────────────────────────────────────
9080
write_hook () { # env_name repo_path
9181
local env="$1" repo="$2"
92-
# skip if env missing
9382
$COMMAND_PKG env list | awk '{print $1}' | grep -qx "$env" || return 0
94-
9583
local prefix
9684
if [[ $COMMAND_PKG == micromamba ]]; then
9785
prefix="$(micromamba info --base)/envs/$env"
@@ -101,17 +89,16 @@ write_hook () { # env_name repo_path
10189
local act="$prefix/etc/conda/activate.d/zzz-tsgcn.sh"
10290
local deact="$prefix/etc/conda/deactivate.d/zzz-tsgcn.sh"
10391
mkdir -p "${act%/*}" "${deact%/*}"
104-
rm -f "$act" "$deact" # ensure fresh copy each run
92+
rm -f "$act" "$deact"
10593

106-
# --- activate ----------------------------------------------------------
94+
# activate hook
10795
cat >"$act" <<EOF
10896
# TS-GCN hook – $(date +%F)
10997
export TSGCN_ROOT="$repo"
110-
case ":\$PYTHONPATH:" in *":\$TSGCN_ROOT:"*) ;; \
111-
*) export PYTHONPATH="\$TSGCN_ROOT:\${PYTHONPATH:-}" ;; esac
98+
case ":\$PYTHONPATH:" in *":\$TSGCN_ROOT:") ;; *) export PYTHONPATH="\$TSGCN_ROOT:\${PYTHONPATH:-}" ;; esac
11299
EOF
113100

114-
# --- deactivate --------------------------------------------------------
101+
# deactivate hook
115102
cat >"$deact" <<'EOF'
116103
_strip () { local n=":$1:"; local s=":$2:"; echo "${s//$n/:}" | sed 's/^://;s/:$//'; }
117104
export PYTHONPATH=$(_strip "$TSGCN_ROOT" ":${PYTHONPATH:-}:")
@@ -120,97 +107,58 @@ EOF
120107
echo "🔗 PYTHONPATH hook refreshed in $env"
121108
}
122109

123-
# ── locate folders relative to this script ────────────────────────────────
124-
SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)"
125-
ARC_ROOT="$(cd "$SCRIPT_DIR/.." && pwd)" # …/ARC_Mol
126-
CLONE_ROOT="$(cd "$ARC_ROOT/.." && pwd)" # directory that *contains* ARC_Mol
110+
# ── locate folders ─────────────────────────────────────────────────────────
111+
SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" && pwd)"
112+
CLONE_ROOT="$(cd "$SCRIPT_DIR/.." && pwd)"
127113
cd "$CLONE_ROOT"
128114

129-
if command -v micromamba &> /dev/null; then
130-
echo "✔️ Micromamba is installed."
115+
# choose backend
116+
if command -v micromamba &>/dev/null; then
131117
COMMAND_PKG=micromamba
132-
elif command -v mamba &> /dev/null; then
133-
echo "✔️ Mamba is installed."
118+
elif command -v mamba &>/dev/null; then
134119
COMMAND_PKG=mamba
135-
elif command -v conda &> /dev/null; then
136-
echo "✔️ Conda is installed."
137-
COMMAND_PKG=conda
138-
else
139-
echo "❌ Micromamba, Mamba, or Conda is required. Please install one."
140-
exit 1
141-
fi
142-
143-
if [ "$COMMAND_PKG" = "micromamba" ]; then
144-
eval "$(micromamba shell hook --shell=bash)"
145120
else
146-
BASE=$(conda info --base)
147-
# shellcheck source=/dev/null
148-
. "$BASE/etc/profile.d/conda.sh"
121+
COMMAND_PKG=conda
149122
fi
150123

124+
eval "\$($COMMAND_PKG shell hook --shell=bash)"
151125

152-
echo ">>> Cloning or updating TS-GCN..."
153-
if [ -d TS-GCN ]; then
154-
cd TS-GCN
155-
git fetch origin
156-
git checkout main
157-
git pull origin main
158-
else
159-
git clone https://github.com/ReactionMechanismGenerator/TS-GCN
160-
cd TS-GCN
161-
fi
126+
# clone/update repo
127+
if [ -d TS-GCN ]; then cd TS-GCN && git fetch && git checkout main && git pull; else git clone https://github.com/ReactionMechanismGenerator/TS-GCN && cd TS-GCN; fi
162128

163-
# 3. PATH vs hooks ----------------------------------------------------------
129+
# 3. PATH vs hooks
164130
if [[ $MODE == path ]]; then
165-
GCN_LINE="export PYTHONPATH=\$PYTHONPATH:$(pwd)"
166-
if ! grep -Fqx "$GCN_LINE" ~/.bashrc; then
167-
echo "$GCN_LINE" >> ~/.bashrc
168-
echo "✔️ Added TS-GCN path to ~/.bashrc"
169-
else
170-
echo "ℹ️ TS-GCN path already exists in ~/.bashrc"
171-
fi
131+
GCN_LINE="export PYTHONPATH=\\$PYTHONPATH:$(pwd)"
132+
grep -Fqx "$GCN_LINE" ~/.bashrc || { echo "$GCN_LINE" >> ~/.bashrc; echo "✔️ Added TS-GCN path to ~/.bashrc"; }
172133
fi
173134

174-
# ---------------------------------------------------------------------------
175-
# create / update env *here* (unchanged)
176-
if grep -q '^conda_env:' Makefile; then
177-
echo ">>> Creating GCN conda environment via Makefile"
178-
# --- pick the fastest Conda frontend just for create_env.sh ---------------
179-
if command -v micromamba >/dev/null; then
180-
_backend="micromamba"
181-
elif command -v mamba >/dev/null; then
182-
_backend="mamba"
135+
# 4. inline env creation & install
136+
if [[ -f environment.yml ]]; then
137+
echo "Creating/updating ts_gcn environment"
138+
if $COMMAND_PKG env list | awk '{print $1}' | grep -qx ts_gcn; then
139+
$COMMAND_PKG env update -n ts_gcn -f environment.yml --prune -y
183140
else
184-
_backend="conda" # fallback to classic
141+
$COMMAND_PKG env create -n ts_gcn -f environment.yml -y
185142
fi
186-
echo "⚡ Using $_backend for create_env.sh"
187-
# run make in a subshell so the alias doesn't leak
188-
(
189-
alias conda="$_backend"
190-
export CUDA CUDA_VERSION
191-
192-
# map CUDA_VERSION → select index in create_env.sh’s menu
193-
case "$CUDA_VERSION" in
194-
cu92) pick=1 ;;
195-
cu101) pick=2 ;;
196-
cu102) pick=3 ;;
197-
cu110) pick=4 ;;
198-
cpu) pick=5 ;;
199-
*) pick=5 ;;
200-
esac
201-
202-
echo "→ Auto-selecting menu item #$pick for CUDA install"
203-
# pipe the choice into create_env.sh via make
204-
printf '%s\n' "$pick" | make conda_env
205-
)
206-
207143

144+
conda activate ts_gcn
145+
echo "Installing PyTorch + torchvision with $CUDA"
146+
$COMMAND_PKG install -n ts_gcn pytorch torchvision $CUDA -c pytorch -y
147+
148+
TORCH_VER=$(python -c "import torch; print(torch.__version__)" | cut -c1-4)0
149+
WHEEL_URL="https://pytorch-geometric.com/whl/torch-${TORCH_VER}+${CUDA_VERSION}.html"
150+
pip install torch-scatter -f "$WHEEL_URL"
151+
pip install torch-sparse -f "$WHEEL_URL"
152+
pip install torch-cluster -f "$WHEEL_URL"
153+
pip install torch-spline-conv -f "$WHEEL_URL"
154+
pip install torch-geometric
155+
echo "✅ ts_gcn environment ready"
208156
else
209-
echo "Makefile target 'conda_env' not found. Please check TS-GCN repo."
157+
echo "environment.yml not found."
210158
exit 1
211159
fi
212160

213-
# 4. write hooks *after* env exists -----------------------------------------
161+
# 5. write hooks
214162
if [[ $MODE == conda ]]; then
215163
write_hook ts_gcn "$(pwd)"
216164
if $COMMAND_PKG env list | awk '{print $1}' | grep -qx arc_env; then

0 commit comments

Comments
 (0)