Skip to content

Commit 9075a3b

Browse files
committed
Change the logic of CUDA detection
1 parent 180366e commit 9075a3b

File tree

1 file changed

+6
-28
lines changed

1 file changed

+6
-28
lines changed

tensorflow_cc/cmake/build_tensorflow.sh.in

Lines changed: 6 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -32,38 +32,16 @@ export NCCL_INSTALL_PATH=${NCCL_INSTALL_PATH:-/usr}
3232
export PYTHON_BIN_PATH=${PYTHON_BIN_PATH:-"$(which python3)"}
3333
export PYTHON_LIB_PATH="$($PYTHON_BIN_PATH -c 'import site; print(site.getsitepackages()[0])')"
3434

35-
# configure cuda environmental variables
36-
37-
if [ -e /opt/cuda ]; then
38-
echo "Using CUDA from /opt/cuda"
39-
export CUDA_TOOLKIT_PATH=/opt/cuda
40-
elif [ -e /usr/local/cuda ]; then
41-
echo "Using CUDA from /usr/local/cuda"
42-
export CUDA_TOOLKIT_PATH=/usr/local/cuda
43-
fi
44-
45-
if [ -e /opt/cuda/include/cudnn.h ]; then
46-
echo "Using CUDNN from /opt/cuda"
47-
export CUDNN_INSTALL_PATH=/opt/cuda
48-
elif [ -e /usr/local/cuda/include/cudnn.h ]; then
49-
echo "Using CUDNN from /usr/local/cuda"
50-
export CUDNN_INSTALL_PATH=/usr/local/cuda
51-
elif [ -e /usr/include/cudnn.h ]; then
52-
echo "Using CUDNN from /usr"
53-
export CUDNN_INSTALL_PATH=/usr
54-
fi
55-
56-
if [ "@ALLOW_CUDA@" = "ON" ] && [ -n "${CUDA_TOOLKIT_PATH}" ]; then
57-
if [[ -z "${CUDNN_INSTALL_PATH}" ]]; then
58-
echo "CUDA found but no cudnn.h found. Please install cuDNN."
59-
exit 1
60-
fi
35+
# check if cuda support requested and supported
36+
if [ "@ALLOW_CUDA@" = "ON" ] && hash nvcc 2>/dev/null; then
6137
echo "CUDA support enabled"
6238
cuda_config_opts="--config=cuda"
6339
export TF_NEED_CUDA=1
6440
export TF_CUDA_COMPUTE_CAPABILITIES=${TF_CUDA_COMPUTE_CAPABILITIES:-"3.5,7.0"} # default from configure.py
65-
export TF_CUDA_VERSION="$($CUDA_TOOLKIT_PATH/bin/nvcc --version | sed -n 's/^.*release \(.*\),.*/\1/p')"
66-
export TF_CUDNN_VERSION="$(sed -n 's/^#define CUDNN_MAJOR\s*\(.*\).*/\1/p' $CUDNN_INSTALL_PATH/include/cudnn.h)"
41+
export TF_CUDA_PATHS=${TF_CUDA_PATHS:-"/opt/cuda,/usr/local/cuda,/usr/local,/usr/cuda,/usr"}
42+
export TF_CUDA_VERSION="$(nvcc --version | sed -n 's/^.*release \(.*\),.*/\1/p')"
43+
export TF_NCCL_VERSION="$(find / -name 'libnccl.so.*' | tail -n1 | sed -r 's/^.*\.so\.//')"
44+
export TF_CUDNN_VERSION="$(find / -name 'libcudnn.so.*' | tail -n1 | sed -r 's/^.*\.so\.//')"
6745

6846
# choose the right version of CUDA compiler
6947
if [ -z "$GCC_HOST_COMPILER_PATH" ]; then

0 commit comments

Comments
 (0)