@@ -26,6 +26,7 @@ class LayoutElements(TextRegions):
2626 element_class_id_map : dict [int , str ] = field (default_factory = dict )
2727 text_as_html : np .ndarray = field (default_factory = lambda : np .array ([]))
2828 table_as_cells : np .ndarray = field (default_factory = lambda : np .array ([]))
29+ table_extraction_method : np .ndarray = field (default_factory = lambda : np .array ([]))
2930 routing : str | None = None
3031 routing_score : float | None = None
3132 _optional_array_attributes : list [str ] = field (
@@ -38,6 +39,7 @@ class LayoutElements(TextRegions):
3839 "element_class_ids" ,
3940 "text_as_html" ,
4041 "table_as_cells" ,
42+ "table_extraction_method" ,
4143 ],
4244 )
4345 _scalar_to_array_mappings : dict [str , str ] = field (
@@ -71,6 +73,9 @@ def __eq__(self, other: object) -> bool:
7173 and np .array_equal (self .is_extracted_array [mask ], other .is_extracted_array [mask ])
7274 and np .array_equal (self .text_as_html [mask ], other .text_as_html [mask ])
7375 and np .array_equal (self .table_as_cells [mask ], other .table_as_cells [mask ])
76+ and np .array_equal (
77+ self .table_extraction_method [mask ], other .table_extraction_method [mask ]
78+ )
7479 )
7580
7681 def __getitem__ (self , indices ):
@@ -88,13 +93,14 @@ def slice(self, indices) -> LayoutElements:
8893 element_class_id_map = self .element_class_id_map ,
8994 text_as_html = self .text_as_html [indices ],
9095 table_as_cells = self .table_as_cells [indices ],
96+ table_extraction_method = self .table_extraction_method [indices ],
9197 )
9298
9399 @classmethod
94100 def concatenate (cls , groups : Iterable [LayoutElements ]) -> LayoutElements :
95101 """concatenate a sequence of LayoutElements in order as one LayoutElements"""
96102 coords , texts , probs , class_ids , sources , is_extracted_array = [], [], [], [], [], []
97- text_as_html , table_as_cells = [], []
103+ text_as_html , table_as_cells , table_extraction_method = [], [], []
98104 class_id_reverse_map : dict [str , int ] = {}
99105 for group in groups :
100106 coords .append (group .element_coords )
@@ -104,6 +110,7 @@ def concatenate(cls, groups: Iterable[LayoutElements]) -> LayoutElements:
104110 is_extracted_array .append (group .is_extracted_array )
105111 text_as_html .append (group .text_as_html )
106112 table_as_cells .append (group .table_as_cells )
113+ table_extraction_method .append (group .table_extraction_method )
107114
108115 idx = group .element_class_ids .copy ()
109116 if group .element_class_id_map :
@@ -126,6 +133,7 @@ def concatenate(cls, groups: Iterable[LayoutElements]) -> LayoutElements:
126133 is_extracted_array = np .concatenate (is_extracted_array ),
127134 text_as_html = np .concatenate (text_as_html ),
128135 table_as_cells = np .concatenate (table_as_cells ),
136+ table_extraction_method = np .concatenate (table_extraction_method ),
129137 )
130138
131139 def iter_elements (self ):
@@ -140,6 +148,7 @@ def iter_elements(self):
140148 is_extracted ,
141149 text_as_html ,
142150 table_as_cells ,
151+ table_extraction_method ,
143152 ) in zip (
144153 self .element_coords ,
145154 self .texts ,
@@ -149,6 +158,7 @@ def iter_elements(self):
149158 self .is_extracted_array ,
150159 self .text_as_html ,
151160 self .table_as_cells ,
161+ self .table_extraction_method ,
152162 ):
153163 yield LayoutElement .from_coords (
154164 x1 ,
@@ -166,6 +176,7 @@ def iter_elements(self):
166176 is_extracted = is_extracted ,
167177 text_as_html = text_as_html ,
168178 table_as_cells = table_as_cells ,
179+ table_extraction_method = table_extraction_method ,
169180 )
170181
171182 @classmethod
@@ -176,7 +187,16 @@ def from_list(cls, elements: list):
176187 coords = np .empty ((len_ele , 4 ), dtype = float )
177188 # text and probs can be Nones so use lists first then convert into array to avoid them being
178189 # filled as nan
179- texts , text_as_html , table_as_cells , sources , is_extracted_array , class_probs = (
190+ (
191+ texts ,
192+ text_as_html ,
193+ table_as_cells ,
194+ table_extraction_method ,
195+ sources ,
196+ is_extracted_array ,
197+ class_probs ,
198+ ) = (
199+ [],
180200 [],
181201 [],
182202 [],
@@ -193,6 +213,7 @@ def from_list(cls, elements: list):
193213 is_extracted_array .append (element .is_extracted )
194214 text_as_html .append (element .text_as_html )
195215 table_as_cells .append (element .table_as_cells )
216+ table_extraction_method .append (getattr (element , "table_extraction_method" , None ))
196217 class_probs .append (element .prob )
197218 class_types [i ] = element .type or "None"
198219
@@ -209,6 +230,7 @@ def from_list(cls, elements: list):
209230 is_extracted_array = np .array (is_extracted_array ),
210231 text_as_html = np .array (text_as_html ),
211232 table_as_cells = np .array (table_as_cells ),
233+ table_extraction_method = np .array (table_extraction_method ),
212234 )
213235
214236
@@ -220,6 +242,7 @@ class LayoutElement(TextRegion):
220242 parent : Optional [LayoutElement ] = None
221243 text_as_html : Optional [str ] = None
222244 table_as_cells : Optional [str ] = None
245+ table_extraction_method : Optional [str ] = None
223246
224247 def to_dict (self ) -> dict :
225248 """Converts the class instance to dictionary form."""
@@ -264,6 +287,7 @@ def from_coords(
264287 prob : Optional [float ] = None ,
265288 text_as_html : Optional [str ] = None ,
266289 table_as_cells : Optional [str ] = None ,
290+ table_extraction_method : Optional [str ] = None ,
267291 ** kwargs ,
268292 ) -> LayoutElement :
269293 """Constructs a LayoutElement from coordinates."""
@@ -276,6 +300,7 @@ def from_coords(
276300 source = source ,
277301 text_as_html = text_as_html ,
278302 table_as_cells = table_as_cells ,
303+ table_extraction_method = table_extraction_method ,
279304 bbox = bbox ,
280305 ** kwargs ,
281306 )
@@ -429,7 +454,16 @@ def clean_layoutelements(elements: LayoutElements, subregion_threshold: float =
429454 final_attrs : dict [str , Any ] = {
430455 "element_class_id_map" : elements .element_class_id_map ,
431456 }
432- for attr in ("element_class_ids" , "element_probs" , "texts" , "sources" , "is_extracted_array" ):
457+ for attr in (
458+ "element_class_ids" ,
459+ "element_probs" ,
460+ "texts" ,
461+ "sources" ,
462+ "is_extracted_array" ,
463+ "text_as_html" ,
464+ "table_as_cells" ,
465+ "table_extraction_method" ,
466+ ):
433467 if (original_attr := getattr (elements , attr )) is None :
434468 continue
435469 final_attrs [attr ] = original_attr [sorted_by_area ][mask ][sorted_by_y1 ]
@@ -505,11 +539,21 @@ def clean_layoutelements_for_class(
505539
506540 final_coords = np .vstack ([target_coords [mask ], other_coords [other_mask ]])
507541 final_attrs : dict [str , Any ] = {"element_class_id_map" : elements .element_class_id_map }
508- for attr in ("element_class_ids" , "element_probs" , "texts" , "sources" , "is_extracted_array" ):
542+ for attr in (
543+ "element_class_ids" ,
544+ "element_probs" ,
545+ "texts" ,
546+ "sources" ,
547+ "is_extracted_array" ,
548+ "text_as_html" ,
549+ "table_as_cells" ,
550+ "table_extraction_method" ,
551+ ):
509552 if (original_attr := getattr (elements , attr )) is None :
510553 continue
554+ sorted_attr = original_attr [sorted_by_area ]
511555 final_attrs [attr ] = np .concatenate (
512- (original_attr [target_indices ][mask ], original_attr [~ target_indices ][other_mask ]),
556+ (sorted_attr [target_indices ][mask ], sorted_attr [~ target_indices ][other_mask ]),
513557 )
514558 final_elements = LayoutElements (element_coords = final_coords , ** final_attrs )
515559 return final_elements
0 commit comments