|
36 | 36 | import tempfile |
37 | 37 | import time |
38 | 38 | from abc import ABC, abstractmethod |
| 39 | +from dataclasses import dataclass |
39 | 40 | from pathlib import Path |
40 | 41 | from typing import Any |
41 | 42 | from urllib.parse import parse_qs, urlparse |
@@ -153,6 +154,178 @@ def _write_log_file(self, file: Path | str | None, content: str) -> None: |
153 | 154 | ) |
154 | 155 | _STD_PATTERN = r"\[I\]\s+GPU Compute Time:.*?median\s*=\s*([\d.]+)\s*ms" |
155 | 156 |
|
# Matches the ``:password@`` part of a URL authority. The password run is
# greedy and may itself contain ``@``, so the match extends to the LAST ``@``
# before the authority ends (``/``, ``?`` or ``#``). That is the same split
# point ``urllib.parse.urlparse`` uses for userinfo, so a password such as
# ``p@ss`` is covered in full instead of leaking its tail.
_URL_PASSWORD_RE = re.compile(r"(://[^:/?#@]+):[^/?#]+@")


def _redact_url_password(s: str) -> str:
    """Mask the password of every ``scheme://user:password@host`` substring.

    Used so SSH passwords supplied via ``--remoteAutoTuningConfig`` don't leak
    into log messages or exception strings.

    The scheme, user, host and everything after the authority are preserved;
    only the password is replaced by ``******``. Strings without a
    ``user:password@`` authority are returned unchanged. Redaction errs on the
    side of hiding too much: if free text containing another ``@`` follows the
    URL without a ``/``, ``?`` or ``#`` in between, that text is masked too.

    Args:
        s: Arbitrary text that may embed one or more URLs.

    Returns:
        ``s`` with each URL password replaced by ``******``.
    """
    return _URL_PASSWORD_RE.sub(r"\1:******@", s)
