|
25 | 25 | lookup_converter, |
26 | 26 | needs_edge_program, |
27 | 27 | ) |
| 28 | +from executorch.backends.nvidia.tensorrt.serialization import ( |
| 29 | + serialize_blob, |
| 30 | + TensorRTBlobMetadata, |
| 31 | + TensorRTIOBinding, |
| 32 | +) |
28 | 33 |
|
29 | 34 | logger = logging.getLogger(__name__) |
30 | 35 | logger.setLevel(logging.WARNING) |
@@ -100,14 +105,21 @@ def preprocess( |
100 | 105 | ) |
101 | 106 | _mark_network_outputs(network, output_nodes, input_map) |
102 | 107 |
|
| 108 | + # Collect I/O bindings from network |
| 109 | + io_bindings = _collect_io_bindings(network) |
| 110 | + |
103 | 111 | # Configure and build engine |
104 | 112 | config = _create_builder_config(builder, spec, trt) |
105 | 113 | serialized_engine = builder.build_serialized_network(network, config) |
106 | 114 |
|
107 | 115 | if serialized_engine is None: |
108 | 116 | raise RuntimeError("Failed to build TensorRT engine") |
109 | 117 |
|
110 | | - return PreprocessResult(processed_bytes=bytes(serialized_engine)) |
| 118 | + # Serialize with metadata |
| 119 | + metadata = TensorRTBlobMetadata(io_bindings=io_bindings) |
| 120 | + blob = serialize_blob(bytes(serialized_engine), metadata) |
| 121 | + |
| 122 | + return PreprocessResult(processed_bytes=blob) |
111 | 123 |
|
112 | 124 |
|
113 | 125 | def _get_input_nodes( |
@@ -284,6 +296,65 @@ def _mark_network_outputs( |
284 | 296 | network.mark_output(output_tensor) |
285 | 297 |
|
286 | 298 |
|
| 299 | +def _trt_dtype_to_string(dtype: Any) -> str: |
| 300 | + """Convert TensorRT DataType to string representation.""" |
| 301 | + dtype_name = str(dtype) |
| 302 | + # dtype looks like "DataType.FLOAT" or "DataType.HALF" |
| 303 | + if "." in dtype_name: |
| 304 | + dtype_name = dtype_name.split(".")[-1] |
| 305 | + |
| 306 | + dtype_map = { |
| 307 | + "FLOAT": "float32", |
| 308 | + "HALF": "float16", |
| 309 | + "INT8": "int8", |
| 310 | + "INT32": "int32", |
| 311 | + "INT64": "int64", |
| 312 | + "BOOL": "bool", |
| 313 | + "UINT8": "uint8", |
| 314 | + "FP8": "float8", |
| 315 | + "BF16": "bfloat16", |
| 316 | + } |
| 317 | + return dtype_map.get(dtype_name, "float32") |
| 318 | + |
| 319 | + |
def _collect_io_bindings(network: Any) -> List[TensorRTIOBinding]:
    """Gather binding metadata for every network input and output.

    Args:
        network: TensorRT network definition.

    Returns:
        One TensorRTIOBinding per tensor: all inputs first (in index order),
        followed by all outputs (in index order).
    """

    def _describe(tensor: Any, is_input: bool) -> TensorRTIOBinding:
        # Capture name, dtype string and shape for one I/O tensor.
        return TensorRTIOBinding(
            name=tensor.name,
            dtype=_trt_dtype_to_string(tensor.dtype),
            shape=list(tensor.shape),
            is_input=is_input,
        )

    input_bindings = [
        _describe(network.get_input(idx), True)
        for idx in range(network.num_inputs)
    ]
    output_bindings = [
        _describe(network.get_output(idx), False)
        for idx in range(network.num_outputs)
    ]
    return input_bindings + output_bindings
| 356 | + |
| 357 | + |
287 | 358 | def _create_builder_config(builder: Any, spec: TensorRTCompileSpec, trt: Any) -> Any: |
288 | 359 | """Create and configure TensorRT builder config.""" |
289 | 360 | config = builder.create_builder_config() |
|
0 commit comments