55import os
66import platform
77import warnings
8- from typing import Any , Callable , Collection , Dict , List , Optional , Sequence , Set , Tuple , Union
8+ from typing import Any , Callable , Collection , List , Optional , Sequence , Set , Tuple , Union
99
1010import torch
1111from torch .export import ExportedProgram
4040 post_lowering ,
4141 pre_export_lowering ,
4242)
43-
4443from torch_tensorrt .dynamo .partitioning ._resource_partitioner import (
4544 resource_partition ,
4645)
5756
5857logger = logging .getLogger (__name__ )
5958
# Extension hook registries — populated at import time by external modules
# (e.g. torch_tensorrt.annotation). torch_tensorrt itself never imports
# those modules; they register themselves here instead.

# Passes run after post_lowering and before TRT conversion. Each pass is
# invoked as fn(exported_program=..., fx_module=..., logger=...). A pass may
# either mutate its arguments in place and return None, or return an
# (exported_program, fx_module) tuple that replaces the current objects.
_compile_passes: List[Callable] = []

# ExportedProgram attribute names to carry across run_decompositions(), which
# returns a fresh EP object and discards custom attributes set before the call.
_preserved_ep_attrs: List[str] = []

# Context-manager factories applied around torch.export.export when an
# nn.Module is traced. Each factory receives (model, inputs) and returns a
# context manager.
_export_context_factories: List[Callable] = []

# Hooks called after dynamo_trace() (torch.export) and before dynamo_compile().
# Each hook receives (exported_program, inputs) and may return a new EP.
_post_trace_hooks: List[Callable] = []
8866
8967
def register_compile_pass(fn: Callable) -> None:
    """Register a pre-TRT FX pass that runs after ``post_lowering``.

    Registered passes are invoked in registration order as
    ``fn(exported_program=..., fx_module=..., logger=...)``. A pass may mutate
    its arguments in place and return ``None``, or return an
    ``(exported_program, fx_module)`` tuple to replace the current objects.

    Args:
        fn: Callable implementing the pass.

    Raises:
        TypeError: If ``fn`` is not callable.
    """
    # Fail at registration (extension import time) rather than in the middle
    # of a compile, where the offending registration is much harder to locate.
    if not callable(fn):
        raise TypeError(
            f"register_compile_pass expects a callable, got {type(fn).__name__}"
        )
    _compile_passes.append(fn)
9371
9472
def register_preserved_ep_attr(name: str) -> None:
    """Preserve a custom ExportedProgram attribute across run_decompositions().

    ``run_decompositions()`` returns a fresh ExportedProgram, so state an
    extension stored on ``exported_program.<name>`` beforehand is not carried
    over. Attribute names registered here are saved before that call and
    copied onto the new program afterwards.

    Args:
        name: Attribute name to preserve.

    Raises:
        TypeError: If ``name`` is not a string.
    """
    # Reject bad names at registration time instead of at the later
    # hasattr()/getattr() lookup during compilation.
    if not isinstance(name, str):
        raise TypeError(
            "register_preserved_ep_attr expects an attribute name of type str, "
            f"got {type(name).__name__}"
        )
    # The saved attributes are collected into a dict keyed by name, so a
    # repeated registration adds nothing; keep the registry free of duplicates.
    if name not in _preserved_ep_attrs:
        _preserved_ep_attrs.append(name)
10376
10477
def register_export_context(fn: Callable) -> None:
    """Register a context factory ``fn(model, inputs)`` wrapping torch.export.export.

    The factory receives ``(model, inputs)`` and returns a context manager, so
    an extension can keep custom state active during the export step without
    modifying the tracer directly.

    Args:
        fn: Context-manager factory.

    Raises:
        TypeError: If ``fn`` is not callable.
    """
    # Surface a bad registration when the extension is imported, not later
    # when the factory is invoked around the export call.
    if not callable(fn):
        raise TypeError(
            f"register_export_context expects a callable, got {type(fn).__name__}"
        )
    _export_context_factories.append(fn)
