Skip to content

Commit 499c930

Browse files
committed
hip/rocm support, with multiprocessing-safe sweeps
1 parent 72b734e commit 499c930

9 files changed

Lines changed: 533 additions & 170 deletions

File tree

build.sh

Lines changed: 239 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,12 @@ set -e
1010
# ./build.sh breakout --fast # Standalone executable (optimized)
1111
# ./build.sh breakout --web # Emscripten web build
1212
# ./build.sh breakout --profile # Kernel profiling binary
13+
# ./build.sh breakout --rocm # HIP/ROCm training backend
14+
# ./build.sh breakout --cuda # CUDA training backend
1315
# ./build.sh all # Build all envs with default and --float
1416

1517
if [ -z "$1" ]; then
16-
echo "Usage: ./build.sh ENV_NAME [--float] [--debug] [--local|--fast|--web|--profile|--cpu|--all]"
18+
echo "Usage: ./build.sh ENV_NAME [--float] [--debug] [--cuda|--rocm] [--local|--fast|--web|--profile|--cpu|--all]"
1719
exit 1
1820
fi
1921
ENV=$1
@@ -28,6 +30,8 @@ for arg in "$@"; do
2830
--web) MODE=web ;;
2931
--profile) MODE=profile ;;
3032
--cpu) MODE=cpu; PRECISION="-DPRECISION_FLOAT" ;;
33+
--cuda) BACKEND=cuda ;;
34+
--rocm) BACKEND=rocm ;;
3135
*) echo "Error: unknown argument '$arg'" && exit 1 ;;
3236
esac
3337
done
@@ -205,57 +209,9 @@ elif [ "$MODE" = "web" ]; then
205209
exit 0
206210
fi
207211

208-
# Find cuDNN path
209-
CUDA_HOME=${CUDA_HOME:-${CUDA_PATH:-$(dirname "$(dirname "$(which nvcc)")")}}
210-
CUDNN_IFLAG=""
211-
CUDNN_LFLAG=""
212-
for dir in /usr/local/cuda/include /usr/include; do
213-
if [ -f "$dir/cudnn.h" ]; then
214-
CUDNN_IFLAG="-I$dir"
215-
break
216-
fi
217-
done
218-
for dir in /usr/local/cuda/lib64 /usr/lib/x86_64-linux-gnu; do
219-
if [ -f "$dir/libcudnn.so" ]; then
220-
CUDNN_LFLAG="-L$dir"
221-
break
222-
fi
223-
done
224-
if [ -z "$CUDNN_IFLAG" ]; then
225-
CUDNN_IFLAG=$("$PYTHON_BIN" -c "import nvidia.cudnn, os; print('-I' + os.path.join(nvidia.cudnn.__path__[0], 'include'))" 2>/dev/null || echo "")
226-
fi
227-
if [ -z "$CUDNN_LFLAG" ]; then
228-
CUDNN_LFLAG=$("$PYTHON_BIN" -c "import nvidia.cudnn, os; print('-L' + os.path.join(nvidia.cudnn.__path__[0], 'lib'))" 2>/dev/null || echo "")
229-
fi
230-
231-
# NCCL include/lib fallback (mirrors the cuDNN fallback above).
232-
# Needed when NCCL is provided by the nvidia-nccl-cu12 wheel in the active venv.
233-
NCCL_IFLAG=""
234-
NCCL_LFLAG=""
235-
for dir in /usr/include /usr/local/cuda/include; do
236-
if [ -f "$dir/nccl.h" ]; then NCCL_IFLAG="-I$dir"; break; fi
237-
done
238-
for dir in /usr/lib/x86_64-linux-gnu /usr/local/cuda/lib64; do
239-
if [ -f "$dir/libnccl.so" ] || [ -f "$dir/libnccl.so.2" ]; then NCCL_LFLAG="-L$dir"; break; fi
240-
done
241-
if [ -z "$NCCL_IFLAG" ]; then
242-
NCCL_IFLAG=$("$PYTHON_BIN" -c "import nvidia.nccl, os; print('-I' + os.path.join(nvidia.nccl.__path__[0], 'include'))" 2>/dev/null || echo "")
243-
fi
244-
if [ -z "$NCCL_LFLAG" ]; then
245-
NCCL_LFLAG=$("$PYTHON_BIN" -c "import nvidia.nccl, os; print('-L' + os.path.join(nvidia.nccl.__path__[0], 'lib'))" 2>/dev/null || echo "")
246-
fi
247-
248-
WHEEL_RPATH_FLAGS=()
249-
for lib_flag in "$CUDNN_LFLAG" "$NCCL_LFLAG"; do
250-
if [[ "$lib_flag" == -L* ]]; then
251-
WHEEL_RPATH_FLAGS+=("-Wl,-rpath,${lib_flag#-L}")
252-
fi
253-
done
254-
255212
export CCACHE_DIR="${CCACHE_DIR:-$HOME/.ccache}"
256213
export CCACHE_BASEDIR="$(pwd)"
257214
export CCACHE_COMPILERCHECK=content
258-
NVCC="ccache $CUDA_HOME/bin/nvcc"
259215
CC="${CC:-$(command -v ccache >/dev/null && echo 'ccache clang' || echo 'clang')}"
260216
ARCH=${NVCC_ARCH:-native}
261217

