55#include " cuda_dlib.h"
66#include " cudnn_dlibapi.h"
77#include < math_constants.h>
8+ #include < cstdlib>
9+ #include < cstring>
810
911
1012namespace dlib
1113{
14+ namespace
15+ {
// Reports whether the CUDA runtime can be queried and sees at least one device.
// Returns false when cudaGetDeviceCount() does not return cudaSuccess, or when
// it succeeds but reports a device count that is not positive.
bool cuda_device_available()
{
    int device_count;
    const auto status = cudaGetDeviceCount(&device_count);
    if (status != cudaSuccess)
        return false;  // device_count is only inspected after a successful query
    return device_count > 0;
}
22+
// Reports whether the DLIB_DISABLE_CUDA_USE environment variable asks for CUDA
// to be disabled.
//
// Returns false when the variable is unset or holds one of the recognised
// "off" spellings: the empty string, "0", "false", "False" or "FALSE".
// Any other value returns true.
bool cuda_disabled_by_environment()
{
    // The variable name and the values below must not carry stray whitespace:
    // getenv() and strcmp() compare exact bytes, so a padded literal would
    // never match what a user actually sets.
    const char* value = std::getenv("DLIB_DISABLE_CUDA_USE");
    if (value == nullptr)
        return false;

    static const char* const not_disabling[] = {"", "0", "false", "False", "FALSE"};
    for (const char* candidate : not_disabling)
    {
        if (std::strcmp(value, candidate) == 0)
            return false;
    }
    return true;
}
34+
// Decides whether CUDA should be used. The decision is stored in a
// function-local static, so it is computed on the first call and the same
// value is returned afterwards.
bool use_cuda_impl()
{
    static const bool cuda_enabled = []
    {
        // Consult the environment first: when CUDA is disabled there, the
        // device query is not performed at all.
        if (cuda_disabled_by_environment())
            return false;
        return cuda_device_available();
    }();
    return cuda_enabled;
}
41+
42+ }
43+
1244 namespace cuda
1345 {
1446
@@ -18,21 +50,34 @@ namespace dlib
1850 int dev
1951 )
2052 {
53+ if (!use_cuda ())
54+ {
55+ DLIB_CASSERT (dev == 0 , " dlib::cuda::set_device(id) called with an invalid device id." );
56+ return ;
57+ }
58+
2159 CHECK_CUDA (cudaSetDevice (dev));
2260 }
2361
// Returns the id reported by cudaGetDevice(), or 0 when use_cuda() is false
// (in which case the CUDA runtime is not called).
int get_device()
{
    if (!use_cuda())
        return 0;

    int current = 0;
    CHECK_CUDA(cudaGetDevice(&current));
    return current;
}
3170
3271 std::string get_device_name (
3372 int device
3473 )
3574 {
75+ if (!use_cuda ())
76+ {
77+ DLIB_CASSERT (device == 0 , " dlib::cuda::get_device_name(device) called with an invalid device id." );
78+ return " CUDA_DISABLED" ;
79+ }
80+
3681 cudaDeviceProp props;
3782 CHECK_CUDA (cudaGetDeviceProperties (&props, device));
3883 return props.name ;
@@ -41,19 +86,32 @@ namespace dlib
// Calls cudaSetDeviceFlags() with cudaDeviceScheduleBlockingSync.
// Does nothing when use_cuda() is false.
void set_current_device_blocking_sync()
{
    if (!use_cuda())
        return;

    CHECK_CUDA(cudaSetDeviceFlags(cudaDeviceScheduleBlockingSync));
}
92+
// Public query for whether CUDA is in use; forwards to use_cuda_impl() in the
// file-local namespace and returns its result unchanged.
bool use_cuda()
{
    const bool cuda_enabled = use_cuda_impl();
    return cuda_enabled;
}
4698
// Returns the device count reported by cudaGetDeviceCount(), or 0 when
// use_cuda() is false (in which case the CUDA runtime is not queried).
int get_num_devices()
{
    int device_count = 0;
    if (use_cuda())
    {
        CHECK_CUDA(cudaGetDeviceCount(&device_count));
    }
    return device_count;
}
54109
55110 bool can_access_peer (int device_id, int peer_device_id)
56111 {
112+ if (!use_cuda ())
113+ return false ;
114+
57115 int can_access;
58116 CHECK_CUDA (cudaDeviceCanAccessPeer (&can_access, device_id, peer_device_id));
59117 return can_access != 0 ;
@@ -65,6 +123,9 @@ namespace dlib
65123
// Calls cudaDeviceSynchronize() while a raii_set_device object constructed
// from dev is in scope. Returns immediately when use_cuda() is false.
void device_synchronize(int dev)
{
    if (use_cuda())
    {
        raii_set_device scoped_device(dev);
        CHECK_CUDA(cudaDeviceSynchronize());
    }
}
@@ -76,6 +137,9 @@ namespace dlib
76137 int peer_device_id
77138 ) : call_disable(false ), device_id(device_id), peer_device_id(peer_device_id)
78139 {
140+ if (!use_cuda ())
141+ return ;
142+
79143 raii_set_device set_dev (device_id);
80144
81145 auto err = cudaDeviceEnablePeerAccess (peer_device_id, 0 );
@@ -3220,4 +3284,3 @@ namespace dlib
32203284
32213285 }
32223286}
3223-
0 commit comments