Skip to content

Commit 6bb3e10

Browse files
committed
chore: update files
1 parent 732c29e commit 6bb3e10

3 files changed

Lines changed: 10 additions & 3 deletions

File tree

demo.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,5 +13,5 @@
1313
table_engine = RapidTable(input_args)
1414

1515
img_list = list(Path("images").iterdir())
16-
results = table_engine(img_path, batch_size=3)
16+
results = table_engine(img_list, batch_size=3)
1717
results.vis(save_dir="outputs", save_name="vis", indexes=(0, 1, 2))

rapid_table/main.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,14 @@ def get_ocr_results(
126126
batch_dt_boxes, batch_rec_res = [], []
127127

128128
if ocr_results is not None:
129-
for img, ocr_result in zip(imgs, ocr_results[start_i:end_i]):
129+
ocr_results_batch = ocr_results[start_i:end_i]
130+
if len(ocr_results_batch) != len(imgs):
131+
raise ValueError(
132+
f"Batch size mismatch: {len(imgs)} images but {len(ocr_results_batch)} OCR results "
133+
f"(indices {start_i}:{end_i})."
134+
)
135+
136+
for img, ocr_result in zip(imgs, ocr_results_batch):
130137
img_h, img_w = img.shape[:2]
131138
dt_boxes, rec_res = format_ocr_results(ocr_result, img_h, img_w)
132139
batch_dt_boxes.append(dt_boxes)

rapid_table/table_structure/unitable/unitable_modules.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -748,7 +748,7 @@ def forward(self, memory: Tensor, tgt: Tensor) -> Tensor:
748748
logits = layer(logits, memory, input_pos=input_pos, tgt_mask=tgt_mask)
749749

750750
logits = self.generator(logits)[:, -1, :]
751-
total = set(list(range(logits.shape[-1])))
751+
total = set(range(logits.shape[-1]))
752752
black_list = list(total.difference(set(self.token_white_list)))
753753
logits[..., black_list] = -1e9
754754
probs = F.softmax(logits, dim=-1)

0 commit comments

Comments
 (0)