Skip to content

Commit 336eb4e

Browse files
authored
🎉 Release FFDetr (#26)
* add ffdetr inference code * FFDetr inference * test update to pyproject * test update to pyproject * add tests for ffdetr vs. ffdnet
1 parent 920aa10 commit 336eb4e

5 files changed

Lines changed: 126 additions & 16 deletions

File tree

‎.github/workflows/ci.yml‎

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ jobs:
3434
.venv
3535
key: uv-${{ runner.os }}-${{ matrix.python-version }}-${{ hashFiles('pyproject.toml', 'uv.lock') }}
3636
restore-keys: |
37-
uv-${{ runner.os }}-${{ matrix.python-version }}-
37+
uv-${{ runner.os }}-${{ matrix.python-version }}-${{ hashFiles('pyproject.toml', 'uv.lock') }}
3838
3939
- name: Install (uv)
4040
run: |

‎commonforms/__main__.py‎

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ def main():
5858

5959
args = parser.parse_args()
6060

61+
print(f"**{args.confidence=}")
6162
prepare_form(
6263
args.input,
6364
args.output,

‎commonforms/inference.py‎

Lines changed: 100 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,24 +2,109 @@
22
from ultralytics import YOLO
33
from pathlib import Path
44
from huggingface_hub import hf_hub_download
5+
from rfdetr import RFDETRNano, RFDETRBase, RFDETRMedium, RFDETRLarge
56

67
from commonforms.utils import BoundingBox, Page, Widget
78
from commonforms.form_creator import PyPdfFormCreator
89
from commonforms.exceptions import EncryptedPdfError
910

1011
import formalpdf
1112
import pypdfium2
13+
import logging
14+
import PIL
1215

1316

14-
# our mapping from (model_name, fast) to (repo_id, filename) for the huggingface hub
17+
logging.basicConfig(level=logging.INFO)
18+
19+
20+
# our mapping from (model_name_upper, fast) to (repo_id, filename) for the huggingface hub.
21+
# keeping it simple and declarative like this becuase it's not like we're adding a bunch
22+
# of models.
1523
models = {
1624
("FFDNET-S", True): ("jbarrow/FFDNet-S-cpu", "FFDNet-S.onnx"),
1725
("FFDNET-S", False): ("jbarrow/FFDNet-S", "FFDNet-S.pt"),
1826
("FFDNET-L", True): ("jbarrow/FFDNet-L-cpu", "FFDNet-L.onnx"),
1927
("FFDNET-L", False): ("jbarrow/FFDNet-L", "FFDNet-L.pt"),
28+
("FFDETR", False): ("jbarrow/FFDetr", "FFDetr.pth"),
2029
}
2130

2231

32+
def batch(lst: list, n: int = 8):
33+
l = len(lst)
34+
for ndx in range(0, l, n):
35+
yield lst[ndx : min(ndx + n, l)]
36+
37+
38+
class FFDetrDetector:
39+
def __init__(self, model_or_path: str, device: int | str = "cpu") -> None:
40+
self.device = device
41+
self.model = RFDETRMedium(pretrain_weights=self.get_model_path(model_or_path))
42+
43+
self.id_to_cls = {0: "TextBox", 1: "ChoiceButton", 2: "Signature"}
44+
45+
def get_model_path(self, model_or_path: str) -> str:
46+
model_upper = model_or_path.upper()
47+
if model_upper in ["FFDETR"]:
48+
# download the model, will just use the cached version if it already exists
49+
repo_id, filename = models[(model_upper, False)]
50+
model_path = hf_hub_download(repo_id=repo_id, filename=filename)
51+
else:
52+
model_path = model_or_path
53+
54+
return model_path
55+
56+
def resize(
57+
self,
58+
image: PIL.Image.Image,
59+
size: tuple[int, int] | int,
60+
) -> PIL.Image.Image:
61+
if isinstance(size, int):
62+
size = (size, size)
63+
64+
return image.resize(size, PIL.Image.Resampling.LANCZOS)
65+
66+
def extract_widgets(
67+
self,
68+
pages: list[Page],
69+
confidence: float = 0.4,
70+
image_size: int = 1120,
71+
batch_size: int = 3,
72+
) -> dict[int, list[Widget]]:
73+
image_size = 1024
74+
results = []
75+
for b in batch([self.resize(p.image, image_size) for p in pages], n=batch_size):
76+
predictions = self.model.predict(b, threshold=confidence)
77+
if len(pages) == 1 or batch_size == 1:
78+
predictions = [predictions]
79+
results.extend(predictions)
80+
81+
widgets = {}
82+
83+
for page_ix, detections in enumerate(results):
84+
logging.info(f" Page {page_ix}: {len(detections)} fields detected")
85+
detections = detections.with_nms(threshold=0.1, class_agnostic=True)
86+
logging.info(f"\t\t{len(detections)} after nms")
87+
widgets[page_ix] = []
88+
89+
for class_id, box in zip(detections.class_id, detections.xyxy):
90+
x0, x1 = box[[0, 2]] / pages[page_ix].image.width
91+
y0, y1 = box[[1, 3]] / pages[page_ix].image.height
92+
93+
widget_type = self.id_to_cls[class_id]
94+
95+
widgets[page_ix].append(
96+
Widget(
97+
widget_type=widget_type,
98+
bounding_box=BoundingBox(x0=x0, y0=y0, x1=x1, y1=y1),
99+
page=page_ix,
100+
)
101+
)
102+
103+
widgets[page_ix] = sort_widgets(widgets[page_ix])
104+
105+
return widgets
106+
107+
23108
class FFDNetDetector:
24109
def __init__(
25110
self, model_or_path: str, device: int | str = "cpu", fast: bool = False
@@ -43,8 +128,8 @@ def get_model_path(
43128
model_upper = model_or_path.upper()
44129
if model_upper in ["FFDNET-S", "FFDNET-L"]:
45130
# download the model, will just use the cached version if it already exists
46-
repo_id, filename = models[(model_upper, fast)]
47-
model_path = hf_hub_download(repo_id=repo_id, filename=filename)
131+
repo_id, filename = models[(model_upper, fast)]
132+
model_path = hf_hub_download(repo_id=repo_id, filename=filename)
48133
else:
49134
model_path = model_or_path
50135

@@ -148,7 +233,7 @@ def render_pdf(pdf_path: str) -> list[Page]:
148233
doc = formalpdf.open(pdf_path)
149234
try:
150235
for page in doc:
151-
image = page.render()
236+
image = page.render(dpi=144)
152237
pages.append(Page(image=image, width=image.width, height=image.height))
153238
return pages
154239
finally:
@@ -159,16 +244,20 @@ def prepare_form(
159244
input_path: str | Path,
160245
output_path: str | Path,
161246
*,
162-
model_or_path: str = "FFDNet-L",
247+
model_or_path: str = "FFDetr",
163248
keep_existing_fields: bool = False,
164249
use_signature_fields: bool = False,
165250
device: int | str = "cpu",
166-
image_size: int = 1600,
167-
confidence: float = 0.3,
251+
image_size: int = 1024,
252+
confidence: float = 0.4,
168253
fast: bool = False,
169254
multiline: bool = False,
255+
batch_size: int = 4,
170256
):
171-
detector = FFDNetDetector(model_or_path, device=device, fast=fast)
257+
if "FFDNET" in model_or_path.upper():
258+
detector = FFDNetDetector(model_or_path, device=device, fast=fast)
259+
else:
260+
detector = FFDetrDetector(model_or_path)
172261

173262
try:
174263
pages = render_pdf(input_path)
@@ -188,7 +277,9 @@ def prepare_form(
188277
name = f"{widget.widget_type.lower()}_{widget.page}_{i}"
189278

190279
if widget.widget_type == "TextBox":
191-
writer.add_text_box(name, page_ix, widget.bounding_box, multiline=multiline)
280+
writer.add_text_box(
281+
name, page_ix, widget.bounding_box, multiline=multiline
282+
)
192283
elif widget.widget_type == "ChoiceButton":
193284
writer.add_checkbox(name, page_ix, widget.bounding_box)
194285
elif widget.widget_type == "Signature":

‎pyproject.toml‎

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@ dependencies = [
1919
"pillow>=11.3.0",
2020
"pydantic>=2.11.9",
2121
"pypdf>=6.1.1",
22+
"rfdetr>=1.3.0",
23+
"transformers>=4.57",
2224
"ultralytics>=8.3.204",
2325
]
2426

‎tests/inference_test.py‎

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
def test_inference(tmp_path):
99
# tmp_path is a built-in pythest fixture where we'll write the outputs
1010
output_path = tmp_path / "output.pdf"
11-
commonforms.prepare_form("./tests/resources/input.pdf", output_path)
11+
commonforms.prepare_form("./tests/resources/input.pdf", output_path, model_or_path="FFDetr")
1212

1313
assert output_path.exists()
1414

@@ -20,7 +20,7 @@ def test_inference(tmp_path):
2020

2121
def test_inference_fast(tmp_path):
2222
output_path = tmp_path / "output.pdf"
23-
commonforms.prepare_form("./tests/resources/input.pdf", output_path, fast=True)
23+
commonforms.prepare_form("./tests/resources/input.pdf", output_path, fast=True, model_or_path="FFDNet-L")
2424

2525
assert output_path.exists()
2626

@@ -32,7 +32,9 @@ def test_inference_fast(tmp_path):
3232

3333
def test_mutlinline(tmp_path):
3434
output_path = tmp_path / "output.pdf"
35-
commonforms.prepare_form("./tests/resources/input.pdf", output_path, fast=True, multiline=True)
35+
commonforms.prepare_form(
36+
"./tests/resources/input.pdf", output_path, fast=True, multiline=True
37+
)
3638

3739
assert output_path.exists()
3840

@@ -42,7 +44,6 @@ def test_mutlinline(tmp_path):
4244
doc.document.close()
4345

4446

45-
4647
def test_encrypted_failure(tmp_path):
4748
# Reminder to future Joe: password for encrypted PDF is "kanbanery"
4849
output_path = tmp_path / "output.pdf"
@@ -51,7 +52,22 @@ def test_encrypted_failure(tmp_path):
5152
commonforms.prepare_form("./tests/resources/encrypted.pdf", output_path)
5253

5354

55+
def test_inference_ffdetr(tmp_path):
56+
# tmp_path is a built-in pythest fixture where we'll write the outputs
57+
output_path = tmp_path / "output.pdf"
58+
commonforms.prepare_form(
59+
"./tests/resources/input.pdf", output_path, model_or_path="FFDetr"
60+
)
61+
62+
assert output_path.exists()
63+
64+
doc = formalpdf.open(output_path)
65+
assert len(doc[0].widgets()) > 0
66+
67+
doc.document.close()
68+
69+
5470
# TODO(joe): future tests around handling encrypted PDFs
5571
# 1. add a --password flag and test that inference doesn't fail
56-
# 2. if a password is provided, ensure that the _output_ PDF remains encrpyted
57-
# with the same password
72+
# 2. if a password is provided, ensure that the _output_ PDF remains encrpyted
73+
# with the same password

0 commit comments

Comments
 (0)