diff --git a/.gitignore b/.gitignore index f08d97d448..25a7acad99 100644 --- a/.gitignore +++ b/.gitignore @@ -81,3 +81,4 @@ coverage.xml *.log *.pt2 examples/torchtrt_aoti_example/torchtrt_aoti_example +CLAUDE.md diff --git a/docsrc/debugging/capture_and_replay.rst b/docsrc/debugging/capture_and_replay.rst index cbd7295502..c80a8c509e 100644 --- a/docsrc/debugging/capture_and_replay.rst +++ b/docsrc/debugging/capture_and_replay.rst @@ -13,11 +13,56 @@ Prerequisites Quick start: Capture -------------------- +Example ``test.py``: + +.. code-block:: python + + import torch + import torch_tensorrt as torchtrt + import torchvision.models as models + class MyModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv1d(3, 3, 3, padding=1, stride=1, bias=True) + + def forward(self, x): + return self.conv(x) + + model = MyModule().eval().to("cuda") + input = torch.randn((1, 3, 3)).to("cuda").to(torch.float32) + + compile_spec = { + "inputs": [ + torchtrt.Input( + min_shape=(1, 3, 3), + opt_shape=(2, 3, 3), + max_shape=(3, 3, 3), + dtype=torch.float32, + ) + ], + "min_block_size": 1, + "cache_built_engines": False, + "reuse_cached_engines": False, + "use_python_runtime": True, + } + + try: + with torchtrt.dynamo.Debugger( + "graphs", + logging_dir="debuglogs", + ): + trt_mod = torchtrt.compile(model, **compile_spec) + + except Exception as e: + raise e + + print("done.....") + .. code-block:: bash TORCHTRT_ENABLE_TENSORRT_API_CAPTURE=1 python test.py -You should see ``shim.json`` and ``shim.bin`` generated in ``/tmp/torch_tensorrt_{current_user}/shim``. +When ``TORCHTRT_ENABLE_TENSORRT_API_CAPTURE=1`` is set, capture and replay files are automatically saved under ``debuglogs/capture_replay/`` (i.e., the ``capture_replay`` subdirectory of ``logging_dir``). You should see ``capture.json`` and associated ``.bin`` files generated there. Replay: Build the engine from the capture ----------------------------------------- @@ -26,7 +71,7 @@ Use ``tensorrt_player`` to replay the captured build without the original framew .. code-block:: bash - tensorrt_player -j /absolute/path/to/shim.json -o /absolute/path/to/output_engine + tensorrt_player -j debuglogs/capture_replay/capture.json -o /absolute/path/to/output_engine This produces a serialized TensorRT engine at ``output_engine``. diff --git a/packaging/pre_build_script.sh b/packaging/pre_build_script.sh index 38825a1f43..3bd1dbe6f1 100755 --- a/packaging/pre_build_script.sh +++ b/packaging/pre_build_script.sh @@ -3,7 +3,7 @@ set -x # Install dependencies -python3 -m pip install pyyaml +python3 -m pip install pyyaml packaging if [[ $(uname -m) == "aarch64" ]]; then IS_AARCH64=true @@ -59,6 +59,18 @@ fi export TORCH_BUILD_NUMBER=$(python -c "import torch, urllib.parse as ul; print(ul.quote_plus(torch.__version__))") export TORCH_INSTALL_PATH=$(python -c "import torch, os; print(os.path.dirname(torch.__file__))") +if [[ -z "${TORCH_INSTALL_PATH}" ]]; then + echo "ERROR: TORCH_INSTALL_PATH is empty — could not locate torch installation." + echo "Ensure the active Python environment has torch installed, or set TORCH_PATH explicitly." + exit 1 +fi + +if [[ ! -d "${TORCH_INSTALL_PATH}/include/c10" ]]; then + echo "ERROR: torch at '${TORCH_INSTALL_PATH}' is missing include/c10/ C++ headers." + echo "Install a full PyTorch wheel (pip install torch) that includes dev headers." + exit 1 +fi + # CU_UPPERBOUND eg:13.2 or 12.9 # tensorrt tar for linux and windows are different across cuda version # for sbsa it is the same tar across cuda version diff --git a/py/torch_tensorrt/_TensorRTProxyModule.py b/py/torch_tensorrt/_TensorRTProxyModule.py index cc88953c01..75751a70d1 100644 --- a/py/torch_tensorrt/_TensorRTProxyModule.py +++ b/py/torch_tensorrt/_TensorRTProxyModule.py @@ -1,12 +1,11 @@ import ctypes import importlib -import importlib.util import importlib.metadata +import importlib.util import logging import os import platform import sys -import tempfile from types import ModuleType from typing import Any, Dict, List @@ -54,6 +53,7 @@ def enable_capture_tensorrt_api_recording() -> None: elif platform.uname().processor == "aarch64": linux_lib_path.append("/usr/lib/aarch64-linux-gnu") + tensorrt_lib_path = None for path in linux_lib_path: if os.path.isfile(os.path.join(path, "libtensorrt_shim.so")): try: @@ -74,24 +74,7 @@ def enable_capture_tensorrt_api_recording() -> None: os.environ["TRT_SHIM_NVINFER_LIB_NAME"] = os.path.join( tensorrt_lib_path, "libnvinfer.so" ) - import pwd - - current_user = pwd.getpwuid(os.getuid())[0] - shim_temp_dir = os.path.join( - tempfile.gettempdir(), f"torch_tensorrt_{current_user}/shim" - ) - os.makedirs(shim_temp_dir, exist_ok=True) - json_file_name = os.path.join(shim_temp_dir, "shim.json") - os.environ["TRT_SHIM_OUTPUT_JSON_FILE"] = json_file_name - bin_file_name = os.path.join(shim_temp_dir, "shim.bin") - # if exists, delete the file, so that we can capture the new one - if os.path.exists(json_file_name): - os.remove(json_file_name) - if os.path.exists(bin_file_name): - os.remove(bin_file_name) - _LOGGER.info( - f"Capturing TensorRT API calls feature is enabled and the captured output is in the {shim_temp_dir} directory" - ) + _LOGGER.info("Capturing TensorRT API calls feature is enabled") # TensorRTProxyModule is a proxy module that allows us to register the tensorrt or tensorrt-rtx package diff --git a/py/torch_tensorrt/dynamo/debug/_Debugger.py b/py/torch_tensorrt/dynamo/debug/_Debugger.py index e565929861..7b645c04a4 100644 --- a/py/torch_tensorrt/dynamo/debug/_Debugger.py +++ b/py/torch_tensorrt/dynamo/debug/_Debugger.py @@ -1,4 +1,5 @@ import contextlib +import ctypes import functools import logging import os @@ -33,7 +34,6 @@ def __init__( capture_fx_graph_before: Optional[List[str]] = None, capture_fx_graph_after: Optional[List[str]] = None, save_engine_profile: bool = False, - capture_tensorrt_api_recording: bool = False, profile_format: str = "perfetto", engine_builder_monitor: bool = True, logging_dir: str = DEBUG_LOGGING_DIR, @@ -51,9 +51,6 @@ def __init__( after execution of a lowering pass. Defaults to None. save_engine_profile (bool): Whether to save TensorRT engine profiling information. Defaults to False. - capture_tensorrt_api_recording (bool): Whether to enable the capture TensorRT API recording feature, when this is enabled, it will output the catputure TensorRT API recording in the /tmp/torch_tensorrt_{current_user}/shim directory. - It is part of the TensorRT capture and replay feature, the captured output will be able to replay for debug purpose. - Defaults to False. profile_format (str): Format for profiling data. Choose from 'perfetto', 'trex', 'cudagraph'. If you need to generate engine graph using the profiling files, set it to 'trex' and use the C++ runtime. If you need to generate cudagraph visualization, set it to 'cudagraph'. @@ -67,6 +64,31 @@ def __init__( """ os.makedirs(logging_dir, exist_ok=True) + + # Auto-detect TensorRT API capture from environment variable + env_flag = os.environ.get("TORCHTRT_ENABLE_TENSORRT_API_CAPTURE", None) + capture_tensorrt_api_recording = env_flag is not None and ( + env_flag == "1" or env_flag.lower() == "true" + ) + + if capture_tensorrt_api_recording: + if not sys.platform.startswith("linux"): + _LOGGER.warning( + f"Capturing TensorRT API calls is only supported on Linux, therefore ignoring TORCHTRT_ENABLE_TENSORRT_API_CAPTURE for {sys.platform}" + ) + capture_tensorrt_api_recording = False + elif ENABLED_FEATURES.tensorrt_rtx: + _LOGGER.warning( + "Capturing TensorRT API calls is not supported for TensorRT-RTX, therefore ignoring TORCHTRT_ENABLE_TENSORRT_API_CAPTURE" + ) + capture_tensorrt_api_recording = False + else: + _LOGGER.info("Capturing TensorRT API calls feature is enabled") + + if capture_tensorrt_api_recording: + capture_replay_dir = os.path.join(logging_dir, "capture_replay") + os.makedirs(capture_replay_dir, exist_ok=True) + self.cfg = DebuggerConfig( log_level=log_level, save_engine_profile=save_engine_profile, @@ -98,23 +120,6 @@ def __init__( self.capture_fx_graph_before = capture_fx_graph_before self.capture_fx_graph_after = capture_fx_graph_after - if self.cfg.capture_tensorrt_api_recording: - if not sys.platform.startswith("linux"): - _LOGGER.warning( - f"Capturing TensorRT API calls is only supported on Linux, therefore ignoring the capture_tensorrt_api_recording setting for {sys.platform}" - ) - elif ENABLED_FEATURES.tensorrt_rtx: - _LOGGER.warning( - "Capturing TensorRT API calls is not supported for TensorRT-RTX, therefore ignoring the capture_tensorrt_api_recording setting" - ) - else: - env_flag = os.environ.get("TORCHTRT_ENABLE_TENSORRT_API_CAPTURE", None) - if env_flag is None or (env_flag != "1" and env_flag.lower() != "true"): - _LOGGER.warning( - "In order to capture TensorRT API calls, please invoke the script with environment variable TORCHTRT_ENABLE_TENSORRT_API_CAPTURE=1" - ) - _LOGGER.info("Capturing TensorRT API calls feature is enabled") - def __enter__(self) -> None: self.original_lvl = _LOGGER.getEffectiveLevel() if ENABLED_FEATURES.torch_tensorrt_runtime: @@ -166,6 +171,8 @@ def __enter__(self) -> None: for c in _DEBUG_ENABLED_CLS ] + self.set_capture_tensorrt_api_recording_json_file() + def __exit__(self, exc_type: Any, exc_value: Any, exc_tb: Any) -> None: dictConfig(self.get_logging_config(None)) @@ -224,3 +231,36 @@ def get_logging_config(self, log_level: Optional[int] = None) -> dict[str, Any]: } config["loggers"][""]["handlers"].append("file") return config + + def set_capture_tensorrt_api_recording_json_file(self) -> None: + if self.cfg.capture_tensorrt_api_recording is False: + return + + capture_replay_dir = os.path.join(self.cfg.logging_dir, "capture_replay") + json_file = os.path.join(capture_replay_dir, "capture.json") + + if os.path.isfile(json_file): + os.remove(json_file) + + nvinfer_lib = os.environ.get("TRT_SHIM_NVINFER_LIB_NAME", None) + if nvinfer_lib is None: + _LOGGER.warning( + "TRT_SHIM_NVINFER_LIB_NAME is not set, therefore capturing TensorRT API recording is not supported" + ) + return + lib_path = os.path.dirname(nvinfer_lib) + shim_path = os.path.join(lib_path, "libtensorrt_shim.so") + if not os.path.isfile(shim_path): + _LOGGER.warning( + f"libtensorrt_shim.so is not found in the {lib_path} directory, therefore capturing TensorRT API recording is not supported" + ) + return + try: + shim_lib = ctypes.CDLL(shim_path, mode=ctypes.RTLD_GLOBAL) + shim_lib.trtShimSetOutputJsonFile(json_file.encode("utf-8")) + _LOGGER.info(f"TensorRT API recording will be saved to {json_file}") + except Exception as e: + _LOGGER.warning( + f"Failed to set the output JSON file for TensorRT API recording: {e}" + ) + return diff --git a/third_party/libtorch/BUILD b/third_party/libtorch/BUILD index 37309f7209..5f36debe1c 100644 --- a/third_party/libtorch/BUILD +++ b/third_party/libtorch/BUILD @@ -39,9 +39,12 @@ cc_library( exclude = [ "include/torch/csrc/api/include/**/*.h", ], - ) + glob([ - "include/torch/csrc/api/include/**/*.h", - ]), + ) + glob( + [ + "include/torch/csrc/api/include/**/*.h", + ], + allow_empty = True, + ), includes = [ "include", "include/torch/csrc/api/include/", @@ -58,9 +61,12 @@ cc_library( ":windows": ["lib/c10_cuda.lib"], "//conditions:default": ["lib/libc10_cuda.so"], }), - hdrs = glob([ - "include/c10/**/*.h", - ]), + hdrs = glob( + [ + "include/c10/**/*.h", + ], + allow_empty = True, + ), strip_include_prefix = "include", deps = [ ":c10", @@ -73,17 +79,23 @@ cc_library( ":windows": ["lib/c10.lib"], "//conditions:default": ["lib/libc10.so"], }), - hdrs = glob([ - "include/c10/**/*.h", - ]), + hdrs = glob( + [ + "include/c10/**/*.h", + ], + allow_empty = True, + ), strip_include_prefix = "include", ) cc_library( name = "ATen", - hdrs = glob([ - "include/ATen/**/*.h", - ]), + hdrs = glob( + [ + "include/ATen/**/*.h", + ], + allow_empty = True, + ), strip_include_prefix = "include", ) @@ -97,8 +109,11 @@ cc_library( "lib/libcaffe2_nvrtc.so", ], }), - hdrs = glob([ - "include/caffe2/**/*.h", - ]), + hdrs = glob( + [ + "include/caffe2/**/*.h", + ], + allow_empty = True, + ), strip_include_prefix = "include", ) diff --git a/third_party/tensorrt/local/BUILD b/third_party/tensorrt/local/BUILD index b28ef63e7c..1d860b2c25 100644 --- a/third_party/tensorrt/local/BUILD +++ b/third_party/tensorrt/local/BUILD @@ -83,7 +83,7 @@ cc_import( name = "nvinfer_static_lib", static_library = select({ ":aarch64_linux": "lib/aarch64-linux-gnu/libnvinfer_static.a", - ":ci_rhel_x86_64_linux": "lib64/libnvinfer_static.a", + ":ci_rhel_x86_64_linux": "lib/libnvinfer_static.a", ":windows": "lib/nvinfer_10.lib", "//conditions:default": "lib/x86_64-linux-gnu/libnvinfer_static.a", }), @@ -94,7 +94,7 @@ cc_import( name = "nvinfer_lib", shared_library = select({ ":aarch64_linux": "lib/aarch64-linux-gnu/libnvinfer.so", - ":ci_rhel_x86_64_linux": "lib64/libnvinfer.so", + ":ci_rhel_x86_64_linux": "lib/libnvinfer.so", ":windows": "bin/nvinfer_10.dll", "//conditions:default": "lib/x86_64-linux-gnu/libnvinfer.so", }), @@ -122,7 +122,7 @@ cc_import( name = "nvparsers_lib", shared_library = select({ ":aarch64_linux": "lib/aarch64-linux-gnu/libnvparsers.so", - ":ci_rhel_x86_64_linux": "lib64/libnvparsers.so", + ":ci_rhel_x86_64_linux": "lib/libnvparsers.so", ":windows": "lib/nvparsers.dll", "//conditions:default": "lib/x86_64-linux-gnu/libnvparsers.so", }), @@ -186,7 +186,7 @@ cc_import( name = "nvonnxparser_lib", shared_library = select({ ":aarch64_linux": "lib/aarch64-linux-gnu/libnvonnxparser.so", - ":ci_rhel_x86_64_linux": "lib64/libnvonnxparser.so", + ":ci_rhel_x86_64_linux": "lib/libnvonnxparser.so", ":windows": "lib/nvonnxparser.dll", "//conditions:default": "lib/x86_64-linux-gnu/libnvonnxparser.so", }), @@ -242,7 +242,7 @@ cc_import( name = "nvonnxparser_runtime_lib", shared_library = select({ ":aarch64_linux": "lib/x86_64-linux-gnu/libnvonnxparser_runtime.so", - ":ci_rhel_x86_64_linux": "lib64/libnvonnxparser_runtime.so", + ":ci_rhel_x86_64_linux": "lib/libnvonnxparser_runtime.so", ":windows": "lib/nvonnxparser_runtime.dll", "//conditions:default": "lib/x86_64-linux-gnu/libnvonnxparser_runtime.so", }), @@ -290,7 +290,7 @@ cc_import( name = "nvcaffeparser_lib", shared_library = select({ ":aarch64_linux": "lib/aarch64-linux-gnu/libnvcaffe_parsers.so", - ":ci_rhel_x86_64_linux": "lib64/libnvcaffe_parsers.so", + ":ci_rhel_x86_64_linux": "lib/libnvcaffe_parsers.so", ":windows": "lib/nvcaffe_parsers.dll", "//conditions:default": "lib/x86_64-linux-gnu/libnvcaffe_parsers.so", }), @@ -338,7 +338,7 @@ cc_library( name = "nvinferplugin", srcs = select({ ":aarch64_linux": ["lib/aarch64-linux-gnu/libnvinfer_plugin.so"], - ":ci_rhel_x86_64_linux": ["lib64/libnvinfer_plugin.so"], + ":ci_rhel_x86_64_linux": ["lib/libnvinfer_plugin.so"], ":windows": ["lib/nvinfer_plugin_10.lib"], "//conditions:default": ["lib/x86_64-linux-gnu/libnvinfer_plugin.so"], }), diff --git a/toolchains/local_torch.bzl b/toolchains/local_torch.bzl index 52eb641c93..37b5291c31 100644 --- a/toolchains/local_torch.bzl +++ b/toolchains/local_torch.bzl @@ -76,6 +76,15 @@ def _local_torch_impl(ctx): torch_path = ctx.path(torch_dir) + # Validate that the installation has the expected C++ headers. + c10_include = torch_path.get_child("include").get_child("c10") + if not c10_include.exists: + fail( + "torch at '" + torch_dir + "' is missing include/c10/ C++ headers. " + + "Install a full PyTorch wheel (pip install torch) that includes dev headers, " + + "or set TORCH_PATH to the correct directory.", + ) + # Symlink the subdirectories the BUILD file references into the synthetic repo for sub in ["include", "lib", "share"]: child = torch_path.get_child(sub)