@@ -275,6 +231,41 @@ if [ ! -f "$BINDING_SRC" ]; then
275231
exit 1
276232
fi
277233

234+
if [ -z "$MODE" ]; then
235+
if [ -z "$BACKEND" ]; then
236+
if "$PYTHON_BIN" -c "from torch.utils.cpp_extension import IS_HIP_EXTENSION; raise SystemExit(0 if IS_HIP_EXTENSION else 1)" 2>/dev/null && ! command -v nvcc >/dev/null 2>&1; then
237+
BACKEND=rocm
238+
else
239+
BACKEND=cuda
240+
fi
241+
fi
242+
elif [ -n "$BACKEND" ]; then
243+
echo "Error: --cuda/--rocm only apply to the training backend"
244+
exit 1
245+
fi
246+
247+
if [ "$BACKEND" = "rocm" ] && [ "$ENV" = "nmmo3" ]; then
248+
echo "Error: NMMO3 native encoder is CUDA-only in build.sh --rocm"
249+
exit 1
250+
fi
251+
252+
CUDA_HOME=${CUDA_HOME:-${CUDA_PATH:-}}
253+
CUDA_IFLAG=""
254+
if [ "$BACKEND" != "rocm" ]; then
255+
CUDA_HOME=${CUDA_HOME:-${CUDA_PATH:-$(dirname "$(dirname "$(which nvcc)")")}}
256+
fi
257+
if [ "$BACKEND" = "cuda" ] || [ "$MODE" = "profile" ]; then
258+
if [ -z "$CUDA_HOME" ]; then
259+
if command -v nvcc >/dev/null 2>&1; then
260+
CUDA_HOME=$(dirname "$(dirname "$(command -v nvcc)")")
261+
else
262+
echo "Error: nvcc not found. Use --rocm for HIP/ROCm or --cpu for CPU fallback."
263+
exit 1
264+
fi
265+
fi
266+
CUDA_IFLAG="-I$CUDA_HOME/include"
267+
fi
268+
278269
if [ "$MODE" = "cpu" ]; then
279270
echo "Compiling static library for $ENV..."
280271
${CC:-clang} -c "${CLANG_OPT[@]}" $EXTRA_CFLAGS \
@@ -288,7 +279,59 @@ if [ "$MODE" = "cpu" ]; then
288279
ar rcs "$STATIC_LIB" "$STATIC_OBJ"
289280
fi
290281

