Skip to content

Commit 17c662c

Browse files
committed
cleanup trtexec based benchmarking code
Signed-off-by: dmoodie <dmoodie@nvidia.com>
1 parent 304f0e1 commit 17c662c

2 files changed

Lines changed: 454 additions & 92 deletions

File tree

modelopt/onnx/quantization/autotune/benchmark.py

Lines changed: 202 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
import tempfile
3737
import time
3838
from abc import ABC, abstractmethod
39+
from dataclasses import dataclass
3940
from pathlib import Path
4041
from typing import Any
4142
from urllib.parse import parse_qs, urlparse
@@ -153,6 +154,178 @@ def _write_log_file(self, file: Path | str | None, content: str) -> None:
153154
)
154155
_STD_PATTERN = r"\[I\]\s+GPU Compute Time:.*?median\s*=\s*([\d.]+)\s*ms"
155156

157+
_URL_PASSWORD_RE = re.compile(r"(://[^:/?#@]+):[^@/?#]+@")
158+
159+
160+
def _redact_url_password(s: str) -> str:
161+
"""Replace any ``scheme://user:password@host`` substring with ``user:******@host``.
162+
163+
Used so SSH passwords supplied via ``--remoteAutoTuningConfig`` don't leak
164+
into log messages or exception strings.
165+
"""
166+
return _URL_PASSWORD_RE.sub(r"\1:******@", s)
167+
168+
169+
def _build_base_trtexec_cmd(
170+
*,
171+
timing_runs: int,
172+
warmup_runs: int,
173+
engine_path: str,
174+
timing_cache_file: str,
175+
plugin_libraries: list[str] | None = None,
176+
log: Any = None,
177+
) -> list[str]:
178+
"""Build the static portion of the trtexec command line (no ``--onnx=`` yet).
179+
180+
Plugin libraries that don't exist on disk are skipped with a warning if a
181+
logger is supplied. The leading ``trtexec`` binary path is not included —
182+
the caller is responsible for prepending it.
183+
184+
Args:
185+
timing_runs: Value for ``--avgRuns`` and ``--iterations``.
186+
warmup_runs: Value for ``--warmUp``.
187+
engine_path: Path used for ``--saveEngine=``.
188+
timing_cache_file: Path used for ``--timingCacheFile=``.
189+
plugin_libraries: Paths to ``.so`` libraries for ``--staticPlugins``.
190+
log: Optional logger used to warn about missing plugins and trace adds.
191+
"""
192+
cmd = [
193+
f"--avgRuns={timing_runs}",
194+
f"--iterations={timing_runs}",
195+
f"--warmUp={warmup_runs}",
196+
"--stronglyTyped",
197+
f"--saveEngine={engine_path}",
198+
f"--timingCacheFile={timing_cache_file}",
199+
]
200+
for plugin_lib in plugin_libraries or []:
201+
plugin_path = Path(plugin_lib).resolve()
202+
if not plugin_path.exists():
203+
if log is not None:
204+
log.warning(f"Plugin library not found: {plugin_path}")
205+
continue
206+
cmd.append(f"--staticPlugins={plugin_path}")
207+
if log is not None:
208+
log.debug(f"Added plugin library: {plugin_path}")
209+
return cmd
210+
211+
212+
def _extract_remote_config_value(trtexec_args: list[str], *, log: Any = None) -> str | None:
213+
"""Find the value of ``--remoteAutoTuningConfig`` in ``trtexec_args``.
214+
215+
Supports both inline (``--remoteAutoTuningConfig=value``) and split
216+
(``--remoteAutoTuningConfig value``) forms.
217+
218+
Returns:
219+
The value as a string, or ``None`` if the flag is absent. Returning
220+
an empty string is possible (e.g. ``--remoteAutoTuningConfig=``); the
221+
caller decides whether to treat that as an error.
222+
223+
Raises:
224+
ValueError: If the flag appears more than once, has no value at the
225+
end of the list, or is malformed (e.g. missing the ``=``
226+
separator). SSH passwords in malformed args are redacted before
227+
being included in the error or debug log.
228+
"""
229+
matches = [a for a in trtexec_args if "--remoteAutoTuningConfig" in a]
230+
if not matches:
231+
return None
232+
if len(matches) != 1:
233+
raise ValueError("Exactly one --remoteAutoTuningConfig argument is required")
234+
235+
for i, arg in enumerate(trtexec_args):
236+
if not arg.startswith("--remoteAutoTuningConfig"):
237+
continue
238+
if arg == "--remoteAutoTuningConfig":
239+
if i + 1 >= len(trtexec_args):
240+
raise ValueError("Missing value for --remoteAutoTuningConfig")
241+
return trtexec_args[i + 1]
242+
if arg.startswith("--remoteAutoTuningConfig="):
243+
return arg.split("=", 1)[1]
244+
# Malformed: starts with the flag name but neither uses ``=`` nor is
245+
# the bare flag. Redact any embedded SSH password before surfacing.
246+
redacted_arg = _redact_url_password(arg)
247+
if log is not None:
248+
log.debug(f"Parsing remoteAutoTuningConfig arg: {redacted_arg}")
249+
raise ValueError(f"Malformed --remoteAutoTuningConfig argument: {redacted_arg}")
250+
return None # pragma: no cover — unreachable; ``matches`` proved presence
251+
252+
253+
@dataclass(frozen=True)
254+
class _RemoteAutotuningConfig:
255+
"""Resolved remote-autotuning destination parsed from a ``ssh://`` URL."""
256+
257+
user: str
258+
password: str # may be empty when no password was supplied
259+
ip: str
260+
port: int
261+
options: dict[str, str]
262+
bin_path: str # dirname of ``remote_exec_path``
263+
lib_path: str # value of ``remote_lib_path``
264+
265+
266+
def _parse_remote_autotuning_url(url: str) -> _RemoteAutotuningConfig:
267+
"""Parse a ``--remoteAutoTuningConfig`` URL into structured fields.
268+
269+
Required URL form::
270+
271+
ssh://user[:password]@host[:port]?remote_exec_path=PATH&remote_lib_path=PATH
272+
273+
Raises:
274+
ValueError: If the scheme is not ``ssh://``; if user or host are
275+
missing; or if required query parameters are missing or
276+
duplicated. Duplicate keys are rejected explicitly because
277+
silently collapsing them would produce empty remote paths
278+
downstream.
279+
"""
280+
if not url.startswith("ssh://"):
281+
raise ValueError("Only 'ssh://' remote autotuning config URLs are supported")
282+
parsed = urlparse(url)
283+
if parsed.username is None:
284+
raise ValueError("Unable to parse remote user from --remoteAutoTuningConfig")
285+
if parsed.hostname is None:
286+
raise ValueError("Unable to parse remote IP from --remoteAutoTuningConfig")
287+
288+
parsed_query = parse_qs(parsed.query)
289+
duplicates = sorted(k for k, v in parsed_query.items() if len(v) > 1)
290+
if duplicates:
291+
raise ValueError(f"Duplicate query parameters in --remoteAutoTuningConfig: {duplicates}")
292+
options = {k: v[0] for k, v in parsed_query.items()}
293+
294+
required_params = ["remote_exec_path", "remote_lib_path"]
295+
missing = [p for p in required_params if p not in options]
296+
if missing:
297+
raise ValueError(
298+
f"Missing required query parameters in --remoteAutoTuningConfig: {missing}"
299+
)
300+
301+
return _RemoteAutotuningConfig(
302+
user=parsed.username,
303+
password=parsed.password or "",
304+
ip=parsed.hostname,
305+
port=parsed.port if parsed.port is not None else 22,
306+
options=options,
307+
bin_path=os.path.dirname(options["remote_exec_path"]),
308+
lib_path=options["remote_lib_path"],
309+
)
310+
311+
312+
def _ensure_remote_autotuning_flags(trtexec_args: list[str], *, log: Any = None) -> list[str]:
313+
"""Return ``trtexec_args`` with ``--safe`` and ``--skipInference`` appended if missing.
314+
315+
Remote autotuning requires both flags. A warning is emitted for each flag
316+
that has to be injected so the user sees that their argv was modified.
317+
"""
318+
result = list(trtexec_args)
319+
for flag in ("--safe", "--skipInference"):
320+
if flag in result:
321+
continue
322+
if log is not None:
323+
log.warning(
324+
f"Remote autotuning requires '{flag}' to be set. Adding it to trtexec arguments."
325+
)
326+
result.append(flag)
327+
return result
328+
156329

