Skip to content

Commit 5c516a4

Browse files
author
Connor Baker
authored
Merge pull request #249259 from ConnorBaker/feat/torch-use-cuda-redist
python3Packages.torch: migrate to CUDA redist from CUDA Toolkit
2 parents aa1f784 + b0bd194 commit 5c516a4

1 file changed

Lines changed: 68 additions & 35 deletions

File tree

pkgs/development/python-modules/torch/default.nix

Lines changed: 68 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
{ stdenv, lib, fetchFromGitHub, buildPythonPackage, python,
1+
{ stdenv, lib, fetchFromGitHub, fetchpatch, buildPythonPackage, python,
22
config, cudaSupport ? config.cudaSupport, cudaPackages, magma,
33
useSystemNccl ? true,
44
MPISupport ? false, mpi,
@@ -52,17 +52,8 @@
5252

5353
let
5454
inherit (lib) lists strings trivial;
55-
inherit (cudaPackages) cudatoolkit cudaFlags cudnn nccl;
56-
in
55+
inherit (cudaPackages) cudaFlags cudnn nccl;
5756

58-
assert cudaSupport -> stdenv.isLinux;
59-
assert cudaSupport -> (cudaPackages.cudaMajorVersion == "11");
60-
61-
# confirm that cudatoolkits are sync'd across dependencies
62-
assert !(MPISupport && cudaSupport) || mpi.cudatoolkit == cudatoolkit;
63-
assert !cudaSupport || magma.cudaPackages.cudatoolkit == cudatoolkit;
64-
65-
let
6657
setBool = v: if v then "1" else "0";
6758

6859
# https://github.com/pytorch/pytorch/blob/v2.0.1/torch/utils/cpp_extension.py#L1744
@@ -103,23 +94,6 @@ let
10394
throw "No GPU targets specified"
10495
);
10596

106-
cudatoolkit_joined = symlinkJoin {
107-
name = "${cudatoolkit.name}-unsplit";
108-
# nccl is here purely for semantic grouping it could be moved to nativeBuildInputs
109-
paths = [ cudatoolkit.out cudatoolkit.lib nccl.dev nccl.out ];
110-
};
111-
112-
# Normally libcuda.so.1 is provided at runtime by nvidia-x11 via
113-
# LD_LIBRARY_PATH=/run/opengl-driver/lib. We only use the stub
114-
# libcuda.so from cudatoolkit for running tests, so that we don’t have
115-
# to recompile pytorch on every update to nvidia-x11 or the kernel.
116-
cudaStub = linkFarm "cuda-stub" [{
117-
name = "libcuda.so.1";
118-
path = "${cudatoolkit}/lib/stubs/libcuda.so";
119-
}];
120-
cudaStubEnv = lib.optionalString cudaSupport
121-
"LD_LIBRARY_PATH=${cudaStub}\${LD_LIBRARY_PATH:+:}$LD_LIBRARY_PATH ";
122-
12397
rocmtoolkit_joined = symlinkJoin {
12498
name = "rocm-merged";
12599

@@ -160,6 +134,12 @@ in buildPythonPackage rec {
160134
# base is 10.12. Until we upgrade, we can fall back on the older
161135
# pthread support.
162136
./pthreadpool-disable-gcd.diff
137+
] ++ lib.optionals stdenv.isLinux [
138+
# Propagate CUPTI to Kineto by overriding the search path with environment variables.
139+
(fetchpatch {
140+
url = "https://github.com/pytorch/pytorch/pull/108847/commits/7ae4d7c0e2dec358b4fe81538efe9da5eb580ec9.patch";
141+
hash = "sha256-skFaDg98xcJqJfzxWk+qhUxPLHDStqvd0mec3PgksIg=";
142+
})
163143
];
164144

165145
postPatch = lib.optionalString rocmSupport ''
@@ -184,6 +164,13 @@ in buildPythonPackage rec {
184164
--replace "set(ROCM_PATH \$ENV{ROCM_PATH})" \
185165
"set(ROCM_PATH \$ENV{ROCM_PATH})''\nset(ROCM_VERSION ${lib.concatStrings (lib.intersperse "0" (lib.splitString "." hip.version))})"
186166
''
167+
# Detection of NCCL version doesn't work particularly well when using the static binary.
168+
+ lib.optionalString cudaSupport ''
169+
substituteInPlace cmake/Modules/FindNCCL.cmake \
170+
--replace \
171+
'message(FATAL_ERROR "Found NCCL header version and library version' \
172+
'message(WARNING "Found NCCL header version and library version'
173+
''
187174
# error: no member named 'aligned_alloc' in the global namespace; did you mean simply 'aligned_alloc'
188175
# This lib overrided aligned_alloc hence the error message. Tltr: his function is linkable but not in header.
189176
+ lib.optionalString (stdenv.isDarwin && lib.versionOlder stdenv.targetPlatform.darwinSdkVersion "11.0") ''
@@ -192,12 +179,16 @@ in buildPythonPackage rec {
192179
inline void *aligned_alloc(size_t align, size_t size)'
193180
'';
194181

182+
# NOTE(@connorbaker): Though we do not disable Gloo or MPI when building with CUDA support, caution should be taken
183+
# when using the different backends. Gloo's GPU support isn't great, and MPI and CUDA can't be used at the same time
184+
# without extreme care to ensure they don't lock each other out of shared resources.
185+
# For more, see https://github.com/open-mpi/ompi/issues/7733#issuecomment-629806195.
195186
preConfigure = lib.optionalString cudaSupport ''
196187
export TORCH_CUDA_ARCH_LIST="${gpuTargetString}"
197-
export CC=${cudatoolkit.cc}/bin/gcc CXX=${cudatoolkit.cc}/bin/g++
198-
'' + lib.optionalString (cudaSupport && cudnn != null) ''
199188
export CUDNN_INCLUDE_DIR=${cudnn.dev}/include
200189
export CUDNN_LIB_DIR=${cudnn.lib}/lib
190+
export CUPTI_INCLUDE_DIR=${cudaPackages.cuda_cupti.dev}/include
191+
export CUPTI_LIBRARY_DIR=${cudaPackages.cuda_cupti.lib}/lib
201192
'' + lib.optionalString rocmSupport ''
202193
export ROCM_PATH=${rocmtoolkit_joined}
203194
export ROCM_SOURCE_DIR=${rocmtoolkit_joined}
@@ -256,6 +247,7 @@ in buildPythonPackage rec {
256247
PYTORCH_BUILD_NUMBER = 0;
257248

258249
USE_SYSTEM_NCCL = setBool useSystemNccl; # don't build pytorch's third_party NCCL
250+
USE_STATIC_NCCL = setBool useSystemNccl;
259251

260252
# Suppress a weird warning in mkl-dnn, part of ideep in pytorch
261253
# (upstream seems to have fixed this in the wrong place?)
@@ -286,12 +278,43 @@ in buildPythonPackage rec {
286278
pybind11
287279
pythonRelaxDepsHook
288280
removeReferencesTo
289-
] ++ lib.optionals cudaSupport [ cudatoolkit_joined ]
290-
++ lib.optionals rocmSupport [ rocmtoolkit_joined ];
281+
] ++ lib.optionals cudaSupport (with cudaPackages; [
282+
autoAddOpenGLRunpathHook
283+
cuda_nvcc
284+
])
285+
++ lib.optionals rocmSupport [ rocmtoolkit_joined ];
291286

292287
buildInputs = [ blas blas.provider pybind11 ]
293288
++ lib.optionals stdenv.isLinux [ linuxHeaders_5_19 ] # TMP: avoid "flexible array member" errors for now
294-
++ lib.optionals cudaSupport [ cudnn.dev cudnn.lib nccl ]
289+
++ lib.optionals cudaSupport (with cudaPackages; [
290+
cuda_cccl.dev # <thrust/*>
291+
cuda_cudart # cuda_runtime.h and libraries
292+
cuda_cupti.dev # For kineto
293+
cuda_cupti.lib # For kineto
294+
cuda_nvcc.dev # crt/host_config.h; even though we include this in nativeBuildinputs, it's needed here too
295+
cuda_nvml_dev.dev # <nvml.h>
296+
cuda_nvrtc.dev
297+
cuda_nvrtc.lib
298+
cuda_nvtx.dev
299+
cuda_nvtx.lib # -llibNVToolsExt
300+
cudnn.dev
301+
cudnn.lib
302+
libcublas.dev
303+
libcublas.lib
304+
libcufft.dev
305+
libcufft.lib
306+
libcurand.dev
307+
libcurand.lib
308+
libcusolver.dev
309+
libcusolver.lib
310+
libcusparse.dev
311+
libcusparse.lib
312+
nccl.dev # Provides nccl.h AND a static copy of NCCL!
313+
] ++ lists.optionals (strings.versionOlder cudaVersion "11.8") [
314+
cuda_nvprof.dev # <cuda_profiler_api.h>
315+
] ++ lists.optionals (strings.versionAtLeast cudaVersion "11.8") [
316+
cuda_profiler_api.dev # <cuda_profiler_api.h>
317+
])
295318
++ lib.optionals rocmSupport [ openmp ]
296319
++ lib.optionals (cudaSupport || rocmSupport) [ magma ]
297320
++ lib.optionals stdenv.isLinux [ numactl ]
@@ -335,7 +358,6 @@ in buildPythonPackage rec {
335358

336359
checkPhase = with lib.versions; with lib.strings; concatStringsSep " " [
337360
"runHook preCheck"
338-
cudaStubEnv
339361
"${python.interpreter} test/run_test.py"
340362
"--exclude"
341363
(concatStringsSep " " [
@@ -419,6 +441,17 @@ in buildPythonPackage rec {
419441
license = licenses.bsd3;
420442
maintainers = with maintainers; [ teh thoughtpolice tscholak ]; # tscholak esp. for darwin-related builds
421443
platforms = with platforms; linux ++ lib.optionals (!cudaSupport && !rocmSupport) darwin;
422-
broken = rocmSupport && cudaSupport; # CUDA and ROCm are mutually exclusive
444+
broken = builtins.any trivial.id [
445+
# CUDA and ROCm are mutually exclusive
446+
(cudaSupport && rocmSupport)
447+
# CUDA is only supported on Linux
448+
(cudaSupport && !stdenv.isLinux)
449+
# Only CUDA 11 is currently supported
450+
(cudaSupport && (cudaPackages.cudaMajorVersion != "11"))
451+
# MPI cudatoolkit does not match cudaPackages.cudatoolkit
452+
(MPISupport && cudaSupport && (mpi.cudatoolkit != cudaPackages.cudatoolkit))
453+
# Magma cudaPackages does not match cudaPackages
454+
(cudaSupport && (magma.cudaPackages != cudaPackages))
455+
];
423456
};
424457
}

0 commit comments

Comments
 (0)