| 167 | + |
| 168 | + |
| 169 | +def _build_base_trtexec_cmd( |
| 170 | + *, |
| 171 | + timing_runs: int, |
| 172 | + warmup_runs: int, |
| 173 | + engine_path: str, |
| 174 | + timing_cache_file: str, |
| 175 | + plugin_libraries: list[str] | None = None, |
| 176 | + log: Any = None, |
| 177 | +) -> list[str]: |
| 178 | + """Build the static portion of the trtexec command line (no ``--onnx=`` yet). |
| 179 | +
|
| 180 | + Plugin libraries that don't exist on disk are skipped with a warning if a |
| 181 | + logger is supplied. The leading ``trtexec`` binary path is not included — |
| 182 | + the caller is responsible for prepending it. |
| 183 | +
|
| 184 | + Args: |
| 185 | + timing_runs: Value for ``--avgRuns`` and ``--iterations``. |
| 186 | + warmup_runs: Value for ``--warmUp``. |
| 187 | + engine_path: Path used for ``--saveEngine=``. |
| 188 | + timing_cache_file: Path used for ``--timingCacheFile=``. |
| 189 | + plugin_libraries: Paths to ``.so`` libraries for ``--staticPlugins``. |
| 190 | + log: Optional logger used to warn about missing plugins and trace adds. |
| 191 | + """ |
| 192 | + cmd = [ |
| 193 | + f"--avgRuns={timing_runs}", |
| 194 | + f"--iterations={timing_runs}", |
| 195 | + f"--warmUp={warmup_runs}", |
| 196 | + "--stronglyTyped", |
| 197 | + f"--saveEngine={engine_path}", |
| 198 | + f"--timingCacheFile={timing_cache_file}", |
| 199 | + ] |
| 200 | + for plugin_lib in plugin_libraries or []: |
| 201 | + plugin_path = Path(plugin_lib).resolve() |
| 202 | + if not plugin_path.exists(): |
| 203 | + if log is not None: |
| 204 | + log.warning(f"Plugin library not found: {plugin_path}") |
| 205 | + continue |
| 206 | + cmd.append(f"--staticPlugins={plugin_path}") |
| 207 | + if log is not None: |
| 208 | + log.debug(f"Added plugin library: {plugin_path}") |
| 209 | + return cmd |
| 210 | + |
| 211 | + |
| 212 | +def _extract_remote_config_value(trtexec_args: list[str], *, log: Any = None) -> str | None: |
| 213 | + """Find the value of ``--remoteAutoTuningConfig`` in ``trtexec_args``. |
| 214 | +
|
| 215 | + Supports both inline (``--remoteAutoTuningConfig=value``) and split |
| 216 | + (``--remoteAutoTuningConfig value``) forms. |
| 217 | +
|
| 218 | + Returns: |
| 219 | + The value as a string, or ``None`` if the flag is absent. Returning |
| 220 | + an empty string is possible (e.g. ``--remoteAutoTuningConfig=``); the |
| 221 | + caller decides whether to treat that as an error. |
| 222 | +
|
| 223 | + Raises: |
| 224 | + ValueError: If the flag appears more than once, has no value at the |
| 225 | + end of the list, or is malformed (e.g. missing the ``=`` |
| 226 | + separator). SSH passwords in malformed args are redacted before |
| 227 | + being included in the error or debug log. |
| 228 | + """ |
| 229 | + matches = [a for a in trtexec_args if "--remoteAutoTuningConfig" in a] |
| 230 | + if not matches: |
| 231 | + return None |
| 232 | + if len(matches) != 1: |
| 233 | + raise ValueError("Exactly one --remoteAutoTuningConfig argument is required") |
| 234 | + |
| 235 | + for i, arg in enumerate(trtexec_args): |
| 236 | + if not arg.startswith("--remoteAutoTuningConfig"): |
| 237 | + continue |
| 238 | + if arg == "--remoteAutoTuningConfig": |
| 239 | + if i + 1 >= len(trtexec_args): |
| 240 | + raise ValueError("Missing value for --remoteAutoTuningConfig") |
| 241 | + return trtexec_args[i + 1] |
| 242 | + if arg.startswith("--remoteAutoTuningConfig="): |
| 243 | + return arg.split("=", 1)[1] |
| 244 | + # Malformed: starts with the flag name but neither uses ``=`` nor is |
| 245 | + # the bare flag. Redact any embedded SSH password before surfacing. |
| 246 | + redacted_arg = _redact_url_password(arg) |
| 247 | + if log is not None: |
| 248 | + log.debug(f"Parsing remoteAutoTuningConfig arg: {redacted_arg}") |
| 249 | + raise ValueError(f"Malformed --remoteAutoTuningConfig argument: {redacted_arg}") |
| 250 | + return None # pragma: no cover — unreachable; ``matches`` proved presence |
| 251 | + |
| 252 | + |
@dataclass(frozen=True)
class _RemoteAutotuningConfig:
    """Resolved remote-autotuning destination parsed from a ``ssh://`` URL."""

    user: str  # login name from the URL userinfo; never empty
    password: str  # may be empty when no password was supplied
    ip: str  # host component of the URL
    port: int  # explicit URL port, or 22 when the URL has none
    options: dict[str, str]  # query parameters, one value per key
    bin_path: str  # dirname of ``remote_exec_path``
    lib_path: str  # value of ``remote_lib_path``


