Skip to content

Commit a4fd085

Browse files
authored
feat: support batch inference both model and refacter code
feat: support batch inference both model and refacter code
2 parents 6da3974 + 6bb3e10 commit a4fd085

22 files changed

Lines changed: 468 additions & 564 deletions

File tree

README.md

Lines changed: 34 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ unitable是来源unitable的transformer模型,精度最高,暂仅支持pytor
2424

2525
### 📅 最近动态
2626

27+
2025-08-29 update: 发布2.1.0,支持batch推理
2728
2025-06-22 update: 发布v2.x,适配rapidocr v3.x \
2829
2025-01-09 update: 发布v1.x,全新接口升级。 \
2930
2024.12.30 update:支持Unitable模型的表格识别,使用pytorch框架 \
@@ -109,11 +110,13 @@ unitable是来源unitable的transformer模型,精度最高,暂仅支持pytor
109110

110111
|`rapid_table`|OCR|
111112
|:---:|:---|
112-
|v0.x|`rapidocr_onnxruntime`|
113-
|v1.0.x|`rapidocr>=2.0.0,<3.0.0`|
114113
|v2.x|`rapidocr>=3.0.0`|
114+
|v1.0.x|`rapidocr>=2.0.0,<3.0.0`|
115+
|v0.x|`rapidocr_onnxruntime`|
115116

116-
由于模型较小,预先将slanet-plus表格识别模型(`slanet-plus.onnx`)打包进了whl包内。其余模型在初始化`RapidTable`类时,会根据`model_type`来自动下载模型到安装包所在`models`目录下。当然也可以通过`RapidTableInput(model_path='')`来指定自己模型路径(`v1.0.x` 参数变量名使用`model_path`, `v2.x` 参数变量名变更为`model_dir_or_path`)。注意仅限于我们现支持的`model_type`
117+
由于模型较小,预先将slanet-plus表格识别模型(`slanet-plus.onnx`)打包进了whl包内。其余模型在初始化`RapidTable`类时,会根据`model_type`来自动下载模型到安装包所在`models`目录下。
118+
119+
当然也可以通过`RapidTableInput(model_path='')`来指定自己模型路径(`v1.0.x` 参数变量名使用`model_path`, `v2.x` 参数变量名变更为`model_dir_or_path`)。注意仅限于我们现支持的`model_type`
117120

118121
> ⚠️注意:`rapid_table>=v1.0.0`之后,不再将`rapidocr`依赖强制打包到`rapid_table`中。使用前,需要自行安装`rapidocr`包。
119122
>
@@ -141,69 +144,66 @@ ModelType支持已有的4个模型 ([source](./rapid_table/utils/typings.py)):
141144

142145
```python
143146
class ModelType(Enum):
144-
PPSTRUCTURE_EN = "ppstructure_en"
145-
PPSTRUCTURE_ZH = "ppstructure_zh"
146-
SLANETPLUS = "slanet_plus"
147-
UNITABLE = "unitable"
147+
PPSTRUCTURE_EN = "ppstructure_en" # onnxruntime
148+
PPSTRUCTURE_ZH = "ppstructure_zh" # onnxruntime
149+
SLANETPLUS = "slanet_plus" # onnxruntime
150+
UNITABLE = "unitable" # torch推理引擎
148151
```
149152

150-
##### CPU使用
153+
#### batch_size推理
151154

152155
```python
153-
154-
from rapidocr import RapidOCR
156+
from pathlib import Path
155157
156158
from rapid_table import ModelType, RapidTable, RapidTableInput
157159
158-
ocr_engine = RapidOCR()
160+
input_args = RapidTableInput(model_type=ModelType.PPSTRUCTURE_ZH)
161+
table_engine = RapidTable(input_args)
162+
163+
img_list = list(Path("images").iterdir())
164+
results = table_engine(img_path, batch_size=3) # 这里,batch_size默认为1
165+
166+
# indexes:指定可视化的图像索引。默认为0
167+
results.vis(save_dir="outputs", save_name="vis", indexes=(0, 1, 2))
168+
```
169+
170+
##### CPU使用
171+
172+
```python
173+
from rapid_table import ModelType, RapidTable, RapidTableInput
159174
160175
input_args = RapidTableInput(model_type=ModelType.UNITABLE)
161176
table_engine = RapidTable(input_args)
162177
163178
img_path = "https://raw.githubusercontent.com/RapidAI/RapidTable/refs/heads/main/tests/test_files/table.jpg"
164-
165-
# # 使用单字识别
166-
# ori_ocr_res = ocr_engine(img_path, return_word_box=True)
167-
# ocr_results = [
168-
# [word_result[0][2], word_result[0][0], word_result[0][1]]
169-
# for word_result in ori_ocr_res.word_results
170-
# ]
171-
# ocr_results = list(zip(*ocr_results))
172-
173179
ori_ocr_res = ocr_engine(img_path)
174-
ocr_results = [ori_ocr_res.boxes, ori_ocr_res.txts, ori_ocr_res.scores]
175-
results = table_engine(img_path, ocr_results=ocr_results)
180+
results = table_engine(img_path)
176181
results.vis(save_dir="outputs", save_name="vis")
177182
```
178183

179184
##### GPU使用
180185

