Skip to content

Commit 44a9de1

Browse files
committed
Refactor install_gcn.sh to enhance CUDA handling and command-line options
1 parent 7366934 commit 44a9de1

1 file changed

Lines changed: 98 additions & 19 deletions

File tree

devtools/install_gcn.sh

Lines changed: 98 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,91 @@
1-
#!/bin/bash -l
2-
set -e
1+
#!/usr/bin/env bash
2+
set -euo pipefail
33

4-
# ── command-line flags ────────────────────────────────────────────────
5-
MODE=path # {path|conda}
4+
# ── defaults ───────────────────────────────────────────────────────────────
5+
MODE="path" # or "conda"
6+
TSGCN_CUDA_REQ=""
7+
FORCE_CPU=false
68

7-
TEMP=$(getopt -o ch --long conda,help -- "$@")
9+
# ── parse flags ────────────────────────────────────────────────────────────
10+
TEMP=$(getopt -o h --long cuda:,cpu,conda,path,help -- "$@")
811
eval set -- "$TEMP"
912
while true; do
1013
case "$1" in
11-
-c|--conda) MODE=conda; shift ;;
12-
-h|--help)
13-
cat <<EOF
14-
Usage: $0 [--conda]
15-
16-
-c, --conda write PYTHONPATH hooks into ts_gcn (and arc_env if it exists)
17-
instead of modifying ~/.bashrc
14+
--cuda)
15+
TSGCN_CUDA_REQ="$2"
16+
shift 2
17+
;;
18+
--cpu)
19+
FORCE_CPU=true
20+
shift
21+
;;
22+
--conda)
23+
MODE="conda"
24+
shift
25+
;;
26+
--path)
27+
MODE="path"
28+
shift
29+
;;
30+
-h|--help)
31+
cat <<EOF
32+
Usage: $0 [--cuda <9.2|10.1|10.2|11.0>] [--cpu] [--conda|--path] [--help]
33+
34+
--cuda request a specific CUDA version (overrides auto-detect)
35+
--cpu force a CPU-only install
36+
--conda install hooks into conda activate/deactivate
37+
--path append TS-GCN to ~/.bashrc
38+
-h this help
1839
EOF
19-
exit 0 ;;
20-
--) shift; break ;;
21-
*) echo "Internal getopt error"; exit 1 ;;
40+
exit 0
41+
;;
42+
--) shift; break ;;
43+
*) echo "Invalid flag: $1" >&2; exit 1 ;;
2244
esac
2345
done
2446

47+
# ── determine CUDA vs CPU ───────────────────────────────────────────────────
48+
if [[ -n "$TSGCN_CUDA_REQ" ]]; then
49+
# user override
50+
case "$TSGCN_CUDA_REQ" in
51+
9.2|10.1|10.2|11.0)
52+
CUDA="cudatoolkit=${TSGCN_CUDA_REQ}"
53+
CUDA_VERSION="cu${TSGCN_CUDA_REQ/./}"
54+
;;
55+
*)
56+
echo "Error: unsupported --cuda version: $TSGCN_CUDA_REQ" >&2
57+
exit 1
58+
;;
59+
esac
60+
61+
elif $FORCE_CPU; then
62+
CUDA="cpuonly"
63+
CUDA_VERSION="cpu"
64+
65+
else
66+
# auto-detect via nvcc
67+
if command -v nvcc &>/dev/null; then
68+
VER=$(nvcc --version | grep -oP "release \K[0-9]+\.[0-9]+")
69+
echo "Detected nvcc CUDA $VER"
70+
CUDA="cudatoolkit=$VER"
71+
CUDA_VERSION="cu${VER/./}"
72+
73+
# or via nvidia-smi
74+
elif command -v nvidia-smi &>/dev/null; then
75+
VER=$(nvidia-smi --query-gpu=driver_version --format=csv,noheader | head -n1 | cut -d. -f1-2)
76+
echo "Detected NVIDIA-driver CUDA $VER"
77+
CUDA="cudatoolkit=$VER"
78+
CUDA_VERSION="cu${VER/./}"
79+
80+
else
81+
echo "No CUDA toolchain found: defaulting to CPU build"
82+
CUDA="cpuonly"
83+
CUDA_VERSION="cpu"
84+
fi
85+
fi
86+
87+
echo "→ Installing with $CUDA on platform $CUDA_VERSION"
88+
2589
# ── functions ─────────────────────────────────────────────────────────────
2690
write_hook () { # env_name repo_path
2791
local env="$1" repo="$2"
@@ -121,10 +185,25 @@ if grep -q '^conda_env:' Makefile; then
121185
fi
122186
echo "⚡ Using $_backend for create_env.sh"
123187
# run make in a subshell so the alias doesn't leak
124-
(
125-
alias conda="$_backend" # This alias is only for this subshell - so we can use the fastest Conda frontend
126-
make conda_env
127-
)
188+
(
189+
alias conda="$_backend"
190+
export CUDA CUDA_VERSION
191+
192+
# map CUDA_VERSION → select index in create_env.sh’s menu
193+
case "$CUDA_VERSION" in
194+
cu92) pick=1 ;;
195+
cu101) pick=2 ;;
196+
cu102) pick=3 ;;
197+
cu110) pick=4 ;;
198+
cpu) pick=5 ;;
199+
*) pick=5 ;;
200+
esac
201+
202+
echo "→ Auto-selecting menu item #$pick for CUDA install"
203+
# pipe the choice into create_env.sh via make
204+
printf '%s\n' "$pick" | make conda_env
205+
)
206+
128207

129208
else
130209
echo "❌ Makefile target 'conda_env' not found. Please check TS-GCN repo."

0 commit comments

Comments
 (0)