11# -*- encoding: utf-8 -*-
22# @Author: SWHL
33# @Contact: liekkaskono@163.com
4- """
5- TensorRT Inference Session for RapidOCR.
6-
7- This module provides TensorRT-based inference for OCR models, offering
8- significant performance improvements over ONNX Runtime on NVIDIA GPUs.
9-
10- Key Features:
11- - Automatic engine building and caching
12- - Dynamic shape support via optimization profiles
13- - Pre-allocated buffers for minimal inference overhead
14- - Optional pinned memory for faster data transfers
15-
16- Performance Optimizations:
17- 1. Pre-allocated buffers with max shape (avoids reallocation overhead)
18- 2. Pinned memory for faster CPU-GPU transfers (~2x on discrete GPUs)
19- 3. Persistent CUDA stream (no stream creation per inference)
20- 4. Async memory copies overlapped with computation
21-
22- Example:
23- >>> from rapidocr.inference_engine.tensorrt import TRTInferSession
24- >>> with TRTInferSession(config) as session:
25- ... output = session(input_array)
26- """
27-
284import traceback
295from pathlib import Path
306from typing import Any , Dict , List , Optional
4218
4319
4420class TRTInferSession (InferSession ):
45- """TensorRT Inference Session for RapidOCR.
46-
47- This class provides GPU-accelerated inference using NVIDIA TensorRT.
48- It manages engine loading/building, memory allocation, and inference
49- execution with optimizations for minimal latency.
50-
51- Supports context manager protocol for automatic resource cleanup:
52- >>> with TRTInferSession(cfg) as session:
53- ... result = session(input_data)
54-
55- Attributes:
56- cfg: Configuration dictionary.
57- engine: TensorRT engine instance.
58- context: TensorRT execution context.
59- stream: CUDA stream for async operations.
60- inputs: List of input buffer objects.
61- outputs: List of output buffer objects.
62- """
63-
6421 def __init__ (self , cfg : Dict [str , Any ]):
65- """Initialize TensorRT inference session.
66-
67- Args:
68- cfg: Configuration dictionary containing:
69- - engine_cfg: TensorRT-specific settings (device_id, precision, etc.)
70- - model_path: Optional path to custom ONNX model
71- - task_type, lang_type, etc.: For default model selection
72-
73- Raises:
74- AssertionError: If CUDA device setup fails.
75- RuntimeError: If engine building fails.
76- """
7722 self .model_root_dir = Path (cfg .get ("model_root_dir" ))
7823 if not self .model_root_dir .exists ():
7924 raise FileNotFoundError (
@@ -84,17 +29,13 @@ def __init__(self, cfg: Dict[str, Any]):
8429 self .engine_cfg = cfg .get ("engine_cfg" , {})
8530 self ._closed = False
8631
87- # Initialize CUDA device
8832 self .device_id = self ._setup_cuda_device ()
8933
90- # TensorRT logger
9134 self .trt_logger = trt .Logger (trt .Logger .WARNING )
9235
93- # Get or build engine
9436 engine_path = self ._get_engine_path (cfg )
9537 self .engine = self ._load_or_build_engine (cfg , engine_path )
9638
97- # Create execution context
9839 self .context = self .engine .create_execution_context ()
9940
10041 # Allocate memory buffers (pre-allocated with max shape)
@@ -115,45 +56,13 @@ def __init__(self, cfg: Dict[str, Any]):
11556 self ._requires_square_input = False
11657 self ._max_square_size = 2048
11758
118- # =========================================================================
119- # Context Manager Protocol
120- # =========================================================================
121-
12259 def __enter__ (self ) -> "TRTInferSession" :
123- """Enter context manager.
124-
125- Returns:
126- self: The session instance.
127- """
12860 return self
12961
13062 def __exit__ (self , exc_type , exc_val , exc_tb ) -> None :
131- """Exit context manager and cleanup resources.
132-
133- Args:
134- exc_type: Exception type if an exception occurred.
135- exc_val: Exception value if an exception occurred.
136- exc_tb: Exception traceback if an exception occurred.
137- """
13863 self .close ()
13964
14065 def close (self ) -> None :
141- """Explicitly close and cleanup resources.
142-
143- This method releases all CUDA resources including:
144- - GPU memory buffers
145- - CUDA stream
146- - TensorRT context and engine
147-
148- Safe to call multiple times.
149-
150- Example:
151- >>> session = TRTInferSession(cfg)
152- >>> try:
153- ... result = session(data)
154- ... finally:
155- ... session.close()
156- """
15766 if self ._closed :
15867 return
15968
@@ -185,13 +94,8 @@ def close(self) -> None:
18594 logger .debug (f"Error during session close: { e } " )
18695
18796 def __del__ (self ):
188- """Destructor - cleanup resources if not already closed."""
18997 self .close ()
19098
191- # =========================================================================
192- # Inference Methods
193- # =========================================================================
194-
19599 def __call__ (self , input_content : np .ndarray ) -> np .ndarray :
196100 """Run inference on input data.
197101
@@ -244,31 +148,13 @@ def __call__(self, input_content: np.ndarray) -> np.ndarray:
244148 ) from e
245149
246150 def _set_input_shape (self , input_content : np .ndarray ) -> tuple :
247- """Set input shape and get corresponding output shape.
248-
249- For dynamic shape networks, we need to inform TensorRT of the
250- actual input shape before each inference.
251-
252- Args:
253- input_content: Input array to get shape from.
254-
255- Returns:
256- Output shape tuple after setting input shape.
257- """
258151 input_name = self .engine .get_tensor_name (0 )
259152 self .context .set_input_shape (input_name , input_content .shape )
260153
261154 output_name = self .engine .get_tensor_name (1 )
262155 return self .context .get_tensor_shape (output_name )
263156
264157 def _copy_input_to_device (self , input_content : np .ndarray ) -> None :
265- """Copy input data to GPU asynchronously.
266-
267- Uses pre-allocated pinned memory buffer for optimal transfer speed.
268-
269- Args:
270- input_content: Input data to copy.
271- """
272158 input_flat = input_content .ravel ()
273159 self .inputs [0 ].host [: input_flat .size ] = input_flat
274160
@@ -285,16 +171,6 @@ def _execute_inference(self) -> None:
285171 self .context .execute_async_v3 (stream_handle = self .stream )
286172
287173 def _copy_output_to_host (self , output_shape : tuple ) -> np .ndarray :
288- """Copy output from GPU to CPU and reshape.
289-
290- Only copies the actual output size, not the full pre-allocated buffer.
291-
292- Args:
293- output_shape: Expected output shape.
294-
295- Returns:
296- Output array reshaped to correct dimensions.
297- """
298174 output_size = int (np .prod (output_shape ))
299175 output_nbytes = output_size * self .outputs [0 ].host .itemsize
300176
@@ -311,19 +187,7 @@ def _copy_output_to_host(self, output_shape: tuple) -> np.ndarray:
311187
312188 return self .outputs [0 ].host [:output_size ].reshape (output_shape )
313189
314- # =========================================================================
315- # MULTI Model Square Input Handling
316- # =========================================================================
317-
318190 def _check_multi_model (self , cfg : Dict [str , Any ]) -> bool :
319- """Check if this is MULTI detection model requiring square input.
320-
321- Args:
322- cfg: Configuration dictionary.
323-
324- Returns:
325- True if this is MULTI detection model, False otherwise.
326- """
327191 try :
328192 from ...utils .typings import LangDet , TaskType
329193
@@ -334,11 +198,6 @@ def _check_multi_model(self, cfg: Dict[str, Any]) -> bool:
334198 return False
335199
336200 def _get_max_profile_size (self ) -> int :
337- """Get maximum dimension from TensorRT optimization profile.
338-
339- Returns:
340- Maximum height/width allowed by the optimization profile.
341- """
342201 try :
343202 # Try to get from config first
344203 profile_cfg = self .engine_cfg .get ("det_profile" , {})
@@ -359,16 +218,6 @@ def _get_max_profile_size(self) -> int:
359218 return 2048
360219
361220 def _pad_to_square (self , input_content : np .ndarray ) -> tuple :
362- """Pad input to square shape for MULTI model.
363-
364- Args:
365- input_content: Input array with shape (N, C, H, W).
366-
367- Returns:
368- Tuple of (padded_input, original_hw) where:
369- - padded_input: Square input array (N, C, S, S)
370- - original_hw: Original (H, W) for later cropping, or None if already square
371- """
372221 N , C , H , W = input_content .shape
373222
374223 if H == W :
@@ -426,19 +275,7 @@ def _crop_output(self, output: np.ndarray, original_hw: tuple) -> np.ndarray:
426275 crop_w = int (orig_w * scale_w )
427276 return output [:, :, :crop_h , :crop_w ]
428277
429- # =========================================================================
430- # Initialization Helpers
431- # =========================================================================
432-
433278 def _setup_cuda_device (self ) -> int :
434- """Setup CUDA device for inference.
435-
436- Returns:
437- Device ID that was set.
438-
439- Raises:
440- AssertionError: If CUDA device setup fails.
441- """
442279 device_id = self .engine_cfg .get ("device_id" , 0 )
443280 status_tuple = cudart .cudaSetDevice (device_id )
444281 status = status_tuple [0 ]
@@ -449,17 +286,6 @@ def _setup_cuda_device(self) -> int:
449286 return device_id
450287
451288 def _get_engine_path (self , cfg : Dict [str , Any ]) -> Path :
452- """Determine the TensorRT engine file path.
453-
454- Engine files are cached per GPU architecture and precision setting
455- to avoid rebuilding on subsequent runs.
456-
457- Args:
458- cfg: Configuration dictionary.
459-
460- Returns:
461- Path to engine file (may not exist yet).
462- """
463289 cache_dir = self .engine_cfg .get ("cache_dir" )
464290 if cache_dir is None :
465291 cache_dir = self .model_root_dir / "models"
@@ -501,15 +327,6 @@ def _get_gpu_arch(self) -> str:
501327 def _load_or_build_engine (
502328 self , cfg : Dict [str , Any ], engine_path : Path
503329 ) -> trt .ICudaEngine :
504- """Load cached engine or build new one from ONNX.
505-
506- Args:
507- cfg: Configuration dictionary.
508- engine_path: Path where engine should be cached.
509-
510- Returns:
511- TensorRT engine instance.
512- """
513330 force_rebuild = self .engine_cfg .get ("force_rebuild" , False )
514331
515332 # Try to load cached engine
@@ -536,11 +353,9 @@ def _load_or_build_engine(
536353 return builder .build ()
537354
538355 def _get_onnx_path (self , cfg : Dict [str , Any ]) -> Path :
539- """Get ONNX model path, downloading if necessary."""
540356 model_path = cfg .get ("model_path" )
541357
542358 if model_path is None :
543- # Download default ONNX model
544359 original_engine_type = cfg .engine_type
545360 cfg .engine_type = EngineType .ONNXRUNTIME
546361
@@ -556,7 +371,7 @@ def _get_onnx_path(self, cfg: Dict[str, Any]) -> Path:
556371
557372 cfg .engine_type = original_engine_type
558373
559- model_path = self .DEFAULT_MODEL_PATH / Path (model_info ["model_dir" ]).name
374+ model_path = self .model_root_dir / Path (model_info ["model_dir" ]).name
560375 download_params = DownloadFileInput (
561376 file_url = model_info ["model_dir" ],
562377 sha256 = model_info ["SHA256" ],
@@ -570,42 +385,19 @@ def _get_onnx_path(self, cfg: Dict[str, Any]) -> Path:
570385 return model_path
571386
def _load_engine(self, engine_path: Path) -> trt.ICudaEngine:
    """Read a serialized engine file and deserialize it.

    Args:
        engine_path: Location of the serialized engine on disk.

    Returns:
        Whatever ``deserialize_cuda_engine`` returns for the file contents.
    """
    trt_runtime = trt.Runtime(self.trt_logger)
    with open(engine_path, "rb") as engine_file:
        serialized = engine_file.read()
    engine = trt_runtime.deserialize_cuda_engine(serialized)
    return engine
578392
579- # =========================================================================
580- # Interface Methods (required by InferSession)
581- # =========================================================================
582-
def have_key(self, key: str = "character") -> bool:
    """Report whether the engine carries the metadata entry ``key``.

    This session exposes no model metadata, so the answer is the same for
    every key.

    Args:
        key: Metadata key to look for (not inspected).

    Returns:
        Always ``False``.
    """
    found = False
    return found
592395
def get_character_list(self, key: str = "character") -> List[str]:
    """Return the character list stored under ``key``.

    No character list is available from this session, so a new empty list
    is returned on every call.

    Args:
        key: Metadata key to look up (not inspected).

    Returns:
        An empty list.
    """
    return list()
595398
596399 @classmethod
597400 def get_dict_key_url (cls , file_info : FileInfo ) -> Optional [str ]:
598- """Get dictionary URL by falling back to Paddle/ONNX model config.
599-
600- TensorRT doesn't have entries in default_models.yaml, so we
601- look up the dictionary URL from Paddle or ONNX configurations.
602-
603- Args:
604- file_info: Model file information.
605-
606- Returns:
607- Dictionary URL string or None if not found.
608- """
609401 # Try Paddle first (usually has dict_url)
610402 for engine_type in [EngineType .PADDLE , EngineType .ONNXRUNTIME ]:
611403 try :
@@ -626,6 +418,4 @@ def get_dict_key_url(cls, file_info: FileInfo) -> Optional[str]:
626418
627419
class TensorRTError(Exception):
    """Error type for failures in the TensorRT inference session."""
0 commit comments