def _parse_remote_autotuning_url(url: str) -> _RemoteAutotuningConfig:
    """Parse a ``--remoteAutoTuningConfig`` URL into structured fields.

    Required URL form::

        ssh://user[:password]@host[:port]?remote_exec_path=PATH&remote_lib_path=PATH

    Args:
        url: The raw value of ``--remoteAutoTuningConfig``.

    Returns:
        A frozen ``_RemoteAutotuningConfig``. ``port`` defaults to 22 and
        ``password`` to the empty string when the URL omits them.

    Raises:
        ValueError: If the scheme is not ``ssh://``; if user or host are
            missing or empty; if the port is not a valid port number; or if
            required query parameters are missing or duplicated. Duplicate
            keys are rejected explicitly because silently collapsing them
            would produce empty remote paths downstream.
    """
    if not url.startswith("ssh://"):
        raise ValueError("Only 'ssh://' remote autotuning config URLs are supported")
    parsed = urlparse(url)

    # ``username`` is ``None`` when there is no userinfo at all, but an empty
    # string for URLs such as ``ssh://:password@host`` -- reject both.
    if not parsed.username:
        raise ValueError("Unable to parse remote user from --remoteAutoTuningConfig")
    if not parsed.hostname:
        raise ValueError("Unable to parse remote IP from --remoteAutoTuningConfig")

    # ``port`` is validated lazily by urllib and raises ValueError on access
    # for non-numeric or out-of-range values; re-raise naming the flag.
    try:
        port = parsed.port
    except ValueError as err:
        raise ValueError("Unable to parse remote port from --remoteAutoTuningConfig") from err

    parsed_query = parse_qs(parsed.query)
    duplicates = sorted(k for k, v in parsed_query.items() if len(v) > 1)
    if duplicates:
        raise ValueError(f"Duplicate query parameters in --remoteAutoTuningConfig: {duplicates}")
    options = {k: v[0] for k, v in parsed_query.items()}

    required_params = ["remote_exec_path", "remote_lib_path"]
    missing = [p for p in required_params if p not in options]
    if missing:
        raise ValueError(
            f"Missing required query parameters in --remoteAutoTuningConfig: {missing}"
        )

    return _RemoteAutotuningConfig(
        user=parsed.username,
        password=parsed.password or "",
        ip=parsed.hostname,
        port=port if port is not None else 22,
        options=options,
        bin_path=os.path.dirname(options["remote_exec_path"]),
        lib_path=options["remote_lib_path"],
    )
| 310 | + |
| 311 | + |
def _ensure_remote_autotuning_flags(trtexec_args: list[str], *, log: Any = None) -> list[str]:
    """Return a copy of ``trtexec_args`` that contains ``--safe`` and ``--skipInference``.

    Remote autotuning requires both flags. Any flag not already present is
    appended (``--safe`` before ``--skipInference``), and a warning is logged
    for each addition so the user sees that their argv was modified. The
    input list itself is left untouched.

    Args:
        trtexec_args: The user-supplied trtexec arguments.
        log: Optional logger; receives one ``warning`` per injected flag.

    Returns:
        A new list with the required flags guaranteed to be present.
    """
    augmented = [*trtexec_args]
    for required in ("--safe", "--skipInference"):
        if required not in augmented:
            if log is not None:
                log.warning(
                    f"Remote autotuning requires '{required}' to be set. Adding it to trtexec arguments."
                )
            augmented += [required]
    return augmented