157330
class TrtExecBenchmark(Benchmark):
158331
"""TensorRT benchmark using trtexec command-line tool.
@@ -182,120 +355,58 @@ def __init__(
182355
Example: ['--fp16', '--workspace=4096', '--verbose']
183356
"""
184357
super().__init__(timing_cache_file, warmup_runs, timing_runs, plugin_libraries)
185-
self.trtexec_args = trtexec_args if trtexec_args is not None else []
358+
self.trtexec_args = list(trtexec_args) if trtexec_args is not None else []
186359
self.temp_dir = tempfile.mkdtemp(prefix="trtexec_benchmark_")
187360
self.engine_path = os.path.join(self.temp_dir, "engine.trt")
188361
self.temp_model_path = os.path.join(self.temp_dir, "temp_model.onnx")
189362
self.logger.debug(f"Created temporary engine directory: {self.temp_dir}")
190363
self.logger.debug(f"Temporary model path: {self.temp_model_path}")
191364

192-
self._base_cmd = [
193-
f"--avgRuns={self.timing_runs}",
194-
f"--iterations={self.timing_runs}",
195-
f"--warmUp={self.warmup_runs}",
196-
"--stronglyTyped",
197-
f"--saveEngine={self.engine_path}",
198-
f"--timingCacheFile={self.timing_cache_file}",
199-
]
200-
201-
for plugin_lib in self.plugin_libraries:
202-
plugin_path = Path(plugin_lib).resolve()
203-
if not plugin_path.exists():
204-
self.logger.warning(f"Plugin library not found: {plugin_path}")
205-
continue
206-
self._base_cmd.append(f"--staticPlugins={plugin_path}")
207-
self.logger.debug(f"Added plugin library: {plugin_path}")
365+
self._base_cmd = _build_base_trtexec_cmd(
366+
timing_runs=self.timing_runs,
367+
warmup_runs=self.warmup_runs,
368+
engine_path=self.engine_path,
369+
timing_cache_file=self.timing_cache_file,
370+
plugin_libraries=self.plugin_libraries,
371+
log=self.logger,
372+
)
208373

