Skip to content

Commit b73ec51

Browse files
authored
feat: add PP-OCRv5 cls module
1 parent 9023589 commit b73ec51

2 files changed

Lines changed: 11 additions & 3 deletions

File tree

python/rapidocr/ch_ppocr_cls/main.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,19 @@
2020
import numpy as np
2121

2222
from rapidocr.inference_engine.base import get_engine
23+
from rapidocr.utils.typings import OCRVersion
2324

2425
from .utils import ClsPostProcess, TextClsOutput
2526

27+
CLS_SHAPE_BY_OCR_VERSION = {
28+
OCRVersion.PPOCRV4: [3, 48, 192],
29+
OCRVersion.PPOCRV5: [3, 80, 160],
30+
}
31+
2632

2733
class TextClassifier:
2834
def __init__(self, cfg: Dict[str, Any]):
29-
self.cls_image_shape = cfg["cls_image_shape"]
35+
self.cls_image_shape = CLS_SHAPE_BY_OCR_VERSION[cfg["ocr_version"]]
3036
self.cls_batch_num = cfg["cls_batch_num"]
3137
self.cls_thresh = cfg["cls_thresh"]
3238
self.postprocess_op = ClsPostProcess(cfg["label_list"])
@@ -83,7 +89,9 @@ def resize_norm_img(self, img: np.ndarray) -> np.ndarray:
8389
else:
8490
resized_w = int(math.ceil(img_h * ratio))
8591

86-
resized_image = cv2.resize(img, (resized_w, img_h))
92+
resized_image = cv2.resize(
93+
img, (resized_w, img_h), interpolation=cv2.INTER_LINEAR
94+
)
8795
resized_image = resized_image.astype("float32")
8896
if img_c == 1:
8997
resized_image = resized_image / 255

python/rapidocr/inference_engine/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@ def get_model_url(cls, file_info: FileInfo) -> Dict[str, str]:
136136
return model_dict[k]
137137

138138
for k in model_dict:
139-
if k.startswith(lang_type):
139+
if k.startswith(lang_type) and model_type in k:
140140
return model_dict[k]
141141

142142
logger.error(

0 commit comments

Comments
 (0)