Skip to content

Commit cf1399b

Browse files
committed
FFDetr inference
1 parent 1457778 commit cf1399b

1 file changed

Lines changed: 34 additions & 15 deletions

File tree

commonforms/inference.py

Lines changed: 34 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -10,16 +10,22 @@
1010

1111
import formalpdf
1212
import pypdfium2
13+
import logging
1314
import PIL
1415

1516

16-
# our mapping from (model_name, fast) to (repo_id, filename) for the huggingface hub
17+
logging.basicConfig(level=logging.INFO)
18+
19+
20+
# our mapping from (model_name_upper, fast) to (repo_id, filename) for the huggingface hub.
21+
# keeping it simple and declarative like this because it's not like we're adding a bunch
22+
# of models.
1723
models = {
1824
("FFDNET-S", True): ("jbarrow/FFDNet-S-cpu", "FFDNet-S.onnx"),
1925
("FFDNET-S", False): ("jbarrow/FFDNet-S", "FFDNet-S.pt"),
2026
("FFDNET-L", True): ("jbarrow/FFDNet-L-cpu", "FFDNet-L.onnx"),
2127
("FFDNET-L", False): ("jbarrow/FFDNet-L", "FFDNet-L.pt"),
22-
("FFDetr-Nano", False): ("./models/FFDetr-Nano", "checkpoint_best_ema.pth")
28+
("FFDETR", False): ("jbarrow/FFDetr", "FFDetr.pth")
2329
}
2430

2531

@@ -30,15 +36,25 @@ def batch(lst: list, n: int = 8):
3036
yield lst[ndx:min(ndx + n, l)]
3137

3238

33-
3439
class FFDetrDetector:
3540
def __init__(
3641
self, model_or_path: str, device: int | str = "cpu"
3742
) -> None:
3843
self.device = device
39-
self.model = RFDETRMedium(pretrain_weights=model_or_path, resolution=224*5, num_classes=2)
44+
self.model = RFDETRMedium(pretrain_weights=self.get_model_path(model_or_path))
45+
46+
self.id_to_cls = {0: "TextBox", 1: "ChoiceButton", 2: "Signature"}
47+
48+
def get_model_path(self, model_or_path: str) -> str:
49+
model_upper = model_or_path.upper()
50+
if model_upper in ["FFDETR"]:
51+
# download the model, will just use the cached version if it already exists
52+
repo_id, filename = models[(model_upper, False)]
53+
model_path = hf_hub_download(repo_id=repo_id, filename=filename)
54+
else:
55+
model_path = model_or_path
4056

41-
self.id_to_cls = {0: "TextBox", 1: "ChoiceButton"}
57+
return model_path
4258

4359
def resize(
4460
self,
@@ -51,22 +67,26 @@ def resize(
5167
return image.resize(size, PIL.Image.Resampling.LANCZOS)
5268

5369
def extract_widgets(
54-
self, pages: list[Page], confidence: float = 0.2, image_size: int = 1120
70+
self,
71+
pages: list[Page],
72+
confidence: float = 0.4,
73+
image_size: int = 1120,
74+
batch_size: int = 3,
5575
) -> dict[int, list[Widget]]:
5676
image_size = 1024
5777
results = []
58-
for b in batch([p.image for p in pages], n=1):
59-
results += [self.model.predict(b, threshold=confidence)]
78+
for b in batch([p.image for p in pages], n=batch_size):
79+
predictions = self.model.predict(b, threshold=confidence)
80+
if len(pages) == 1 or batch_size == 1:
81+
predictions = [predictions]
82+
results.extend(predictions)
6083

6184
widgets = {}
6285

63-
if len(pages) == 1:
64-
results = [results]
65-
6686
for page_ix, detections in enumerate(results):
67-
print(f"{page_ix}: {len(detections)} fields detected")
87+
logging.info(f" Page {page_ix}: {len(detections)} fields detected")
6888
detections = detections.with_nms(threshold=0.1, class_agnostic=True)
69-
print(f"{len(detections)} after nms")
89+
logging.info(f"\t\t{len(detections)} after nms")
7090
widgets[page_ix] = []
7191

7292
for class_id, box in zip(detections.class_id, detections.xyxy):
@@ -217,7 +237,6 @@ def render_pdf(pdf_path: str) -> list[Page]:
217237
try:
218238
for page in doc:
219239
image = page.render(dpi=144)
220-
print(image.width, image.height)
221240
pages.append(Page(image=image, width=image.width, height=image.height))
222241
return pages
223242
finally:
@@ -238,7 +257,7 @@ def prepare_form(
238257
multiline: bool = False,
239258
):
240259
# detector = FFDNetDetector(model_or_path, device=device, fast=fast)
241-
detector = FFDetrDetector("./models/FFDetr-Medium/checkpoint_best_ema.pth")
260+
detector = FFDetrDetector("FFDetr")
242261

243262
try:
244263
pages = render_pdf(input_path)

0 commit comments

Comments
 (0)