2525 ("FFDNET-S" , False ): ("jbarrow/FFDNet-S" , "FFDNet-S.pt" ),
2626 ("FFDNET-L" , True ): ("jbarrow/FFDNet-L-cpu" , "FFDNet-L.onnx" ),
2727 ("FFDNET-L" , False ): ("jbarrow/FFDNet-L" , "FFDNet-L.pt" ),
28- ("FFDETR" , False ): ("jbarrow/FFDetr" , "FFDetr.pth" )
28+ ("FFDETR" , False ): ("jbarrow/FFDetr" , "FFDetr.pth" ),
2929}
3030
3131
32-
def batch(lst: list, n: int = 8):
    """Yield consecutive chunks of *lst*, each holding at most *n* items.

    Chunks are produced lazily and in order; every chunk has exactly *n*
    items except possibly the last, which holds the remainder. An empty
    *lst* yields nothing.

    Args:
        lst: The sequence to split. It must support ``len()`` and slicing.
        n: Maximum number of items per chunk. Must be at least 1.

    Yields:
        Slices of *lst* of length ``n`` (the final one may be shorter).

    Raises:
        ValueError: If *n* is less than 1. Without this check a negative
            *n* would make the range empty and silently drop every item.
    """
    if n < 1:
        raise ValueError(f"batch size must be a positive integer, got {n}")

    total = len(lst)
    for start in range(0, total, n):
        # Slicing clamps at the end of the sequence, so no explicit
        # min(start + n, total) is needed for the final, shorter chunk.
        yield lst[start : start + n]
3736
3837
3938class FFDetrDetector :
40- def __init__ (
41- self , model_or_path : str , device : int | str = "cpu"
42- ) -> None :
39+ def __init__ (self , model_or_path : str , device : int | str = "cpu" ) -> None :
4340 self .device = device
4441 self .model = RFDETRMedium (pretrain_weights = self .get_model_path (model_or_path ))
4542
@@ -49,8 +46,8 @@ def get_model_path(self, model_or_path: str) -> str:
4946 model_upper = model_or_path .upper ()
5047 if model_upper in ["FFDETR" ]:
5148 # download the model, will just use the cached version if it already exists
52- repo_id , filename = models [(model_upper , False )]
53- model_path = hf_hub_download (repo_id = repo_id , filename = filename )
49+ repo_id , filename = models [(model_upper , False )]
50+ model_path = hf_hub_download (repo_id = repo_id , filename = filename )
5451 else :
5552 model_path = model_or_path
5653
@@ -75,7 +72,7 @@ def extract_widgets(
7572 ) -> dict [int , list [Widget ]]:
7673 image_size = 1024
7774 results = []
78- for b in batch ([p .image for p in pages ], n = batch_size ):
75+ for b in batch ([self . resize ( p .image , image_size ) for p in pages ], n = batch_size ):
7976 predictions = self .model .predict (b , threshold = confidence )
8077 if len (pages ) == 1 or batch_size == 1 :
8178 predictions = [predictions ]
@@ -131,8 +128,8 @@ def get_model_path(
131128 model_upper = model_or_path .upper ()
132129 if model_upper in ["FFDNET-S" , "FFDNET-L" ]:
133130 # download the model, will just use the cached version if it already exists
134- repo_id , filename = models [(model_upper , fast )]
135- model_path = hf_hub_download (repo_id = repo_id , filename = filename )
131+ repo_id , filename = models [(model_upper , fast )]
132+ model_path = hf_hub_download (repo_id = repo_id , filename = filename )
136133 else :
137134 model_path = model_or_path
138135
@@ -247,17 +244,20 @@ def prepare_form(
247244 input_path : str | Path ,
248245 output_path : str | Path ,
249246 * ,
250- model_or_path : str = "FFDNet-L " ,
247+ model_or_path : str = "FFDetr " ,
251248 keep_existing_fields : bool = False ,
252249 use_signature_fields : bool = False ,
253250 device : int | str = "cpu" ,
254- image_size : int = 1600 ,
255- confidence : float = 0.3 ,
251+ image_size : int = 1024 ,
252+ confidence : float = 0.4 ,
256253 fast : bool = False ,
257254 multiline : bool = False ,
255+ batch_size : int = 4 ,
258256):
259- # detector = FFDNetDetector(model_or_path, device=device, fast=fast)
260- detector = FFDetrDetector ("FFDetr" )
257+ if "FFDNET" in model_or_path .upper ():
258+ detector = FFDNetDetector (model_or_path , device = device , fast = fast )
259+ else :
260+ detector = FFDetrDetector (model_or_path )
261261
262262 try :
263263 pages = render_pdf (input_path )
@@ -277,7 +277,9 @@ def prepare_form(
277277 name = f"{ widget .widget_type .lower ()} _{ widget .page } _{ i } "
278278
279279 if widget .widget_type == "TextBox" :
280- writer .add_text_box (name , page_ix , widget .bounding_box , multiline = multiline )
280+ writer .add_text_box (
281+ name , page_ix , widget .bounding_box , multiline = multiline
282+ )
281283 elif widget .widget_type == "ChoiceButton" :
282284 writer .add_checkbox (name , page_ix , widget .bounding_box )
283285 elif widget .widget_type == "Signature" :
0 commit comments