@@ -85,11 +85,21 @@ static Try<set<Gpu>> enumerateGpus(
8585 if (flags.nvidia_gpu_devices .isSome ()) {
8686 indices = flags.nvidia_gpu_devices .get ();
8787 } else {
88- for (size_t i = 0 ; i < resources.gpus ().getOrElse (0 ); ++i) {
88+ Try<unsigned int > available = nvml::deviceGetCount ();
89+ if (available.isError ()) {
90+ return Error (" Failed to nvml::deviceGetCount: " + available.error ());
91+ }
92+
93+ for (unsigned int i = 0 ; i < available.get (); ++i) {
8994 indices.push_back (i);
9095 }
9196 }
9297
98+ Try<unsigned int > caps_major = nvml::systemGetCapsMajor ();
99+ if (caps_major.isError ()) {
100+ return Error (" Failed to get nvidia caps major: " + caps_major.error ());
101+ }
102+
93103 set<Gpu> gpus;
94104
95105 foreach (unsigned int index, indices) {
@@ -103,17 +113,91 @@ static Try<set<Gpu>> enumerateGpus(
103113 return Error (" Failed to nvml::deviceGetMinorNumber: " + minor.error ());
104114 }
105115
106- Gpu gpu;
107- gpu.major = NVIDIA_MAJOR_DEVICE;
108- gpu.minor = minor.get ();
116+ Try<bool > ismig = nvml::deviceGetMigMode (handle.get ());
117+ if (ismig.isError ()) {
118+ return Error (" Failed to nvml::deviceGetMigMode: " + ismig.error ());
119+ }
120+
121+ if (!ismig.get ()) {
122+ Gpu gpu;
123+ gpu.major = NVIDIA_MAJOR_DEVICE;
124+ gpu.minor = minor.get ();
125+
126+ gpus.insert (gpu);
109127
110- gpus.insert (gpu);
128+ continue ;
129+ }
130+
131+ Try<unsigned int > migcount = nvml::deviceGetMigDeviceCount (handle.get ());
132+ if (migcount.isError ()) {
133+ return Error (" Failed to nvml::deviceGetMigDeviceCount: " + migcount.error ());
134+ }
135+
136+ for (unsigned int migindex = 0 ; migindex < migcount.get (); migindex++) {
137+ Try<nvmlDevice_t> mighandle = nvml::deviceGetMigDeviceHandleByIndex (handle.get (), migindex);
138+ if (mighandle.isError ()) {
139+ return Error (" Failed to nvml::deviceGetMigDeviceHandleByIndex: " + mighandle.error ());
140+ }
141+
142+ Try<unsigned int > gi_minor = nvml::deviceGetGpuInstanceMinor (mighandle.get ());
143+ if (gi_minor.isError ()) {
144+ return Error (" Failed to nvml::deviceGetGpuInstanceMinor: " + gi_minor.error ());
145+ }
146+
147+ Try<unsigned int > ci_minor = nvml::deviceGetComputeInstanceMinor (mighandle.get ());
148+ if (ci_minor.isError ()) {
149+ return Error (" Failed to nvml::deviceGetComputeInstanceMinor: " + ci_minor.error ());
150+ }
151+
152+ Gpu gpu;
153+ gpu.major = NVIDIA_MAJOR_DEVICE;
154+ gpu.minor = minor.get ();
155+ gpu.ismig = true ;
156+ gpu.caps_major = caps_major.get ();
157+ gpu.gi_minor = gi_minor.get ();
158+ gpu.ci_minor = ci_minor.get ();
159+
160+ gpus.insert (gpu);
161+ }
111162 }
112163
113164 return gpus;
114165}
115166
116167
168+ static Try<unsigned int > countGpuInstancesForDevices (
169+ const vector<unsigned int >& devices)
170+ {
171+ unsigned int count = 0 ;
172+
173+ foreach (unsigned int device, devices) {
174+ Try<nvmlDevice_t> handle = nvml::deviceGetHandleByIndex (device);
175+ if (handle.isError ()) {
176+ return Error (" Failed to nvml::deviceGetHandleByIndex: " + handle.error ());
177+ }
178+
179+ Try<bool > ismig = nvml::deviceGetMigMode (handle.get ());
180+ if (ismig.isError ()) {
181+ return Error (" Failed to nvml::deviceGetMigMode: " + ismig.error ());
182+ }
183+
184+ if (!ismig.get ()) {
185+ count++;
186+ continue ;
187+ }
188+
189+ Try<unsigned int > migcount = nvml::deviceGetMigDeviceCount (handle.get ());
190+ if (migcount.isError ()) {
191+ return Error (" Failed to nvml::deviceGetMigDeviceCount: " + migcount.error ());
192+ }
193+
194+ count += migcount.get ();
195+ }
196+
197+ return count;
198+ }
199+
200+
117201// To determine the proper number of GPU resources to return, we
118202// need to check both --resources and --nvidia_gpu_devices.
119203// There are two cases to consider:
@@ -174,11 +258,6 @@ static Try<Resources> enumerateGpuResources(const Flags& flags)
174258 return Error (" Failed to nvml::initialize: " + initialized.error ());
175259 }
176260
177- Try<unsigned int > available = nvml::deviceGetCount ();
178- if (available.isError ()) {
179- return Error (" Failed to nvml::deviceGetCount: " + available.error ());
180- }
181-
182261 // The `Resources` wrapper does not allow us to distinguish between
183262 // a user specifying "gpus:0" in the --resources flag and not
184263 // specifying "gpus" at all. To help with this we short circuit
@@ -225,9 +304,11 @@ static Try<Resources> enumerateGpuResources(const Flags& flags)
225304 return Error (" '--nvidia_gpu_devices' contains duplicates" );
226305 }
227306
228- if (flags.nvidia_gpu_devices ->size () != resources.gpus ().get ()) {
229- return Error (" '--resources' and '--nvidia_gpu_devices' specify"
230- " different numbers of GPU devices" );
307+ Try<unsigned int > available = countGpuInstancesForDevices (unique);
308+ if (available.isError ()) {
309+ return Error (" Failed to count all GPU instances for devices"
310+ " specified by --nvidia_gpu_devices: "
311+ + available.error ());
231312 }
232313
233314 if (resources.gpus ().get () > available.get ()) {
@@ -238,6 +319,22 @@ static Try<Resources> enumerateGpuResources(const Flags& flags)
238319 return resources;
239320 }
240321
322+ Try<unsigned int > available = nvml::deviceGetCount ();
323+ if (available.isError ()) {
324+ return Error (" Failed to nvml::deviceGetCount: " + available.error ());
325+ }
326+
327+ vector<unsigned int > indices;
328+ for (unsigned int i = 0 ; i < available.get (); ++i) {
329+ indices.push_back (i);
330+ }
331+
332+ available = countGpuInstancesForDevices (indices);
333+ if (available.isError ()) {
334+ return Error (" Failed to count all GPU instances: "
335+ + available.error ());
336+ }
337+
241338 return Resources::parse (
242339 " gpus" ,
243340 stringify (available.get ()),
@@ -378,7 +475,15 @@ Future<Nothing> NvidiaGpuAllocator::deallocate(const set<Gpu>& gpus)
378475bool operator <(const Gpu& left, const Gpu& right)
379476{
380477 if (left.major == right.major ) {
381- return left.minor < right.minor ;
478+ // Either or both aren't MIG, comparing major/minor is enough
479+ if (!left.ismig || !right.ismig || (left.minor != right.minor )) {
480+ return left.minor < right.minor ;
481+ }
482+
483+ if (left.gi_minor == right.gi_minor ) {
484+ return left.ci_minor < right.ci_minor ;
485+ }
486+ return left.gi_minor < right.gi_minor ;
382487 }
383488 return left.major < right.major ;
384489}
@@ -404,7 +509,14 @@ bool operator>=(const Gpu& left, const Gpu& right)
404509
405510bool operator ==(const Gpu& left, const Gpu& right)
406511{
407- return left.major == right.major && left.minor == right.minor ;
512+ if (left.ismig != right.ismig )
513+ return false ;
514+
515+ if (!left.ismig )
516+ return left.major == right.major && left.minor == right.minor ;
517+
518+ return left.major == right.major && left.minor == right.minor
519+ && left.gi_minor == right.gi_minor && left.ci_minor == right.ci_minor ;
408520}
409521
410522
@@ -416,7 +528,10 @@ bool operator!=(const Gpu& left, const Gpu& right)
416528
417529ostream& operator <<(ostream& stream, const Gpu& gpu)
418530{
419- return stream << gpu.major << ' .' << gpu.minor ;
531+ if (gpu.ismig )
532+ return stream << gpu.major << ' .' << gpu.minor << ' :' << gpu.gi_minor << ' .' << gpu.ci_minor ;
533+ else
534+ return stream << gpu.major << ' .' << gpu.minor ;
420535}
421536
422537} // namespace slave {
0 commit comments