11281
11382
def register_post_trace_hook(fn: Callable) -> None:
    """Register a hook called after torch.export and before TRT compilation.

    ``fn(exported_program, inputs)`` may return a new ExportedProgram to
    replace the current one, or ``None`` to leave it unchanged.

    Args:
        fn: Hook callable.

    Raises:
        TypeError: If ``fn`` is not callable.
    """
    # Surface a bad registration when the extension is imported, not later
    # when the hook is invoked between tracing and compilation.
    if not callable(fn):
        raise TypeError(
            f"register_post_trace_hook expects a callable, got {type(fn).__name__}"
        )
    _post_trace_hooks.append(fn)
12186
12287
@@ -174,6 +139,7 @@ def cross_compile_for_windows(
174139 enable_resource_partitioning : bool = _defaults .ENABLE_RESOURCE_PARTITIONING ,
175140 cpu_memory_budget : Optional [int ] = _defaults .CPU_MEMORY_BUDGET ,
176141 dynamically_allocate_resources : bool = _defaults .DYNAMICALLY_ALLOCATE_RESOURCES ,
142+ decompose_attention : bool = _defaults .DECOMPOSE_ATTENTION ,
177143 ** kwargs : Any ,
178144) -> torch .fx .GraphModule :
179145 """Compile an ExportedProgram module using TensorRT in Linux for Inference in Windows
@@ -251,6 +217,7 @@ def cross_compile_for_windows(
251217 enable_resource_partitioning (bool): Enable resource-aware partitioning. This is useful when the model is large and the CPU memory is limited.
252218 cpu_memory_budget (Optional[int]): The maximum amount of CPU memory to use for the compilation. If the compilation requires more memory than this budget, the compilation will fail.
253219 dynamically_allocate_resources (bool): Dynamically allocate resources during engine execution.
220+ decompose_attention (bool): Whether to decompose attention layers. We have converters for handling attention ops, but if you want to decompose them into smaller ops, you can set this to True.
254221 **kwargs: Any,
255222 Returns:
256223 torch.fx.GraphModule: Compiled FX Module, when run it will execute via TensorRT
@@ -409,6 +376,7 @@ def cross_compile_for_windows(
409376 "enable_resource_partitioning" : enable_resource_partitioning ,
410377 "cpu_memory_budget" : cpu_memory_budget ,
411378 "dynamically_allocate_resources" : dynamically_allocate_resources ,
379+ "decompose_attention" : decompose_attention ,
412380 }
413381
414382 # disable the following settings is not supported for cross compilation for windows feature
@@ -428,16 +396,10 @@ def cross_compile_for_windows(
428396
429397 settings = CompilationSettings (** compilation_options )
430398 logger .info ("Compilation Settings: %s\n " , settings )
431- # Preserve custom EP attributes across run_decompositions() — the returned
432- # EP is a fresh object and does not carry any attributes set by extensions.
433- _saved_ep_attrs = {
434- k : getattr (exported_program , k )
435- for k in _preserved_ep_attrs
436- if hasattr (exported_program , k )
437- }
438399 exported_program = pre_export_lowering (exported_program , settings )
400+ _saved_ep_attrs = {k : getattr (exported_program , k ) for k in _preserved_ep_attrs if hasattr (exported_program , k )}
439401 exported_program = exported_program .run_decompositions (
440- get_decompositions (enable_experimental_decompositions )
402+ get_decompositions (enable_experimental_decompositions , decompose_attention )
441403 )
442404 for k , v in _saved_ep_attrs .items ():
443405 try :
@@ -448,23 +410,23 @@ def cross_compile_for_windows(
448410 gm = exported_program .module ()
449411 logger .debug ("Input graph: " + str (gm .graph ))
450412
451- # Apply lowering on the graph module
413+ # Apply lowering on the graph module. Note: constant_fold runs inside post_lowering and requires
414+ # module parameters to still be on GPU, so we must not deallocate before this call.
452415 gm = post_lowering (gm , settings )
416+ logger .debug (f"CPU memory usage after post_lowering: { get_cpu_memory_usage ()} MB" )
453417 logger .debug ("Lowered Input graph: " + str (gm .graph ))
454418 for _pass in _compile_passes :
455- _result = _pass (
456- exported_program = exported_program ,
457- fx_module = gm ,
458- logger = logger ,
459- )
419+ _result = _pass (exported_program = exported_program , fx_module = gm , logger = logger )
460420 if isinstance (_result , tuple ):
461421 exported_program , gm = _result
422+
462423 # Move the weights in the state_dict to CPU
463424 if offload_module_to_cpu :
464- deallocate_module (exported_program . module (), delete_module = False )
425+ deallocate_module (gm )
465426 logger .info (
466427 "The PyTorch model was moved to the CPU to allocate all GPU memory to TensorRT. To retain the model on the GPU, set offload_module_to_cpu=False"
467428 )
429+ logger .debug (f"CPU memory usage after CPU offload: { get_cpu_memory_usage ()} MB" )
468430 else :
469431 remaining_memory , total_memory = torch .cuda .mem_get_info ()
470432 if remaining_memory < total_memory // 2 :
@@ -546,6 +508,7 @@ def compile(
546508 cpu_memory_budget : Optional [int ] = _defaults .CPU_MEMORY_BUDGET ,
547509 enable_resource_partitioning : bool = _defaults .ENABLE_RESOURCE_PARTITIONING ,
548510 dynamically_allocate_resources : bool = _defaults .DYNAMICALLY_ALLOCATE_RESOURCES ,
511+ decompose_attention : bool = _defaults .DECOMPOSE_ATTENTION ,
549512 profiling_verbosity : Optional [Any ] = None ,
550513 ** kwargs : Any ,
551514) -> torch .fx .GraphModule :
@@ -634,6 +597,7 @@ def compile(
634597 enable_resource_partitioning (bool): Enable resource-aware partitioning. This is useful when the model is large and the CPU memory is limited.
635598 cpu_memory_budget (Optional[int]): The maximum amount of CPU memory to use for the compilation. If the compilation requires more memory than this budget, the compilation will fail.
636599 dynamically_allocate_resources (bool): Dynamically allocate resources during engine execution.
600+ decompose_attention (bool): Whether to decompose attention layers. We have converters for handling attention ops, but if you want to decompose them into smaller ops, you can set this to True.
637601 **kwargs: Any,
638602 Returns:
639603 torch.fx.GraphModule: Compiled FX Module, when run it will execute via TensorRT
@@ -648,7 +612,7 @@ def compile(
648612
649613 if not kwargs .get ("use_explicit_typing" , False ):
650614 warnings .warn (
651- "`use_explicit_typing` is deprecated. This setting will be removed and you should enable autocast instead ." ,
615+ "`use_explicit_typing` is deprecated. use_explicit_types is now on by default, this setting will be removed and you should enable autocast to recover weak typing behavior ." ,
652616 DeprecationWarning ,
653617 stacklevel = 2 ,
654618 )
@@ -837,21 +801,17 @@ def compile(
837801 "enable_resource_partitioning" : enable_resource_partitioning ,
838802 "cpu_memory_budget" : cpu_memory_budget ,
839803 "dynamically_allocate_resources" : dynamically_allocate_resources ,
840- "profiling_verbosity " : profiling_verbosity ,
804+ "decompose_attention " : decompose_attention ,
841805 }
806+ if profiling_verbosity is not None :
807+ compilation_options ["profiling_verbosity" ] = profiling_verbosity
842808 logger .debug (f"CPU memory usage before lowering: { get_cpu_memory_usage ()} MB" )
843809 settings = CompilationSettings (** compilation_options )
844810 logger .info ("Compilation Settings: %s\n " , settings )
845- # Preserve custom EP attributes across run_decompositions() — the returned
846- # EP is a fresh object and does not carry any attributes set by extensions.
847- _saved_ep_attrs = {
848- k : getattr (exported_program , k )
849- for k in _preserved_ep_attrs
850- if hasattr (exported_program , k )
851- }
852811 exported_program = pre_export_lowering (exported_program , settings )
812+ _saved_ep_attrs = {k : getattr (exported_program , k ) for k in _preserved_ep_attrs if hasattr (exported_program , k )}
853813 exported_program = exported_program .run_decompositions (
854- get_decompositions (enable_experimental_decompositions )
814+ get_decompositions (enable_experimental_decompositions , decompose_attention )
855815 )
856816 for k , v in _saved_ep_attrs .items ():
857817 try :
@@ -863,24 +823,19 @@ def compile(
863823 # Move the weights in the state_dict to CPU
864824 logger .debug ("Input graph: " + str (gm .graph ))
865825
866- # Apply lowering on the graph module
826+ # Apply lowering on the graph module. Note: constant_fold runs inside post_lowering and requires
827+ # module parameters to still be on GPU, so we must not deallocate before this call.
867828 gm = post_lowering (gm , settings )
868829 logger .debug (f"CPU memory usage after post_lowering: { get_cpu_memory_usage ()} MB" )
869830 logger .debug ("Lowered Input graph: " + str (gm .graph ))
870-
871831 for _pass in _compile_passes :
872- _result = _pass (
873- exported_program = exported_program ,
874- fx_module = gm ,
875- logger = logger ,
876- )
832+ _result = _pass (exported_program = exported_program , fx_module = gm , logger = logger )
877833 if isinstance (_result , tuple ):
878834 exported_program , gm = _result
879835
880836 # Move the weights in the state_dict to CPU
881837 if offload_module_to_cpu :
882- deallocate_module (gm , delete_module = False )
883- deallocate_module (exported_program .module (), delete_module = False )
838+ deallocate_module (gm )
884839 logger .info (
885840 "The PyTorch model was moved to the CPU to allocate all GPU memory to TensorRT. To retain the model on the GPU, set offload_module_to_cpu=False"
886841 )
@@ -892,11 +847,7 @@ def compile(
892847 "Remaining GPU memory may not be enough to compile the TensorRT engine for this model resulting in an OOM error, Consider setting offload_module_to_cpu=True"
893848 )
894849 trt_gm = compile_module (
895- gm ,
896- trt_arg_inputs ,
897- trt_kwarg_inputs ,
898- settings ,
899- engine_cache ,
850+ gm , trt_arg_inputs , trt_kwarg_inputs , settings , engine_cache
900851 )
901852 return trt_gm
902853
@@ -1152,7 +1103,6 @@ def preserve_module_specs(
11521103 trt_modules [name ] = trt_module
11531104
11541105 if _debugger_config :
1155-
11561106 if _debugger_config .save_engine_profile :
11571107 if settings .use_python_runtime :
11581108 if _debugger_config .profile_format != "cudagraph" :
@@ -1267,6 +1217,7 @@ def convert_exported_program_to_serialized_trt_engine(
12671217 l2_limit_for_tiling : int = _defaults .L2_LIMIT_FOR_TILING ,
12681218 offload_module_to_cpu : bool = _defaults .OFFLOAD_MODULE_TO_CPU ,
12691219 use_distributed_mode_trace : bool = _defaults .USE_DISTRIBUTED_MODE_TRACE ,
1220+ decompose_attention : bool = _defaults .DECOMPOSE_ATTENTION ,
12701221 ** kwargs : Any ,
12711222) -> bytes :
12721223 """Convert an ExportedProgram to a serialized TensorRT engine
@@ -1341,6 +1292,7 @@ def convert_exported_program_to_serialized_trt_engine(
13411292 l2_limit_for_tiling (int): The target L2 cache usage limit (in bytes) for tiling optimization (default is -1 which means no limit).
13421293 offload_module_to_cpu (bool): Offload the module to CPU. This is useful when we need to minimize GPU memory usage.
13431294 use_distributed_mode_trace (bool): Using aot_autograd to trace the graph. This is enabled when DTensors or distributed tensors are present in distributed model.
1295+ decompose_attention (bool): Whether to decompose attention layers. We have converters for handling attention ops, but if you want to decompose them into smaller ops, you can set this to True.
13441296 **kwargs: Any,
13451297 Returns:
13461298 bytes: Serialized TensorRT engine, can either be saved to a file or deserialized via TensorRT APIs
@@ -1510,20 +1462,15 @@ def convert_exported_program_to_serialized_trt_engine(
15101462 "l2_limit_for_tiling" : l2_limit_for_tiling ,
15111463 "offload_module_to_cpu" : offload_module_to_cpu ,
15121464 "use_distributed_mode_trace" : use_distributed_mode_trace ,
1465+ "decompose_attention" : decompose_attention ,
15131466 }
15141467
15151468 settings = CompilationSettings (** compilation_options )
15161469 logger .info ("Compilation Settings: %s\n " , settings )
1517- # Preserve custom EP attributes across run_decompositions() — the returned
1518- # EP is a fresh object and does not carry any attributes set by extensions.
1519- _saved_ep_attrs = {
1520- k : getattr (exported_program , k )
1521- for k in _preserved_ep_attrs
1522- if hasattr (exported_program , k )
1523- }
15241470 exported_program = pre_export_lowering (exported_program , settings )
1471+ _saved_ep_attrs = {k : getattr (exported_program , k ) for k in _preserved_ep_attrs if hasattr (exported_program , k )}
15251472 exported_program = exported_program .run_decompositions (
1526- get_decompositions (enable_experimental_decompositions )
1473+ get_decompositions (enable_experimental_decompositions , decompose_attention )
15271474 )
15281475 for k , v in _saved_ep_attrs .items ():
15291476 try :
@@ -1539,17 +1486,13 @@ def convert_exported_program_to_serialized_trt_engine(
15391486 gm = post_lowering (gm , settings )
15401487 logger .debug ("Lowered Input graph: " + str (gm .graph ))
15411488 for _pass in _compile_passes :
1542- _result = _pass (
1543- exported_program = exported_program ,
1544- fx_module = gm ,
1545- logger = logger ,
1546- )
1489+ _result = _pass (exported_program = exported_program , fx_module = gm , logger = logger )
15471490 if isinstance (_result , tuple ):
15481491 exported_program , gm = _result
15491492
15501493 # Move the weights in the state_dict to CPU
15511494 if offload_module_to_cpu :
1552- deallocate_module (exported_program .module (), delete_module = False )
1495+ deallocate_module (exported_program .module ())
15531496 logger .info (
15541497 "The PyTorch model was moved to the CPU to allocate all GPU memory to TensorRT. To retain the model on the GPU, set offload_module_to_cpu=False"
15551498 )
@@ -1560,6 +1503,9 @@ def convert_exported_program_to_serialized_trt_engine(
15601503 "Remaining GPU memory may not be enough to compile the TensorRT engine for this model resulting in an OOM error, Consider setting offload_module_to_cpu=True"
15611504 )
15621505
1506+ if trt_kwarg_inputs is None :
1507+ trt_kwarg_inputs = {}
1508+
15631509 flattened_input_list = get_flat_args_with_check (
15641510 exported_program , list (trt_arg_inputs ), trt_kwarg_inputs
15651511 )[0 ]
@@ -1571,16 +1517,20 @@ def convert_exported_program_to_serialized_trt_engine(
15711517 settings = settings ,
15721518 engine_cache = engine_cache ,
15731519 )
1574- except UnsupportedOperatorException :
1520+ except UnsupportedOperatorException as e :
15751521 logger .error (
15761522 f"Conversion of module { gm } not currently fully supported or convertible!" ,
15771523 exc_info = True ,
15781524 )
1525+ raise UnsupportedOperatorException (
1526+ f"Conversion of module { gm } not currently fully supported or convertible!"
1527+ ) from e
15791528 except Exception as e :
15801529 logger .error (
15811530 f"While interpreting the module got an error: { e } " ,
15821531 exc_info = True ,
15831532 )
1533+ raise RuntimeError (f"While interpreting the module got an error: { e } " ) from e
15841534
15851535 serialized_engine : bytes = interpreter_result .serialized_engine
15861536 return serialized_engine
0 commit comments