|
11 | 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | 12 | # See the License for the specific language governing permissions and |
13 | 13 | # limitations under the License. |
14 | | -import time |
15 | 14 | from typing import Any, Dict, List, Tuple |
16 | 15 |
|
17 | 16 | import numpy as np |
18 | 17 |
|
19 | 18 | from ...inference_engine.base import get_engine |
20 | | -from ...utils.typings import EngineType, ModelType |
21 | | -from ..utils import wrap_with_html_struct |
| 19 | +from ...utils.typings import EngineType |
22 | 20 | from .post_process import TableLabelDecode |
23 | 21 | from .pre_process import TablePreprocess |
24 | 22 |
|
@@ -47,71 +45,3 @@ def __call__( |
47 | 45 | bbox_preds, struct_probs, shape_lists, ori_imgs |
48 | 46 | ) |
49 | 47 | return table_structs, cell_bboxes |
50 | | - |
51 | | - def batch_process( |
52 | | - self, img_list: List[np.ndarray] |
53 | | - ) -> List[Tuple[List[str], np.ndarray, float]]: |
54 | | - """批量处理图像列表 |
55 | | - Args: |
56 | | - img_list: 图像列表 |
57 | | -
|
58 | | - Returns: |
59 | | - 结果列表,每个元素包含 (table_struct_str, cell_bboxes, elapse) |
60 | | - """ |
61 | | - starttime = time.perf_counter() |
62 | | - |
63 | | - batch_data = self.batch_preprocess_op(img_list) |
64 | | - |
65 | | - preprocessed_images = batch_data[0] |
66 | | - shape_lists = batch_data[1] |
67 | | - preprocessed_images = np.array(preprocessed_images) |
68 | | - |
69 | | - bbox_preds, struct_probs = self.session(preprocessed_images) |
70 | | - |
71 | | - batch_size = preprocessed_images.shape[0] |
72 | | - results = [] |
73 | | - |
74 | | - for i in range(batch_size): |
75 | | - single_bbox_preds = bbox_preds[i : i + 1] |
76 | | - single_struct_probs = struct_probs[i : i + 1] |
77 | | - single_shape_list = np.array([shape_lists[i]]) |
78 | | - |
79 | | - post_result = self.postprocess_op( |
80 | | - single_bbox_preds, single_struct_probs, [single_shape_list] |
81 | | - ) |
82 | | - |
83 | | - table_struct_str = wrap_with_html_struct( |
84 | | - post_result["structure_batch_list"][0][0] |
85 | | - ) |
86 | | - cell_bboxes = post_result["bbox_batch_list"][0] |
87 | | - |
88 | | - if self.cfg["model_type"] == ModelType.SLANETPLUS: |
89 | | - cell_bboxes = self.rescale_cell_bboxes(img_list[i], cell_bboxes) |
90 | | - |
91 | | - cell_bboxes = self.filter_blank_bbox(cell_bboxes) |
92 | | - |
93 | | - results.append((table_struct_str, cell_bboxes, 0)) |
94 | | - |
95 | | - total_elapse = time.perf_counter() - starttime |
96 | | - for i in range(len(results)): |
97 | | - results[i] = (results[i][0], results[i][1], total_elapse / batch_size) |
98 | | - |
99 | | - return results |
100 | | - |
101 | | - def rescale_cell_bboxes( |
102 | | - self, img: np.ndarray, cell_bboxes: np.ndarray |
103 | | - ) -> np.ndarray: |
104 | | - h, w = img.shape[:2] |
105 | | - resized = 488 |
106 | | - ratio = min(resized / h, resized / w) |
107 | | - w_ratio = resized / (w * ratio) |
108 | | - h_ratio = resized / (h * ratio) |
109 | | - cell_bboxes[:, 0::2] *= w_ratio |
110 | | - cell_bboxes[:, 1::2] *= h_ratio |
111 | | - return cell_bboxes |
112 | | - |
113 | | - @staticmethod |
114 | | - def filter_blank_bbox(cell_bboxes: np.ndarray) -> np.ndarray: |
115 | | - # 过滤掉占位的bbox |
116 | | - mask = ~np.all(cell_bboxes == 0, axis=1) |
117 | | - return cell_bboxes[mask] |
0 commit comments