11from __future__ import absolute_import
22
3+ import os
34import json
5+ import numpy as np
46
57from ocrd_utils import (
68 getLogger ,
79 make_file_id ,
810 assert_file_grp_cardinality ,
911 coordinates_of_segment ,
12+ xywh_from_polygon ,
1013 polygon_from_points ,
1114 MIME_TO_EXT
1215)
1316from ocrd_modelfactory import page_from_file
1417from ocrd import Processor
1518
1619from .config import OCRD_TOOL
17- from .extract_pages import CLASSES
20+ from .extract_pages import CLASSES , segment_poly
1821
1922TOOL = 'ocrd-segment-extract-regions'
2023
@@ -40,7 +43,10 @@ def process(self):
4043 If ``transparency`` is true, then also add an alpha channel which is
4144 fully transparent outside of the mask.
4245
43- Create a JSON file with:
46+ Create two JSON files with region types and coordinates: one (page-wise)
47+ in our custom format and one (global) in MS-COCO.
48+
49+ The custom JSON files contain:
4450 * the IDs of the region and its parents,
4551 * the region's coordinates relative to the region image,
4652 * the region's absolute coordinates,
@@ -62,6 +68,7 @@ def process(self):
6268 * ID + '.bin.png': region image (if the workflow provides binarized images)
6369 * ID + '.nrm.png': region image (if the workflow provides grayscale-normalized images)
6470 * ID + '.json': region metadata.
71+ * output_file_grp + '.coco.json': categories, images and annotations of all regions in MS-COCO format (one global file).
6572 """
6673 LOG = getLogger ('processor.ExtractRegions' )
6774 assert_file_grp_cardinality (self .input_file_grp , 1 )
@@ -72,12 +79,32 @@ def process(self):
7279 if self .parameter ["classes" ]:
7380 selected_classes = self .parameter ["classes" ]
7481 classes = { region : classes [region ] for region in selected_classes }
82+ # COCO: init data structures
83+ images = list ()
84+ annotations = list ()
85+ categories = list ()
86+ i = 0
87+ for cat , color in classes .items ():
88+ # COCO format does not allow alpha channel
89+ color = (int (color [0 :2 ], 16 ),
90+ int (color [2 :4 ], 16 ),
91+ int (color [4 :6 ], 16 ))
92+ try :
93+ supercat , name = cat .split (':' )
94+ except ValueError :
95+ name = cat
96+ supercat = ''
97+ categories .append (
98+ {'id' : i , 'name' : name , 'supercategory' : supercat ,
99+ 'source' : 'PAGE' , 'color' : color })
100+ i += 1
101+ i = 0 # subregion count (i.e. annotation id)
102+ j = 0 # region count (i.e. image id)
75103 # pylint: disable=attribute-defined-outside-init
76104 for n , input_file in enumerate (self .input_files ):
77105 page_id = input_file .pageId or input_file .ID
78106 LOG .info ("INPUT FILE %i / %s" , n , page_id )
79107 pcgts = page_from_file (self .workspace .download_file (input_file ))
80- self .add_metadata (pcgts )
81108 page = pcgts .get_Page ()
82109 page_image , page_coords , page_image_info = self .workspace .image_from_page (
83110 page , page_id ,
@@ -92,9 +119,9 @@ def process(self):
92119 ptype = page .get_type ()
93120
94121 regions = dict ()
95- for name in classes . keys () :
122+ for name in classes :
96123 if not name or not name .endswith ("Region" ):
97- # no subtypes here
124+ # only top-level regions here
98125 continue
99126 regions [name ] = getattr (page , 'get_' + name )()
100127 for rtype , rlist in regions .items ():
@@ -110,6 +137,7 @@ def process(self):
110137 subrtype = region .get_type ()
111138 else :
112139 subrtype = None
140+ j += 1
113141 description ['subtype' ] = subrtype
114142 description ['coords_rel' ] = coordinates_of_segment (
115143 region , region_image , region_coords ).tolist ()
@@ -164,7 +192,50 @@ def process(self):
164192 extension = '.nrm'
165193 else :
166194 extension = '.raw'
167-
195+ subregions = dict ()
196+ for name in classes :
197+ if not name or ':' in name :
198+ # no subtypes here
199+ continue
200+ if not hasattr (region , 'get_' + name ):
201+ continue
202+ subregions [name ] = getattr (region , 'get_' + name )()
203+ for subrtype , subrlist in subregions .items ():
204+ for subregion in subrlist :
205+ poly = segment_poly (page_id , subregion , region_coords )
206+ if not poly :
207+ continue
208+ polygon = np .array (poly .exterior .coords , np .int )[:- 1 ].tolist ()
209+ xywh = xywh_from_polygon (polygon )
210+ area = poly .area
211+ if subrtype in ['TextRegion' , 'ChartRegion' , 'GraphicRegion' ]:
212+ subsubrtype = subregion .get_type ()
213+ else :
214+ subsubrtype = None
215+ if subsubrtype :
216+ subrtype0 = subrtype + ':' + subsubrtype
217+ else :
218+ subrtype0 = subrtype
219+ description .setdefault ('regions' , []).append (
220+ { 'type' : subrtype ,
221+ 'subtype' : subsubrtype ,
222+ 'coords' : polygon ,
223+ 'area' : area ,
224+ 'region.ID' : subregion .id
225+ })
226+ # COCO: add annotations
227+ i += 1
228+ annotations .append (
229+ {'id' : i , 'image_id' : j ,
230+ 'category_id' : next ((cat ['id' ] for cat in categories if cat ['name' ] == subsubrtype ),
231+ next ((cat ['id' ] for cat in categories if cat ['name' ] == subrtype ))),
232+ 'segmentation' : np .array (poly .exterior .coords , np .int )[:- 1 ].reshape (1 , - 1 ).tolist (),
233+ 'area' : area ,
234+ 'bbox' : [xywh ['x' ], xywh ['y' ], xywh ['w' ], xywh ['h' ]],
235+ 'iscrowd' : 0 })
236+
237+
238+
168239 file_id = make_file_id (input_file , self .output_file_grp ) + '_' + region .id + extension
169240 file_path = self .workspace .save_image_file (
170241 region_image ,
@@ -179,3 +250,25 @@ def process(self):
179250 pageId = input_file .pageId ,
180251 mimetype = 'application/json' ,
181252 content = json .dumps (description ))
253+ # COCO: add image
254+ images .append ({
255+ 'id' : j ,
256+ # all exported coordinates are relative to the cropped region:
257+ # -> use that for reference
258+ 'file_name' : file_path ,
259+ # -> use its size
260+ 'width' : region_image .width ,
261+ 'height' : region_image .height })
262+ # COCO: write result
263+ file_id = self .output_file_grp + '.coco.json'
264+ LOG .info ('Writing COCO result file "%s"' , file_id )
265+ self .workspace .add_file (
266+ ID = file_id ,
267+ file_grp = self .output_file_grp ,
268+ local_filename = os .path .join (self .output_file_grp , file_id ),
269+ mimetype = 'application/json' ,
270+ pageId = None ,
271+ content = json .dumps (
272+ {'categories' : categories ,
273+ 'images' : images ,
274+ 'annotations' : annotations }))
0 commit comments