1010
1111import formalpdf
1212import pypdfium2
13+ import logging
1314import PIL
1415
1516
16- # our mapping from (model_name, fast) to (repo_id, filename) for the huggingface hub
17+ logging .basicConfig (level = logging .INFO )
18+
19+
20+ # our mapping from (model_name_upper, fast) to (repo_id, filename) for the huggingface hub.
21+ # keeping it simple and declarative like this because it's not like we're adding a bunch
22+ # of models.
1723models = {
1824 ("FFDNET-S" , True ): ("jbarrow/FFDNet-S-cpu" , "FFDNet-S.onnx" ),
1925 ("FFDNET-S" , False ): ("jbarrow/FFDNet-S" , "FFDNet-S.pt" ),
2026 ("FFDNET-L" , True ): ("jbarrow/FFDNet-L-cpu" , "FFDNet-L.onnx" ),
2127 ("FFDNET-L" , False ): ("jbarrow/FFDNet-L" , "FFDNet-L.pt" ),
22- ("FFDetr-Nano " , False ): ("./models/ FFDetr-Nano " , "checkpoint_best_ema .pth" )
28+ ("FFDETR " , False ): ("jbarrow/ FFDetr" , "FFDetr .pth" )
2329}
2430
2531
@@ -30,15 +36,25 @@ def batch(lst: list, n: int = 8):
3036 yield lst [ndx :min (ndx + n , l )]
3137
3238
33-
3439class FFDetrDetector :
3540 def __init__ (
3641 self , model_or_path : str , device : int | str = "cpu"
3742 ) -> None :
3843 self .device = device
39- self .model = RFDETRMedium (pretrain_weights = model_or_path , resolution = 224 * 5 , num_classes = 2 )
44+ self .model = RFDETRMedium (pretrain_weights = self .get_model_path (model_or_path ))
45+
46+ self .id_to_cls = {0 : "TextBox" , 1 : "ChoiceButton" , 2 : "Signature" }
47+
48+ def get_model_path (self , model_or_path : str ) -> str :
49+ model_upper = model_or_path .upper ()
50+ if model_upper in ["FFDETR" ]:
51+ # download the model, will just use the cached version if it already exists
52+ repo_id , filename = models [(model_upper , False )]
53+ model_path = hf_hub_download (repo_id = repo_id , filename = filename )
54+ else :
55+ model_path = model_or_path
4056
41- self . id_to_cls = { 0 : "TextBox" , 1 : "ChoiceButton" }
57+ return model_path
4258
4359 def resize (
4460 self ,
@@ -51,22 +67,26 @@ def resize(
5167 return image .resize (size , PIL .Image .Resampling .LANCZOS )
5268
5369 def extract_widgets (
54- self , pages : list [Page ], confidence : float = 0.2 , image_size : int = 1120
70+ self ,
71+ pages : list [Page ],
72+ confidence : float = 0.4 ,
73+ image_size : int = 1120 ,
74+ batch_size : int = 3 ,
5575 ) -> dict [int , list [Widget ]]:
5676 image_size = 1024
5777 results = []
58- for b in batch ([p .image for p in pages ], n = 1 ):
59- results += [self .model .predict (b , threshold = confidence )]
78+ for b in batch ([p .image for p in pages ], n = batch_size ):
79+ predictions = self .model .predict (b , threshold = confidence )
80+ if len (pages ) == 1 or batch_size == 1 :
81+ predictions = [predictions ]
82+ results .extend (predictions )
6083
6184 widgets = {}
6285
63- if len (pages ) == 1 :
64- results = [results ]
65-
6686 for page_ix , detections in enumerate (results ):
67- print (f"{ page_ix } : { len (detections )} fields detected" )
87+ logging . info (f" Page { page_ix } : { len (detections )} fields detected" )
6888 detections = detections .with_nms (threshold = 0.1 , class_agnostic = True )
69- print (f"{ len (detections )} after nms" )
89+ logging . info (f"\t \t { len (detections )} after nms" )
7090 widgets [page_ix ] = []
7191
7292 for class_id , box in zip (detections .class_id , detections .xyxy ):
@@ -217,7 +237,6 @@ def render_pdf(pdf_path: str) -> list[Page]:
217237 try :
218238 for page in doc :
219239 image = page .render (dpi = 144 )
220- print (image .width , image .height )
221240 pages .append (Page (image = image , width = image .width , height = image .height ))
222241 return pages
223242 finally :
@@ -238,7 +257,7 @@ def prepare_form(
238257 multiline : bool = False ,
239258):
240259 # detector = FFDNetDetector(model_or_path, device=device, fast=fast)
241- detector = FFDetrDetector ("./models/ FFDetr-Medium/checkpoint_best_ema.pth " )
260+ detector = FFDetrDetector ("FFDetr" )
242261
243262 try :
244263 pages = render_pdf (input_path )
0 commit comments