22from ultralytics import YOLO
33from pathlib import Path
44from huggingface_hub import hf_hub_download
5+ from rfdetr import RFDETRNano , RFDETRBase , RFDETRMedium , RFDETRLarge
56
67from commonforms .utils import BoundingBox , Page , Widget
78from commonforms .form_creator import PyPdfFormCreator
89from commonforms .exceptions import EncryptedPdfError
910
1011import formalpdf
1112import pypdfium2
13+ import PIL
1214
1315
1416# our mapping from (model_name, fast) to (repo_id, filename) for the huggingface hub
1719 ("FFDNET-S" , False ): ("jbarrow/FFDNet-S" , "FFDNet-S.pt" ),
1820 ("FFDNET-L" , True ): ("jbarrow/FFDNet-L-cpu" , "FFDNet-L.onnx" ),
1921 ("FFDNET-L" , False ): ("jbarrow/FFDNet-L" , "FFDNet-L.pt" ),
22+ ("FFDetr-Nano" , False ): ("./models/FFDetr-Nano" , "checkpoint_best_ema.pth" )
2023}
2124
2225
26+
def batch(lst: list, n: int = 8):
    """Yield consecutive slices of ``lst`` holding at most ``n`` items each.

    Args:
        lst: The sequence to split. Anything that supports ``len`` and
            slicing works; each chunk has the type that slicing ``lst``
            produces.
        n: Maximum number of items per chunk. Must be at least 1.

    Yields:
        Slices of ``lst`` in order. The final slice may be shorter than ``n``.

    Raises:
        ValueError: If ``n`` is smaller than 1. Because this is a generator,
            the error surfaces when iteration starts, not at call time.
    """
    if n < 1:
        raise ValueError(f"batch size must be a positive integer, got {n}")

    # Slicing already clamps to the end of the sequence, so the last chunk
    # needs no explicit upper bound.
    for start in range(0, len(lst), n):
        yield lst[start:start + n]
31+
32+
33+
class FFDetrDetector:
    """Form-field detector that wraps an RF-DETR model.

    Detections are converted into ``Widget`` objects whose bounding boxes are
    expressed as fractions of the rendered page image's width and height.
    """

    def __init__(
        self, model_or_path: str, device: int | str = "cpu"
    ) -> None:
        """Build the underlying RF-DETR model from a checkpoint.

        Args:
            model_or_path: Passed to ``RFDETRMedium`` as ``pretrain_weights``.
            device: Stored on the instance as ``self.device``.
                NOTE(review): it is not forwarded to the model here -- confirm
                whether the model should be placed on this device.
        """
        self.device = device
        self.model = RFDETRMedium(
            pretrain_weights=model_or_path, resolution=224 * 5, num_classes=2
        )

        # Maps the model's numeric class ids to widget type names.
        self.id_to_cls = {0: "TextBox", 1: "ChoiceButton"}

    def resize(
        self,
        image: PIL.Image.Image,
        size: tuple[int, int] | int,
    ) -> PIL.Image.Image:
        """Return ``image`` resized with Lanczos resampling.

        Args:
            image: The image to resize.
            size: Target ``(width, height)``; a single int is used for both
                dimensions.

        Returns:
            The resized image.
        """
        if isinstance(size, int):
            size = (size, size)

        return image.resize(size, PIL.Image.Resampling.LANCZOS)

    def extract_widgets(
        self, pages: list[Page], confidence: float = 0.2, image_size: int = 1120
    ) -> dict[int, list[Widget]]:
        """Detect form fields on every page and return them as widgets.

        Args:
            pages: Rendered pages; each page's ``image`` is fed to the model.
            confidence: Score threshold passed to the model's ``predict``.
            image_size: Currently unused; kept so existing callers that pass
                it keep working. The model resolution is fixed in ``__init__``.

        Returns:
            A dict mapping the zero-based page index to that page's widgets,
            ordered by ``sort_widgets``. Empty when ``pages`` is empty.
        """
        # Collect exactly one detections object per page. ``predict`` may hand
        # back either a single object or a list of them depending on how many
        # images it was given (NOTE(review): confirm against the installed
        # rfdetr version), so both shapes are flattened into ``results``.
        results = []
        for images in batch([p.image for p in pages], n=1):
            predicted = self.model.predict(images, threshold=confidence)
            if isinstance(predicted, list):
                results.extend(predicted)
            else:
                results.append(predicted)

        widgets = {}

        for page_ix, (page, detections) in enumerate(zip(pages, results)):
            print(f"{page_ix}: {len(detections)} fields detected")
            # Class-agnostic NMS with a low overlap threshold so that heavily
            # overlapping boxes collapse to one, regardless of predicted class.
            detections = detections.with_nms(threshold=0.1, class_agnostic=True)
            print(f"{len(detections)} after nms")

            img_width = page.image.width
            img_height = page.image.height

            page_widgets = []
            for class_id, box in zip(detections.class_id, detections.xyxy):
                # Convert numpy scalars to plain floats and normalise pixel
                # coordinates by the rendered image size.
                left, top, right, bottom = (float(v) for v in box)

                page_widgets.append(
                    Widget(
                        widget_type=self.id_to_cls[int(class_id)],
                        bounding_box=BoundingBox(
                            x0=left / img_width,
                            y0=top / img_height,
                            x1=right / img_width,
                            y1=bottom / img_height,
                        ),
                        page=page_ix,
                    )
                )

            widgets[page_ix] = sort_widgets(page_widgets)

        return widgets
89+
90+
2391class FFDNetDetector :
2492 def __init__ (
2593 self , model_or_path : str , device : int | str = "cpu" , fast : bool = False
@@ -148,7 +216,8 @@ def render_pdf(pdf_path: str) -> list[Page]:
148216 doc = formalpdf .open (pdf_path )
149217 try :
150218 for page in doc :
151- image = page .render ()
219+ image = page .render (dpi = 144 )
220+ print (image .width , image .height )
152221 pages .append (Page (image = image , width = image .width , height = image .height ))
153222 return pages
154223 finally :
@@ -168,7 +237,8 @@ def prepare_form(
168237 fast : bool = False ,
169238 multiline : bool = False ,
170239):
171- detector = FFDNetDetector (model_or_path , device = device , fast = fast )
240+ # detector = FFDNetDetector(model_or_path, device=device, fast=fast)
241+ detector = FFDetrDetector ("./models/FFDetr-Medium/checkpoint_best_ema.pth" )
172242
173243 try :
174244 pages = render_pdf (input_path )
0 commit comments