181-
```python
182-
183-
from rapidocr import RapidOCR
186+
> `engine_cfg`中参数是和[`engine_cfg.yaml`](https://github.com/RapidAI/RapidTable/blob/6da3974a35ac5da8a5cf58194eab00b6886212e8/rapid_table/engine_cfg.yaml)相对应的。
184187

188+
```python
185189
from rapid_table import ModelType, RapidTable, RapidTableInput
186190
187-
ocr_engine = RapidOCR()
188-
189191
# onnxruntime-gpu
190192
input_args = RapidTableInput(
191-
model_type=ModelType.SLANETPLUS, engine_cfg={"use_cuda": True, "gpu_id": 1}
193+
model_type=ModelType.SLANETPLUS,
194+
engine_cfg={"use_cuda": True, "cuda_ep_cfg.gpu_id": 1}
192195
)
193196
194197
# torch gpu
195198
# input_args = RapidTableInput(
196199
# model_type=ModelType.UNITABLE,
197-
# engine_cfg={"use_cuda": True, "cuda_ep_cfg.gpu_id": 1},
200+
# engine_cfg={"use_cuda": True, "gpu_id": 1},
198201
# )
202+
199203
table_engine = RapidTable(input_args)
200204
201205
img_path = "https://raw.githubusercontent.com/RapidAI/RapidTable/refs/heads/main/tests/test_files/table.jpg"
202-
203-
ori_ocr_res = ocr_engine(img_path)
204-
ocr_results = [ori_ocr_res.boxes, ori_ocr_res.txts, ori_ocr_res.scores]
205-
206-
results = table_engine(img_path, ocr_results=ocr_results)
206+
results = table_engine(img_path)
207207
results.vis(save_dir="outputs", save_name="vis")
208208
```
209209

batch_demo.py

Lines changed: 0 additions & 66 deletions
This file was deleted.

demo.py

Lines changed: 9 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,32 +1,17 @@
11
# -*- encoding: utf-8 -*-
22
# @Author: SWHL
33
# @Contact: liekkaskono@163.com
4-
from rapidocr import EngineType, RapidOCR
4+
from pathlib import Path
55

66
from rapid_table import ModelType, RapidTable, RapidTableInput
77

8-
ocr_engine = RapidOCR(
9-
params={
10-
"Det.engine_type": EngineType.TORCH,
11-
"Cls.engine_type": EngineType.TORCH,
12-
"Rec.engine_type": EngineType.TORCH,
13-
}
14-
)
15-
16-
input_args = RapidTableInput(model_type=ModelType.UNITABLE)
8+
# input_args = RapidTableInput(
9+
# model_type=ModelType.UNITABLE,
10+
# engine_cfg={"use_cuda": True, "gpu_id": 1},
11+
# )
12+
input_args = RapidTableInput(model_type=ModelType.PPSTRUCTURE_ZH)
1713
table_engine = RapidTable(input_args)
1814

19-
img_path = "https://raw.githubusercontent.com/RapidAI/RapidTable/refs/heads/main/tests/test_files/table.jpg"
20-
21-
# # 使用单字识别
22-
# ori_ocr_res = ocr_engine(img_path, return_word_box=True)
23-
# ocr_results = [
24-
# [word_result[0][2], word_result[0][0], word_result[0][1]]
25-
# for word_result in ori_ocr_res.word_results
26-
# ]
27-
# ocr_results = list(zip(*ocr_results))
28-
29-
ori_ocr_res = ocr_engine(img_path)
30-
ocr_results = [ori_ocr_res.boxes, ori_ocr_res.txts, ori_ocr_res.scores]
31-
results = table_engine(img_path, ocr_results=ocr_results)
32-
results.vis(save_dir="outputs", save_name="vis")
15+
img_list = list(Path("images").iterdir())
16+
results = table_engine(img_list, batch_size=3)
17+
results.vis(save_dir="outputs", save_name="vis", indexes=(0, 1, 2))

rapid_table/engine_cfg.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ onnxruntime:
88

99
use_cuda: false
1010
cuda_ep_cfg:
11-
device_id: 0
11+
gpu_id: 0
1212
arena_extend_strategy: "kNextPowerOfTwo"
1313
cudnn_conv_algo_search: "EXHAUSTIVE"
1414
do_copy_in_default_stream: true
@@ -18,7 +18,7 @@ onnxruntime:
1818

1919
use_cann: false
2020
cann_ep_cfg:
21-
device_id: 0
21+
gpu_id: 0
2222
arena_extend_strategy: "kNextPowerOfTwo"
2323
npu_mem_limit: 21474836480 # 20 * 1024 * 1024 * 1024
2424
op_select_impl_mode: "high_performance"

rapid_table/inference_engine/torch.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,11 @@ def __init__(self, cfg) -> None:
2323
self.engine_cfg[cfg["engine_type"].value], cfg["engine_cfg"]
2424
)
2525

26-
self.device = "cpu"
27-
if engine_cfg.use_cuda:
28-
self.device = f"cuda:{engine_cfg.gpu_id}"
26+
self.device = torch.device(
27+
f"cuda:{engine_cfg.gpu_id}"
28+
if torch.cuda.is_available() and engine_cfg.use_cuda
29+
else "cpu"
30+
)
2931

3032
model_info = cfg["model_dir_or_path"]
3133
self.encoder = self._init_model(model_info["encoder"], Encoder)

0 commit comments

Comments
 (0)