Skip to content

Commit d73e75a

Browse files
committed
Support for ZEL_DRIVERS_ORDER to order based off user input
Signed-off-by: Neil R. Spruit <neil.r.spruit@intel.com>
1 parent ff8c99d commit d73e75a

4 files changed

Lines changed: 1066 additions & 4 deletions

File tree

source/loader/ze_loader.cpp

Lines changed: 273 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,13 @@
99

1010
#include "driver_discovery.h"
1111
#include <iostream>
12+
#include <string>
13+
#include <vector>
14+
#include <map>
15+
#include <set>
16+
#include <sstream>
17+
#include <cstdlib>
18+
#include <algorithm>
1219

1320
#ifdef __linux__
1421
#include <unistd.h>
@@ -72,6 +79,266 @@ namespace loader
7279
return a.driverType < b.driverType;
7380
}
7481

82+
// Helper function to map driver type string to enum
83+
zel_driver_type_t stringToDriverType(const std::string& typeStr) {
84+
if (typeStr == "DISCRETE_GPU_ONLY") {
85+
return ZEL_DRIVER_TYPE_DISCRETE_GPU;
86+
} else if (typeStr == "GPU") {
87+
return ZEL_DRIVER_TYPE_GPU;
88+
} else if (typeStr == "INTEGRATED_GPU_ONLY") {
89+
return ZEL_DRIVER_TYPE_INTEGRATED_GPU;
90+
} else if (typeStr == "NPU") {
91+
return ZEL_DRIVER_TYPE_NPU;
92+
}
93+
return ZEL_DRIVER_TYPE_FORCE_UINT32; // Invalid
94+
}
95+
96+
// Helper function to trim whitespace
97+
std::string trim(const std::string& str) {
98+
const std::string whitespace = " \t\n\r\f\v";
99+
size_t start = str.find_first_not_of(whitespace);
100+
if (start == std::string::npos) return "";
101+
size_t end = str.find_last_not_of(whitespace);
102+
return str.substr(start, end - start + 1);
103+
}
104+
105+
// Structure to hold parsed ordering instructions
106+
struct DriverOrderSpec {
107+
enum Type { BY_GLOBAL_INDEX, BY_TYPE, BY_TYPE_AND_INDEX } type;
108+
uint32_t globalIndex = 0;
109+
zel_driver_type_t driverType = ZEL_DRIVER_TYPE_FORCE_UINT32;
110+
uint32_t typeIndex = 0;
111+
};
112+
113+
// Parse ZEL_DRIVERS_ORDER environment variable
114+
std::vector<DriverOrderSpec> parseDriverOrder(const std::string& orderStr) {
115+
std::vector<DriverOrderSpec> specs;
116+
117+
// Split by comma
118+
std::vector<std::string> tokens;
119+
std::stringstream ss(orderStr);
120+
std::string token;
121+
122+
while (std::getline(ss, token, ',')) {
123+
token = trim(token);
124+
if (token.empty()) continue;
125+
126+
DriverOrderSpec spec;
127+
128+
// Check if it contains a colon (type:index format)
129+
size_t colonPos = token.find(':');
130+
if (colonPos != std::string::npos) {
131+
// Format: <driver_type>:<driver_index>
132+
std::string typeStr = trim(token.substr(0, colonPos));
133+
std::string indexStr = trim(token.substr(colonPos + 1));
134+
135+
spec.driverType = stringToDriverType(typeStr);
136+
if (spec.driverType == ZEL_DRIVER_TYPE_FORCE_UINT32) {
137+
continue; // Invalid driver type, skip
138+
}
139+
140+
try {
141+
spec.typeIndex = std::stoul(indexStr);
142+
spec.type = DriverOrderSpec::BY_TYPE_AND_INDEX;
143+
specs.push_back(spec);
144+
} catch (const std::exception&) {
145+
// Invalid index, skip
146+
continue;
147+
}
148+
} else {
149+
// Check if it's a pure number (global index) or driver type
150+
try {
151+
spec.globalIndex = std::stoul(token);
152+
spec.type = DriverOrderSpec::BY_GLOBAL_INDEX;
153+
specs.push_back(spec);
154+
} catch (const std::exception&) {
155+
// Not a number, try as driver type
156+
spec.driverType = stringToDriverType(token);
157+
if (spec.driverType != ZEL_DRIVER_TYPE_FORCE_UINT32) {
158+
spec.type = DriverOrderSpec::BY_TYPE;
159+
specs.push_back(spec);
160+
}
161+
}
162+
}
163+
}
164+
165+
return specs;
166+
}
167+
168+
void context_t::driverOrdering(driver_vector_t *drivers) {
169+
const char* orderEnvVar = std::getenv("ZEL_DRIVERS_ORDER");
170+
if (!orderEnvVar || strlen(orderEnvVar) == 0) {
171+
return; // No ordering specified
172+
}
173+
174+
std::string orderStr(orderEnvVar);
175+
std::vector<DriverOrderSpec> specs = parseDriverOrder(orderStr);
176+
177+
if (specs.empty()) {
178+
if (debugTraceEnabled) {
179+
std::string message = "driverOrdering: ZEL_DRIVERS_ORDER parsing failed or empty: " + orderStr;
180+
debug_trace_message(message, "");
181+
}
182+
return;
183+
}
184+
185+
if (debugTraceEnabled) {
186+
std::string message = "driverOrdering:ZEL_DRIVERS_ORDER parsing successful: " + orderStr + ", specs count: " + std::to_string(specs.size());
187+
debug_trace_message(message, "");
188+
}
189+
190+
// Create a copy of the original driver vector for reference
191+
driver_vector_t originalDrivers = *drivers;
192+
193+
driver_vector_t discreteGPUDrivers;
194+
driver_vector_t integratedGPUDrivers;
195+
driver_vector_t npuDrivers;
196+
driver_vector_t gpuDrivers;
197+
198+
std::vector<uint32_t> discreteGPUIndices;
199+
std::vector<uint32_t> integratedGPUIndices;
200+
std::vector<uint32_t> npuIndices;
201+
std::vector<uint32_t> gpuIndices;
202+
203+
// Group drivers by type and track their original indices
204+
for (uint32_t i = 0; i < originalDrivers.size(); ++i) {
205+
const auto& driver = originalDrivers[i];
206+
switch (driver.driverType) {
207+
case ZEL_DRIVER_TYPE_DISCRETE_GPU:
208+
discreteGPUDrivers.push_back(driver);
209+
discreteGPUIndices.push_back(i);
210+
break;
211+
case ZEL_DRIVER_TYPE_INTEGRATED_GPU:
212+
integratedGPUDrivers.push_back(driver);
213+
integratedGPUIndices.push_back(i);
214+
break;
215+
case ZEL_DRIVER_TYPE_GPU:
216+
gpuDrivers.push_back(driver);
217+
gpuIndices.push_back(i);
218+
break;
219+
case ZEL_DRIVER_TYPE_NPU:
220+
npuDrivers.push_back(driver);
221+
npuIndices.push_back(i);
222+
break;
223+
case ZEL_DRIVER_TYPE_OTHER:
224+
npuDrivers.push_back(driver);
225+
npuIndices.push_back(i);
226+
break;
227+
case ZEL_DRIVER_TYPE_MIXED:
228+
// Mixed drivers go to gpuDrivers
229+
gpuDrivers.push_back(driver);
230+
gpuIndices.push_back(i);
231+
break;
232+
default:
233+
break;
234+
}
235+
}
236+
237+
// Create new ordered driver vector
238+
driver_vector_t orderedDrivers;
239+
std::set<uint32_t> usedGlobalIndices;
240+
std::set<std::pair<zel_driver_type_t, uint32_t>> usedTypeIndices;
241+
242+
// Apply ordering specifications
243+
for (const auto& spec : specs) {
244+
switch (spec.type) {
245+
case DriverOrderSpec::BY_GLOBAL_INDEX:
246+
if (spec.globalIndex < originalDrivers.size() &&
247+
usedGlobalIndices.find(spec.globalIndex) == usedGlobalIndices.end()) {
248+
orderedDrivers.push_back(originalDrivers[spec.globalIndex]);
249+
usedGlobalIndices.insert(spec.globalIndex);
250+
}
251+
break;
252+
253+
case DriverOrderSpec::BY_TYPE:
254+
// Add all drivers of this type that haven't been used
255+
{
256+
std::vector<uint32_t>* typeIndices = nullptr;
257+
switch (spec.driverType) {
258+
case ZEL_DRIVER_TYPE_DISCRETE_GPU:
259+
typeIndices = &discreteGPUIndices;
260+
break;
261+
case ZEL_DRIVER_TYPE_INTEGRATED_GPU:
262+
typeIndices = &integratedGPUIndices;
263+
break;
264+
case ZEL_DRIVER_TYPE_GPU:
265+
typeIndices = &gpuIndices;
266+
break;
267+
case ZEL_DRIVER_TYPE_NPU:
268+
case ZEL_DRIVER_TYPE_OTHER:
269+
typeIndices = &npuIndices;
270+
break;
271+
default:
272+
break;
273+
}
274+
275+
if (typeIndices) {
276+
for (uint32_t globalIdx : *typeIndices) {
277+
if (usedGlobalIndices.find(globalIdx) == usedGlobalIndices.end()) {
278+
orderedDrivers.push_back(originalDrivers[globalIdx]);
279+
usedGlobalIndices.insert(globalIdx);
280+
}
281+
}
282+
}
283+
}
284+
break;
285+
286+
case DriverOrderSpec::BY_TYPE_AND_INDEX:
287+
{
288+
std::vector<uint32_t>* typeIndices = nullptr;
289+
switch (spec.driverType) {
290+
case ZEL_DRIVER_TYPE_DISCRETE_GPU:
291+
typeIndices = &discreteGPUIndices;
292+
break;
293+
case ZEL_DRIVER_TYPE_INTEGRATED_GPU:
294+
typeIndices = &integratedGPUIndices;
295+
break;
296+
case ZEL_DRIVER_TYPE_GPU:
297+
typeIndices = &gpuIndices;
298+
break;
299+
case ZEL_DRIVER_TYPE_NPU:
300+
case ZEL_DRIVER_TYPE_OTHER:
301+
typeIndices = &npuIndices;
302+
break;
303+
default:
304+
break;
305+
}
306+
307+
if (typeIndices && spec.typeIndex < typeIndices->size()) {
308+
auto typeIndexPair = std::make_pair(spec.driverType, spec.typeIndex);
309+
if (usedTypeIndices.find(typeIndexPair) == usedTypeIndices.end()) {
310+
uint32_t globalIdx = (*typeIndices)[spec.typeIndex];
311+
if (usedGlobalIndices.find(globalIdx) == usedGlobalIndices.end()) {
312+
orderedDrivers.push_back(originalDrivers[globalIdx]);
313+
usedGlobalIndices.insert(globalIdx);
314+
usedTypeIndices.insert(typeIndexPair);
315+
}
316+
}
317+
}
318+
}
319+
break;
320+
}
321+
}
322+
323+
// Add remaining drivers in their original order
324+
for (uint32_t i = 0; i < originalDrivers.size(); ++i) {
325+
if (usedGlobalIndices.find(i) == usedGlobalIndices.end()) {
326+
orderedDrivers.push_back(originalDrivers[i]);
327+
}
328+
}
329+
330+
// Replace the original driver vector with the ordered one
331+
*drivers = orderedDrivers;
332+
333+
if (debugTraceEnabled) {
334+
std::string message = "driverOrdering: Drivers after ZEL_DRIVERS_ORDER:";
335+
for (uint32_t i = 0; i < drivers->size(); ++i) {
336+
message += "\n[" + std::to_string(i) + "] Driver Type: " + std::to_string((*drivers)[i].driverType) + " Driver Name: " + (*drivers)[i].name;
337+
}
338+
debug_trace_message(message, "");
339+
}
340+
}
341+
75342
bool context_t::driverSorting(driver_vector_t *drivers, ze_init_driver_type_desc_t* desc, bool sysmanOnly) {
76343
ze_init_driver_type_desc_t permissiveDesc = {};
77344
permissiveDesc.stype = ZE_STRUCTURE_TYPE_INIT_DRIVER_TYPE_DESC;
@@ -246,6 +513,10 @@ namespace loader
246513
}
247514
debug_trace_message(message, "");
248515
}
516+
517+
// Apply driver ordering based on ZEL_DRIVERS_ORDER environment variable
518+
driverOrdering(drivers);
519+
249520
return true;
250521
}
251522

