Skip to content

Commit bf1ae45

Browse files
committed
add tests for ffdetr vs. ffdnet
1 parent 4e6497d commit bf1ae45

2 files changed

Lines changed: 41 additions & 23 deletions

File tree

commonforms/inference.py

Lines changed: 19 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -25,21 +25,18 @@
2525
("FFDNET-S", False): ("jbarrow/FFDNet-S", "FFDNet-S.pt"),
2626
("FFDNET-L", True): ("jbarrow/FFDNet-L-cpu", "FFDNet-L.onnx"),
2727
("FFDNET-L", False): ("jbarrow/FFDNet-L", "FFDNet-L.pt"),
28-
("FFDETR", False): ("jbarrow/FFDetr", "FFDetr.pth")
28+
("FFDETR", False): ("jbarrow/FFDetr", "FFDetr.pth"),
2929
}
3030

3131

32-
3332
def batch(lst: list, n: int = 8):
3433
l = len(lst)
3534
for ndx in range(0, l, n):
36-
yield lst[ndx:min(ndx + n, l)]
35+
yield lst[ndx : min(ndx + n, l)]
3736

3837

3938
class FFDetrDetector:
40-
def __init__(
41-
self, model_or_path: str, device: int | str = "cpu"
42-
) -> None:
39+
def __init__(self, model_or_path: str, device: int | str = "cpu") -> None:
4340
self.device = device
4441
self.model = RFDETRMedium(pretrain_weights=self.get_model_path(model_or_path))
4542

@@ -49,8 +46,8 @@ def get_model_path(self, model_or_path: str) -> str:
4946
model_upper = model_or_path.upper()
5047
if model_upper in ["FFDETR"]:
5148
# download the model, will just use the cached version if it already exists
52-
repo_id, filename = models[(model_upper, False)]
53-
model_path = hf_hub_download(repo_id=repo_id, filename=filename)
49+
repo_id, filename = models[(model_upper, False)]
50+
model_path = hf_hub_download(repo_id=repo_id, filename=filename)
5451
else:
5552
model_path = model_or_path
5653

@@ -75,7 +72,7 @@ def extract_widgets(
7572
) -> dict[int, list[Widget]]:
7673
image_size = 1024
7774
results = []
78-
for b in batch([p.image for p in pages], n=batch_size):
75+
for b in batch([self.resize(p.image, image_size) for p in pages], n=batch_size):
7976
predictions = self.model.predict(b, threshold=confidence)
8077
if len(pages) == 1 or batch_size == 1:
8178
predictions = [predictions]
@@ -131,8 +128,8 @@ def get_model_path(
131128
model_upper = model_or_path.upper()
132129
if model_upper in ["FFDNET-S", "FFDNET-L"]:
133130
# download the model, will just use the cached version if it already exists
134-
repo_id, filename = models[(model_upper, fast)]
135-
model_path = hf_hub_download(repo_id=repo_id, filename=filename)
131+
repo_id, filename = models[(model_upper, fast)]
132+
model_path = hf_hub_download(repo_id=repo_id, filename=filename)
136133
else:
137134
model_path = model_or_path
138135

@@ -247,17 +244,20 @@ def prepare_form(
247244
input_path: str | Path,
248245
output_path: str | Path,
249246
*,
250-
model_or_path: str = "FFDNet-L",
247+
model_or_path: str = "FFDetr",
251248
keep_existing_fields: bool = False,
252249
use_signature_fields: bool = False,
253250
device: int | str = "cpu",
254-
image_size: int = 1600,
255-
confidence: float = 0.3,
251+
image_size: int = 1024,
252+
confidence: float = 0.4,
256253
fast: bool = False,
257254
multiline: bool = False,
255+
batch_size: int = 4,
258256
):
259-
# detector = FFDNetDetector(model_or_path, device=device, fast=fast)
260-
detector = FFDetrDetector("FFDetr")
257+
if "FFDNET" in model_or_path.upper():
258+
detector = FFDNetDetector(model_or_path, device=device, fast=fast)
259+
else:
260+
detector = FFDetrDetector(model_or_path)
261261

262262
try:
263263
pages = render_pdf(input_path)
@@ -277,7 +277,9 @@ def prepare_form(
277277
name = f"{widget.widget_type.lower()}_{widget.page}_{i}"
278278

279279
if widget.widget_type == "TextBox":
280-
writer.add_text_box(name, page_ix, widget.bounding_box, multiline=multiline)
280+
writer.add_text_box(
281+
name, page_ix, widget.bounding_box, multiline=multiline
282+
)
281283
elif widget.widget_type == "ChoiceButton":
282284
writer.add_checkbox(name, page_ix, widget.bounding_box)
283285
elif widget.widget_type == "Signature":

tests/inference_test.py

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
def test_inference(tmp_path):
99
# tmp_path is a built-in pytest fixture where we'll write the outputs
1010
output_path = tmp_path / "output.pdf"
11-
commonforms.prepare_form("./tests/resources/input.pdf", output_path)
11+
commonforms.prepare_form("./tests/resources/input.pdf", output_path, model_or_path="FFDetr")
1212

1313
assert output_path.exists()
1414

@@ -20,7 +20,7 @@ def test_inference(tmp_path):
2020

2121
def test_inference_fast(tmp_path):
2222
output_path = tmp_path / "output.pdf"
23-
commonforms.prepare_form("./tests/resources/input.pdf", output_path, fast=True)
23+
commonforms.prepare_form("./tests/resources/input.pdf", output_path, fast=True, model_or_path="FFDNet-L")
2424

2525
assert output_path.exists()
2626

@@ -32,7 +32,9 @@ def test_inference_fast(tmp_path):
3232

3333
def test_mutlinline(tmp_path):
3434
output_path = tmp_path / "output.pdf"
35-
commonforms.prepare_form("./tests/resources/input.pdf", output_path, fast=True, multiline=True)
35+
commonforms.prepare_form(
36+
"./tests/resources/input.pdf", output_path, fast=True, multiline=True
37+
)
3638

3739
assert output_path.exists()
3840

@@ -42,7 +44,6 @@ def test_mutlinline(tmp_path):
4244
doc.document.close()
4345

4446

45-
4647
def test_encrypted_failure(tmp_path):
4748
# Reminder to future Joe: password for encrypted PDF is "kanbanery"
4849
output_path = tmp_path / "output.pdf"
@@ -51,7 +52,22 @@ def test_encrypted_failure(tmp_path):
5152
commonforms.prepare_form("./tests/resources/encrypted.pdf", output_path)
5253

5354

55+
def test_inference_ffdetr(tmp_path):
56+
# tmp_path is a built-in pytest fixture where we'll write the outputs
57+
output_path = tmp_path / "output.pdf"
58+
commonforms.prepare_form(
59+
"./tests/resources/input.pdf", output_path, model_or_path="FFDetr"
60+
)
61+
62+
assert output_path.exists()
63+
64+
doc = formalpdf.open(output_path)
65+
assert len(doc[0].widgets()) > 0
66+
67+
doc.document.close()
68+
69+
5470
# TODO(joe): future tests around handling encrypted PDFs
5571
# 1. add a --password flag and test that inference doesn't fail
56-
# 2. if a password is provided, ensure that the _output_ PDF remains encrypted
57-
# with the same password
72+
# 2. if a password is provided, ensure that the _output_ PDF remains encrypted
73+
# with the same password

0 commit comments

Comments
 (0)