22from ultralytics import YOLO
33from pathlib import Path
44from huggingface_hub import hf_hub_download
5+ from rfdetr import RFDETRNano , RFDETRBase , RFDETRMedium , RFDETRLarge
56
67from commonforms .utils import BoundingBox , Page , Widget
78from commonforms .form_creator import PyPdfFormCreator
89from commonforms .exceptions import EncryptedPdfError
910
1011import formalpdf
1112import pypdfium2
13+ import logging
14+ import PIL
1215
1316
14- # our mapping from (model_name, fast) to (repo_id, filename) for the huggingface hub
17+ logging .basicConfig (level = logging .INFO )
18+
19+
# Our mapping from (model_name_upper, fast) to (repo_id, filename) on the Hugging Face Hub.
# It is kept simple and declarative like this because we don't expect to add many
# more models.
models: dict[tuple[str, bool], tuple[str, str]] = {
    # FFDNet small: .onnx file for fast=True, .pt checkpoint otherwise
    ("FFDNET-S", True): ("jbarrow/FFDNet-S-cpu", "FFDNet-S.onnx"),
    ("FFDNET-S", False): ("jbarrow/FFDNet-S", "FFDNet-S.pt"),
    # FFDNet large: same split between the two file formats
    ("FFDNET-L", True): ("jbarrow/FFDNet-L-cpu", "FFDNet-L.onnx"),
    ("FFDNET-L", False): ("jbarrow/FFDNet-L", "FFDNet-L.pt"),
    # FFDetr: only a fast=False entry is registered
    ("FFDETR", False): ("jbarrow/FFDetr", "FFDetr.pth"),
}
2130
2231
def batch(lst: list, n: int = 8):
    """Yield consecutive slices of ``lst`` holding at most ``n`` items each.

    The final slice may be shorter than ``n``. An empty ``lst`` yields nothing.
    """
    total = len(lst)
    for start in range(0, total, n):
        stop = start + n
        # Slicing clamps to the end of the list, so the last chunk needs no
        # special handling.
        yield lst[start:stop]
36+
37+
class FFDetrDetector:
    """Detect form widgets on rendered pages with an RF-DETR (medium) model."""

    def __init__(self, model_or_path: str, device: int | str = "cpu") -> None:
        """Load the detection model.

        Args:
            model_or_path: The hub alias ``"FFDetr"`` (case-insensitive) or a
                path to a weights file.
            device: Stored on the instance. It is not forwarded to the model
                constructor here.
        """
        self.device = device
        self.model = RFDETRMedium(pretrain_weights=self.get_model_path(model_or_path))

        # Maps the model's integer class ids onto widget type names.
        self.id_to_cls = {0: "TextBox", 1: "ChoiceButton", 2: "Signature"}

    def get_model_path(self, model_or_path: str) -> str:
        """Resolve ``model_or_path`` to a weights path.

        The alias ``"FFDetr"`` (any casing) is downloaded from the Hugging Face
        Hub; any other value is returned unchanged and treated as a path.
        """
        model_upper = model_or_path.upper()
        if model_upper == "FFDETR":
            # download the model, will just use the cached version if it already exists
            repo_id, filename = models[(model_upper, False)]
            return hf_hub_download(repo_id=repo_id, filename=filename)

        return model_or_path

    def resize(
        self,
        image: PIL.Image.Image,
        size: tuple[int, int] | int,
    ) -> PIL.Image.Image:
        """Return ``image`` resized to ``size`` using Lanczos resampling.

        An integer ``size`` is interpreted as a square ``(size, size)`` target.
        """
        if isinstance(size, int):
            size = (size, size)

        return image.resize(size, PIL.Image.Resampling.LANCZOS)

    def extract_widgets(
        self,
        pages: list[Page],
        confidence: float = 0.4,
        image_size: int = 1120,
        batch_size: int = 3,
    ) -> dict[int, list[Widget]]:
        """Run detection over ``pages`` and return the widgets found per page.

        Args:
            pages: Rendered pages; each page's ``image`` is resized before
                inference.
            confidence: Score threshold passed to the model's ``predict``.
            image_size: Currently ignored; see the note in the body.
            batch_size: Number of page images sent to the model per call.

        Returns:
            A dict mapping page index to that page's widgets, ordered by
            ``sort_widgets``. Box coordinates are divided by the page image's
            width and height.
        """
        # NOTE(review): the caller-supplied image_size is overridden here, so
        # the parameter has no effect. Kept as-is because it may be a
        # deliberate pin; confirm before removing.
        image_size = 1024

        resized = [self.resize(p.image, image_size) for p in pages]

        results = []
        for chunk in batch(resized, n=batch_size):
            predictions = self.model.predict(chunk, threshold=confidence)
            # predict() can hand back a single detections object instead of a
            # list (e.g. for a one-image trailing batch when there are several
            # pages). Normalise by type rather than guessing from the page
            # count, so extend() never iterates over one page's detections.
            if not isinstance(predictions, list):
                predictions = [predictions]
            results.extend(predictions)

        widgets: dict[int, list[Widget]] = {}

        for page_ix, detections in enumerate(results):
            logging.info(f"Page {page_ix}: {len(detections)} fields detected")
            # Class-agnostic NMS: overlapping boxes suppress each other even
            # when they were assigned different widget types.
            detections = detections.with_nms(threshold=0.1, class_agnostic=True)
            logging.info(f"\t\t{len(detections)} after nms")

            page_image = pages[page_ix].image
            page_widgets = []

            for class_id, box in zip(detections.class_id, detections.xyxy):
                # NOTE(review): boxes are normalised by the original page image
                # size although inference ran on the resized image; confirm
                # which coordinate frame predict() reports boxes in.
                x0, x1 = box[[0, 2]] / page_image.width
                y0, y1 = box[[1, 3]] / page_image.height

                page_widgets.append(
                    Widget(
                        widget_type=self.id_to_cls[class_id],
                        bounding_box=BoundingBox(x0=x0, y0=y0, x1=x1, y1=y1),
                        page=page_ix,
                    )
                )

            widgets[page_ix] = sort_widgets(page_widgets)

        return widgets
