Skip to content

Commit d4bf8f2

Browse files
authored
fix: add missing DEFAULT_MODEL_PATH property (#664)
1 parent d70bbbb commit d4bf8f2

2 files changed

Lines changed: 2 additions & 244 deletions

File tree

python/rapidocr/inference_engine/tensorrt/main.py

Lines changed: 1 addition & 211 deletions
Original file line number | Diff line number | Diff line change
@@ -1,30 +1,6 @@
11
# -*- encoding: utf-8 -*-
22
# @Author: SWHL
33
# @Contact: liekkaskono@163.com
4-
"""
5-
TensorRT Inference Session for RapidOCR.
6-
7-
This module provides TensorRT-based inference for OCR models, offering
8-
significant performance improvements over ONNX Runtime on NVIDIA GPUs.
9-
10-
Key Features:
11-
- Automatic engine building and caching
12-
- Dynamic shape support via optimization profiles
13-
- Pre-allocated buffers for minimal inference overhead
14-
- Optional pinned memory for faster data transfers
15-
16-
Performance Optimizations:
17-
1. Pre-allocated buffers with max shape (avoids reallocation overhead)
18-
2. Pinned memory for faster CPU-GPU transfers (~2x on discrete GPUs)
19-
3. Persistent CUDA stream (no stream creation per inference)
20-
4. Async memory copies overlapped with computation
21-
22-
Example:
23-
>>> from rapidocr.inference_engine.tensorrt import TRTInferSession
24-
>>> with TRTInferSession(config) as session:
25-
... output = session(input_array)
26-
"""
27-
284
import traceback
295
from pathlib import Path
306
from typing import Any, Dict, List, Optional
@@ -42,38 +18,7 @@
4218

4319

4420
class TRTInferSession(InferSession):
45-
"""TensorRT Inference Session for RapidOCR.
46-
47-
This class provides GPU-accelerated inference using NVIDIA TensorRT.
48-
It manages engine loading/building, memory allocation, and inference
49-
execution with optimizations for minimal latency.
50-
51-
Supports context manager protocol for automatic resource cleanup:
52-
>>> with TRTInferSession(cfg) as session:
53-
... result = session(input_data)
54-
55-
Attributes:
56-
cfg: Configuration dictionary.
57-
engine: TensorRT engine instance.
58-
context: TensorRT execution context.
59-
stream: CUDA stream for async operations.
60-
inputs: List of input buffer objects.
61-
outputs: List of output buffer objects.
62-
"""
63-
6421
def __init__(self, cfg: Dict[str, Any]):
65-
"""Initialize TensorRT inference session.
66-
67-
Args:
68-
cfg: Configuration dictionary containing:
69-
- engine_cfg: TensorRT-specific settings (device_id, precision, etc.)
70-
- model_path: Optional path to custom ONNX model
71-
- task_type, lang_type, etc.: For default model selection
72-
73-
Raises:
74-
AssertionError: If CUDA device setup fails.
75-
RuntimeError: If engine building fails.
76-
"""
7722
self.model_root_dir = Path(cfg.get("model_root_dir"))
7823
if not self.model_root_dir.exists():
7924
raise FileNotFoundError(
@@ -84,17 +29,13 @@ def __init__(self, cfg: Dict[str, Any]):
8429
self.engine_cfg = cfg.get("engine_cfg", {})
8530
self._closed = False
8631

87-
# Initialize CUDA device
8832
self.device_id = self._setup_cuda_device()
8933

90-
# TensorRT logger
9134
self.trt_logger = trt.Logger(trt.Logger.WARNING)
9235

93-
# Get or build engine
9436
engine_path = self._get_engine_path(cfg)
9537
self.engine = self._load_or_build_engine(cfg, engine_path)
9638

97-
# Create execution context
9839
self.context = self.engine.create_execution_context()
9940

10041
# Allocate memory buffers (pre-allocated with max shape)
@@ -115,45 +56,13 @@ def __init__(self, cfg: Dict[str, Any]):
11556
self._requires_square_input = False
11657
self._max_square_size = 2048
11758

118-
# =========================================================================
119-
# Context Manager Protocol
120-
# =========================================================================
121-
12259
def __enter__(self) -> "TRTInferSession":
123-
"""Enter context manager.
124-
125-
Returns:
126-
self: The session instance.
127-
"""
12860
return self
12961

13062
def __exit__(self, exc_type, exc_val, exc_tb) -> None:
131-
"""Exit context manager and cleanup resources.
132-
133-
Args:
134-
exc_type: Exception type if an exception occurred.
135-
exc_val: Exception value if an exception occurred.
136-
exc_tb: Exception traceback if an exception occurred.
137-
"""
13863
self.close()
13964

14065
def close(self) -> None:
141-
"""Explicitly close and cleanup resources.
142-
143-
This method releases all CUDA resources including:
144-
- GPU memory buffers
145-
- CUDA stream
146-
- TensorRT context and engine
147-
148-
Safe to call multiple times.
149-
150-
Example:
151-
>>> session = TRTInferSession(cfg)
152-
>>> try:
153-
... result = session(data)
154-
... finally:
155-
... session.close()
156-
"""
15766
if self._closed:
15867
return
15968

@@ -185,13 +94,8 @@ def close(self) -> None:
18594
logger.debug(f"Error during session close: {e}")
18695

18796
def __del__(self):
188-
"""Destructor - cleanup resources if not already closed."""
18997
self.close()
19098

191-
# =========================================================================
192-
# Inference Methods
193-
# =========================================================================
194-
19599
def __call__(self, input_content: np.ndarray) -> np.ndarray:
196100
"""Run inference on input data.
197101
@@ -244,31 +148,13 @@ def __call__(self, input_content: np.ndarray) -> np.ndarray:
244148
) from e
245149

246150
def _set_input_shape(self, input_content: np.ndarray) -> tuple:
247-
"""Set input shape and get corresponding output shape.
248-
249-
For dynamic shape networks, we need to inform TensorRT of the
250-
actual input shape before each inference.
251-
252-
Args:
253-
input_content: Input array to get shape from.
254-
255-
Returns:
256-
Output shape tuple after setting input shape.
257-
"""
258151
input_name = self.engine.get_tensor_name(0)
259152
self.context.set_input_shape(input_name, input_content.shape)
260153

261154
output_name = self.engine.get_tensor_name(1)
262155
return self.context.get_tensor_shape(output_name)
263156

264157
def _copy_input_to_device(self, input_content: np.ndarray) -> None:
265-
"""Copy input data to GPU asynchronously.
266-
267-
Uses pre-allocated pinned memory buffer for optimal transfer speed.
268-
269-
Args:
270-
input_content: Input data to copy.
271-
"""
272158
input_flat = input_content.ravel()
273159
self.inputs[0].host[: input_flat.size] = input_flat
274160

@@ -285,16 +171,6 @@ def _execute_inference(self) -> None:
285171
self.context.execute_async_v3(stream_handle=self.stream)
286172

287173
def _copy_output_to_host(self, output_shape: tuple) -> np.ndarray:
288-
"""Copy output from GPU to CPU and reshape.
289-
290-
Only copies the actual output size, not the full pre-allocated buffer.
291-
292-
Args:
293-
output_shape: Expected output shape.
294-
295-
Returns:
296-
Output array reshaped to correct dimensions.
297-
"""
298174
output_size = int(np.prod(output_shape))
299175
output_nbytes = output_size * self.outputs[0].host.itemsize
300176

@@ -311,19 +187,7 @@ def _copy_output_to_host(self, output_shape: tuple) -> np.ndarray:
311187

312188
return self.outputs[0].host[:output_size].reshape(output_shape)
313189

314-
# =========================================================================
315-
# MULTI Model Square Input Handling
316-
# =========================================================================
317-
318190
def _check_multi_model(self, cfg: Dict[str, Any]) -> bool:
319-
"""Check if this is MULTI detection model requiring square input.
320-
321-
Args:
322-
cfg: Configuration dictionary.
323-
324-
Returns:
325-
True if this is MULTI detection model, False otherwise.
326-
"""
327191
try:
328192
from ...utils.typings import LangDet, TaskType
329193

@@ -334,11 +198,6 @@ def _check_multi_model(self, cfg: Dict[str, Any]) -> bool:
334198
return False
335199

336200
def _get_max_profile_size(self) -> int:
337-
"""Get maximum dimension from TensorRT optimization profile.
338-
339-
Returns:
340-
Maximum height/width allowed by the optimization profile.
341-
"""
342201
try:
343202
# Try to get from config first
344203
profile_cfg = self.engine_cfg.get("det_profile", {})
@@ -359,16 +218,6 @@ def _get_max_profile_size(self) -> int:
359218
return 2048
360219

361220
def _pad_to_square(self, input_content: np.ndarray) -> tuple:
362-
"""Pad input to square shape for MULTI model.
363-
364-
Args:
365-
input_content: Input array with shape (N, C, H, W).
366-
367-
Returns:
368-
Tuple of (padded_input, original_hw) where:
369-
- padded_input: Square input array (N, C, S, S)
370-
- original_hw: Original (H, W) for later cropping, or None if already square
371-
"""
372221
N, C, H, W = input_content.shape
373222

374223
if H == W:
@@ -426,19 +275,7 @@ def _crop_output(self, output: np.ndarray, original_hw: tuple) -> np.ndarray:
426275
crop_w = int(orig_w * scale_w)
427276
return output[:, :, :crop_h, :crop_w]
428277

429-
# =========================================================================
430-
# Initialization Helpers
431-
# =========================================================================
432-
433278
def _setup_cuda_device(self) -> int:
434-
"""Setup CUDA device for inference.
435-
436-
Returns:
437-
Device ID that was set.
438-
439-
Raises:
440-
AssertionError: If CUDA device setup fails.
441-
"""
442279
device_id = self.engine_cfg.get("device_id", 0)
443280
status_tuple = cudart.cudaSetDevice(device_id)
444281
status = status_tuple[0]
@@ -449,17 +286,6 @@ def _setup_cuda_device(self) -> int:
449286
return device_id
450287

451288
def _get_engine_path(self, cfg: Dict[str, Any]) -> Path:
452-
"""Determine the TensorRT engine file path.
453-
454-
Engine files are cached per GPU architecture and precision setting
455-
to avoid rebuilding on subsequent runs.
456-
457-
Args:
458-
cfg: Configuration dictionary.
459-
460-
Returns:
461-
Path to engine file (may not exist yet).
462-
"""
463289
cache_dir = self.engine_cfg.get("cache_dir")
464290
if cache_dir is None:
465291
cache_dir = self.model_root_dir / "models"
@@ -501,15 +327,6 @@ def _get_gpu_arch(self) -> str:
501327
def _load_or_build_engine(
502328
self, cfg: Dict[str, Any], engine_path: Path
503329
) -> trt.ICudaEngine:
504-
"""Load cached engine or build new one from ONNX.
505-
506-
Args:
507-
cfg: Configuration dictionary.
508-
engine_path: Path where engine should be cached.
509-
510-
Returns:
511-
TensorRT engine instance.
512-
"""
513330
force_rebuild = self.engine_cfg.get("force_rebuild", False)
514331

515332
# Try to load cached engine
@@ -536,11 +353,9 @@ def _load_or_build_engine(
536353
return builder.build()
537354

538355
def _get_onnx_path(self, cfg: Dict[str, Any]) -> Path:
539-
"""Get ONNX model path, downloading if necessary."""
540356
model_path = cfg.get("model_path")
541357

542358
if model_path is None:
543-
# Download default ONNX model
544359
original_engine_type = cfg.engine_type
545360
cfg.engine_type = EngineType.ONNXRUNTIME
546361

@@ -556,7 +371,7 @@ def _get_onnx_path(self, cfg: Dict[str, Any]) -> Path:
556371

557372
cfg.engine_type = original_engine_type
558373

559-
model_path = self.DEFAULT_MODEL_PATH / Path(model_info["model_dir"]).name
374+
model_path = self.model_root_dir / Path(model_info["model_dir"]).name
560375
download_params = DownloadFileInput(
561376
file_url=model_info["model_dir"],
562377
sha256=model_info["SHA256"],
@@ -570,42 +385,19 @@ def _get_onnx_path(self, cfg: Dict[str, Any]) -> Path:
570385
return model_path
571386

572387
def _load_engine(self, engine_path: Path) -> trt.ICudaEngine:
573-
"""Load a serialized TensorRT engine from disk."""
574388
runtime = trt.Runtime(self.trt_logger)
575389
with open(engine_path, "rb") as f:
576390
engine_data = f.read()
577391
return runtime.deserialize_cuda_engine(engine_data)
578392

579-
# =========================================================================
580-
# Interface Methods (required by InferSession)
581-
# =========================================================================
582-
583393
def have_key(self, key: str = "character") -> bool:
584-
"""Check if engine has metadata key.
585-
586-
TensorRT engines don't store custom metadata like ONNX models.
587-
588-
Returns:
589-
Always False for TensorRT engines.
590-
"""
591394
return False
592395

593396
def get_character_list(self, key: str = "character") -> List[str]:
594397
return []
595398

596399
@classmethod
597400
def get_dict_key_url(cls, file_info: FileInfo) -> Optional[str]:
598-
"""Get dictionary URL by falling back to Paddle/ONNX model config.
599-
600-
TensorRT doesn't have entries in default_models.yaml, so we
601-
look up the dictionary URL from Paddle or ONNX configurations.
602-
603-
Args:
604-
file_info: Model file information.
605-
606-
Returns:
607-
Dictionary URL string or None if not found.
608-
"""
609401
# Try Paddle first (usually has dict_url)
610402
for engine_type in [EngineType.PADDLE, EngineType.ONNXRUNTIME]:
611403
try:
@@ -626,6 +418,4 @@ def get_dict_key_url(cls, file_info: FileInfo) -> Optional[str]:
626418

627419

628420
class TensorRTError(Exception):
629-
"""Exception raised for TensorRT inference errors."""
630-
631421
pass

0 commit comments

Comments
 (0)