Skip to content

Commit 89bfbdf

Browse files
Bowen Fu and claude committed
refactor: minimize changes outside annotation folder
Revert non-essential modifications to core torch_tensorrt files, keeping only what TTA strictly requires: _compile.py: - Restore module_type == _ModuleType.ep branch (preserve EP input handling) - Restore load() with extra_files/kwargs support - Restore save() with all original params (extra_files, use_legacy_exporter, dynamic_shapes, Input type annotations, full docstring) - Restore original imports (inspect, Dict/Tuple, default_device, etc.) - Keep only the post-trace hook loop as the TTA addition _defaults.py / _settings.py: - Remove editable_timing_cache, error_on_timing_cache_miss (autotune, out of scope) - Restore DECOMPOSE_ATTENTION and decompose_attention field - Restore cpu_memory_budget: Optional[int] - Keep profiling_verbosity (needed for ILayer.metadata inspection) _TRTInterpreter.py: - Remove algorithm_selector parameter (autotune, out of scope) - Remove _mark_debug_candidates / mark_debug logic (debug feature, out of scope) - Remove editable_timing_cache / error_on_timing_cache_miss flag handling Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent 917e8be commit 89bfbdf

4 files changed

Lines changed: 130 additions & 78 deletions

File tree

py/torch_tensorrt/_compile.py

Lines changed: 123 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
from __future__ import annotations
22

33
import collections.abc
4+
import inspect
45
import logging
56
import platform
67
import warnings
78
from enum import Enum
8-
from typing import Any, Callable, List, Optional, Sequence, Set, Union
9+
from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple, Union
910

1011
import torch
1112
from torch_tensorrt._enums import dtype
@@ -23,9 +24,9 @@
2324
from torch_tensorrt.fx.lower import compile as fx_compile
2425
from torch_tensorrt.fx.utils import LowerPrecision
2526

26-
InputType = Union[Input, torch.Tensor, InputTensorSpec]
27-
else:
2827
InputType = Union[Input, torch.Tensor]
28+
else:
29+
InputType = Union[Input, torch.Tensor] # type: ignore
2930

3031
if ENABLED_FEATURES.torchscript_frontend:
3132
import torch_tensorrt.ts
@@ -49,7 +50,13 @@
4950
from torch_tensorrt.dynamo._compiler import (
5051
save_cross_compiled_exported_program as dynamo_save_cross_compiled_exported_program,
5152
)
53+
from torch_tensorrt.dynamo._defaults import default_device
54+
from torch_tensorrt.dynamo._tracer import (
55+
get_dynamic_shapes_args,
56+
get_dynamic_shapes_kwargs,
57+
)
5258
from torch_tensorrt.dynamo._tracer import trace as dynamo_trace
59+
from torch_tensorrt.dynamo.utils import get_torch_inputs
5360

5461
logger = logging.getLogger(__name__)
5562

