@@ -10,10 +10,12 @@ set -e
1010# ./build.sh breakout --fast # Standalone executable (optimized)
1111# ./build.sh breakout --web # Emscripten web build
1212# ./build.sh breakout --profile # Kernel profiling binary
13+ # ./build.sh breakout --rocm # HIP/ROCm training backend
14+ # ./build.sh breakout --cuda # CUDA training backend
1315# ./build.sh all # Build all envs with default and --float
1416
1517if [ -z " $1 " ]; then
16- echo " Usage: ./build.sh ENV_NAME [--float] [--debug] [--local|--fast|--web|--profile|--cpu|--all]"
18+ echo " Usage: ./build.sh ENV_NAME [--float] [--debug] [--cuda|--rocm] [-- local|--fast|--web|--profile|--cpu|--all]"
1719 exit 1
1820fi
1921ENV=$1
@@ -28,6 +30,8 @@ for arg in "$@"; do
2830 --web) MODE=web ;;
2931 --profile) MODE=profile ;;
3032 --cpu) MODE=cpu; PRECISION=" -DPRECISION_FLOAT" ;;
33+ --cuda) BACKEND=cuda ;;
34+ --rocm) BACKEND=rocm ;;
3135 * ) echo " Error: unknown argument '$arg '" && exit 1 ;;
3236 esac
3337done
@@ -205,57 +209,9 @@ elif [ "$MODE" = "web" ]; then
205209 exit 0
206210fi
207211
208- # Find cuDNN path
209- CUDA_HOME=${CUDA_HOME:- ${CUDA_PATH:- $(dirname " $( dirname " $( which nvcc) " ) " )} }
210- CUDNN_IFLAG=" "
211- CUDNN_LFLAG=" "
212- for dir in /usr/local/cuda/include /usr/include; do
213- if [ -f " $dir /cudnn.h" ]; then
214- CUDNN_IFLAG=" -I$dir "
215- break
216- fi
217- done
218- for dir in /usr/local/cuda/lib64 /usr/lib/x86_64-linux-gnu; do
219- if [ -f " $dir /libcudnn.so" ]; then
220- CUDNN_LFLAG=" -L$dir "
221- break
222- fi
223- done
224- if [ -z " $CUDNN_IFLAG " ]; then
225- CUDNN_IFLAG=$( " $PYTHON_BIN " -c " import nvidia.cudnn, os; print('-I' + os.path.join(nvidia.cudnn.__path__[0], 'include'))" 2> /dev/null || echo " " )
226- fi
227- if [ -z " $CUDNN_LFLAG " ]; then
228- CUDNN_LFLAG=$( " $PYTHON_BIN " -c " import nvidia.cudnn, os; print('-L' + os.path.join(nvidia.cudnn.__path__[0], 'lib'))" 2> /dev/null || echo " " )
229- fi
230-
231- # NCCL include/lib fallback (mirrors the cuDNN fallback above).
232- # Needed when NCCL is provided by the nvidia-nccl-cu12 wheel in the active venv.
233- NCCL_IFLAG=" "
234- NCCL_LFLAG=" "
235- for dir in /usr/include /usr/local/cuda/include; do
236- if [ -f " $dir /nccl.h" ]; then NCCL_IFLAG=" -I$dir " ; break ; fi
237- done
238- for dir in /usr/lib/x86_64-linux-gnu /usr/local/cuda/lib64; do
239- if [ -f " $dir /libnccl.so" ] || [ -f " $dir /libnccl.so.2" ]; then NCCL_LFLAG=" -L$dir " ; break ; fi
240- done
241- if [ -z " $NCCL_IFLAG " ]; then
242- NCCL_IFLAG=$( " $PYTHON_BIN " -c " import nvidia.nccl, os; print('-I' + os.path.join(nvidia.nccl.__path__[0], 'include'))" 2> /dev/null || echo " " )
243- fi
244- if [ -z " $NCCL_LFLAG " ]; then
245- NCCL_LFLAG=$( " $PYTHON_BIN " -c " import nvidia.nccl, os; print('-L' + os.path.join(nvidia.nccl.__path__[0], 'lib'))" 2> /dev/null || echo " " )
246- fi
247-
248- WHEEL_RPATH_FLAGS=()
249- for lib_flag in " $CUDNN_LFLAG " " $NCCL_LFLAG " ; do
250- if [[ " $lib_flag " == -L * ]]; then
251- WHEEL_RPATH_FLAGS+=(" -Wl,-rpath,${lib_flag# -L} " )
252- fi
253- done
254-
255212export CCACHE_DIR=" ${CCACHE_DIR:- $HOME / .ccache} "
256213export CCACHE_BASEDIR=" $( pwd) "
257214export CCACHE_COMPILERCHECK=content
258- NVCC=" ccache $CUDA_HOME /bin/nvcc"
259215CC=" ${CC:- $(command -v ccache >/ dev/ null && echo ' ccache clang' || echo ' clang' )} "
260216ARCH=${NVCC_ARCH:- native}
261217
@@ -275,6 +231,41 @@ if [ ! -f "$BINDING_SRC" ]; then
275231 exit 1
276232fi
277233
# Decide which GPU backend the default (training) build targets.
#
# Globals read:    MODE, BACKEND, PYTHON_BIN
# Globals written: BACKEND (set to "cuda" or "rocm" when MODE is empty)
# Exits 1 when a backend is combined with a non-training mode, or when
# BACKEND holds a value other than cuda/rocm.
#
# NOTE(review): no visible part of this change initialises BACKEND before
# argument parsing, so a value inherited from the caller's environment
# (e.g. BACKEND=hip) would match neither the CUDA nor the ROCm build branch
# and the script would finish without building anything. Unknown values are
# therefore rejected here instead of being passed through.
select_backend() {
    if [ -n "${MODE:-}" ]; then
        # --local/--fast/--web/--profile/--cpu builds do not take a backend.
        if [ -n "${BACKEND:-}" ]; then
            echo "Error: --cuda/--rocm only apply to the training backend"
            exit 1
        fi
        return 0
    fi

    case "${BACKEND:-}" in
        cuda|rocm)
            # Explicit choice: keep it.
            ;;
        "")
            # Auto-detect: pick ROCm only when the installed torch is a HIP
            # build AND no nvcc is on PATH; every other case (including a
            # failed probe) falls back to CUDA.
            if "${PYTHON_BIN:-}" -c "from torch.utils.cpp_extension import IS_HIP_EXTENSION; raise SystemExit(0 if IS_HIP_EXTENSION else 1)" 2>/dev/null \
                && ! command -v nvcc >/dev/null 2>&1; then
                BACKEND=rocm
            else
                BACKEND=cuda
            fi
            ;;
        *)
            echo "Error: unknown backend '$BACKEND' (expected cuda or rocm)"
            exit 1
            ;;
    esac
}
select_backend
246+
247+ if [ " $BACKEND " = " rocm" ] && [ " $ENV " = " nmmo3" ]; then
248+ echo " Error: NMMO3 native encoder is CUDA-only in build.sh --rocm"
249+ exit 1
250+ fi
251+
# Resolve the CUDA toolkit root and the matching include flag.
#
# Globals read:    CUDA_HOME, CUDA_PATH, BACKEND, MODE
# Globals written: CUDA_HOME, CUDA_IFLAG
# Exits 1 when a CUDA toolkit is required (BACKEND=cuda or MODE=profile) but
# neither CUDA_HOME/CUDA_PATH nor an nvcc on PATH can supply it.
#
# The previous fallback, dirname "$(dirname "$(which nvcc)")", evaluates to
# "." when nvcc is missing (dirname of an empty string is "."). CUDA_HOME was
# therefore never empty, the "nvcc not found" check below could not fire, and
# the build failed later on ./bin/nvcc. The lookup now uses `command -v` and
# only derives a path when nvcc actually exists.
resolve_cuda_home() {
    local nvcc_path=""

    CUDA_HOME=${CUDA_HOME:-${CUDA_PATH:-}}
    CUDA_IFLAG=""

    # Best-effort discovery for every non-ROCm build.
    if [ -z "$CUDA_HOME" ] && [ "${BACKEND:-}" != "rocm" ]; then
        nvcc_path=$(command -v nvcc 2>/dev/null || true)
        if [ -n "$nvcc_path" ]; then
            # <root>/bin/nvcc -> <root>
            CUDA_HOME=$(dirname "$(dirname "$nvcc_path")")
        fi
    fi

    # The toolkit is mandatory only for the CUDA backend and profiling builds.
    if [ "${BACKEND:-}" = "cuda" ] || [ "${MODE:-}" = "profile" ]; then
        if [ -z "$CUDA_HOME" ]; then
            echo "Error: nvcc not found. Use --rocm for HIP/ROCm or --cpu for CPU fallback."
            exit 1
        fi
        CUDA_IFLAG="-I$CUDA_HOME/include"
    fi
}
resolve_cuda_home
268+
278269if [ " $MODE " = " cpu" ]; then
279270 echo " Compiling static library for $ENV ..."
280271 ${CC:- clang} -c " ${CLANG_OPT[@]} " $EXTRA_CFLAGS \
@@ -288,7 +279,59 @@ if [ "$MODE" = "cpu" ]; then
288279 ar rcs " $STATIC_LIB " " $STATIC_OBJ "
289280fi
290281
291- if [ -z " $MODE " ]; then
282+ if [ -z " $MODE " ] && [ " $BACKEND " = " cuda" ]; then
283+ # Find cuDNN path
284+ CUDNN_IFLAG=" "
285+ CUDNN_LFLAG=" "
286+ for dir in /usr/local/cuda/include /usr/include; do
287+ if [ -f " $dir /cudnn.h" ]; then
288+ CUDNN_IFLAG=" -I$dir "
289+ break
290+ fi
291+ done
292+ for dir in /usr/local/cuda/lib64 /usr/lib/x86_64-linux-gnu; do
293+ if [ -f " $dir /libcudnn.so" ]; then
294+ CUDNN_LFLAG=" -L$dir "
295+ break
296+ fi
297+ done
298+ if [ -z " $CUDNN_IFLAG " ]; then
299+ CUDNN_IFLAG=$( " $PYTHON_BIN " -c " import nvidia.cudnn, os; print('-I' + os.path.join(nvidia.cudnn.__path__[0], 'include'))" 2> /dev/null || echo " " )
300+ fi
301+ if [ -z " $CUDNN_LFLAG " ]; then
302+ CUDNN_LFLAG=$( " $PYTHON_BIN " -c " import nvidia.cudnn, os; print('-L' + os.path.join(nvidia.cudnn.__path__[0], 'lib'))" 2> /dev/null || echo " " )
303+ fi
304+
305+ # NCCL include/lib fallback (mirrors the cuDNN fallback above).
306+ # Needed when NCCL is provided by the nvidia-nccl-cu12 wheel in the active venv.
307+ NCCL_IFLAG=" "
308+ NCCL_LFLAG=" "
309+ for dir in /usr/include /usr/local/cuda/include; do
310+ if [ -f " $dir /nccl.h" ]; then NCCL_IFLAG=" -I$dir " ; break ; fi
311+ done
312+ for dir in /usr/lib/x86_64-linux-gnu /usr/local/cuda/lib64; do
313+ if [ -f " $dir /libnccl.so" ] || [ -f " $dir /libnccl.so.2" ]; then NCCL_LFLAG=" -L$dir " ; break ; fi
314+ done
315+ if [ -z " $NCCL_IFLAG " ]; then
316+ NCCL_IFLAG=$( " $PYTHON_BIN " -c " import nvidia.nccl, os; print('-I' + os.path.join(nvidia.nccl.__path__[0], 'include'))" 2> /dev/null || echo " " )
317+ fi
318+ if [ -z " $NCCL_LFLAG " ]; then
319+ NCCL_LFLAG=$( " $PYTHON_BIN " -c " import nvidia.nccl, os; print('-L' + os.path.join(nvidia.nccl.__path__[0], 'lib'))" 2> /dev/null || echo " " )
320+ fi
321+
322+ WHEEL_RPATH_FLAGS=()
323+ for lib_flag in " $CUDNN_LFLAG " " $NCCL_LFLAG " ; do
324+ if [[ " $lib_flag " == -L * ]]; then
325+ WHEEL_RPATH_FLAGS+=(" -Wl,-rpath,${lib_flag# -L} " )
326+ fi
327+ done
328+
329+ if command -v ccache > /dev/null 2>&1 ; then
330+ NVCC=" ccache $CUDA_HOME /bin/nvcc"
331+ else
332+ NVCC=" $CUDA_HOME /bin/nvcc"
333+ fi
334+
292335 echo " Compiling CUDA ($ARCH ) training backend with $ENV binding..."
293336 $NVCC -c -arch=$ARCH -Xcompiler -fPIC \
294337 -Xcompiler=-D_GLIBCXX_USE_CXX11_ABI=1 \
@@ -320,6 +363,152 @@ if [ -z "$MODE" ]; then
320363 " ${LINK_CMD[@]} "
321364 echo " Built: $OUTPUT "
322365
366+ elif [ -z " $MODE " ] && [ " $BACKEND " = " rocm" ]; then
# Ask the active PyTorch install for the ROCm root, the hipcc path and the
# include/library search paths (one value per output line, in that order).
#
# The query runs in a checked command substitution rather than a process
# substitution: the exit status of `< <(cmd)` is discarded, so a failing
# probe (torch missing, ROCM_HOME unresolved) used to leave every variable
# below empty and let the build carry on.
if ! ROCM_INFO_RAW=$("$PYTHON_BIN" - <<'PY'
import os
from torch.utils.cpp_extension import ROCM_HOME, library_paths, include_paths

rocm_home = os.environ.get("ROCM_HOME") or ROCM_HOME
if not rocm_home:
    raise SystemExit("ROCM_HOME not found. Install/use a ROCm-enabled PyTorch environment.")
print(rocm_home)
print(os.environ.get("HIPCC") or os.path.join(rocm_home, "bin", "hipcc"))
print(os.pathsep.join(include_paths("cuda")))
print(os.pathsep.join(library_paths("cuda")))
PY
); then
    echo "Error: could not query ROCm paths from PyTorch"
    exit 1
