Skip to content

Commit c7f9b7d

Browse files
Bowen Fu and claude
authored and committed
refactor: minimize core changes to essential extension hooks only
- _compiler.py: add generic extension hook registries (_compile_passes, _preserved_ep_attrs, _export_context_factories, _post_trace_hooks) + EP preservation around run_decompositions() + compile pass loop after post_lowering; add profiling_verbosity param to dynamo_compile - _settings.py: add profiling_verbosity field to CompilationSettings - _tracer.py: wrap torch.export.export in ExitStack for registered context factories - _ConversionContext.py: add current_node field - _TRTInterpreter.py: add PREFER_AOT_PYTHON_PLUGINS flag, route profiling_verbosity from settings, set ctx.current_node in run_node, stamp layer.metadata from node.meta["layer_metadata"] (only if not already set by converter) Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent 2b8e544 commit c7f9b7d

5 files changed

Lines changed: 73 additions & 163 deletions

File tree

py/torch_tensorrt/dynamo/_compiler.py

Lines changed: 48 additions & 98 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import os
66
import platform
77
import warnings
8-
from typing import Any, Callable, Collection, Dict, List, Optional, Sequence, Set, Tuple, Union
8+
from typing import Any, Callable, Collection, List, Optional, Sequence, Set, Tuple, Union
99

1010
import torch
1111
from torch.export import ExportedProgram
@@ -40,7 +40,6 @@
4040
post_lowering,
4141
pre_export_lowering,
4242
)
43-
4443
from torch_tensorrt.dynamo.partitioning._resource_partitioner import (
4544
resource_partition,
4645
)
@@ -57,66 +56,32 @@
5756

5857
logger = logging.getLogger(__name__)
5958

60-
# Passes registered by extensions at import time.
61-
# Each entry is called after post_lowering and before TRT conversion.
62-
#
63-
# Pass contract:
64-
# fn(exported_program=..., fx_module=..., logger=...)
65-
#
66-
# A pass may either:
67-
# - mutate in place and return None, or
68-
# - return (exported_program, fx_module) to replace the current objects.
69-
#
70-
# Extensions must never be imported by torch_tensorrt — they register
71-
# themselves here instead.
59+
# Extension hook registries — populated at import time by external modules
60+
# (e.g. torch_tensorrt.annotation). torch_tensorrt itself never imports
61+
# those modules; they register themselves here instead.
7262
_compile_passes: List[Callable] = []
73-
74-
# ExportedProgram attribute names that extensions want preserved across
75-
# run_decompositions() (which returns a fresh EP object, discarding any
76-
# custom attributes set before the call). Extensions register attribute
77-
# names at import time via register_preserved_ep_attr().
7863
_preserved_ep_attrs: List[str] = []
79-
80-
# Context-manager factories called around torch.export.export when an
81-
# nn.Module is traced. Each factory receives (model, inputs) and returns
82-
# a context manager. Registered by extensions at import time.
8364
_export_context_factories: List[Callable] = []
84-
85-
# Hooks called after dynamo_trace() (torch.export) and before dynamo_compile().
86-
# Each hook receives (exported_program, inputs) and may return a new EP.
8765
_post_trace_hooks: List[Callable] = []
8866

8967

9068
def register_compile_pass(fn: Callable) -> None:
91-
"""Register a pre-TRT FX pass. Called by extensions at import time."""
69+
"""Register a pre-TRT FX pass called after post_lowering."""
9270
_compile_passes.append(fn)
9371

9472

9573
def register_preserved_ep_attr(name: str) -> None:
96-
"""Preserve a custom ExportedProgram attribute across run_decompositions().
97-
98-
Extensions that store state on ``exported_program.<name>`` before calling
99-
``torch_tensorrt.compile`` should register that attribute name here so the
100-
compiler can copy it onto the new EP returned by ``run_decompositions()``.
101-
"""
74+
"""Preserve a custom ExportedProgram attribute across run_decompositions()."""
10275
_preserved_ep_attrs.append(name)
10376

10477

10578
def register_export_context(fn: Callable) -> None:
106-
"""Register a context factory ``fn(model, inputs)`` wrapping torch.export.export.
107-
108-
Called by extensions at import time so that custom state can be active
109-
during the export step without modifying the tracer directly.
110-
"""
79+
"""Register a context factory ``fn(model, inputs)`` wrapping torch.export.export."""
11180
_export_context_factories.append(fn)
11281