209-
trtexec_args = self.trtexec_args or []
210-
self.has_remote_config = any("--remoteAutoTuningConfig" in arg for arg in trtexec_args)
374+
# Defaults for remote-autotuning fields; overwritten when configured.
375+
self.has_remote_config: bool = False
211376
self.remote_ip: str | None = None
212377
self.remote_port: int = 22
213378
self.remote_user: str = "root"
214379
self.remote_password: str = ""
215380
self.remote_engine_path: str = "trtexec_benchmark_model.trt"
216381
self.remote_bin_path: str = "trtexec"
382+
self.remote_lib_path: str = ""
383+
self.remote_options: dict[str, str] = {}
217384

218-
if self.has_remote_config:
219-
remote_config = [arg for arg in trtexec_args if "--remoteAutoTuningConfig" in arg]
220-
if len(remote_config) != 1:
221-
raise ValueError("Exactly one --remoteAutoTuningConfig argument is required")
222-
# Parse --remoteAutoTuningConfig argument, which may be given as:
223-
# ('--remoteAutoTuningConfig=ssh://user:pass@host:port?...') or
224-
# ('--remoteAutoTuningConfig', 'ssh://user:pass@host:port?...')
225-
#
226-
# The logic: find the arg starting with '--remoteAutoTuningConfig'
227-
# If formatted as '--remoteAutoTuningConfig=...', split off the '='
228-
# Otherwise, grab the next argument.
229-
config_arg_value: str | None = None
230-
for i, arg in enumerate(trtexec_args):
231-
if arg.startswith("--remoteAutoTuningConfig"):
232-
if arg == "--remoteAutoTuningConfig":
233-
# Value should be the next argument
234-
if i + 1 < len(trtexec_args):
235-
config_arg_value = trtexec_args[i + 1]
236-
else:
237-
raise ValueError("Missing value for --remoteAutoTuningConfig")
238-
elif arg.startswith("--remoteAutoTuningConfig="):
239-
config_arg_value = arg.split("=", 1)[1]
240-
else:
241-
raise ValueError(f"Malformed --remoteAutoTuningConfig argument: {arg}")
242-
break
243-
if not config_arg_value:
385+
remote_value = _extract_remote_config_value(self.trtexec_args, log=self.logger)
386+
if remote_value is not None:
387+
self.has_remote_config = True
388+
if not remote_value:
244389
raise ValueError("Could not parse --remoteAutoTuningConfig argument")
245-
remote_config_str: str = config_arg_value
246-
247-
if not remote_config_str.startswith("ssh://"):
248-
raise ValueError("Only 'ssh://' remote autotuning config URLs are supported")
249-
parsed = urlparse(remote_config_str)
250-
self.remote_user = parsed.username
251-
self.remote_password = parsed.password
252-
self.remote_ip = parsed.hostname
253-
self.remote_port = parsed.port
254-
if self.remote_user is None:
255-
raise ValueError("Unable to parse remote user from --remoteAutoTuningConfig")
256-
if self.remote_ip is None:
257-
raise ValueError("Unable to parse remote IP from --remoteAutoTuningConfig")
258-
if self.remote_port is None:
259-
self.remote_port = 22
260-
# Parse query options into a dict. Reject duplicate keys: silently
261-
# collapsing them would let a stray ``?remote_exec_path=a&remote_exec_path=b``
262-
# produce ``os.path.dirname(str(['a', 'b']))`` == '' downstream.
263-
parsed_query = parse_qs(parsed.query)
264-
duplicates = sorted(k for k, v in parsed_query.items() if len(v) > 1)
265-
if duplicates:
266-
raise ValueError(
267-
f"Duplicate query parameters in --remoteAutoTuningConfig: {duplicates}"
268-
)
269-
self.remote_options = {k: v[0] for k, v in parsed_query.items()}
270-
required_params = ["remote_exec_path", "remote_lib_path"]
271-
missing = [p for p in required_params if p not in self.remote_options]
272-
if missing:
273-
raise ValueError(
274-
f"Missing required query parameters in --remoteAutoTuningConfig: {missing}"
275-
)
276-
self.remote_bin_path = os.path.dirname(str(self.remote_options["remote_exec_path"]))
277-
self.remote_lib_path = str(self.remote_options["remote_lib_path"])
390+
config = _parse_remote_autotuning_url(remote_value)
391+
self.remote_user = config.user
392+
self.remote_password = config.password
393+
self.remote_ip = config.ip
394+
self.remote_port = config.port
395+
self.remote_options = config.options
396+
self.remote_bin_path = config.bin_path
397+
self.remote_lib_path = config.lib_path
278398
try:
279399
_check_for_trtexec(min_version="10.15")
280400
self.logger.debug("TensorRT Python API version >= 10.15 detected")
281-
if "--safe" not in trtexec_args:
282-
self.logger.warning(
283-
"Remote autotuning requires '--safe' to be set. Adding it to trtexec arguments."
284-
)
285-
self.trtexec_args.append("--safe")
286-
if "--skipInference" not in trtexec_args:
287-
self.logger.warning(
288-
"Remote autotuning requires '--skipInference' to be set. Adding it to trtexec arguments."
289-
)
290-
self.trtexec_args.append("--skipInference")
291-
except ImportError as e:
401+
except ImportError:
292402
self.logger.warning(
293403
"Remote autotuning is not supported with TensorRT version < 10.15."
294404
)
295-
raise e
405+
raise
406+
self.trtexec_args = _ensure_remote_autotuning_flags(self.trtexec_args, log=self.logger)
296407

297408
self.is_safe = "--safe" in self.trtexec_args
298-
self._base_cmd.extend(trtexec_args)
409+
self._base_cmd.extend(self.trtexec_args)
299410

300411
self.logger.debug(f"Base command template: {' '.join(self._base_cmd)}")
301412

0 commit comments

Comments
 (0)