fi
mapfile -t ROCM_INFO <<< "$ROCM_INFO_RAW"
ROCM_HOME=${ROCM_INFO[0]:-}
HIPCC=${ROCM_INFO[1]:-}
# Both path lists are joined with os.pathsep and split again further down.
ROCM_INCLUDE_PATHS=${ROCM_INFO[2]:-}
ROCM_LIBRARY_PATHS=${ROCM_INFO[3]:-}
384+
385+ if [ ! -x " $HIPCC " ]; then
386+ if command -v hipcc > /dev/null 2>&1 ; then
387+ HIPCC=$( command -v hipcc)
388+ else
389+ echo " Error: hipcc not found"
390+ exit 1
391+ fi
392+ fi
393+
394+ if [ -z " $HIP_CLANG_PATH " ] || [ ! -x " $HIP_CLANG_PATH /clang++" ]; then
395+ for dir in " $ROCM_HOME /lib/llvm/bin" /usr/lib/llvm/* /bin; do
396+ if [ -x " $dir /clang++" ]; then
397+ export HIP_CLANG_PATH=" $dir "
398+ break
399+ fi
400+ done
401+ fi
402+
# Translate the CUDA sources under ./src into HIP sources under build/hip/src
# using PyTorch's hipify tool, starting from a clean output directory.
#
# Paths are handed to Python through sys.argv with quoted heredocs instead of
# being pasted into the Python source text, so a checkout path containing
# quotes or backslashes can no longer break or alter the generated code.
HIPIFY_SRC="build/hip/src"
HIPIFY_SRC_ABS="$(pwd)/$HIPIFY_SRC"
SRC_ABS="$(pwd)/src"
echo "Hipifying CUDA sources into $HIPIFY_SRC..."
rm -rf "$HIPIFY_SRC"
"$PYTHON_BIN" - "$SRC_ABS" "$HIPIFY_SRC_ABS" <<'PY'
import sys
from torch.utils.hipify import hipify_python

src_dir, out_dir = sys.argv[1:3]
hipify_python.hipify(
    project_directory=src_dir,
    output_directory=out_dir,
    includes=["*"],
    show_progress=False,
    show_detailed=False,
    is_pytorch_extension=True,
)
PY
# Copy the generated vecenv_hip.h to vecenv.h and point the include in
# pufferlib.hip at that name.
cp "$HIPIFY_SRC/vecenv_hip.h" "$HIPIFY_SRC/vecenv.h"
"$PYTHON_BIN" - "$HIPIFY_SRC/pufferlib.hip" <<'PY'
import sys

path = sys.argv[1]
with open(path) as f:
    src = f.read()
src = src.replace('#include "vecenv_hip.h"', '#include "vecenv.h"')
with open(path, 'w') as f:
    f.write(src)
PY
428+
429+ ROCM_IFLAGS=()
430+ IFS=' :' read -ra ROCM_INC_ARR <<< " $ROCM_INCLUDE_PATHS"
431+ for dir in " ${ROCM_INC_ARR[@]} " ; do
432+ [ -n " $dir " ] && ROCM_IFLAGS+=(" -I$dir " )
433+ done
434+ ROCM_LFLAGS=()
435+ ROCM_RPATH_FLAGS=()
436+ if [ -d /usr/lib64 ]; then
437+ ROCM_LFLAGS+=(" -L/usr/lib64" )
438+ ROCM_RPATH_FLAGS+=(" -Wl,-rpath,/usr/lib64" )
439+ fi
440+ IFS=' :' read -ra ROCM_LIB_ARR <<< " $ROCM_LIBRARY_PATHS"
441+ for dir in " ${ROCM_LIB_ARR[@]} " ; do
442+ [ -n " $dir " ] || continue
443+ [ " $dir " = " /usr/lib" ] && [ -d /usr/lib64 ] && continue
444+ ROCM_LFLAGS+=(" -L$dir " )
445+ ROCM_RPATH_FLAGS+=(" -Wl,-rpath,$dir " )
446+ done
447+ ROCM_OMP_LIB=" "
448+ for dir in /usr/lib64 /usr/lib /usr/local/lib; do
449+ if [ -f " $dir /libomp.so" ]; then
450+ ROCM_LFLAGS+=(" -L$dir " )
451+ ROCM_RPATH_FLAGS+=(" -Wl,-rpath,$dir " )
452+ ROCM_OMP_LIB=" -lomp"
453+ break
454+ elif [ -f " $dir /libomp5.so" ]; then
455+ ROCM_LFLAGS+=(" -L$dir " )
456+ ROCM_RPATH_FLAGS+=(" -Wl,-rpath,$dir " )
457+ ROCM_OMP_LIB=" -lomp5"
458+ break
459+ fi
460+ done
461+
462+ ROCM_ARCH_FLAGS=()
463+ if [ -n " $PYTORCH_ROCM_ARCH " ]; then
464+ IFS=' ;' read -ra ROCM_ARCH_ARR <<< " $PYTORCH_ROCM_ARCH"
465+ for arch in " ${ROCM_ARCH_ARR[@]} " ; do
466+ [ -n " $arch " ] && ROCM_ARCH_FLAGS+=(" --offload-arch=$arch " )
467+ done
468+ fi
469+
470+ HIPCC_OPT=()
471+ if [ -n " $DEBUG " ]; then
472+ HIPCC_OPT=(-O0 -g)
473+ else
474+ HIPCC_OPT=(-O2)
475+ fi
476+
477+ echo " Compiling ROCm/HIP training backend with $ENV binding..."
478+ " $HIPCC " " ${ROCM_ARCH_FLAGS[@]} " -c -fPIC \
479+ -D_GLIBCXX_USE_CXX11_ABI=1 \
480+ -DNPY_NO_DEPRECATED_API=NPY_1_7_API_VERSION \
481+ -DPLATFORM_DESKTOP \
482+ -DUSE_ROCM \
483+ -std=c++17 \
484+ -I. -I" $HIPIFY_SRC " -I$SRC_DIR -Ivendor -I$RAYLIB_NAME /include \
485+ -I$PYTHON_INCLUDE -I$PYBIND_INCLUDE -I$NUMPY_INCLUDE \
486+ " ${ROCM_IFLAGS[@]} " \
487+ -fopenmp \
488+ -Wno-c++11-narrowing \
489+ -DENV_BINDING_SRC=\" $BINDING_SRC \" \
490+ -DENV_NAME=$ENV \
491+ $PRECISION " ${HIPCC_OPT[@]} " \
492+ " $HIPIFY_SRC /bindings.hip" -o build/bindings.o
493+
494+ " $HIPCC " -c -fPIC -std=c++17 \
495+ " ${ROCM_IFLAGS[@]} " \
496+ src/rocm_cuda_shim.cpp -o build/rocm_cuda_shim.o
497+
498+ LINK_CMD=(
499+ ${CXX:- g++} -shared -fPIC -fopenmp
500+ build/bindings.o build/rocm_cuda_shim.o " $RAYLIB_A "
501+ " ${ROCM_LFLAGS[@]} "
502+ " ${ROCM_RPATH_FLAGS[@]} "
503+ -lamdhip64 -lhipblas -lhiprand -lrccl -lrocm_smi64
504+ $ROCM_OMP_LIB
505+ $LINK_OPT
506+ " ${SHARED_LDFLAGS[@]} "
507+ -o " $OUTPUT "
508+ )
509+ " ${LINK_CMD[@]} "
510+ echo " Built: $OUTPUT "
511+
323512elif [ " $MODE " = " cpu" ]; then
324513 echo " Compiling CPU training backend..."
325514 ${CXX:- g++} -c -fPIC \
0 commit comments