Skip to content

Commit faaf22c

Browse files
author
Connor Baker
committed
libtorch: work on some cuda refactoring
1 parent 062a9c4 commit faaf22c

1 file changed

Lines changed: 48 additions & 25 deletions

File tree

  • pkgs/development/libraries/science/math/libtorch

pkgs/development/libraries/science/math/libtorch/default.nix

Lines changed: 48 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
fetchFromGitHub,
33
fetchpatch,
44
pkgs,
5-
symlinkJoin,
65
# nativeBuildInputs
76
asmjit,
87
blas,
@@ -23,6 +22,7 @@
2322
mpi,
2423
ninja,
2524
numactl,
25+
onnx,
2626
protobuf,
2727
psimd,
2828
pthreadpool,
@@ -43,7 +43,7 @@
4343
useXnnpack ? true,
4444
useZstd ? true,
4545
}: let
46-
inherit (lib) lists;
46+
inherit (lib) lists strings;
4747
setBool = bool:
4848
if bool
4949
then "ON"
@@ -82,28 +82,6 @@
8282
};
8383
});
8484

85-
cuda-redist = symlinkJoin {
86-
name = "cuda-redist";
87-
paths = with cudaPackages;
88-
[
89-
autoAddOpenGLRunpathHook
90-
cuda_cccl # <thrust> and CUB
91-
cuda_cudart
92-
cuda_cupti # Needed by Kineto for GPU profiling
93-
cuda_nvcc
94-
cuda_nvml_dev
95-
cuda_nvrtc
96-
cuda_nvtx
97-
libcublas
98-
libcufft
99-
libcurand
100-
libcusolver
101-
libcusparse
102-
nccl.dev
103-
]
104-
++ lists.optionals useCudnn [cudnn];
105-
};
106-
10785
mkDerivation =
10886
if useCuda
10987
then cudaPackages.backendStdenv.mkDerivation
@@ -171,6 +149,7 @@ in
171149
rm -rf FXdiv*
172150
rm -rf gloo*
173151
rm -rf ideep/mkl-dnn*
152+
rm -rf onnx*
174153
rm -rf protobuf*
175154
rm -rf psimd*
176155
rm -rf pthreadpool*
@@ -235,6 +214,7 @@ in
235214
fxdiv
236215
gflags
237216
glog
217+
onnx
238218
protobuf
239219
psimd
240220
pthreadpool
@@ -248,7 +228,14 @@ in
248228
zlib
249229
]
250230
# Optional dependencies
251-
++ lists.optionals useCuda [cuda-redist]
231+
++ lists.optionals useCuda (
232+
# TODO(@connorbaker): Is it correct that we need both cudart and nvcc as native dependencies?
233+
with cudaPackages; [
234+
autoAddOpenGLRunpathHook
235+
cuda_cudart # cuda_runtime.h
236+
cuda_nvcc # crt/host_config.h
237+
]
238+
)
252239
++ lists.optionals useGloo [gloo]
253240
++ lists.optionals useMagma [magma]
254241
++ lists.optionals useMkldnn [oneDNN.dev] # oneDNN is the new name for MKL-DNN
@@ -257,13 +244,48 @@ in
257244
++ lists.optionals useXnnpack [xnnpack]
258245
++ lists.optionals useZstd [zstd.dev];
259246

247+
# TODO(@connorbaker): The CUDA build currently fails with:
248+
# CMake Error at cmake/public/cuda.cmake:65 (message):
249+
# Found two conflicting CUDA installs:
250+
#
251+
# V11.8.89 in
252+
# '/nix/store/rsjxr5b5zifa0wbpziwqfzg7lncfz0f0-cuda_cudart-11.8.89/include'
253+
# and
254+
#
255+
# V11.8.89 in
256+
# '/nix/store/rsjxr5b5zifa0wbpziwqfzg7lncfz0f0-cuda_cudart-11.8.89/include;/nix/store/nljxvgbp6fy0q7cbrp5l5igv57p5fa3v-cuda_nvcc-11.8.89/include;/nix/store/mfk63jcw2r77asgai82rzbzbph10dhh8-cuda_cccl-11.8.89/include;/nix/store/0xhbghrnf7x289m78c8ha2dm6n83wfbg-cuda_cupti-11.8.87/include;/nix/store/4x7gb192a6pskj2skwn9s3m0vnn73bff-cuda_nvml_dev-11.8.86/include;/nix/store/00p0i6kqw6qjbrc4fddqfnv07zcg7gi1-cuda_nvrtc-11.8.89/include;/nix/store/953p97p0inb7wdj50qcz47dy3lh58vhq-cuda_nvtx-11.8.86/include;/nix/store/qsm8bjydfnapr77wzlyzyzcsnkc0yrh2-libcublas-11.11.3.6/include;/nix/store/fszipvg6jw9dsj2lz1izwy7363mwh4fj-libcufft-10.9.0.58/include;/nix/store/8r9kj0rh0kk9iqi32kkm1bdxqb8jipbr-libcurand-10.3.0.86/include;/nix/store/f0d08h7g4apgngbyrgqvpjxmlp3azf0m-libcusolver-11.4.1.48/include;/nix/store/141gw8r2ypg27186mzg81rhndl402l80-libcusparse-11.7.5.86/include;/nix/store/z5ppzlnw5wzy5bbvhm76kfmjmirpkqhb-cuda_profiler_api-11.8.86/include'
257+
buildInputs = lists.optionals useCuda (with cudaPackages;
258+
[
259+
(lib.getDev nccl)
260+
cuda_cccl # <thrust/*>
261+
cuda_cupti
262+
cuda_nvml_dev # <nvml.h>
263+
cuda_nvrtc
264+
cuda_nvtx # -lnvToolsExt
265+
libcublas
266+
libcufft
267+
libcurand
268+
libcusolver
269+
libcusparse
270+
nccl
271+
]
272+
++ lists.optionals useCudnn [cudnn]
273+
++ lists.optionals (strings.versionOlder cudaVersion "11.8") [
274+
cuda_nvprof # <cuda_profiler_api.h>
275+
]
276+
++ lists.optionals (strings.versionAtLeast cudaVersion "11.8") [
277+
cuda_profiler_api # <cuda_profiler_api.h>
278+
]);
279+
260280
cmakeFlags =
261281
# Core configuration options
262282
[
263283
"-DATEN_NO_TEST:BOOL=ON"
264284
"-DBUILD_PYTHON:BOOL=OFF"
265285
"-DBUILD_SHARED_LIBS:BOOL=ON"
266286
"-DCMAKE_BUILD_TYPE:STRING=Release"
287+
"-DCMAKE_C_STANDARD:STRING=17"
288+
"-DCMAKE_CXX_STANDARD:STRING=17"
267289
"-DUSE_PRECOMPILED_HEADERS:BOOL=ON"
268290
]
269291
# Core dependencies
@@ -279,6 +301,7 @@ in
279301
"-DUSE_SYSTEM_FMT:BOOL=ON"
280302
"-DUSE_SYSTEM_FP16:BOOL=ON"
281303
"-DUSE_SYSTEM_FXDIV:BOOL=ON"
304+
"-DUSE_SYSTEM_ONNX:BOOL=ON"
282305
"-DUSE_SYSTEM_PSIMD:BOOL=ON"
283306
"-DUSE_SYSTEM_PTHREADPOOL:BOOL=ON"
284307
"-DUSE_SYSTEM_PYBIND11:BOOL=ON"

0 commit comments

Comments
 (0)