11382

11483
def register_post_trace_hook(fn: Callable) -> None:
115-
"""Register a hook called after torch.export and before TRT compilation.
116-
117-
``fn(exported_program, inputs)`` may return a new ExportedProgram to
118-
replace the current one, or ``None`` to leave it unchanged.
119-
"""
84+
"""Register a hook called after torch.export and before TRT compilation."""
12085
_post_trace_hooks.append(fn)
12186

12287

@@ -174,6 +139,7 @@ def cross_compile_for_windows(
174139
enable_resource_partitioning: bool = _defaults.ENABLE_RESOURCE_PARTITIONING,
175140
cpu_memory_budget: Optional[int] = _defaults.CPU_MEMORY_BUDGET,
176141
dynamically_allocate_resources: bool = _defaults.DYNAMICALLY_ALLOCATE_RESOURCES,
142+
decompose_attention: bool = _defaults.DECOMPOSE_ATTENTION,
177143
**kwargs: Any,
178144
) -> torch.fx.GraphModule:
179145
"""Compile an ExportedProgram module using TensorRT in Linux for Inference in Windows
@@ -251,6 +217,7 @@ def cross_compile_for_windows(
251217
enable_resource_partitioning (bool): Enable resource-aware partitioning. This is useful when the model is large and the CPU memory is limited.
252218
cpu_memory_budget (Optional[int]): The maximum amount of CPU memory to use for the compilation. If the compilation requires more memory than this budget, the compilation will fail.
253219
dynamically_allocate_resources (bool): Dynamically allocate resources during engine execution.
220+
decompose_attention (bool): Whether to decompose attention layers. We have converters for handling attention ops, but if you want to decompose them into smaller ops, you can set this to True.
254221
**kwargs: Any,
255222
Returns:
256223
torch.fx.GraphModule: Compiled FX Module, when run it will execute via TensorRT
@@ -409,6 +376,7 @@ def cross_compile_for_windows(
409376
"enable_resource_partitioning": enable_resource_partitioning,
410377
"cpu_memory_budget": cpu_memory_budget,
411378
"dynamically_allocate_resources": dynamically_allocate_resources,
379+
"decompose_attention": decompose_attention,
412380
}
413381

414382
# disable the following settings is not supported for cross compilation for windows feature
@@ -428,16 +396,10 @@ def cross_compile_for_windows(
428396

429397
settings = CompilationSettings(**compilation_options)
430398
logger.info("Compilation Settings: %s\n", settings)
431-
# Preserve custom EP attributes across run_decompositions() — the returned
432-
# EP is a fresh object and does not carry any attributes set by extensions.
433-
_saved_ep_attrs = {
434-
k: getattr(exported_program, k)
435-
for k in _preserved_ep_attrs
436-
if hasattr(exported_program, k)
437-
}
438399
exported_program = pre_export_lowering(exported_program, settings)
400+
_saved_ep_attrs = {k: getattr(exported_program, k) for k in _preserved_ep_attrs if hasattr(exported_program, k)}
439401
exported_program = exported_program.run_decompositions(
440-
get_decompositions(enable_experimental_decompositions)
402+
get_decompositions(enable_experimental_decompositions, decompose_attention)
441403
)
442404
for k, v in _saved_ep_attrs.items():
443405
try:
@@ -448,23 +410,23 @@ def cross_compile_for_windows(
448410
gm = exported_program.module()
449411
logger.debug("Input graph: " + str(gm.graph))
450412

451-
# Apply lowering on the graph module
413+
# Apply lowering on the graph module. Note: constant_fold runs inside post_lowering and requires
414+
# module parameters to still be on GPU, so we must not deallocate before this call.
452415
gm = post_lowering(gm, settings)
416+
logger.debug(f"CPU memory usage after post_lowering: {get_cpu_memory_usage()} MB")
453417
logger.debug("Lowered Input graph: " + str(gm.graph))
454418
for _pass in _compile_passes:
455-
_result = _pass(
456-
exported_program=exported_program,
457-
fx_module=gm,
458-
logger=logger,
459-
)
419+
_result = _pass(exported_program=exported_program, fx_module=gm, logger=logger)
460420
if isinstance(_result, tuple):
461421
exported_program, gm = _result
422+
462423
# Move the weights in the state_dict to CPU
463424
if offload_module_to_cpu:
464-
deallocate_module(exported_program.module(), delete_module=False)
425+
deallocate_module(gm)
465426
logger.info(
466427
"The PyTorch model was moved to the CPU to allocate all GPU memory to TensorRT. To retain the model on the GPU, set offload_module_to_cpu=False"
467428
)
429+
logger.debug(f"CPU memory usage after CPU offload: {get_cpu_memory_usage()} MB")
468430
else:
469431
remaining_memory, total_memory = torch.cuda.mem_get_info()
470432
if remaining_memory < total_memory // 2:
@@ -546,6 +508,7 @@ def compile(
546508
cpu_memory_budget: Optional[int] = _defaults.CPU_MEMORY_BUDGET,
547509
enable_resource_partitioning: bool = _defaults.ENABLE_RESOURCE_PARTITIONING,
548510
dynamically_allocate_resources: bool = _defaults.DYNAMICALLY_ALLOCATE_RESOURCES,
511+
decompose_attention: bool = _defaults.DECOMPOSE_ATTENTION,
549512
profiling_verbosity: Optional[Any] = None,
550513
**kwargs: Any,
551514
) -> torch.fx.GraphModule:
@@ -634,6 +597,7 @@ def compile(
634597
enable_resource_partitioning (bool): Enable resource-aware partitioning. This is useful when the model is large and the CPU memory is limited.
635598
cpu_memory_budget (Optional[int]): The maximum amount of CPU memory to use for the compilation. If the compilation requires more memory than this budget, the compilation will fail.
636599
dynamically_allocate_resources (bool): Dynamically allocate resources during engine execution.
600+
decompose_attention (bool): Whether to decompose attention layers. We have converters for handling attention ops, but if you want to decompose them into smaller ops, you can set this to True.
637601
**kwargs: Any,
638602
Returns:
639603
torch.fx.GraphModule: Compiled FX Module, when run it will execute via TensorRT
@@ -648,7 +612,7 @@ def compile(
648612

649613
if not kwargs.get("use_explicit_typing", False):
650614
warnings.warn(
651-
"`use_explicit_typing` is deprecated. This setting will be removed and you should enable autocast instead.",
615+
"`use_explicit_typing` is deprecated. use_explicit_types is now on by default, this setting will be removed and you should enable autocast to recover weak typing behavior.",
652616
DeprecationWarning,
653617
stacklevel=2,
654618
)
@@ -837,21 +801,17 @@ def compile(
837801
"enable_resource_partitioning": enable_resource_partitioning,
838802
"cpu_memory_budget": cpu_memory_budget,
839803
"dynamically_allocate_resources": dynamically_allocate_resources,
840-
"profiling_verbosity": profiling_verbosity,
804+
"decompose_attention": decompose_attention,
841805
}
806+
if profiling_verbosity is not None:
807+
compilation_options["profiling_verbosity"] = profiling_verbosity
842808
logger.debug(f"CPU memory usage before lowering: {get_cpu_memory_usage()} MB")
843809
settings = CompilationSettings(**compilation_options)
844810
logger.info("Compilation Settings: %s\n", settings)
845-
# Preserve custom EP attributes across run_decompositions() — the returned
846-
# EP is a fresh object and does not carry any attributes set by extensions.
847-
_saved_ep_attrs = {
848-
k: getattr(exported_program, k)
849-
for k in _preserved_ep_attrs
850-
if hasattr(exported_program, k)
851-
}
852811
exported_program = pre_export_lowering(exported_program, settings)
812+
_saved_ep_attrs = {k: getattr(exported_program, k) for k in _preserved_ep_attrs if hasattr(exported_program, k)}
853813
exported_program = exported_program.run_decompositions(
854-
get_decompositions(enable_experimental_decompositions)
814+
get_decompositions(enable_experimental_decompositions, decompose_attention)
855815
)
856816
for k, v in _saved_ep_attrs.items():
857817
try:
@@ -863,24 +823,19 @@ def compile(
863823
# Move the weights in the state_dict to CPU
864824
logger.debug("Input graph: " + str(gm.graph))
865825

866-
# Apply lowering on the graph module
826+
# Apply lowering on the graph module. Note: constant_fold runs inside post_lowering and requires
827+
# module parameters to still be on GPU, so we must not deallocate before this call.
867828
gm = post_lowering(gm, settings)
868829
logger.debug(f"CPU memory usage after post_lowering: {get_cpu_memory_usage()} MB")
869830
logger.debug("Lowered Input graph: " + str(gm.graph))
870-
871831
for _pass in _compile_passes:
872-
_result = _pass(
873-
exported_program=exported_program,
874-
fx_module=gm,
875-
logger=logger,
876-
)
832+
_result = _pass(exported_program=exported_program, fx_module=gm, logger=logger)
877833
if isinstance(_result, tuple):
878834
exported_program, gm = _result
879835

880836
# Move the weights in the state_dict to CPU
881837
if offload_module_to_cpu:
882-
deallocate_module(gm, delete_module=False)
883-
deallocate_module(exported_program.module(), delete_module=False)
838+
deallocate_module(gm)
884839
logger.info(
885840
"The PyTorch model was moved to the CPU to allocate all GPU memory to TensorRT. To retain the model on the GPU, set offload_module_to_cpu=False"
886841
)
@@ -892,11 +847,7 @@ def compile(
892847
"Remaining GPU memory may not be enough to compile the TensorRT engine for this model resulting in an OOM error, Consider setting offload_module_to_cpu=True"
893848
)
894849
trt_gm = compile_module(
895-
gm,
896-
trt_arg_inputs,
897-
trt_kwarg_inputs,
898-
settings,
899-
engine_cache,
850+
gm, trt_arg_inputs, trt_kwarg_inputs, settings, engine_cache
900851
)
901852
return trt_gm
902853

@@ -1152,7 +1103,6 @@ def preserve_module_specs(
11521103
trt_modules[name] = trt_module
11531104

11541105
if _debugger_config:
1155-
11561106
if _debugger_config.save_engine_profile:
11571107
if settings.use_python_runtime:
11581108
if _debugger_config.profile_format != "cudagraph":
@@ -1267,6 +1217,7 @@ def convert_exported_program_to_serialized_trt_engine(
12671217
l2_limit_for_tiling: int = _defaults.L2_LIMIT_FOR_TILING,
12681218
offload_module_to_cpu: bool = _defaults.OFFLOAD_MODULE_TO_CPU,
12691219
use_distributed_mode_trace: bool = _defaults.USE_DISTRIBUTED_MODE_TRACE,
1220+
decompose_attention: bool = _defaults.DECOMPOSE_ATTENTION,
12701221
**kwargs: Any,
12711222
) -> bytes:
12721223
"""Convert an ExportedProgram to a serialized TensorRT engine
@@ -1341,6 +1292,7 @@ def convert_exported_program_to_serialized_trt_engine(
13411292
l2_limit_for_tiling (int): The target L2 cache usage limit (in bytes) for tiling optimization (default is -1 which means no limit).
13421293
offload_module_to_cpu (bool): Offload the module to CPU. This is useful when we need to minimize GPU memory usage.
13431294
use_distributed_mode_trace (bool): Using aot_autograd to trace the graph. This is enabled when DTensors or distributed tensors are present in distributed model.
1295+
decompose_attention (bool): Whether to decompose attention layers. We have converters for handling attention ops, but if you want to decompose them into smaller ops, you can set this to True.
13441296
**kwargs: Any,
13451297
Returns:
13461298
bytes: Serialized TensorRT engine, can either be saved to a file or deserialized via TensorRT APIs
@@ -1510,20 +1462,15 @@ def convert_exported_program_to_serialized_trt_engine(
15101462
"l2_limit_for_tiling": l2_limit_for_tiling,
15111463
"offload_module_to_cpu": offload_module_to_cpu,
15121464
"use_distributed_mode_trace": use_distributed_mode_trace,
1465+
"decompose_attention": decompose_attention,
15131466
}
15141467

15151468
settings = CompilationSettings(**compilation_options)
15161469
logger.info("Compilation Settings: %s\n", settings)
1517-
# Preserve custom EP attributes across run_decompositions() — the returned
1518-
# EP is a fresh object and does not carry any attributes set by extensions.
1519-
_saved_ep_attrs = {
1520-
k: getattr(exported_program, k)
1521-
for k in _preserved_ep_attrs
1522-
if hasattr(exported_program, k)
1523-
}
15241470
exported_program = pre_export_lowering(exported_program, settings)
1471+
_saved_ep_attrs = {k: getattr(exported_program, k) for k in _preserved_ep_attrs if hasattr(exported_program, k)}
15251472
exported_program = exported_program.run_decompositions(
1526-
get_decompositions(enable_experimental_decompositions)
1473+
get_decompositions(enable_experimental_decompositions, decompose_attention)
15271474
)
15281475
for k, v in _saved_ep_attrs.items():
15291476
try:
@@ -1539,17 +1486,13 @@ def convert_exported_program_to_serialized_trt_engine(
15391486
gm = post_lowering(gm, settings)
15401487
logger.debug("Lowered Input graph: " + str(gm.graph))
15411488
for _pass in _compile_passes:
1542-
_result = _pass(
1543-
exported_program=exported_program,
1544-
fx_module=gm,
1545-
logger=logger,
1546-
)
1489+
_result = _pass(exported_program=exported_program, fx_module=gm, logger=logger)
15471490
if isinstance(_result, tuple):
15481491
exported_program, gm = _result
15491492

15501493
# Move the weights in the state_dict to CPU
15511494
if offload_module_to_cpu:
1552-
deallocate_module(exported_program.module(), delete_module=False)
1495+
deallocate_module(exported_program.module())
15531496
logger.info(
15541497
"The PyTorch model was moved to the CPU to allocate all GPU memory to TensorRT. To retain the model on the GPU, set offload_module_to_cpu=False"
15551498
)
@@ -1560,6 +1503,9 @@ def convert_exported_program_to_serialized_trt_engine(
15601503
"Remaining GPU memory may not be enough to compile the TensorRT engine for this model resulting in an OOM error, Consider setting offload_module_to_cpu=True"
15611504
)
15621505

1506+
if trt_kwarg_inputs is None:
1507+
trt_kwarg_inputs = {}
1508+
15631509
flattened_input_list = get_flat_args_with_check(
15641510
exported_program, list(trt_arg_inputs), trt_kwarg_inputs
15651511
)[0]
@@ -1571,16 +1517,20 @@ def convert_exported_program_to_serialized_trt_engine(
15711517
settings=settings,
15721518
engine_cache=engine_cache,
15731519
)
1574-
except UnsupportedOperatorException:
1520+
except UnsupportedOperatorException as e:
15751521
logger.error(
15761522
f"Conversion of module {gm} not currently fully supported or convertible!",
15771523
exc_info=True,
15781524
)
1525+
raise UnsupportedOperatorException(
1526+
f"Conversion of module {gm} not currently fully supported or convertible!"
1527+
) from e
15791528
except Exception as e:
15801529
logger.error(
15811530
f"While interpreting the module got an error: {e}",
15821531
exc_info=True,
15831532
)
1533+
raise RuntimeError(f"While interpreting the module got an error: {e}") from e
15841534

15851535
serialized_engine: bytes = interpreter_result.serialized_engine
15861536
return serialized_engine

py/torch_tensorrt/dynamo/_settings.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,13 @@
1616
AUTOCAST_MAX_OUTPUT_THRESHOLD,
1717
CACHE_BUILT_ENGINES,
1818
CPU_MEMORY_BUDGET,
19+
DECOMPOSE_ATTENTION,
1920
DISABLE_TF32,
2021
DLA_GLOBAL_DRAM_SIZE,
2122
DLA_LOCAL_DRAM_SIZE,
2223
DLA_SRAM_SIZE,
23-
DECOMPOSE_ATTENTION,
24-
DYNAMICALLY_ALLOCATE_RESOURCES,
2524
DRYRUN,
25+
DYNAMICALLY_ALLOCATE_RESOURCES,
2626
ENABLE_AUTOCAST,
2727
ENABLE_CROSS_COMPILE_FOR_WINDOWS,
2828
ENABLE_EXPERIMENTAL_DECOMPOSITIONS,

py/torch_tensorrt/dynamo/_tracer.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
import logging
4+
from contextlib import ExitStack
45
from inspect import signature
56
from typing import Any, Optional, Tuple, Union
67

@@ -76,8 +77,6 @@ def trace(
7677
# Constructing dynamic shape list as a nested dict
7778
dynamic_shapes = get_dynamic_shapes_args(mod, arg_inputs)
7879
dynamic_shapes.update(get_dynamic_shapes_kwargs(kwarg_inputs))
79-
# Apply any registered export-context factories.
80-
from contextlib import ExitStack
8180
from torch_tensorrt.dynamo._compiler import _export_context_factories
8281
with ExitStack() as _stack:
8382
for _factory in _export_context_factories:

0 commit comments

Comments (0)