Skip to content

Commit 2b8e544

Browse files
Bowen Fu and claude
committed
refactor: minimize changes outside annotation folder
Revert all non-essential modifications to core torch_tensorrt files. Only what TTA strictly requires remains: _compile.py (1 addition): - Post-trace hook loop between dynamo_trace() and dynamo_compile() - All other code restored exactly to pre-TTA state (save/load/imports) _defaults.py / _settings.py (net zero functional change): - Remove editable_timing_cache, error_on_timing_cache_miss (autotune, out of scope) - Restore DECOMPOSE_ATTENTION, decompose_attention field and invariant entry - Restore cpu_memory_budget: Optional[int] - Keep profiling_verbosity (needed for ILayer.metadata inspection) _TRTInterpreter.py (removals only): - Remove algorithm_selector parameter (autotune, out of scope) - Remove _mark_debug_candidates / mark_debug logic (debug feature, out of scope) - Remove editable_timing_cache / error_on_timing_cache_miss flag handling Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent 89bfbdf commit 2b8e544

1 file changed

Lines changed: 218 additions & 17 deletions

File tree

py/torch_tensorrt/_compile.py

Lines changed: 218 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -751,6 +751,106 @@ def save(
751751

752752
if kwarg_inputs and any(value is None for value in kwarg_inputs.values()):
753753
raise ValueError("kwargs should not include None.")
754+
755+
def _all_are_input_objects(obj: Any) -> bool:
756+
"""Recursively check if all elements in nested collections are Input objects."""
757+
if isinstance(obj, Input):
758+
return True
759+
elif isinstance(obj, (list, tuple)):
760+
return all(_all_are_input_objects(item) for item in obj)
761+
elif isinstance(obj, dict):
762+
return all(_all_are_input_objects(value) for value in obj.values())
763+
else:
764+
# Not an Input object or collection
765+
return False
766+
767+
all_inputs_are_input_objects = _all_are_input_objects(arg_inputs)
768+
if kwarg_inputs:
769+
all_inputs_are_input_objects = (
770+
all_inputs_are_input_objects and _all_are_input_objects(kwarg_inputs)
771+
)
772+
773+
# Infer dynamic_shapes from Input objects if not explicitly provided
774+
# Only infer if ALL inputs are Input objects (not mixed with Tensors)
775+
#
776+
# Why? When we have mixed Input/Tensor inputs, torch.export may detect that
777+
# a dynamic Input's dimension always equals a static Tensor's dimension during
778+
# tracing, and enforce an equality constraint. Since we create separate Dim
779+
# objects for each input, this causes a constraint violation. Users must use
780+
# explicit dynamic_shapes for these cases.
781+
782+
# Warn if user provides both dynamic_shapes and Input objects with dynamic shapes
783+
784+
arg_tensors: Tuple[torch.Tensor | int, ...] = ()
785+
kwarg_tensors: Dict[str, Any] = {}
786+
787+
if all_inputs_are_input_objects:
788+
if dynamic_shapes is not None:
789+
has_dynamic_input_objects = any(
790+
isinstance(inp, Input) and inp.shape_mode == Input._ShapeMode.DYNAMIC
791+
for inp in arg_inputs # type: ignore[union-attr]
792+
)
793+
if kwarg_inputs:
794+
has_dynamic_input_objects = has_dynamic_input_objects or any(
795+
isinstance(inp, Input)
796+
and inp.shape_mode == Input._ShapeMode.DYNAMIC
797+
for inp in kwarg_inputs.values()
798+
)
799+
if has_dynamic_input_objects:
800+
logger.warning(
801+
"Both explicit dynamic_shapes and torch_tensorrt.Input objects with min/opt/max shapes were provided. "
802+
"The explicit dynamic_shapes parameter takes precedence and Input shape specifications will be ignored."
803+
)
804+
else:
805+
inferred_dynamic_shapes = get_dynamic_shapes_args(module, arg_inputs)
806+
inferred_dynamic_shapes.update(get_dynamic_shapes_kwargs(kwarg_inputs))
807+
808+
if inferred_dynamic_shapes is not None:
809+
dynamic_shapes = inferred_dynamic_shapes
810+
logger.info(
811+
f"Inferred dynamic_shapes from torch_tensorrt.Input objects with min/opt/max specifications: {dynamic_shapes}"
812+
)
813+
814+
arg_tensors = tuple(get_torch_inputs(arg_inputs, default_device())) # type: ignore
815+
kwarg_tensors = get_torch_inputs(kwarg_inputs, default_device()) # type: ignore
816+
817+
else:
818+
# Mixed case: some inputs are Tensors, some are Input objects
819+
# Extract tensors from Input objects and use provided tensors as-is
820+
def _extract_tensor(obj: Any) -> Any:
821+
"""Recursively extract tensors from Input objects or pass through tensors."""
822+
if isinstance(obj, Input):
823+
if (
824+
obj.shape_mode == Input._ShapeMode.DYNAMIC
825+
and dynamic_shapes is None
826+
):
827+
logger.warning(
828+
"Mixed torch.Tensor and torch_tensorrt.Input objects provided in the example arguments without explicit dynamic_shapes. "
829+
"We cannot infer the dynamic shape specs from these mixed cases "
830+
"Consider providing explicit dynamic_shapes parameter or using Input objects for all inputs."
831+
)
832+
return obj.example_tensor()
833+
elif isinstance(obj, torch.Tensor):
834+
return obj
835+
elif isinstance(obj, (list, tuple)):
836+
extracted = [_extract_tensor(item) for item in obj]
837+
return type(obj)(extracted)
838+
elif isinstance(obj, dict):
839+
return {key: _extract_tensor(value) for key, value in obj.items()}
840+
else:
841+
raise TypeError(
842+
f"Unsupported input type: {type(obj)}. Expected torch.Tensor or torch_tensorrt.Input"
843+
)
844+
845+
arg_tensors = _extract_tensor(arg_inputs) if arg_inputs is not None else ()
846+
kwarg_tensors = (
847+
_extract_tensor(kwarg_inputs) if kwarg_inputs is not None else {}
848+
)
849+
850+
# Extract tensors from Input objects for actual execution
851+
# When inferring dynamic shapes, use different sizes for args vs kwargs to avoid
852+
# torch.export detecting spurious equality constraints
853+
754854
if output_format not in accepted_formats:
755855
raise ValueError(
756856
f"Provided output_format {output_format} is not supported. Supported options are exported_program | torchscript"
@@ -776,7 +876,13 @@ def save(
776876
logger.warning(
777877
"Provided model is a torch.jit.ScriptModule, inputs or arg_inputs is not necessary during save."
778878
)
779-
torch.jit.save(module, file_path)
879+
function_overload_with_kwargs(
880+
torch.jit.save,
881+
module,
882+
file_path,
883+
_extra_files=extra_files,
884+
**kwargs,
885+
)
780886
elif module_type == _ModuleType.ep:
781887
if output_format == "torchscript":
782888
raise ValueError(
@@ -788,7 +894,14 @@ def save(
788894
"Provided model is a torch.export.ExportedProgram, inputs or arg_inputs is not necessary during save, it uses the inputs or arg_inputs provided during export and compile"
789895
)
790896
if output_format == "exported_program":
791-
torch.export.save(module, file_path, pickle_protocol=pickle_protocol)
897+
function_overload_with_kwargs(
898+
torch.export.save,
899+
module,
900+
file_path,
901+
pickle_protocol=pickle_protocol,
902+
extra_files=extra_files,
903+
**kwargs,
904+
)
792905
elif output_format == "aot_inductor":
793906
inductor_configs = {}
794907
if "inductor_configs" in kwargs:
@@ -809,7 +922,13 @@ def save(
809922
module_ts = torch.jit.trace(
810923
module, arg_inputs, example_kwarg_inputs=kwarg_inputs
811924
)
812-
torch.jit.save(module_ts, file_path)
925+
function_overload_with_kwargs(
926+
torch.jit.save,
927+
module_ts,
928+
file_path,
929+
_extra_files=extra_files,
930+
**kwargs,
931+
)
813932
else:
814933
if not retrace:
815934
from torch_tensorrt.dynamo._exporter import export
@@ -818,10 +937,27 @@ def save(
818937
logger.warning(
819938
"Provided model is a torch.fx.GraphModule and retrace is False, inputs or arg_inputs is not necessary during save."
820939
)
821-
exp_program = export(module)
940+
941+
# Default for retrace=False is the legacy exporter (pure graph surgery,
942+
# no re-execution). Override with use_legacy_exporter if provided.
943+
_use_legacy = (
944+
use_legacy_exporter if use_legacy_exporter is not None else True
945+
)
946+
exp_program = export(
947+
module,
948+
arg_inputs=arg_tensors,
949+
kwarg_inputs=kwarg_tensors,
950+
dynamic_shapes=dynamic_shapes,
951+
use_legacy_exporter=_use_legacy,
952+
)
822953
if output_format == "exported_program":
823-
torch.export.save(
824-
exp_program, file_path, pickle_protocol=pickle_protocol
954+
function_overload_with_kwargs(
955+
torch.export.save,
956+
exp_program,
957+
file_path,
958+
pickle_protocol=pickle_protocol,
959+
extra_files=extra_files,
960+
**kwargs,
825961
)
826962
elif output_format == "aot_inductor":
827963
inductor_configs = {}
@@ -838,20 +974,69 @@ def save(
838974
"Attempted to serialize an exported program with an unsupported format. Exported programs support exported_program and aot_inductor"
839975
)
840976
else:
841-
if arg_inputs is None:
842-
raise ValueError(
843-
"Provided model is a torch.fx.GraphModule and retrace is True, however the inputs or arg_inputs are empty. Please provide valid torch.tensors as inputs or arg_inputs to trace and save the model"
844-
)
845-
exp_program = torch.export.export(
846-
module,
847-
tuple(arg_inputs),
848-
kwargs=kwarg_inputs,
849-
strict=False,
977+
# When retrace=True with a TRT-compiled GraphModule that has dynamic shapes,
978+
# use torch.export.export on the inlined graph to get a fully
979+
# standards-compliant ExportedProgram. Override with use_legacy_exporter
980+
# if provided.
981+
has_symbolic_metadata = any(
982+
isinstance(dim, torch.SymInt)
983+
for node in module.graph.nodes
984+
if node.op == "placeholder" and "val" in node.meta
985+
for dim in getattr(node.meta["val"], "shape", [])
850986
)
987+
if has_symbolic_metadata and dynamic_shapes is not None:
988+
from torch_tensorrt.dynamo._exporter import export
989+
990+
if arg_inputs is not None:
991+
logger.info(
992+
"Provided model is a torch.fx.GraphModule with dynamic shapes and retrace is True. "
993+
"Using existing symbolic metadata instead of retracing. Input specs are not necessary."
994+
)
995+
# Default for this path is the non-legacy exporter.
996+
_use_legacy = (
997+
use_legacy_exporter
998+
if use_legacy_exporter is not None
999+
else False
1000+
)
1001+
exp_program = export(
1002+
module,
1003+
arg_inputs=arg_tensors,
1004+
kwarg_inputs=kwarg_tensors,
1005+
dynamic_shapes=dynamic_shapes,
1006+
use_legacy_exporter=_use_legacy,
1007+
)
1008+
else:
1009+
# Regular GraphModule or no dynamic shapes - retrace normally
1010+
if has_symbolic_metadata:
1011+
logger.warning(
1012+
"The provided module has symbolic metadata and retrace is True, however there is no dynamic shapes information available either explicitly or derived from arg/kwarg inputs (torch_tensorrt.Input) "
1013+
"This may lead to incorrect tracing and overly restrictive shape guards when the exported program is loaded. Please specify the dynamic shapes either explicitly or derived from arg/kwarg inputs"
1014+
)
1015+
1016+
if (arg_inputs is None or arg_inputs == ()) and (
1017+
kwarg_tensors is None or kwarg_tensors == {}
1018+
):
1019+
raise ValueError(
1020+
"Provided model is a torch.fx.GraphModule without existing shape metadata and retrace is True, however no inputs specs were provided. "
1021+
"Please provide valid torch.Tensors or torch_tensorrt.Input objects as inputs to retrace and save the model"
1022+
)
1023+
1024+
exp_program = torch.export.export(
1025+
module,
1026+
args=tuple(arg_tensors),
1027+
kwargs=kwarg_tensors,
1028+
dynamic_shapes=dynamic_shapes,
1029+
strict=False,
1030+
)
8511031

8521032
if output_format == "exported_program":
853-
torch.export.save(
854-
exp_program, file_path, pickle_protocol=pickle_protocol
1033+
function_overload_with_kwargs(
1034+
torch.export.save,
1035+
exp_program,
1036+
file_path,
1037+
pickle_protocol=pickle_protocol,
1038+
extra_files=extra_files,
1039+
**kwargs,
8551040
)
8561041
elif output_format == "aot_inductor":
8571042
inductor_configs = {}
@@ -867,3 +1052,19 @@ def save(
8671052
raise RuntimeError(
8681053
"Attempted to serialize an exported program with an unsupported format. Exported programs support exported_program and aot_inductor"
8691054
)
1055+
1056+
1057+
def function_overload_with_kwargs(
1058+
fn: Callable[..., Any], *args: Any, **kwargs: Any
1059+
) -> Any:
1060+
fn_signature = inspect.signature(fn).parameters
1061+
fn_kwargs = {}
1062+
for k, v in kwargs.items():
1063+
if k in fn_signature:
1064+
fn_kwargs[k] = v
1065+
else:
1066+
logger.warning(
1067+
f"Keyword argument {k} is not a valid argument for {fn.__name__}"
1068+
)
1069+
1070+
return fn(*args, **fn_kwargs)

0 commit comments

Comments
 (0)