1- { stdenv , lib , fetchFromGitHub , buildPythonPackage , python ,
1+ { stdenv , lib , fetchFromGitHub , fetchpatch , buildPythonPackage , python ,
22 config , cudaSupport ? config . cudaSupport , cudaPackages , magma ,
33 useSystemNccl ? true ,
44 MPISupport ? false , mpi ,
5252
5353let
5454 inherit ( lib ) lists strings trivial ;
55- inherit ( cudaPackages ) cudatoolkit cudaFlags cudnn nccl ;
56- in
55+ inherit ( cudaPackages ) cudaFlags cudnn nccl ;
5756
58- assert cudaSupport -> stdenv . isLinux ;
59- assert cudaSupport -> ( cudaPackages . cudaMajorVersion == "11" ) ;
60-
61- # confirm that cudatoolkits are sync'd across dependencies
62- assert ! ( MPISupport && cudaSupport ) || mpi . cudatoolkit == cudatoolkit ;
63- assert ! cudaSupport || magma . cudaPackages . cudatoolkit == cudatoolkit ;
64-
65- let
6657 setBool = v : if v then "1" else "0" ;
6758
6859 # https://github.com/pytorch/pytorch/blob/v2.0.1/torch/utils/cpp_extension.py#L1744
10394 throw "No GPU targets specified"
10495 ) ;
10596
106- cudatoolkit_joined = symlinkJoin {
107- name = "${ cudatoolkit . name } -unsplit" ;
108- # nccl is here purely for semantic grouping it could be moved to nativeBuildInputs
109- paths = [ cudatoolkit . out cudatoolkit . lib nccl . dev nccl . out ] ;
110- } ;
111-
112- # Normally libcuda.so.1 is provided at runtime by nvidia-x11 via
113- # LD_LIBRARY_PATH=/run/opengl-driver/lib. We only use the stub
114- # libcuda.so from cudatoolkit for running tests, so that we don’t have
115- # to recompile pytorch on every update to nvidia-x11 or the kernel.
116- cudaStub = linkFarm "cuda-stub" [ {
117- name = "libcuda.so.1" ;
118- path = "${ cudatoolkit } /lib/stubs/libcuda.so" ;
119- } ] ;
120- cudaStubEnv = lib . optionalString cudaSupport
121- "LD_LIBRARY_PATH=${ cudaStub } \ ${LD_LIBRARY_PATH:+:}$LD_LIBRARY_PATH " ;
122-
12397 rocmtoolkit_joined = symlinkJoin {
12498 name = "rocm-merged" ;
12599
@@ -160,6 +134,12 @@ in buildPythonPackage rec {
160134 # base is 10.12. Until we upgrade, we can fall back on the older
161135 # pthread support.
162136 ./pthreadpool-disable-gcd.diff
137+ ] ++ lib . optionals stdenv . isLinux [
138+ # Propagate CUPTI to Kineto by overriding the search path with environment variables.
139+ ( fetchpatch {
140+ url = "https://github.com/pytorch/pytorch/pull/108847/commits/7ae4d7c0e2dec358b4fe81538efe9da5eb580ec9.patch" ;
141+ hash = "sha256-skFaDg98xcJqJfzxWk+qhUxPLHDStqvd0mec3PgksIg=" ;
142+ } )
163143 ] ;
164144
165145 postPatch = lib . optionalString rocmSupport ''
@@ -184,6 +164,13 @@ in buildPythonPackage rec {
184164 --replace "set(ROCM_PATH \$ENV{ROCM_PATH})" \
185165 "set(ROCM_PATH \$ENV{ROCM_PATH})''\n set(ROCM_VERSION ${ lib . concatStrings ( lib . intersperse "0" ( lib . splitString "." hip . version ) ) } )"
186166 ''
167+ # Detection of NCCL version doesn't work particularly well when using the static binary.
168+ + lib . optionalString cudaSupport ''
169+ substituteInPlace cmake/Modules/FindNCCL.cmake \
170+ --replace \
171+ 'message(FATAL_ERROR "Found NCCL header version and library version' \
172+ 'message(WARNING "Found NCCL header version and library version'
173+ ''
187174 # error: no member named 'aligned_alloc' in the global namespace; did you mean simply 'aligned_alloc'
188175 # This lib overrides aligned_alloc, hence the error message. Tl;dr: this function is linkable but not declared in the header.
189176 + lib . optionalString ( stdenv . isDarwin && lib . versionOlder stdenv . targetPlatform . darwinSdkVersion "11.0" ) ''
@@ -192,12 +179,16 @@ in buildPythonPackage rec {
192179 inline void *aligned_alloc(size_t align, size_t size)'
193180 '' ;
194181
182+ # NOTE(@connorbaker): Though we do not disable Gloo or MPI when building with CUDA support, caution should be taken
183+ # when using the different backends. Gloo's GPU support isn't great, and MPI and CUDA can't be used at the same time
184+ # without extreme care to ensure they don't lock each other out of shared resources.
185+ # For more, see https://github.com/open-mpi/ompi/issues/7733#issuecomment-629806195.
195186 preConfigure = lib . optionalString cudaSupport ''
196187 export TORCH_CUDA_ARCH_LIST="${ gpuTargetString } "
197- export CC=${ cudatoolkit . cc } /bin/gcc CXX=${ cudatoolkit . cc } /bin/g++
198- '' + lib . optionalString ( cudaSupport && cudnn != null ) ''
199188 export CUDNN_INCLUDE_DIR=${ cudnn . dev } /include
200189 export CUDNN_LIB_DIR=${ cudnn . lib } /lib
190+ export CUPTI_INCLUDE_DIR=${ cudaPackages . cuda_cupti . dev } /include
191+ export CUPTI_LIBRARY_DIR=${ cudaPackages . cuda_cupti . lib } /lib
201192 '' + lib . optionalString rocmSupport ''
202193 export ROCM_PATH=${ rocmtoolkit_joined }
203194 export ROCM_SOURCE_DIR=${ rocmtoolkit_joined }
@@ -256,6 +247,7 @@ in buildPythonPackage rec {
256247 PYTORCH_BUILD_NUMBER = 0 ;
257248
258249 USE_SYSTEM_NCCL = setBool useSystemNccl ; # don't build pytorch's third_party NCCL
250+ USE_STATIC_NCCL = setBool useSystemNccl ;
259251
260252 # Suppress a weird warning in mkl-dnn, part of ideep in pytorch
261253 # (upstream seems to have fixed this in the wrong place?)
@@ -286,12 +278,43 @@ in buildPythonPackage rec {
286278 pybind11
287279 pythonRelaxDepsHook
288280 removeReferencesTo
289- ] ++ lib . optionals cudaSupport [ cudatoolkit_joined ]
290- ++ lib . optionals rocmSupport [ rocmtoolkit_joined ] ;
281+ ] ++ lib . optionals cudaSupport ( with cudaPackages ; [
282+ autoAddOpenGLRunpathHook
283+ cuda_nvcc
284+ ] )
285+ ++ lib . optionals rocmSupport [ rocmtoolkit_joined ] ;
291286
292287 buildInputs = [ blas blas . provider pybind11 ]
293288 ++ lib . optionals stdenv . isLinux [ linuxHeaders_5_19 ] # TMP: avoid "flexible array member" errors for now
294- ++ lib . optionals cudaSupport [ cudnn . dev cudnn . lib nccl ]
289+ ++ lib . optionals cudaSupport ( with cudaPackages ; [
290+ cuda_cccl . dev # <thrust/*>
291+ cuda_cudart # cuda_runtime.h and libraries
292+ cuda_cupti . dev # For kineto
293+ cuda_cupti . lib # For kineto
294+ cuda_nvcc . dev # crt/host_config.h; even though we include this in nativeBuildinputs, it's needed here too
295+ cuda_nvml_dev . dev # <nvml.h>
296+ cuda_nvrtc . dev
297+ cuda_nvrtc . lib
298+ cuda_nvtx . dev
299+ cuda_nvtx . lib # -llibNVToolsExt
300+ cudnn . dev
301+ cudnn . lib
302+ libcublas . dev
303+ libcublas . lib
304+ libcufft . dev
305+ libcufft . lib
306+ libcurand . dev
307+ libcurand . lib
308+ libcusolver . dev
309+ libcusolver . lib
310+ libcusparse . dev
311+ libcusparse . lib
312+ nccl . dev # Provides nccl.h AND a static copy of NCCL!
313+ ] ++ lists . optionals ( strings . versionOlder cudaVersion "11.8" ) [
314+ cuda_nvprof . dev # <cuda_profiler_api.h>
315+ ] ++ lists . optionals ( strings . versionAtLeast cudaVersion "11.8" ) [
316+ cuda_profiler_api . dev # <cuda_profiler_api.h>
317+ ] )
295318 ++ lib . optionals rocmSupport [ openmp ]
296319 ++ lib . optionals ( cudaSupport || rocmSupport ) [ magma ]
297320 ++ lib . optionals stdenv . isLinux [ numactl ]
@@ -335,7 +358,6 @@ in buildPythonPackage rec {
335358
336359 checkPhase = with lib . versions ; with lib . strings ; concatStringsSep " " [
337360 "runHook preCheck"
338- cudaStubEnv
339361 "${ python . interpreter } test/run_test.py"
340362 "--exclude"
341363 ( concatStringsSep " " [
@@ -419,6 +441,17 @@ in buildPythonPackage rec {
419441 license = licenses . bsd3 ;
420442 maintainers = with maintainers ; [ teh thoughtpolice tscholak ] ; # tscholak esp. for darwin-related builds
421443 platforms = with platforms ; linux ++ lib . optionals ( ! cudaSupport && ! rocmSupport ) darwin ;
422- broken = rocmSupport && cudaSupport ; # CUDA and ROCm are mutually exclusive
444+ broken = builtins . any trivial . id [
445+ # CUDA and ROCm are mutually exclusive
446+ ( cudaSupport && rocmSupport )
447+ # CUDA is only supported on Linux
448+ ( cudaSupport && ! stdenv . isLinux )
449+ # Only CUDA 11 is currently supported
450+ ( cudaSupport && ( cudaPackages . cudaMajorVersion != "11" ) )
451+ # MPI cudatoolkit does not match cudaPackages.cudatoolkit
452+ ( MPISupport && cudaSupport && ( mpi . cudatoolkit != cudaPackages . cudatoolkit ) )
453+ # Magma cudaPackages does not match cudaPackages
454+ ( cudaSupport && ( magma . cudaPackages != cudaPackages ) )
455+ ] ;
423456 } ;
424457}
0 commit comments