Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions rapid_videocr/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,14 @@ class RapidVideOCRInput:
is_batch_rec: bool = False
batch_size: int = 10
out_format: str = OutputFormat.ALL.value
ocr_params: Optional[Dict[str, Any]] = None

ocr_params_list: Optional[List[Dict[str, Any]]] = None

class RapidVideOCR:
def __init__(self, input_params: RapidVideOCRInput):
self.logger = Logger(logger_name=__name__).get_log()

self.ocr_processor = OCRProcessor(
input_params.ocr_params, input_params.batch_size
ocr_params_list=input_params.ocr_params_list, batch_size=input_params.batch_size
)

self.cropper = CropByProject()
Expand Down
96 changes: 73 additions & 23 deletions rapid_videocr/ocr_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# @Author: SWHL
# @Contact: liekkaskono@163.com
from pathlib import Path
from typing import Dict, List, Optional, Tuple
from typing import Any, Dict, List, Optional, Tuple

import cv2
import numpy as np
Expand All @@ -17,13 +17,14 @@
padding_img,
read_img,
)
from collections import defaultdict


class OCRProcessor:
def __init__(self, ocr_params: Optional[Dict] = None, batch_size: int = 10):
def __init__(self, ocr_params_list: Optional[List[Dict[str, Any]]] = None, batch_size: int = 10):
self.logger = Logger(logger_name=__name__).get_log()
self.ocr_engine = self._init_ocr_engine(ocr_params)
self.batch_size = batch_size
self.ocr_params_list = ocr_params_list or [{}]

def _init_ocr_engine(self, ocr_params: Optional[Dict] = None) -> RapidOCR:
return RapidOCR(params=ocr_params)
Expand All @@ -32,6 +33,8 @@ def __call__(
self, img_list: List[Path], is_batch_rec: bool, is_txt_dir: bool
) -> Tuple[List[str], List[str], List[str]]:
self.is_txt_dir = is_txt_dir
self.is_batch_rec = is_batch_rec
self.ocr_engines = [self._init_ocr_engine(p) for p in self.ocr_params_list]
process_func = self.batch_rec if is_batch_rec else self.single_rec
rec_results = process_func(img_list)
srt_results = self._generate_srt_results(rec_results)
Expand All @@ -47,14 +50,23 @@ def single_rec(self, img_list: List[Path]) -> List[Tuple[int, str, str, str]]:
time_str = self._get_srt_timestamp(img_path)
ass_time_str = self._get_ass_timestamp(img_path)
img = self._preprocess_image(img_path)

dt_boxes, rec_res = self.get_ocr_result(img)
txts = (
self.process_same_line(dt_boxes, rec_res)
if dt_boxes is not None
else ""
)
rec_results.append([i, time_str, txts, ass_time_str])
results = self.get_ocr_results(img)
max_txt_len = 0
final_txts = ""

# Iterate over all OCR results from different configs.
for idx, (dt_boxes, rec_res) in enumerate(results):
txts = (
self.process_same_line(dt_boxes, rec_res)
if dt_boxes is not None
else ""
)
# Compare and select the best (longest) recognized text for this image.
if max_txt_len < len(txts):
max_txt_len = len(txts)
final_txts = txts

rec_results.append([i, time_str, final_txts, ass_time_str])
return rec_results

@staticmethod
Expand Down Expand Up @@ -132,20 +144,53 @@ def batch_rec(self, img_list: List[Path]) -> List[Tuple[int, str, str, str]]:

img_nums = len(img_list)
rec_results = []

for start_i in tqdm(range(0, img_nums, self.batch_size), desc="Concat Rec"):
end_i = min(img_nums, start_i + self.batch_size)

concat_img, img_coordinates, img_paths = self._prepare_batch(
img_list[start_i:end_i]
)
dt_boxes, rec_res = self.get_ocr_result(concat_img)
if rec_res is None or dt_boxes is None:
results = self.get_ocr_results(concat_img)

