Skip to content

Commit 4035db7

Browse files
author
shoumikhin
committed
[executorch][nvidia][tensorrt][11/n] Complete preprocess integration with serialization
Complete preprocess integration with blob serialization for TensorRT engine compilation. Differential Revision: [D93275051](https://our.internmc.facebook.com/intern/diff/D93275051/) [ghstack-poisoned]
1 parent 91c00ae commit 4035db7

2 files changed

Lines changed: 73 additions & 1 deletion

File tree

backends/nvidia/tensorrt/backend.py

Lines changed: 72 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,11 @@
2525
lookup_converter,
2626
needs_edge_program,
2727
)
28+
from executorch.backends.nvidia.tensorrt.serialization import (
29+
serialize_blob,
30+
TensorRTBlobMetadata,
31+
TensorRTIOBinding,
32+
)
2833

2934
logger = logging.getLogger(__name__)
3035
logger.setLevel(logging.WARNING)
@@ -100,14 +105,21 @@ def preprocess(
100105
)
101106
_mark_network_outputs(network, output_nodes, input_map)
102107

108+
# Collect I/O bindings from network
109+
io_bindings = _collect_io_bindings(network)
110+
103111
# Configure and build engine
104112
config = _create_builder_config(builder, spec, trt)
105113
serialized_engine = builder.build_serialized_network(network, config)
106114

107115
if serialized_engine is None:
108116
raise RuntimeError("Failed to build TensorRT engine")
109117

110-
return PreprocessResult(processed_bytes=bytes(serialized_engine))
118+
# Serialize with metadata
119+
metadata = TensorRTBlobMetadata(io_bindings=io_bindings)
120+
blob = serialize_blob(bytes(serialized_engine), metadata)
121+
122+
return PreprocessResult(processed_bytes=blob)
111123

112124

113125
def _get_input_nodes(
@@ -284,6 +296,65 @@ def _mark_network_outputs(
284296
network.mark_output(output_tensor)
285297

286298

299+
def _trt_dtype_to_string(dtype: Any) -> str:
300+
"""Convert TensorRT DataType to string representation."""
301+
dtype_name = str(dtype)
302+
# dtype looks like "DataType.FLOAT" or "DataType.HALF"
303+
if "." in dtype_name:
304+
dtype_name = dtype_name.split(".")[-1]
305+
306+
dtype_map = {
307+
"FLOAT": "float32",
308+
"HALF": "float16",
309+
"INT8": "int8",
310+
"INT32": "int32",
311+
"INT64": "int64",
312+
"BOOL": "bool",
313+
"UINT8": "uint8",
314+
"FP8": "float8",
315+
"BF16": "bfloat16",
316+
}
317+
return dtype_map.get(dtype_name, "float32")
318+
319+
320+
def _collect_io_bindings(network: Any) -> List[TensorRTIOBinding]:
    """Collect I/O binding information from TensorRT network.

    Args:
        network: TensorRT network definition.

    Returns:
        List of TensorRTIOBinding with input tensor metadata first,
        followed by output tensor metadata.
    """

    def _make_binding(tensor: Any, is_input: bool) -> TensorRTIOBinding:
        # One binding record per network tensor; shape is copied to a
        # plain list so the metadata is independent of TensorRT objects.
        return TensorRTIOBinding(
            name=tensor.name,
            dtype=_trt_dtype_to_string(tensor.dtype),
            shape=list(tensor.shape),
            is_input=is_input,
        )

    inputs = [
        _make_binding(network.get_input(idx), True)
        for idx in range(network.num_inputs)
    ]
    outputs = [
        _make_binding(network.get_output(idx), False)
        for idx in range(network.num_outputs)
    ]
    return inputs + outputs
356+
357+
287358
def _create_builder_config(builder: Any, spec: TensorRTCompileSpec, trt: Any) -> Any:
288359
"""Create and configure TensorRT builder config."""
289360
config = builder.create_builder_config()

backends/nvidia/tensorrt/targets.bzl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ def define_common_targets():
1919
"//executorch/backends/nvidia/tensorrt:compile_spec",
2020
"//executorch/backends/nvidia/tensorrt:converter_registry",
2121
"//executorch/backends/nvidia/tensorrt:converter_utils",
22+
"//executorch/backends/nvidia/tensorrt:serialization",
2223
"//executorch/backends/nvidia/tensorrt/converters:converters",
2324
"//executorch/exir/backend:backend_details",
2425
],

0 commit comments

Comments (0)