291-
if [ -z "$MODE" ]; then
282+
if [ -z "$MODE" ] && [ "$BACKEND" = "cuda" ]; then
283+
# Find cuDNN path
284+
CUDNN_IFLAG=""
285+
CUDNN_LFLAG=""
286+
for dir in /usr/local/cuda/include /usr/include; do
287+
if [ -f "$dir/cudnn.h" ]; then
288+
CUDNN_IFLAG="-I$dir"
289+
break
290+
fi
291+
done
292+
for dir in /usr/local/cuda/lib64 /usr/lib/x86_64-linux-gnu; do
293+
if [ -f "$dir/libcudnn.so" ]; then
294+
CUDNN_LFLAG="-L$dir"
295+
break
296+
fi
297+
done
298+
if [ -z "$CUDNN_IFLAG" ]; then
299+
CUDNN_IFLAG=$("$PYTHON_BIN" -c "import nvidia.cudnn, os; print('-I' + os.path.join(nvidia.cudnn.__path__[0], 'include'))" 2>/dev/null || echo "")
300+
fi
301+
if [ -z "$CUDNN_LFLAG" ]; then
302+
CUDNN_LFLAG=$("$PYTHON_BIN" -c "import nvidia.cudnn, os; print('-L' + os.path.join(nvidia.cudnn.__path__[0], 'lib'))" 2>/dev/null || echo "")
303+
fi
304+
305+
# NCCL include/lib fallback (mirrors the cuDNN fallback above).
306+
# Needed when NCCL is provided by the nvidia-nccl-cu12 wheel in the active venv.
307+
NCCL_IFLAG=""
308+
NCCL_LFLAG=""
309+
for dir in /usr/include /usr/local/cuda/include; do
310+
if [ -f "$dir/nccl.h" ]; then NCCL_IFLAG="-I$dir"; break; fi
311+
done
312+
for dir in /usr/lib/x86_64-linux-gnu /usr/local/cuda/lib64; do
313+
if [ -f "$dir/libnccl.so" ] || [ -f "$dir/libnccl.so.2" ]; then NCCL_LFLAG="-L$dir"; break; fi
314+
done
315+
if [ -z "$NCCL_IFLAG" ]; then
316+
NCCL_IFLAG=$("$PYTHON_BIN" -c "import nvidia.nccl, os; print('-I' + os.path.join(nvidia.nccl.__path__[0], 'include'))" 2>/dev/null || echo "")
317+
fi
318+
if [ -z "$NCCL_LFLAG" ]; then
319+
NCCL_LFLAG=$("$PYTHON_BIN" -c "import nvidia.nccl, os; print('-L' + os.path.join(nvidia.nccl.__path__[0], 'lib'))" 2>/dev/null || echo "")
320+
fi
321+
322+
WHEEL_RPATH_FLAGS=()
323+
for lib_flag in "$CUDNN_LFLAG" "$NCCL_LFLAG"; do
324+
if [[ "$lib_flag" == -L* ]]; then
325+
WHEEL_RPATH_FLAGS+=("-Wl,-rpath,${lib_flag#-L}")
326+
fi
327+
done
328+
329+
if command -v ccache >/dev/null 2>&1; then
330+
NVCC="ccache $CUDA_HOME/bin/nvcc"
331+
else
332+
NVCC="$CUDA_HOME/bin/nvcc"
333+
fi
334+
292335
echo "Compiling CUDA ($ARCH) training backend with $ENV binding..."
293336
$NVCC -c -arch=$ARCH -Xcompiler -fPIC \
294337
-Xcompiler=-D_GLIBCXX_USE_CXX11_ABI=1 \
@@ -320,6 +363,152 @@ if [ -z "$MODE" ]; then
320363
"${LINK_CMD[@]}"
321364
echo "Built: $OUTPUT"
322365