if len(results) == 1:
dt_boxes, rec_res = results[0]
if rec_res is None or dt_boxes is None:
continue
one_batch_rec_results = self._process_batch_results(
start_i, img_coordinates, dt_boxes, rec_res, img_paths
)
rec_results.extend(one_batch_rec_results)
continue

one_batch_rec_results = self._process_batch_results(
start_i, img_coordinates, dt_boxes, rec_res, img_paths
)
rec_results.extend(one_batch_rec_results)
all_batch_results = defaultdict(list)
# Iterate over all OCR results from different configs.
for idx, (dt_boxes, rec_res) in enumerate(results):
if rec_res is None or dt_boxes is None:
continue
one_batch_rec_results = self._process_batch_results(
start_i, img_coordinates, dt_boxes, rec_res, img_paths
)
for i, row in enumerate(one_batch_rec_results):
# row = [cur_frame_idx, time_str, txts, ass_time_str]
all_batch_results[i].append(row)

# Compare and select the best (longest) recognized text for each image.
for i in range(len(img_paths)):
batch_result = all_batch_results[i]
max_txt_len = 0
final_row = None
for row in batch_result:
txts = row[2] # get text
if len(txts) > max_txt_len:
max_txt_len = len(txts)
final_row = row
if final_row is None:
time_str = self._get_srt_timestamp(img_paths[i])
ass_time_str = self._get_ass_timestamp(img_paths[i])
final_row = [start_i + i, time_str, "", ass_time_str]
rec_results.append(final_row)

return rec_results

def _prepare_batch(
Expand Down Expand Up @@ -223,13 +268,18 @@ def _is_box_matched(self, frame_boxes: np.ndarray, dt_box: np.ndarray) -> bool:
box_iou = compute_poly_iou(frame_boxes, dt_box)
return is_inclusive_each_other(frame_boxes, dt_box) or box_iou > 0.1

def get_ocr_result(
def get_ocr_results(
self, img: np.ndarray
) -> Tuple[Optional[np.ndarray], Optional[Tuple[str]]]:
ocr_result = self.ocr_engine(img)
if ocr_result.boxes is None:
return None, None
return ocr_result.boxes, ocr_result.txts
) -> List[Tuple[Optional[np.ndarray], Optional[Tuple[str]]]]:

results = []
for engine in self.ocr_engines:
ocr_result = engine(img)
if ocr_result.boxes is None:
results.append((None, None))
else:
results.append((ocr_result.boxes, ocr_result.txts))
return results

def process_same_line(self, dt_boxes: np.ndarray, rec_res: List[str]) -> str:
if len(rec_res) == 1:
Expand Down
32 changes: 32 additions & 0 deletions tests/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,3 +148,35 @@ def test_out_only_txt(setup_and_teardown):
txt_data = read_txt(txt_path)
assert len(txt_data) == 8
assert txt_data[-2] == "你们接着善后"

@pytest.mark.parametrize("img_dir", [test_dir / "RGBImages"])
def test_ocr_multi_configs(setup_and_teardown, img_dir):
save_dir, srt_path, ass_path, txt_path = setup_and_teardown

ocr_params_list = [
{
"Det.limit_side_len": 4000,
"Det.limit_type": "max",
},
{
"Det.limit_side_len": 640,
"Det.limit_type": "min",
}
]
input_param = RapidVideOCRInput(is_batch_rec=False, ocr_params_list=ocr_params_list)
extractor = RapidVideOCR(input_param)
extractor(img_dir, save_dir)

srt_data = read_txt(srt_path)
assert len(srt_data) == 16
assert srt_data[2] == "空间里面他绝对赢不了的"
assert srt_data[-2] == "你们接着善后"

ass_data = read_txt(ass_path)
assert len(ass_data) == 17
assert ass_data[13].split(",", 9)[-1] == "空间里面他绝对赢不了的"
assert ass_data[-1].split(",", 9)[-1] == "你们接着善后"

txt_data = read_txt(txt_path)
assert len(txt_data) == 8
assert txt_data[-2] == "你们接着善后"