@@ -175,7 +182,7 @@ def compile(
175182
ir: str = "default",
176183
inputs: Optional[Sequence[InputType]] = None,
177184
arg_inputs: Optional[Sequence[Sequence[Any]]] = None,
178-
kwarg_inputs: Optional[dict[Any, Any]] = None,
185+
kwarg_inputs: Optional[Dict[str, Any]] = None,
179186
enabled_precisions: Optional[Set[Union[torch.dtype, dtype]]] = None,
180187
**kwargs: Any,
181188
) -> (
@@ -301,13 +308,18 @@ def _fx_input_interface(
301308
if not isinstance(arg_inputs, collections.abc.Sequence):
302309
arg_inputs = [arg_inputs] # type: ignore
303310

304-
# Export the module
305311
torchtrt_arg_inputs = prepare_inputs(arg_inputs)
306312
torchtrt_kwarg_inputs = prepare_inputs(kwarg_inputs)
307313

308-
exp_program = dynamo_trace(
309-
module, torchtrt_arg_inputs, kwarg_inputs=torchtrt_kwarg_inputs, **kwargs
310-
)
314+
if module_type == _ModuleType.ep:
315+
exp_program = module
316+
else:
317+
exp_program = dynamo_trace(
318+
module,
319+
torchtrt_arg_inputs,
320+
kwarg_inputs=torchtrt_kwarg_inputs,
321+
**kwargs,
322+
)
311323
# Run post-trace hooks.
312324
from torch_tensorrt.dynamo._compiler import _post_trace_hooks
313325
for _hook in _post_trace_hooks:
@@ -329,7 +341,7 @@ def _fx_input_interface(
329341
raise RuntimeError("Module is an unknown format or the ir requested is unknown")
330342

331343

332-
@needs_cross_compile
344+
@needs_cross_compile # type: ignore[misc]
333345
def cross_compile_for_windows(
334346
module: torch.nn.Module,
335347
file_path: str,
@@ -567,36 +579,56 @@ def load_cross_compiled_exported_program(file_path: str = "") -> Any:
567579
return dynamo_load_cross_compiled_exported_program(file_path)
568580

569581

570-
def load(file_path: str = "") -> Any:
582+
def load(
583+
file_path: str = "", extra_files: Optional[dict[str, Any]] = None, **kwargs: Any
584+
) -> Any:
571585
"""
572586
Load either a Torchscript model or ExportedProgram.
573587
574588
Loads a TorchScript or ExportedProgram file from disk. File type will be detect the type using try, except.
575589
576590
Arguments:
577591
file_path (str): Path to file on the disk
592+
extra_files (dict[str, Any]): Extra files to load with the model
593+
594+
Example:
595+
# Load with extra files.
596+
extra_files = {"foo.txt": ""} # values will be replaced with serialized data
597+
ep = torch.export.load("exported_program.pt2", extra_files=extra_files)
598+
print(extra_files["foo.txt"])
578599
579600
Raises:
580601
ValueError: If there is no file or the file is not either a TorchScript file or ExportedProgram file
581602
"""
603+
582604
try:
583605
logger.debug(f"Loading the provided file {file_path} using torch.jit.load()")
584-
ts_module = torch.jit.load(file_path)
606+
ts_module = function_overload_with_kwargs(
607+
torch.export.load,
608+
file_path,
609+
extra_files=extra_files,
610+
**kwargs,
611+
)
585612
return ts_module
586613
except Exception:
587614
logger.info(
588-
f"Loading the provided file {file_path} via torch.jit.load() failed with the following error",
615+
f"Loading the provided file {file_path} via torch.export.load() failed with the following error",
589616
exc_info=True,
590617
)
591618
pass
592619

593620
try:
594621
logger.debug(f"Loading the provided file {file_path} using torch.export.load()")
595-
exp_program = torch.export.load(file_path)
622+
exp_program = function_overload_with_kwargs(
623+
torch.jit.load,
624+
file_path,
625+
_extra_files=extra_files,
626+
**kwargs,
627+
)
596628
return exp_program
597629
except Exception:
598630
logger.info(
599-
f"Loading the provided file {file_path} via torch.export.load() failed with the following error",
631+
f"Loading the provided file {file_path} via torch.jit.load() (after failing to load with torch.export.load()) failed with the following error",
600632
exc_info=True,
601633
)
602634
raise ValueError(
@@ -608,36 +640,104 @@ def save(
608640
module: Any,
609641
file_path: str = "",
610642
*,
643+
extra_files: Optional[dict[str, str]] = None,
611644
output_format: str = "exported_program",
612-
inputs: Optional[Sequence[torch.Tensor]] = None,
613-
arg_inputs: Optional[Sequence[torch.Tensor]] = None,
614-
kwarg_inputs: Optional[dict[str, Any]] = None,
645+
inputs: Optional[Sequence[torch.Tensor | Input]] = None,
646+
arg_inputs: Optional[Sequence[torch.Tensor | Input]] = None,
647+
kwarg_inputs: Optional[Dict[str, Any]] = None,
615648
retrace: bool = True,
649+
use_legacy_exporter: Optional[bool] = None,
616650
pickle_protocol: int = 2,
651+
dynamic_shapes: Optional[Dict[str, Any]] = None,
617652
**kwargs: Any,
618653
) -> None:
619654
"""
620655
Save the model to disk in the specified output format.
621656
622657
Arguments:
623658
module (Optional(torch.jit.ScriptModule | torch.export.ExportedProgram | torch.fx.GraphModule | CudaGraphsTorchTensorRTModule)): Compiled Torch-TensorRT module
624-
inputs (torch.Tensor): Torch input tensors
625-
arg_inputs (Tuple[Any, ...]): Same as inputs. Alias for better understanding with kwarg_inputs.
626-
kwarg_inputs (dict[Any, ...]): Optional, kwarg inputs to the module forward function.
659+
inputs (Union[torch.Tensor, torch_tensorrt.Input]): Torch input tensors or Input specifications
660+
arg_inputs (Tuple[Union[torch.Tensor, torch_tensorrt.Input], ...]): Same as inputs. Alias for better understanding with kwarg_inputs.
661+
kwarg_inputs (dict[str, Union[torch.Tensor, torch_tensorrt.Input]]): Optional, kwarg inputs to the module forward function.
627662
output_format (str): Format to save the model. Options include exported_program | torchscript | aot_inductor.
628663
retrace (bool): When the module type is a fx.GraphModule, this option re-exports the graph using torch.export.export(strict=False) to save it.
629-
This flag is experimental for now.
664+
665+
For TRT-compiled modules with dynamic shapes, both retrace=True and retrace=False are supported:
666+
667+
- **retrace=True**: Automatically detects symbolic shape metadata in the TRT module and preserves it
668+
without retracing. This is the recommended approach as it maintains the exact symbolic shapes
669+
from the original compilation.
670+
671+
- **retrace=False**: Directly serializes the existing graph metadata without any re-export.
672+
This is faster but may not be compatible with all torch.export consumers.
673+
674+
For static shape models, retrace=True performs a standard torch.export.export() call.
675+
676+
use_legacy_exporter (Optional[bool]): Override the exporter used when serializing a torch.fx.GraphModule.
677+
By default (None) the choice is made automatically:
678+
679+
- ``retrace=False`` always uses the legacy exporter (pure graph surgery, no re-execution).
680+
- ``retrace=True`` with dynamic shapes uses ``torch.export.export`` on the inlined graph,
681+
which produces a fully standards-compliant ExportedProgram.
682+
683+
Set to ``True`` to force the legacy exporter regardless of ``retrace``.
684+
Set to ``False`` to force ``torch.export.export`` on the inlined graph; this requires
685+
example inputs and a live CUDA device.
686+
630687
pickle_protocol (int): The pickle protocol to use to save the model. Default is 2. Increase this to 4 or higher for large models
688+
dynamic_shapes (Optional[Union[dict[str, Any], tuple[Any, ...]]]): Dynamic shape specifications for re-exporting the model.
689+
690+
**Method 1: Explicit dynamic_shapes (torch.export style)**
691+
692+
Provide explicit torch.export.Dim specifications::
693+
694+
# For a single input with dynamic batch dimension
695+
dyn_batch = torch.export.Dim("batch", min=1, max=32)
696+
dynamic_shapes = {"x": {0: dyn_batch}}
697+
torch_tensorrt.save(model, "model.ep", arg_inputs=[example_tensor], dynamic_shapes=dynamic_shapes)
698+
699+
# For multiple inputs
700+
dynamic_shapes = ({"x": {0: dyn_batch}}, {"y": {0: dyn_batch}})
701+
702+
**Method 2: Inferred from torch_tensorrt.Input**
703+
704+
Pass torch_tensorrt.Input objects with min/opt/max shapes in arg_inputs/kwarg_inputs,
705+
and dynamic_shapes will be inferred automatically::
706+
707+
inputs = [
708+
torch_tensorrt.Input(
709+
min_shape=(1, 3, 224, 224),
710+
opt_shape=(8, 3, 224, 224),
711+
max_shape=(32, 3, 224, 224),
712+
name="x" # Optional: name for better dim naming
713+
)
714+
]
715+
torch_tensorrt.save(model, "model.ep", arg_inputs=inputs) # dynamic_shapes inferred!
716+
717+
**Important Limitations:**
718+
719+
- Automatic inference creates **separate Dim objects for each input**. If your model requires
720+
multiple inputs to share the same dimension (e.g., matching batch sizes), you MUST use
721+
Method 1 with explicit shared Dim objects::
722+
723+
batch = torch.export.Dim("batch", min=1, max=8)
724+
dynamic_shapes = {"x": {0: batch}, "mask": {0: batch}} # Shared batch dimension
725+
726+
- Automatic inference is **disabled for mixed Input/Tensor inputs** to avoid spurious
727+
equality constraints. Use explicit dynamic_shapes for these cases.
728+
729+
- If both dynamic_shapes and Input objects are provided, the explicit dynamic_shapes
730+
parameter takes precedence.
631731
"""
632732
if isinstance(module, CudaGraphsTorchTensorRTModule):
633733
module = module.compiled_module
634734
module_type = _parse_module_type(module)
635735
accepted_formats = {"exported_program", "torchscript", "aot_inductor"}
636736
if arg_inputs is not None and not all(
637-
isinstance(input, torch.Tensor) for input in arg_inputs
737+
isinstance(input, (torch.Tensor, Input)) for input in arg_inputs
638738
):
639739
raise ValueError(
640-
"Not all inputs provided are torch.tensors. Please provide torch.tensors as inputs"
740+
"Not all inputs provided are torch.Tensor or torch_tensorrt.Input objects. Please provide inputs of a valid type"
641741
)
642742
if arg_inputs and inputs:
643743
raise AssertionError(

py/torch_tensorrt/dynamo/_defaults.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,8 +55,6 @@
5555
ENABLE_CROSS_COMPILE_FOR_WINDOWS = False
5656
TILING_OPTIMIZATION_LEVEL = "none"
5757
L2_LIMIT_FOR_TILING = -1
58-
EDITABLE_TIMING_CACHE = False
59-
ERROR_ON_TIMING_CACHE_MISS = False
6058
USE_DISTRIBUTED_MODE_TRACE = False
6159
OFFLOAD_MODULE_TO_CPU = False
6260
ENABLE_AUTOCAST = False
@@ -69,6 +67,7 @@
6967
ENABLE_RESOURCE_PARTITIONING = False
7068
CPU_MEMORY_BUDGET = None
7169
DYNAMICALLY_ALLOCATE_RESOURCES = False
70+
DECOMPOSE_ATTENTION = False
7271

7372
if platform.system() == "Linux":
7473
import pwd

py/torch_tensorrt/dynamo/_settings.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -20,17 +20,16 @@
2020
DLA_GLOBAL_DRAM_SIZE,
2121
DLA_LOCAL_DRAM_SIZE,
2222
DLA_SRAM_SIZE,
23+
DECOMPOSE_ATTENTION,
2324
DYNAMICALLY_ALLOCATE_RESOURCES,
2425
DRYRUN,
25-
EDITABLE_TIMING_CACHE,
2626
ENABLE_AUTOCAST,
2727
ENABLE_CROSS_COMPILE_FOR_WINDOWS,
2828
ENABLE_EXPERIMENTAL_DECOMPOSITIONS,
2929
ENABLE_RESOURCE_PARTITIONING,
3030
ENABLE_WEIGHT_STREAMING,
3131
ENABLED_PRECISIONS,
3232
ENGINE_CAPABILITY,
33-
ERROR_ON_TIMING_CACHE_MISS,
3433
HARDWARE_COMPATIBLE,
3534
IMMUTABLE_WEIGHTS,
3635
L2_LIMIT_FOR_TILING,
@@ -110,8 +109,6 @@ class CompilationSettings:
110109
True will enable cross-platform compatibility which allows the engine to be built on Linux and run on Windows
111110
tiling_optimization_level (str): The optimization level of tiling strategies. A higher level allows TensorRT to spend more time searching for better tiling strategy. We currently support ["none", "fast", "moderate", "full"].
112111
l2_limit_for_tiling (int): The target L2 cache usage limit (in bytes) for tiling optimization (default is -1 which means no limit).
113-
editable_timing_cache (bool): Allow TensorRT to write new timing measurements into the timing cache during build (TRT 10.8+). Enable this on the first run so the cache is fully populated; subsequent runs can then load the cache and reproduce the same tactic selection. Default: False.
114-
error_on_timing_cache_miss (bool): Raise a build error if any tactic's timing is not found in the loaded timing cache (TRT 10.8+). Use in combination with a pre-populated ``timing_cache_path`` to guarantee that no re-profiling occurs and tactic selection is identical to the seed run, producing bitwise-identical engines. Default: False.
115112
use_distributed_mode_trace (bool): Using aot_autograd to trace the graph. This is enabled when DTensors or distributed tensors are present in distributed model
116113
enable_autocast (bool): Whether to enable autocast. If enabled, use_explicit_typing will be set to True.
117114
autocast_low_precision_type (Optional[Union[torch.dtype, dtype]]): The precision to reduce to. We currently support torch.float16 and torch.bfloat16. Default is None, which means no low precision is used.
@@ -122,6 +119,7 @@ class CompilationSettings:
122119
autocast_calibration_dataloader (Optional[torch.utils.data.DataLoader]): The dataloader to use for autocast calibration. Default is None.
123120
offload_module_to_cpu (bool): Offload the model to CPU to reduce memory footprint during compilation
124121
dynamically_allocate_resources (bool): Dynamically allocate resources for TensorRT engines
122+
decompose_attention (bool): Whether to decompose attention layers. We have converters for handling attention ops, but if you want to decompose them into smaller ops, you can set this to True.
125123
"""
126124

127125
enabled_precisions: Set[dtype] = field(default_factory=lambda: ENABLED_PRECISIONS)
@@ -163,8 +161,6 @@ class CompilationSettings:
163161
enable_cross_compile_for_windows: bool = ENABLE_CROSS_COMPILE_FOR_WINDOWS
164162
tiling_optimization_level: str = TILING_OPTIMIZATION_LEVEL
165163
l2_limit_for_tiling: int = L2_LIMIT_FOR_TILING
166-
editable_timing_cache: bool = EDITABLE_TIMING_CACHE
167-
error_on_timing_cache_miss: bool = ERROR_ON_TIMING_CACHE_MISS
168164
use_distributed_mode_trace: bool = USE_DISTRIBUTED_MODE_TRACE
169165
offload_module_to_cpu: bool = OFFLOAD_MODULE_TO_CPU
170166
enable_autocast: bool = ENABLE_AUTOCAST
@@ -181,8 +177,9 @@ class CompilationSettings:
181177
AUTOCAST_CALIBRATION_DATALOADER
182178
)
183179
enable_resource_partitioning: bool = ENABLE_RESOURCE_PARTITIONING
184-
cpu_memory_budget: int = CPU_MEMORY_BUDGET
180+
cpu_memory_budget: Optional[int] = CPU_MEMORY_BUDGET
185181
dynamically_allocate_resources: bool = DYNAMICALLY_ALLOCATE_RESOURCES
182+
decompose_attention: bool = DECOMPOSE_ATTENTION
186183
profiling_verbosity: Optional[Any] = None
187184

188185
def __getstate__(self) -> dict[str, Any]:
@@ -223,6 +220,7 @@ def __setstate__(self, state: dict[str, Any]) -> None:
223220
"autocast_max_output_threshold",
224221
"autocast_max_depth_of_reduction",
225222
"autocast_calibration_dataloader",
223+
"decompose_attention",
226224
}
227225

228226

0 commit comments

Comments
 (0)