Skip to content

Commit 95baf16

Browse files
Add the capture replay feature improvement for 10.16 (#4158)
1 parent 7f6048e commit 95baf16

File tree

8 files changed

+171
-66
lines changed

8 files changed

+171
-66
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,3 +81,4 @@ coverage.xml
8181
*.log
8282
*.pt2
8383
examples/torchtrt_aoti_example/torchtrt_aoti_example
84+
CLAUDE.md

docsrc/debugging/capture_and_replay.rst

Lines changed: 47 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,56 @@ Prerequisites
1313
Quick start: Capture
1414
--------------------
1515

16+
Example ``test.py``:
17+
18+
.. code-block:: python
19+
20+
import torch
21+
import torch_tensorrt as torchtrt
22+
import torchvision.models as models
23+
class MyModule(torch.nn.Module):
24+
def __init__(self):
25+
super().__init__()
26+
self.conv = torch.nn.Conv1d(3, 3, 3, padding=1, stride=1, bias=True)
27+
28+
def forward(self, x):
29+
return self.conv(x)
30+
31+
model = MyModule().eval().to("cuda")
32+
input = torch.randn((1, 3, 3)).to("cuda").to(torch.float32)
33+
34+
compile_spec = {
35+
"inputs": [
36+
torchtrt.Input(
37+
min_shape=(1, 3, 3),
38+
opt_shape=(2, 3, 3),
39+
max_shape=(3, 3, 3),
40+
dtype=torch.float32,
41+
)
42+
],
43+
"min_block_size": 1,
44+
"cache_built_engines": False,
45+
"reuse_cached_engines": False,
46+
"use_python_runtime": True,
47+
}
48+
49+
try:
50+
with torchtrt.dynamo.Debugger(
51+
"graphs",
52+
logging_dir="debuglogs",
53+
):
54+
trt_mod = torchtrt.compile(model, **compile_spec)
55+
56+
except Exception as e:
57+
raise e
58+
59+
print("done.....")
60+
1661
.. code-block:: bash
1762
1863
TORCHTRT_ENABLE_TENSORRT_API_CAPTURE=1 python test.py
1964
20-
You should see ``shim.json`` and ``shim.bin`` generated in ``/tmp/torch_tensorrt_{current_user}/shim``.
65+
When ``TORCHTRT_ENABLE_TENSORRT_API_CAPTURE=1`` is set, capture and replay files are automatically saved under ``debuglogs/capture_replay/`` (i.e., the ``capture_replay`` subdirectory of ``logging_dir``). You should see ``capture.json`` and associated ``.bin`` files generated there.
2166

2267
Replay: Build the engine from the capture
2368
-----------------------------------------
@@ -26,7 +71,7 @@ Use ``tensorrt_player`` to replay the captured build without the original framew
2671

2772
.. code-block:: bash
2873
29-
tensorrt_player -j /absolute/path/to/shim.json -o /absolute/path/to/output_engine
74+
tensorrt_player -j debuglogs/capture_replay/capture.json -o /absolute/path/to/output_engine
3075
3176
This produces a serialized TensorRT engine at ``output_engine``.
3277

packaging/pre_build_script.sh

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
set -x
44

55
# Install dependencies
6-
python3 -m pip install pyyaml
6+
python3 -m pip install pyyaml packaging
77

88
if [[ $(uname -m) == "aarch64" ]]; then
99
IS_AARCH64=true
@@ -59,6 +59,18 @@ fi
5959
export TORCH_BUILD_NUMBER=$(python -c "import torch, urllib.parse as ul; print(ul.quote_plus(torch.__version__))")
6060
export TORCH_INSTALL_PATH=$(python -c "import torch, os; print(os.path.dirname(torch.__file__))")
6161

62+
if [[ -z "${TORCH_INSTALL_PATH}" ]]; then
63+
echo "ERROR: TORCH_INSTALL_PATH is empty — could not locate torch installation."
64+
echo "Ensure the active Python environment has torch installed, or set TORCH_PATH explicitly."
65+
exit 1
66+
fi
67+
68+
if [[ ! -d "${TORCH_INSTALL_PATH}/include/c10" ]]; then
69+
echo "ERROR: torch at '${TORCH_INSTALL_PATH}' is missing include/c10/ C++ headers."
70+
echo "Install a full PyTorch wheel (pip install torch) that includes dev headers."
71+
exit 1
72+
fi
73+
6274
# CU_UPPERBOUND eg:13.2 or 12.9
6375
# tensorrt tar for linux and windows are different across cuda version
6476
# for sbsa it is the same tar across cuda version

py/torch_tensorrt/_TensorRTProxyModule.py

Lines changed: 3 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,11 @@
11
import ctypes
22
import importlib
3-
import importlib.util
43
import importlib.metadata
4+
import importlib.util
55
import logging
66
import os
77
import platform
88
import sys
9-
import tempfile
109
from types import ModuleType
1110
from typing import Any, Dict, List
1211

@@ -54,6 +53,7 @@ def enable_capture_tensorrt_api_recording() -> None:
5453
elif platform.uname().processor == "aarch64":
5554
linux_lib_path.append("/usr/lib/aarch64-linux-gnu")
5655

56+
tensorrt_lib_path = None
5757
for path in linux_lib_path:
5858
if os.path.isfile(os.path.join(path, "libtensorrt_shim.so")):
5959
try:
@@ -74,24 +74,7 @@ def enable_capture_tensorrt_api_recording() -> None:
7474
os.environ["TRT_SHIM_NVINFER_LIB_NAME"] = os.path.join(
7575
tensorrt_lib_path, "libnvinfer.so"
7676
)
77-
import pwd
78-
79-
current_user = pwd.getpwuid(os.getuid())[0]
80-
shim_temp_dir = os.path.join(
81-
tempfile.gettempdir(), f"torch_tensorrt_{current_user}/shim"
82-
)
83-
os.makedirs(shim_temp_dir, exist_ok=True)
84-
json_file_name = os.path.join(shim_temp_dir, "shim.json")
85-
os.environ["TRT_SHIM_OUTPUT_JSON_FILE"] = json_file_name
86-
bin_file_name = os.path.join(shim_temp_dir, "shim.bin")
87-
# if exists, delete the file, so that we can capture the new one
88-
if os.path.exists(json_file_name):
89-
os.remove(json_file_name)
90-
if os.path.exists(bin_file_name):
91-
os.remove(bin_file_name)
92-
_LOGGER.info(
93-
f"Capturing TensorRT API calls feature is enabled and the captured output is in the {shim_temp_dir} directory"
94-
)
77+
_LOGGER.info("Capturing TensorRT API calls feature is enabled")
9578

9679

9780
# TensorRTProxyModule is a proxy module that allows us to register the tensorrt or tensorrt-rtx package

py/torch_tensorrt/dynamo/debug/_Debugger.py

Lines changed: 61 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import contextlib
2+
import ctypes
23
import functools
34
import logging
45
import os
@@ -33,7 +34,6 @@ def __init__(
3334
capture_fx_graph_before: Optional[List[str]] = None,
3435
capture_fx_graph_after: Optional[List[str]] = None,
3536
save_engine_profile: bool = False,
36-
capture_tensorrt_api_recording: bool = False,
3737
profile_format: str = "perfetto",
3838
engine_builder_monitor: bool = True,
3939
logging_dir: str = DEBUG_LOGGING_DIR,
@@ -51,9 +51,6 @@ def __init__(
5151
after execution of a lowering pass. Defaults to None.
5252
save_engine_profile (bool): Whether to save TensorRT engine profiling information.
5353
Defaults to False.
54-
capture_tensorrt_api_recording (bool): Whether to enable the capture TensorRT API recording feature; when this is enabled, it will output the captured TensorRT API recording in the /tmp/torch_tensorrt_{current_user}/shim directory.
55-
It is part of the TensorRT capture and replay feature, the captured output will be able to replay for debug purpose.
56-
Defaults to False.
5754
profile_format (str): Format for profiling data. Choose from 'perfetto', 'trex', 'cudagraph'.
5855
If you need to generate engine graph using the profiling files, set it to 'trex' and use the C++ runtime.
5956
If you need to generate cudagraph visualization, set it to 'cudagraph'.
@@ -67,6 +64,31 @@ def __init__(
6764
"""
6865

6966
os.makedirs(logging_dir, exist_ok=True)
67+
68+
# Auto-detect TensorRT API capture from environment variable
69+
env_flag = os.environ.get("TORCHTRT_ENABLE_TENSORRT_API_CAPTURE", None)
70+
capture_tensorrt_api_recording = env_flag is not None and (
71+
env_flag == "1" or env_flag.lower() == "true"
72+
)
73+
74+
if capture_tensorrt_api_recording:
75+
if not sys.platform.startswith("linux"):
76+
_LOGGER.warning(
77+
f"Capturing TensorRT API calls is only supported on Linux, therefore ignoring TORCHTRT_ENABLE_TENSORRT_API_CAPTURE for {sys.platform}"
78+
)
79+
capture_tensorrt_api_recording = False
80+
elif ENABLED_FEATURES.tensorrt_rtx:
81+
_LOGGER.warning(
82+
"Capturing TensorRT API calls is not supported for TensorRT-RTX, therefore ignoring TORCHTRT_ENABLE_TENSORRT_API_CAPTURE"
83+
)
84+
capture_tensorrt_api_recording = False
85+
else:
86+
_LOGGER.info("Capturing TensorRT API calls feature is enabled")
87+
88+
if capture_tensorrt_api_recording:
89+
capture_replay_dir = os.path.join(logging_dir, "capture_replay")
90+
os.makedirs(capture_replay_dir, exist_ok=True)
91+
7092
self.cfg = DebuggerConfig(
7193
log_level=log_level,
7294
save_engine_profile=save_engine_profile,
@@ -98,23 +120,6 @@ def __init__(
98120
self.capture_fx_graph_before = capture_fx_graph_before
99121
self.capture_fx_graph_after = capture_fx_graph_after
100122

101-
if self.cfg.capture_tensorrt_api_recording:
102-
if not sys.platform.startswith("linux"):
103-
_LOGGER.warning(
104-
f"Capturing TensorRT API calls is only supported on Linux, therefore ignoring the capture_tensorrt_api_recording setting for {sys.platform}"
105-
)
106-
elif ENABLED_FEATURES.tensorrt_rtx:
107-
_LOGGER.warning(
108-
"Capturing TensorRT API calls is not supported for TensorRT-RTX, therefore ignoring the capture_tensorrt_api_recording setting"
109-
)
110-
else:
111-
env_flag = os.environ.get("TORCHTRT_ENABLE_TENSORRT_API_CAPTURE", None)
112-
if env_flag is None or (env_flag != "1" and env_flag.lower() != "true"):
113-
_LOGGER.warning(
114-
"In order to capture TensorRT API calls, please invoke the script with environment variable TORCHTRT_ENABLE_TENSORRT_API_CAPTURE=1"
115-
)
116-
_LOGGER.info("Capturing TensorRT API calls feature is enabled")
117-
118123
def __enter__(self) -> None:
119124
self.original_lvl = _LOGGER.getEffectiveLevel()
120125
if ENABLED_FEATURES.torch_tensorrt_runtime:
@@ -166,6 +171,8 @@ def __enter__(self) -> None:
166171
for c in _DEBUG_ENABLED_CLS
167172
]
168173

174+
self.set_capture_tensorrt_api_recording_json_file()
175+
169176
def __exit__(self, exc_type: Any, exc_value: Any, exc_tb: Any) -> None:
170177

171178
dictConfig(self.get_logging_config(None))
@@ -224,3 +231,36 @@ def get_logging_config(self, log_level: Optional[int] = None) -> dict[str, Any]:
224231
}
225232
config["loggers"][""]["handlers"].append("file")
226233
return config
234+
235+
def set_capture_tensorrt_api_recording_json_file(self) -> None:
236+
if self.cfg.capture_tensorrt_api_recording is False:
237+
return
238+
239+
capture_replay_dir = os.path.join(self.cfg.logging_dir, "capture_replay")
240+
json_file = os.path.join(capture_replay_dir, "capture.json")
241+
242+
if os.path.isfile(json_file):
243+
os.remove(json_file)
244+
245+
nvinfer_lib = os.environ.get("TRT_SHIM_NVINFER_LIB_NAME", None)
246+
if nvinfer_lib is None:
247+
_LOGGER.warning(
248+
"TRT_SHIM_NVINFER_LIB_NAME is not set, therefore capturing TensorRT API recording is not supported"
249+
)
250+
return
251+
lib_path = os.path.dirname(nvinfer_lib)
252+
shim_path = os.path.join(lib_path, "libtensorrt_shim.so")
253+
if not os.path.isfile(shim_path):
254+
_LOGGER.warning(
255+
f"libtensorrt_shim.so is not found in the {lib_path} directory, therefore capturing TensorRT API recording is not supported"
256+
)
257+
return
258+
try:
259+
shim_lib = ctypes.CDLL(shim_path, mode=ctypes.RTLD_GLOBAL)
260+
shim_lib.trtShimSetOutputJsonFile(json_file.encode("utf-8"))
261+
_LOGGER.info(f"TensorRT API recording will be saved to {json_file}")
262+
except Exception as e:
263+
_LOGGER.warning(
264+
f"Failed to set the output JSON file for TensorRT API recording: {e}"
265+
)
266+
return

third_party/libtorch/BUILD

Lines changed: 30 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -39,9 +39,12 @@ cc_library(
3939
exclude = [
4040
"include/torch/csrc/api/include/**/*.h",
4141
],
42-
) + glob([
43-
"include/torch/csrc/api/include/**/*.h",
44-
]),
42+
) + glob(
43+
[
44+
"include/torch/csrc/api/include/**/*.h",
45+
],
46+
allow_empty = True,
47+
),
4548
includes = [
4649
"include",
4750
"include/torch/csrc/api/include/",
@@ -58,9 +61,12 @@ cc_library(
5861
":windows": ["lib/c10_cuda.lib"],
5962
"//conditions:default": ["lib/libc10_cuda.so"],
6063
}),
61-
hdrs = glob([
62-
"include/c10/**/*.h",
63-
]),
64+
hdrs = glob(
65+
[
66+
"include/c10/**/*.h",
67+
],
68+
allow_empty = True,
69+
),
6470
strip_include_prefix = "include",
6571
deps = [
6672
":c10",
@@ -73,17 +79,23 @@ cc_library(
7379
":windows": ["lib/c10.lib"],
7480
"//conditions:default": ["lib/libc10.so"],
7581
}),
76-
hdrs = glob([
77-
"include/c10/**/*.h",
78-
]),
82+
hdrs = glob(
83+
[
84+
"include/c10/**/*.h",
85+
],
86+
allow_empty = True,
87+
),
7988
strip_include_prefix = "include",
8089
)
8190

8291
cc_library(
8392
name = "ATen",
84-
hdrs = glob([
85-
"include/ATen/**/*.h",
86-
]),
93+
hdrs = glob(
94+
[
95+
"include/ATen/**/*.h",
96+
],
97+
allow_empty = True,
98+
),
8799
strip_include_prefix = "include",
88100
)
89101

@@ -97,8 +109,11 @@ cc_library(
97109
"lib/libcaffe2_nvrtc.so",
98110
],
99111
}),
100-
hdrs = glob([
101-
"include/caffe2/**/*.h",
102-
]),
112+
hdrs = glob(
113+
[
114+
"include/caffe2/**/*.h",
115+
],
116+
allow_empty = True,
117+
),
103118
strip_include_prefix = "include",
104119
)

0 commit comments

Comments
 (0)