| 328 | + |
156 | 329 |
|
157 | 330 | class TrtExecBenchmark(Benchmark): |
158 | 331 | """TensorRT benchmark using trtexec command-line tool. |
@@ -182,120 +355,58 @@ def __init__( |
182 | 355 | Example: ['--fp16', '--workspace=4096', '--verbose'] |
183 | 356 | """ |
184 | 357 | super().__init__(timing_cache_file, warmup_runs, timing_runs, plugin_libraries) |
185 | | - self.trtexec_args = trtexec_args if trtexec_args is not None else [] |
| 358 | + self.trtexec_args = list(trtexec_args) if trtexec_args is not None else [] |
186 | 359 | self.temp_dir = tempfile.mkdtemp(prefix="trtexec_benchmark_") |
187 | 360 | self.engine_path = os.path.join(self.temp_dir, "engine.trt") |
188 | 361 | self.temp_model_path = os.path.join(self.temp_dir, "temp_model.onnx") |
189 | 362 | self.logger.debug(f"Created temporary engine directory: {self.temp_dir}") |
190 | 363 | self.logger.debug(f"Temporary model path: {self.temp_model_path}") |
191 | 364 |
|
192 | | - self._base_cmd = [ |
193 | | - f"--avgRuns={self.timing_runs}", |
194 | | - f"--iterations={self.timing_runs}", |
195 | | - f"--warmUp={self.warmup_runs}", |
196 | | - "--stronglyTyped", |
197 | | - f"--saveEngine={self.engine_path}", |
198 | | - f"--timingCacheFile={self.timing_cache_file}", |
199 | | - ] |
200 | | - |
201 | | - for plugin_lib in self.plugin_libraries: |
202 | | - plugin_path = Path(plugin_lib).resolve() |
203 | | - if not plugin_path.exists(): |
204 | | - self.logger.warning(f"Plugin library not found: {plugin_path}") |
205 | | - continue |
206 | | - self._base_cmd.append(f"--staticPlugins={plugin_path}") |
207 | | - self.logger.debug(f"Added plugin library: {plugin_path}") |
| 365 | + self._base_cmd = _build_base_trtexec_cmd( |
| 366 | + timing_runs=self.timing_runs, |
| 367 | + warmup_runs=self.warmup_runs, |
| 368 | + engine_path=self.engine_path, |
| 369 | + timing_cache_file=self.timing_cache_file, |
| 370 | + plugin_libraries=self.plugin_libraries, |
| 371 | + log=self.logger, |
| 372 | + ) |
208 | 373 |
|
209 | | - trtexec_args = self.trtexec_args or [] |
210 | | - self.has_remote_config = any("--remoteAutoTuningConfig" in arg for arg in trtexec_args) |
| 374 | + # Defaults for remote-autotuning fields; overwritten when configured. |
| 375 | + self.has_remote_config: bool = False |
211 | 376 | self.remote_ip: str | None = None |
212 | 377 | self.remote_port: int = 22 |
213 | 378 | self.remote_user: str = "root" |
214 | 379 | self.remote_password: str = "" |
215 | 380 | self.remote_engine_path: str = "trtexec_benchmark_model.trt" |
216 | 381 | self.remote_bin_path: str = "trtexec" |
| 382 | + self.remote_lib_path: str = "" |
| 383 | + self.remote_options: dict[str, str] = {} |
217 | 384 |
|
218 | | - if self.has_remote_config: |
219 | | - remote_config = [arg for arg in trtexec_args if "--remoteAutoTuningConfig" in arg] |
220 | | - if len(remote_config) != 1: |
221 | | - raise ValueError("Exactly one --remoteAutoTuningConfig argument is required") |
222 | | - # Parse --remoteAutoTuningConfig argument, which may be given as: |
223 | | - # ('--remoteAutoTuningConfig=ssh://user:pass@host:port?...') or |
224 | | - # ('--remoteAutoTuningConfig', 'ssh://user:pass@host:port?...') |
225 | | - # |
226 | | - # The logic: find the arg starting with '--remoteAutoTuningConfig' |
227 | | - # If formatted as '--remoteAutoTuningConfig=...', split off the '=' |
228 | | - # Otherwise, grab the next argument. |
229 | | - config_arg_value: str | None = None |
230 | | - for i, arg in enumerate(trtexec_args): |
231 | | - if arg.startswith("--remoteAutoTuningConfig"): |
232 | | - if arg == "--remoteAutoTuningConfig": |
233 | | - # Value should be the next argument |
234 | | - if i + 1 < len(trtexec_args): |
235 | | - config_arg_value = trtexec_args[i + 1] |
236 | | - else: |
237 | | - raise ValueError("Missing value for --remoteAutoTuningConfig") |
238 | | - elif arg.startswith("--remoteAutoTuningConfig="): |
239 | | - config_arg_value = arg.split("=", 1)[1] |
240 | | - else: |
241 | | - raise ValueError(f"Malformed --remoteAutoTuningConfig argument: {arg}") |
242 | | - break |
243 | | - if not config_arg_value: |
| 385 | + remote_value = _extract_remote_config_value(self.trtexec_args, log=self.logger) |
| 386 | + if remote_value is not None: |
| 387 | + self.has_remote_config = True |
| 388 | + if not remote_value: |
244 | 389 | raise ValueError("Could not parse --remoteAutoTuningConfig argument") |
245 | | - remote_config_str: str = config_arg_value |
246 | | - |
247 | | - if not remote_config_str.startswith("ssh://"): |
248 | | - raise ValueError("Only 'ssh://' remote autotuning config URLs are supported") |
249 | | - parsed = urlparse(remote_config_str) |
250 | | - self.remote_user = parsed.username |
251 | | - self.remote_password = parsed.password |
252 | | - self.remote_ip = parsed.hostname |
253 | | - self.remote_port = parsed.port |
254 | | - if self.remote_user is None: |
255 | | - raise ValueError("Unable to parse remote user from --remoteAutoTuningConfig") |
256 | | - if self.remote_ip is None: |
257 | | - raise ValueError("Unable to parse remote IP from --remoteAutoTuningConfig") |
258 | | - if self.remote_port is None: |
259 | | - self.remote_port = 22 |
260 | | - # Parse query options into a dict. Reject duplicate keys: silently |
261 | | - # collapsing them would let a stray ``?remote_exec_path=a&remote_exec_path=b`` |
262 | | - # produce ``os.path.dirname(str(['a', 'b']))`` == '' downstream. |
263 | | - parsed_query = parse_qs(parsed.query) |
264 | | - duplicates = sorted(k for k, v in parsed_query.items() if len(v) > 1) |
265 | | - if duplicates: |
266 | | - raise ValueError( |
267 | | - f"Duplicate query parameters in --remoteAutoTuningConfig: {duplicates}" |
268 | | - ) |
269 | | - self.remote_options = {k: v[0] for k, v in parsed_query.items()} |
270 | | - required_params = ["remote_exec_path", "remote_lib_path"] |
271 | | - missing = [p for p in required_params if p not in self.remote_options] |
272 | | - if missing: |
273 | | - raise ValueError( |
274 | | - f"Missing required query parameters in --remoteAutoTuningConfig: {missing}" |
275 | | - ) |
276 | | - self.remote_bin_path = os.path.dirname(str(self.remote_options["remote_exec_path"])) |
277 | | - self.remote_lib_path = str(self.remote_options["remote_lib_path"]) |
| 390 | + config = _parse_remote_autotuning_url(remote_value) |
| 391 | + self.remote_user = config.user |
| 392 | + self.remote_password = config.password |
| 393 | + self.remote_ip = config.ip |
| 394 | + self.remote_port = config.port |
| 395 | + self.remote_options = config.options |
| 396 | + self.remote_bin_path = config.bin_path |
| 397 | + self.remote_lib_path = config.lib_path |
278 | 398 | try: |
279 | 399 | _check_for_trtexec(min_version="10.15") |
280 | 400 | self.logger.debug("TensorRT Python API version >= 10.15 detected") |
281 | | - if "--safe" not in trtexec_args: |
282 | | - self.logger.warning( |
283 | | - "Remote autotuning requires '--safe' to be set. Adding it to trtexec arguments." |
284 | | - ) |
285 | | - self.trtexec_args.append("--safe") |
286 | | - if "--skipInference" not in trtexec_args: |
287 | | - self.logger.warning( |
288 | | - "Remote autotuning requires '--skipInference' to be set. Adding it to trtexec arguments." |
289 | | - ) |
290 | | - self.trtexec_args.append("--skipInference") |
291 | | - except ImportError as e: |
| 401 | + except ImportError: |
292 | 402 | self.logger.warning( |
293 | 403 | "Remote autotuning is not supported with TensorRT version < 10.15." |
294 | 404 | ) |
295 | | - raise e |
| 405 | + raise |
| 406 | + self.trtexec_args = _ensure_remote_autotuning_flags(self.trtexec_args, log=self.logger) |
296 | 407 |
|
297 | 408 | self.is_safe = "--safe" in self.trtexec_args |
298 | | - self._base_cmd.extend(trtexec_args) |
| 409 | + self._base_cmd.extend(self.trtexec_args) |
299 | 410 |
|
300 | 411 | self.logger.debug(f"Base command template: {' '.join(self._base_cmd)}") |
301 | 412 |
|
|
0 commit comments