@@ -577,7 +848,7 @@ namespace loader
577848
GET_FUNCTION_PTR(validationLayer, "zelLoaderGetVersion"));
578849
zel_component_version_t compVersion;
579850
if(getVersion && ZE_RESULT_SUCCESS == getVersion(&compVersion))
580-
{
851+
{
581852
compVersions.push_back(compVersion);
582853
}
583854
} else if (debugTraceEnabled) {
@@ -602,7 +873,7 @@ namespace loader
602873
GET_FUNCTION_PTR(tracingLayer, "zelLoaderGetVersion"));
603874
zel_component_version_t compVersion;
604875
if(getVersion && ZE_RESULT_SUCCESS == getVersion(&compVersion))
605-
{
876+
{
606877
compVersions.push_back(compVersion);
607878
}
608879
} else if (debugTraceEnabled) {

source/loader/ze_loader_internal.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ namespace loader
3838
ZEL_DRIVER_TYPE_INTEGRATED_GPU = 2, ///< The driver has Integrated GPUs only
3939
ZEL_DRIVER_TYPE_MIXED = 3, ///< The driver has Heterogenous driver types not limited to GPU or NPU.
4040
ZEL_DRIVER_TYPE_OTHER = 4, ///< The driver has No GPU Devices and has other device types only
41+
ZEL_DRIVER_TYPE_NPU = 5, ///< The driver has NPU devices only
4142
ZEL_DRIVER_TYPE_FORCE_UINT32 = 0x7fffffff
4243

4344
} zel_driver_type_t;
@@ -150,6 +151,7 @@ namespace loader
150151
ze_result_t init_driver(driver_t &driver, ze_init_flags_t flags, ze_init_driver_type_desc_t* desc, ze_global_dditable_t *globalInitStored, zes_global_dditable_t *sysmanGlobalInitStored, bool sysmanOnly);
151152
void add_loader_version();
152153
bool driverSorting(driver_vector_t *drivers, ze_init_driver_type_desc_t* desc, bool sysmanOnly);
154+
void driverOrdering(driver_vector_t *drivers);
153155
~context_t();
154156
bool intercept_enabled = false;
155157
bool debugTraceEnabled = false;

0 commit comments

Comments
 (0)