|
5 | 5 | * SPDX-License-Identifier: MIT |
6 | 6 | * |
7 | 7 | */ |
8 | | -#include "ze_loader_internal.h" |
| 8 | +#include "ze_loader_utils.h" |
9 | 9 |
|
10 | 10 | #include "driver_discovery.h" |
11 | 11 | #include <iostream> |
12 | | -#include <string> |
13 | | -#include <vector> |
14 | | -#include <map> |
15 | | -#include <set> |
16 | | -#include <sstream> |
17 | | -#include <cstdlib> |
18 | | -#include <algorithm> |
19 | 12 |
|
20 | 13 | #ifdef __linux__ |
21 | 14 | #include <unistd.h> |
@@ -79,92 +72,6 @@ namespace loader |
79 | 72 | return a.driverType < b.driverType; |
80 | 73 | } |
81 | 74 |
|
82 | | - // Helper function to map driver type string to enum |
83 | | - zel_driver_type_t stringToDriverType(const std::string& typeStr) { |
84 | | - if (typeStr == "DISCRETE_GPU_ONLY") { |
85 | | - return ZEL_DRIVER_TYPE_DISCRETE_GPU; |
86 | | - } else if (typeStr == "GPU") { |
87 | | - return ZEL_DRIVER_TYPE_GPU; |
88 | | - } else if (typeStr == "INTEGRATED_GPU_ONLY") { |
89 | | - return ZEL_DRIVER_TYPE_INTEGRATED_GPU; |
90 | | - } else if (typeStr == "NPU") { |
91 | | - return ZEL_DRIVER_TYPE_NPU; |
92 | | - } |
93 | | - return ZEL_DRIVER_TYPE_FORCE_UINT32; // Invalid |
94 | | - } |
95 | | - |
96 | | - // Helper function to trim whitespace |
97 | | - std::string trim(const std::string& str) { |
98 | | - const std::string whitespace = " \t\n\r\f\v"; |
99 | | - size_t start = str.find_first_not_of(whitespace); |
100 | | - if (start == std::string::npos) return ""; |
101 | | - size_t end = str.find_last_not_of(whitespace); |
102 | | - return str.substr(start, end - start + 1); |
103 | | - } |
104 | | - |
105 | | - // Structure to hold parsed ordering instructions |
106 | | - struct DriverOrderSpec { |
107 | | - enum Type { BY_GLOBAL_INDEX, BY_TYPE, BY_TYPE_AND_INDEX } type; |
108 | | - uint32_t globalIndex = 0; |
109 | | - zel_driver_type_t driverType = ZEL_DRIVER_TYPE_FORCE_UINT32; |
110 | | - uint32_t typeIndex = 0; |
111 | | - }; |
112 | | - |
113 | | - // Parse ZEL_DRIVERS_ORDER environment variable |
114 | | - std::vector<DriverOrderSpec> parseDriverOrder(const std::string& orderStr) { |
115 | | - std::vector<DriverOrderSpec> specs; |
116 | | - |
117 | | - // Split by comma |
118 | | - std::vector<std::string> tokens; |
119 | | - std::stringstream ss(orderStr); |
120 | | - std::string token; |
121 | | - |
122 | | - while (std::getline(ss, token, ',')) { |
123 | | - token = trim(token); |
124 | | - if (token.empty()) continue; |
125 | | - |
126 | | - DriverOrderSpec spec; |
127 | | - |
128 | | - // Check if it contains a colon (type:index format) |
129 | | - size_t colonPos = token.find(':'); |
130 | | - if (colonPos != std::string::npos) { |
131 | | - // Format: <driver_type>:<driver_index> |
132 | | - std::string typeStr = trim(token.substr(0, colonPos)); |
133 | | - std::string indexStr = trim(token.substr(colonPos + 1)); |
134 | | - |
135 | | - spec.driverType = stringToDriverType(typeStr); |
136 | | - if (spec.driverType == ZEL_DRIVER_TYPE_FORCE_UINT32) { |
137 | | - continue; // Invalid driver type, skip |
138 | | - } |
139 | | - |
140 | | - try { |
141 | | - spec.typeIndex = std::stoul(indexStr); |
142 | | - spec.type = DriverOrderSpec::BY_TYPE_AND_INDEX; |
143 | | - specs.push_back(spec); |
144 | | - } catch (const std::exception&) { |
145 | | - // Invalid index, skip |
146 | | - continue; |
147 | | - } |
148 | | - } else { |
149 | | - // Check if it's a pure number (global index) or driver type |
150 | | - try { |
151 | | - spec.globalIndex = std::stoul(token); |
152 | | - spec.type = DriverOrderSpec::BY_GLOBAL_INDEX; |
153 | | - specs.push_back(spec); |
154 | | - } catch (const std::exception&) { |
155 | | - // Not a number, try as driver type |
156 | | - spec.driverType = stringToDriverType(token); |
157 | | - if (spec.driverType != ZEL_DRIVER_TYPE_FORCE_UINT32) { |
158 | | - spec.type = DriverOrderSpec::BY_TYPE; |
159 | | - specs.push_back(spec); |
160 | | - } |
161 | | - } |
162 | | - } |
163 | | - } |
164 | | - |
165 | | - return specs; |
166 | | - } |
167 | | - |
168 | 75 | void context_t::driverOrdering(driver_vector_t *drivers) { |
169 | 76 | std::string orderStr = getenv_string("ZEL_DRIVERS_ORDER"); |
170 | 77 | if (orderStr.empty()) { |
@@ -241,15 +148,15 @@ namespace loader |
241 | 148 | // Apply ordering specifications |
242 | 149 | for (const auto& spec : specs) { |
243 | 150 | switch (spec.type) { |
244 | | - case DriverOrderSpec::BY_GLOBAL_INDEX: |
| 151 | + case DriverOrderSpecType::BY_GLOBAL_INDEX: |
245 | 152 | if (spec.globalIndex < originalDrivers.size() && |
246 | 153 | usedGlobalIndices.find(spec.globalIndex) == usedGlobalIndices.end()) { |
247 | 154 | orderedDrivers.push_back(originalDrivers[spec.globalIndex]); |
248 | 155 | usedGlobalIndices.insert(spec.globalIndex); |
249 | 156 | } |
250 | 157 | break; |
251 | 158 |
|
252 | | - case DriverOrderSpec::BY_TYPE: |
| 159 | + case DriverOrderSpecType::BY_TYPE: |
253 | 160 | // Add all drivers of this type that haven't been used |
254 | 161 | { |
255 | 162 | std::vector<uint32_t>* typeIndices = nullptr; |
@@ -282,7 +189,7 @@ namespace loader |
282 | 189 | } |
283 | 190 | break; |
284 | 191 |
|
285 | | - case DriverOrderSpec::BY_TYPE_AND_INDEX: |
| 192 | + case DriverOrderSpecType::BY_TYPE_AND_INDEX: |
286 | 193 | { |
287 | 194 | std::vector<uint32_t>* typeIndices = nullptr; |
288 | 195 | switch (spec.driverType) { |
|
0 commit comments