366+
elif [ -z "$MODE" ] && [ "$BACKEND" = "rocm" ]; then
367+
mapfile -t ROCM_INFO < <("$PYTHON_BIN" - <<'PY'
368+
import os
369+
from torch.utils.cpp_extension import ROCM_HOME, library_paths, include_paths
370+
371+
rocm_home = os.environ.get("ROCM_HOME") or ROCM_HOME
372+
if not rocm_home:
373+
raise SystemExit("ROCM_HOME not found. Install/use a ROCm-enabled PyTorch environment.")
374+
print(rocm_home)
375+
print(os.environ.get("HIPCC") or os.path.join(rocm_home, "bin", "hipcc"))
376+
print(os.pathsep.join(include_paths("cuda")))
377+
print(os.pathsep.join(library_paths("cuda")))
378+
PY
379+
)
380+
ROCM_HOME=${ROCM_INFO[0]}
381+
HIPCC=${ROCM_INFO[1]}
382+
ROCM_INCLUDE_PATHS=${ROCM_INFO[2]}
383+
ROCM_LIBRARY_PATHS=${ROCM_INFO[3]}
384+
385+
if [ ! -x "$HIPCC" ]; then
386+
if command -v hipcc >/dev/null 2>&1; then
387+
HIPCC=$(command -v hipcc)
388+
else
389+
echo "Error: hipcc not found"
390+
exit 1
391+
fi
392+
fi
393+
394+
if [ -z "$HIP_CLANG_PATH" ] || [ ! -x "$HIP_CLANG_PATH/clang++" ]; then
395+
for dir in "$ROCM_HOME/lib/llvm/bin" /usr/lib/llvm/*/bin; do
396+
if [ -x "$dir/clang++" ]; then
397+
export HIP_CLANG_PATH="$dir"
398+
break
399+
fi
400+
done
401+
fi
402+
403+
HIPIFY_SRC="build/hip/src"
404+
HIPIFY_SRC_ABS="$(pwd)/$HIPIFY_SRC"
405+
SRC_ABS="$(pwd)/src"
406+
echo "Hipifying CUDA sources into $HIPIFY_SRC..."
407+
rm -rf "$HIPIFY_SRC"
408+
"$PYTHON_BIN" - <<PY
409+
from torch.utils.hipify import hipify_python
410+
hipify_python.hipify(
411+
project_directory="$SRC_ABS",
412+
output_directory="$HIPIFY_SRC_ABS",
413+
includes=["*"],
414+
show_progress=False,
415+
show_detailed=False,
416+
is_pytorch_extension=True,
417+
)
418+
PY
419+
cp "$HIPIFY_SRC/vecenv_hip.h" "$HIPIFY_SRC/vecenv.h"
420+
"$PYTHON_BIN" - <<PY
421+
path = "$HIPIFY_SRC/pufferlib.hip"
422+
with open(path) as f:
423+
src = f.read()
424+
src = src.replace('#include "vecenv_hip.h"', '#include "vecenv.h"')
425+
with open(path, 'w') as f:
426+
f.write(src)
427+
PY
428+
429+
ROCM_IFLAGS=()
430+
IFS=':' read -ra ROCM_INC_ARR <<< "$ROCM_INCLUDE_PATHS"
431+
for dir in "${ROCM_INC_ARR[@]}"; do
432+
[ -n "$dir" ] && ROCM_IFLAGS+=("-I$dir")
433+
done
434+
ROCM_LFLAGS=()
435+
ROCM_RPATH_FLAGS=()
436+
if [ -d /usr/lib64 ]; then
437+
ROCM_LFLAGS+=("-L/usr/lib64")
438+
ROCM_RPATH_FLAGS+=("-Wl,-rpath,/usr/lib64")
439+
fi
440+
IFS=':' read -ra ROCM_LIB_ARR <<< "$ROCM_LIBRARY_PATHS"
441+
for dir in "${ROCM_LIB_ARR[@]}"; do
442+
[ -n "$dir" ] || continue
443+
[ "$dir" = "/usr/lib" ] && [ -d /usr/lib64 ] && continue
444+
ROCM_LFLAGS+=("-L$dir")
445+
ROCM_RPATH_FLAGS+=("-Wl,-rpath,$dir")
446+
done
447+
ROCM_OMP_LIB=""
448+
for dir in /usr/lib64 /usr/lib /usr/local/lib; do
449+
if [ -f "$dir/libomp.so" ]; then
450+
ROCM_LFLAGS+=("-L$dir")
451+
ROCM_RPATH_FLAGS+=("-Wl,-rpath,$dir")
452+
ROCM_OMP_LIB="-lomp"
453+
break
454+
elif [ -f "$dir/libomp5.so" ]; then
455+
ROCM_LFLAGS+=("-L$dir")
456+
ROCM_RPATH_FLAGS+=("-Wl,-rpath,$dir")
457+
ROCM_OMP_LIB="-lomp5"
458+
break
459+
fi
460+
done
461+
462+
ROCM_ARCH_FLAGS=()
463+
if [ -n "$PYTORCH_ROCM_ARCH" ]; then
464+
IFS=';' read -ra ROCM_ARCH_ARR <<< "$PYTORCH_ROCM_ARCH"
465+
for arch in "${ROCM_ARCH_ARR[@]}"; do
466+
[ -n "$arch" ] && ROCM_ARCH_FLAGS+=("--offload-arch=$arch")
467+
done
468+
fi
469+
470+
HIPCC_OPT=()
471+
if [ -n "$DEBUG" ]; then
472+
HIPCC_OPT=(-O0 -g)
473+
else
474+
HIPCC_OPT=(-O2)
475+
fi
476+
477+
echo "Compiling ROCm/HIP training backend with $ENV binding..."
478+
"$HIPCC" "${ROCM_ARCH_FLAGS[@]}" -c -fPIC \
479+
-D_GLIBCXX_USE_CXX11_ABI=1 \
480+
-DNPY_NO_DEPRECATED_API=NPY_1_7_API_VERSION \
481+
-DPLATFORM_DESKTOP \
482+
-DUSE_ROCM \
483+
-std=c++17 \
484+
-I. -I"$HIPIFY_SRC" -I$SRC_DIR -Ivendor -I$RAYLIB_NAME/include \
485+
-I$PYTHON_INCLUDE -I$PYBIND_INCLUDE -I$NUMPY_INCLUDE \
486+
"${ROCM_IFLAGS[@]}" \
487+
-fopenmp \
488+
-Wno-c++11-narrowing \
489+
-DENV_BINDING_SRC=\"$BINDING_SRC\" \
490+
-DENV_NAME=$ENV \
491+
$PRECISION "${HIPCC_OPT[@]}" \
492+
"$HIPIFY_SRC/bindings.hip" -o build/bindings.o
493+
494+
"$HIPCC" -c -fPIC -std=c++17 \
495+
"${ROCM_IFLAGS[@]}" \
496+
src/rocm_cuda_shim.cpp -o build/rocm_cuda_shim.o
497+
498+
LINK_CMD=(
499+
${CXX:-g++} -shared -fPIC -fopenmp
500+
build/bindings.o build/rocm_cuda_shim.o "$RAYLIB_A"
501+
"${ROCM_LFLAGS[@]}"
502+
"${ROCM_RPATH_FLAGS[@]}"
503+
-lamdhip64 -lhipblas -lhiprand -lrccl -lrocm_smi64
504+
$ROCM_OMP_LIB
505+
$LINK_OPT
506+
"${SHARED_LDFLAGS[@]}"
507+
-o "$OUTPUT"
508+
)
509+
"${LINK_CMD[@]}"
510+
echo "Built: $OUTPUT"
511+
323512
elif [ "$MODE" = "cpu" ]; then
324513
echo "Compiling CPU training backend..."
325514
${CXX:-g++} -c -fPIC \

0 commit comments

Comments
 (0)