|
20 | 20 | import numpy as np |
21 | 21 |
|
22 | 22 | from rapidocr.inference_engine.base import get_engine |
| 23 | +from rapidocr.utils.typings import OCRVersion |
23 | 24 |
|
24 | 25 | from .utils import ClsPostProcess, TextClsOutput |
25 | 26 |
|
| 27 | +CLS_SHAPE_BY_OCR_VERSION = { |
| 28 | + OCRVersion.PPOCRV4: [3, 48, 192], |
| 29 | + OCRVersion.PPOCRV5: [3, 80, 160], |
| 30 | +} |
| 31 | + |
26 | 32 |
|
27 | 33 | class TextClassifier: |
28 | 34 | def __init__(self, cfg: Dict[str, Any]): |
29 | | - self.cls_image_shape = cfg["cls_image_shape"] |
| 35 | + self.cls_image_shape = CLS_SHAPE_BY_OCR_VERSION[cfg["ocr_version"]] |
30 | 36 | self.cls_batch_num = cfg["cls_batch_num"] |
31 | 37 | self.cls_thresh = cfg["cls_thresh"] |
32 | 38 | self.postprocess_op = ClsPostProcess(cfg["label_list"]) |
@@ -83,7 +89,9 @@ def resize_norm_img(self, img: np.ndarray) -> np.ndarray: |
83 | 89 | else: |
84 | 90 | resized_w = int(math.ceil(img_h * ratio)) |
85 | 91 |
|
86 | | - resized_image = cv2.resize(img, (resized_w, img_h)) |
| 92 | + resized_image = cv2.resize( |
| 93 | + img, (resized_w, img_h), interpolation=cv2.INTER_LINEAR |
| 94 | + ) |
87 | 95 | resized_image = resized_image.astype("float32") |
88 | 96 | if img_c == 1: |
89 | 97 | resized_image = resized_image / 255 |
|
0 commit comments