Skip to content

Commit 1457778

Browse files
committed
add ffdetr inference code
1 parent 6907542 commit 1457778

3 files changed

Lines changed: 74 additions & 2 deletions

File tree

commonforms/__main__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ def main():
5858

5959
args = parser.parse_args()
6060

61+
print(f"**{args.confidence=}")
6162
prepare_form(
6263
args.input,
6364
args.output,

commonforms/inference.py

Lines changed: 72 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,15 @@
22
from ultralytics import YOLO
33
from pathlib import Path
44
from huggingface_hub import hf_hub_download
5+
from rfdetr import RFDETRNano, RFDETRBase, RFDETRMedium, RFDETRLarge
56

67
from commonforms.utils import BoundingBox, Page, Widget
78
from commonforms.form_creator import PyPdfFormCreator
89
from commonforms.exceptions import EncryptedPdfError
910

1011
import formalpdf
1112
import pypdfium2
13+
import PIL
1214

1315

1416
# our mapping from (model_name, fast) to (repo_id, filename) for the huggingface hub
@@ -17,9 +19,75 @@
1719
("FFDNET-S", False): ("jbarrow/FFDNet-S", "FFDNet-S.pt"),
1820
("FFDNET-L", True): ("jbarrow/FFDNet-L-cpu", "FFDNet-L.onnx"),
1921
("FFDNET-L", False): ("jbarrow/FFDNet-L", "FFDNet-L.pt"),
22+
("FFDetr-Nano", False): ("./models/FFDetr-Nano", "checkpoint_best_ema.pth")
2023
}
2124

2225

26+
27+
def batch(lst: list, n: int = 8):
    """Yield consecutive slices of ``lst`` containing at most ``n`` items each.

    The last slice is shorter when ``len(lst)`` is not a multiple of ``n``;
    an empty ``lst`` yields nothing.

    Args:
        lst: The sequence to split; it is only sliced, never modified.
        n: Maximum number of items per slice. Must be at least 1.

    Raises:
        ValueError: If ``n`` is smaller than 1. Because this is a generator,
            the error surfaces when iteration starts, not at call time.
    """
    # A negative step would make range() empty and silently drop every item,
    # so reject non-positive sizes explicitly.
    if n < 1:
        raise ValueError(f"batch size must be >= 1, got {n}")

    for start in range(0, len(lst), n):
        # Slicing clamps an out-of-range stop index, so no min() is needed.
        yield lst[start:start + n]
31+
32+
33+
34+
class FFDetrDetector:
    """Form-field detector backed by an RF-DETR checkpoint.

    Wraps an ``RFDETRMedium`` model configured for two classes and converts
    its detections into ``Widget`` objects whose bounding boxes are expressed
    as fractions of the page image's width and height.
    """

    def __init__(
        self, model_or_path: str, device: int | str = "cpu"
    ) -> None:
        """Load the detector weights.

        Args:
            model_or_path: Passed to ``RFDETRMedium`` as ``pretrain_weights``.
            device: Stored on the instance only.
                NOTE(review): it is not forwarded to the model here; confirm
                whether the model should be placed on this device.
        """
        self.device = device
        self.model = RFDETRMedium(
            pretrain_weights=model_or_path, resolution=224 * 5, num_classes=2
        )

        # Maps the model's class ids to widget type names.
        self.id_to_cls = {0: "TextBox", 1: "ChoiceButton"}

    def resize(
        self,
        image: PIL.Image.Image,
        size: tuple[int, int] | int,
    ) -> PIL.Image.Image:
        """Return ``image`` resized with LANCZOS resampling.

        Args:
            image: The image to resize.
            size: Target ``(width, height)``; a single int is used for both.
        """
        if isinstance(size, int):
            size = (size, size)

        return image.resize(size, PIL.Image.Resampling.LANCZOS)

    def extract_widgets(
        self, pages: list[Page], confidence: float = 0.2, image_size: int = 1120
    ) -> dict[int, list[Widget]]:
        """Detect form widgets on every page.

        Args:
            pages: Rendered pages; each page's ``image`` is fed to the model.
            confidence: Detection threshold passed to ``model.predict``.
            image_size: Currently unused; kept so the signature stays
                compatible with existing callers.

        Returns:
            A dict mapping each page index to that page's widgets, ordered by
            ``sort_widgets``.
        """
        # Collect exactly one detections object per page. ``predict`` may hand
        # back either a single result or a list of results, so flatten both
        # shapes here. (Previously a single-page input was wrapped in an extra
        # list afterwards, which made the loop below receive a list instead of
        # a detections object and fail on ``with_nms``.)
        results = []
        for image_batch in batch([p.image for p in pages], n=1):
            prediction = self.model.predict(image_batch, threshold=confidence)
            if isinstance(prediction, list):
                results.extend(prediction)
            else:
                results.append(prediction)

        widgets = {}

        for page_ix, detections in enumerate(results):
            print(f"{page_ix}: {len(detections)} fields detected")
            # Class-agnostic non-max suppression over the page's detections.
            detections = detections.with_nms(threshold=0.1, class_agnostic=True)
            print(f"{len(detections)} after nms")

            page_image = pages[page_ix].image
            page_widgets = []

            for class_id, box in zip(detections.class_id, detections.xyxy):
                # Normalize pixel coordinates by the page image dimensions.
                x0, x1 = box[[0, 2]] / page_image.width
                y0, y1 = box[[1, 3]] / page_image.height

                widget_type = self.id_to_cls[class_id]

                page_widgets.append(
                    Widget(
                        widget_type=widget_type,
                        bounding_box=BoundingBox(x0=x0, y0=y0, x1=x1, y1=y1),
                        page=page_ix,
                    )
                )

            widgets[page_ix] = sort_widgets(page_widgets)

        return widgets
89+
90+
2391
class FFDNetDetector:
2492
def __init__(
2593
self, model_or_path: str, device: int | str = "cpu", fast: bool = False
@@ -148,7 +216,8 @@ def render_pdf(pdf_path: str) -> list[Page]:
148216
doc = formalpdf.open(pdf_path)
149217
try:
150218
for page in doc:
151-
image = page.render()
219+
image = page.render(dpi=144)
220+
print(image.width, image.height)
152221
pages.append(Page(image=image, width=image.width, height=image.height))
153222
return pages
154223
finally:
@@ -168,7 +237,8 @@ def prepare_form(
168237
fast: bool = False,
169238
multiline: bool = False,
170239
):
171-
detector = FFDNetDetector(model_or_path, device=device, fast=fast)
240+
# detector = FFDNetDetector(model_or_path, device=device, fast=fast)
241+
detector = FFDetrDetector("./models/FFDetr-Medium/checkpoint_best_ema.pth")
172242

173243
try:
174244
pages = render_pdf(input_path)

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ dependencies = [
1919
"pillow>=11.3.0",
2020
"pydantic>=2.11.9",
2121
"pypdf>=6.1.1",
22+
"rfdetr>=1.3.0",
2223
"ultralytics>=8.3.204",
2324
]
2425

0 commit comments

Comments
 (0)