106+
107+
23108class FFDNetDetector :
24109 def __init__ (
25110 self , model_or_path : str , device : int | str = "cpu" , fast : bool = False
@@ -43,8 +128,8 @@ def get_model_path(
43128 model_upper = model_or_path .upper ()
44129 if model_upper in ["FFDNET-S" , "FFDNET-L" ]:
45130 # download the model, will just use the cached version if it already exists
46- repo_id , filename = models [(model_upper , fast )]
47- model_path = hf_hub_download (repo_id = repo_id , filename = filename )
131+ repo_id , filename = models [(model_upper , fast )]
132+ model_path = hf_hub_download (repo_id = repo_id , filename = filename )
48133 else :
49134 model_path = model_or_path
50135
@@ -148,7 +233,7 @@ def render_pdf(pdf_path: str) -> list[Page]:
148233 doc = formalpdf .open (pdf_path )
149234 try :
150235 for page in doc :
151- image = page .render ()
236+ image = page .render (dpi = 144 )
152237 pages .append (Page (image = image , width = image .width , height = image .height ))
153238 return pages
154239 finally :
@@ -159,16 +244,20 @@ def prepare_form(
159244 input_path : str | Path ,
160245 output_path : str | Path ,
161246 * ,
162- model_or_path : str = "FFDNet-L " ,
247+ model_or_path : str = "FFDetr " ,
163248 keep_existing_fields : bool = False ,
164249 use_signature_fields : bool = False ,
165250 device : int | str = "cpu" ,
166- image_size : int = 1600 ,
167- confidence : float = 0.3 ,
251+ image_size : int = 1024 ,
252+ confidence : float = 0.4 ,
168253 fast : bool = False ,
169254 multiline : bool = False ,
255+ batch_size : int = 4 ,
170256):
171- detector = FFDNetDetector (model_or_path , device = device , fast = fast )
257+ if "FFDNET" in model_or_path .upper ():
258+ detector = FFDNetDetector (model_or_path , device = device , fast = fast )
259+ else :
260+ detector = FFDetrDetector (model_or_path )
172261
173262 try :
174263 pages = render_pdf (input_path )
@@ -188,7 +277,9 @@ def prepare_form(
188277 name = f"{ widget .widget_type .lower ()} _{ widget .page } _{ i } "
189278
190279 if widget .widget_type == "TextBox" :
191- writer .add_text_box (name , page_ix , widget .bounding_box , multiline = multiline )
280+ writer .add_text_box (
281+ name , page_ix , widget .bounding_box , multiline = multiline
282+ )
192283 elif widget .widget_type == "ChoiceButton" :
193284 writer .add_checkbox (name , page_ix , widget .bounding_box )
194285 elif widget .widget_type == "Signature" :
0 commit comments