Skip to content

Commit 56fadb3

Browse files
feat: Track the table extraction method (#513)
<!-- CURSOR_SUMMARY --> > [!NOTE] > **Medium Risk** > Touches core `LayoutElement(s)` data structures and modifies cleanup logic that reorders/filters arrays, which could subtly affect downstream consumers if attribute alignment assumptions are wrong. > > **Overview** > Adds a new `table_extraction_method` attribute to `LayoutElement` and the vectorized `LayoutElements` container to record which table algorithm produced a given table (e.g., grid/tatr/vlm). > > Propagates this new field through `LayoutElements` operations (`__eq__`, `slice`, `concatenate`, `iter_elements`, `from_list`) and through layout cleanup routines so table-related metadata (`text_as_html`, `table_as_cells`, and the new method field) is retained and correctly aligned after sorting/filtering. > > Bumps the library version to `1.6.11` and documents the enhancement in `CHANGELOG.md`. > > <sup>Reviewed by [Cursor Bugbot](https://cursor.com/bugbot) for commit 70d9da7. Bugbot is set up for automated code reviews on this repo. Configure [here](https://www.cursor.com/dashboard/bugbot).</sup> <!-- /CURSOR_SUMMARY -->
1 parent 2458cad commit 56fadb3

3 files changed

Lines changed: 55 additions & 6 deletions

File tree

CHANGELOG.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,8 @@
1+
## 1.6.11
2+
3+
### Enhancement
4+
- Add `table_extraction_method` field to `LayoutElements` and `LayoutElement` to track which algorithm produced a table (grid, tatr, vlm).
5+
16
## 1.6.10
27

38
### Enhancement
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = "1.6.10" # pragma: no cover
1+
__version__ = "1.6.11" # pragma: no cover

unstructured_inference/inference/layoutelement.py

Lines changed: 49 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ class LayoutElements(TextRegions):
2626
element_class_id_map: dict[int, str] = field(default_factory=dict)
2727
text_as_html: np.ndarray = field(default_factory=lambda: np.array([]))
2828
table_as_cells: np.ndarray = field(default_factory=lambda: np.array([]))
29+
table_extraction_method: np.ndarray = field(default_factory=lambda: np.array([]))
2930
routing: str | None = None
3031
routing_score: float | None = None
3132
_optional_array_attributes: list[str] = field(
@@ -38,6 +39,7 @@ class LayoutElements(TextRegions):
3839
"element_class_ids",
3940
"text_as_html",
4041
"table_as_cells",
42+
"table_extraction_method",
4143
],
4244
)
4345
_scalar_to_array_mappings: dict[str, str] = field(
@@ -71,6 +73,9 @@ def __eq__(self, other: object) -> bool:
7173
and np.array_equal(self.is_extracted_array[mask], other.is_extracted_array[mask])
7274
and np.array_equal(self.text_as_html[mask], other.text_as_html[mask])
7375
and np.array_equal(self.table_as_cells[mask], other.table_as_cells[mask])
76+
and np.array_equal(
77+
self.table_extraction_method[mask], other.table_extraction_method[mask]
78+
)
7479
)
7580

7681
def __getitem__(self, indices):
@@ -88,13 +93,14 @@ def slice(self, indices) -> LayoutElements:
8893
element_class_id_map=self.element_class_id_map,
8994
text_as_html=self.text_as_html[indices],
9095
table_as_cells=self.table_as_cells[indices],
96+
table_extraction_method=self.table_extraction_method[indices],
9197
)
9298

9399
@classmethod
94100
def concatenate(cls, groups: Iterable[LayoutElements]) -> LayoutElements:
95101
"""concatenate a sequence of LayoutElements in order as one LayoutElements"""
96102
coords, texts, probs, class_ids, sources, is_extracted_array = [], [], [], [], [], []
97-
text_as_html, table_as_cells = [], []
103+
text_as_html, table_as_cells, table_extraction_method = [], [], []
98104
class_id_reverse_map: dict[str, int] = {}
99105
for group in groups:
100106
coords.append(group.element_coords)
@@ -104,6 +110,7 @@ def concatenate(cls, groups: Iterable[LayoutElements]) -> LayoutElements:
104110
is_extracted_array.append(group.is_extracted_array)
105111
text_as_html.append(group.text_as_html)
106112
table_as_cells.append(group.table_as_cells)
113+
table_extraction_method.append(group.table_extraction_method)
107114

108115
idx = group.element_class_ids.copy()
109116
if group.element_class_id_map:
@@ -126,6 +133,7 @@ def concatenate(cls, groups: Iterable[LayoutElements]) -> LayoutElements:
126133
is_extracted_array=np.concatenate(is_extracted_array),
127134
text_as_html=np.concatenate(text_as_html),
128135
table_as_cells=np.concatenate(table_as_cells),
136+
table_extraction_method=np.concatenate(table_extraction_method),
129137
)
130138

131139
def iter_elements(self):
@@ -140,6 +148,7 @@ def iter_elements(self):
140148
is_extracted,
141149
text_as_html,
142150
table_as_cells,
151+
table_extraction_method,
143152
) in zip(
144153
self.element_coords,
145154
self.texts,
@@ -149,6 +158,7 @@ def iter_elements(self):
149158
self.is_extracted_array,
150159
self.text_as_html,
151160
self.table_as_cells,
161+
self.table_extraction_method,
152162
):
153163
yield LayoutElement.from_coords(
154164
x1,
@@ -166,6 +176,7 @@ def iter_elements(self):
166176
is_extracted=is_extracted,
167177
text_as_html=text_as_html,
168178
table_as_cells=table_as_cells,
179+
table_extraction_method=table_extraction_method,
169180
)
170181

171182
@classmethod
@@ -176,7 +187,16 @@ def from_list(cls, elements: list):
176187
coords = np.empty((len_ele, 4), dtype=float)
177188
# text and probs can be Nones so use lists first then convert into array to avoid them being
178189
# filled as nan
179-
texts, text_as_html, table_as_cells, sources, is_extracted_array, class_probs = (
190+
(
191+
texts,
192+
text_as_html,
193+
table_as_cells,
194+
table_extraction_method,
195+
sources,
196+
is_extracted_array,
197+
class_probs,
198+
) = (
199+
[],
180200
[],
181201
[],
182202
[],
@@ -193,6 +213,7 @@ def from_list(cls, elements: list):
193213
is_extracted_array.append(element.is_extracted)
194214
text_as_html.append(element.text_as_html)
195215
table_as_cells.append(element.table_as_cells)
216+
table_extraction_method.append(getattr(element, "table_extraction_method", None))
196217
class_probs.append(element.prob)
197218
class_types[i] = element.type or "None"
198219

@@ -209,6 +230,7 @@ def from_list(cls, elements: list):
209230
is_extracted_array=np.array(is_extracted_array),
210231
text_as_html=np.array(text_as_html),
211232
table_as_cells=np.array(table_as_cells),
233+
table_extraction_method=np.array(table_extraction_method),
212234
)
213235

214236

@@ -220,6 +242,7 @@ class LayoutElement(TextRegion):
220242
parent: Optional[LayoutElement] = None
221243
text_as_html: Optional[str] = None
222244
table_as_cells: Optional[str] = None
245+
table_extraction_method: Optional[str] = None
223246

224247
def to_dict(self) -> dict:
225248
"""Converts the class instance to dictionary form."""
@@ -264,6 +287,7 @@ def from_coords(
264287
prob: Optional[float] = None,
265288
text_as_html: Optional[str] = None,
266289
table_as_cells: Optional[str] = None,
290+
table_extraction_method: Optional[str] = None,
267291
**kwargs,
268292
) -> LayoutElement:
269293
"""Constructs a LayoutElement from coordinates."""
@@ -276,6 +300,7 @@ def from_coords(
276300
source=source,
277301
text_as_html=text_as_html,
278302
table_as_cells=table_as_cells,
303+
table_extraction_method=table_extraction_method,
279304
bbox=bbox,
280305
**kwargs,
281306
)
@@ -429,7 +454,16 @@ def clean_layoutelements(elements: LayoutElements, subregion_threshold: float =
429454
final_attrs: dict[str, Any] = {
430455
"element_class_id_map": elements.element_class_id_map,
431456
}
432-
for attr in ("element_class_ids", "element_probs", "texts", "sources", "is_extracted_array"):
457+
for attr in (
458+
"element_class_ids",
459+
"element_probs",
460+
"texts",
461+
"sources",
462+
"is_extracted_array",
463+
"text_as_html",
464+
"table_as_cells",
465+
"table_extraction_method",
466+
):
433467
if (original_attr := getattr(elements, attr)) is None:
434468
continue
435469
final_attrs[attr] = original_attr[sorted_by_area][mask][sorted_by_y1]
@@ -505,11 +539,21 @@ def clean_layoutelements_for_class(
505539

506540
final_coords = np.vstack([target_coords[mask], other_coords[other_mask]])
507541
final_attrs: dict[str, Any] = {"element_class_id_map": elements.element_class_id_map}
508-
for attr in ("element_class_ids", "element_probs", "texts", "sources", "is_extracted_array"):
542+
for attr in (
543+
"element_class_ids",
544+
"element_probs",
545+
"texts",
546+
"sources",
547+
"is_extracted_array",
548+
"text_as_html",
549+
"table_as_cells",
550+
"table_extraction_method",
551+
):
509552
if (original_attr := getattr(elements, attr)) is None:
510553
continue
554+
sorted_attr = original_attr[sorted_by_area]
511555
final_attrs[attr] = np.concatenate(
512-
(original_attr[target_indices][mask], original_attr[~target_indices][other_mask]),
556+
(sorted_attr[target_indices][mask], sorted_attr[~target_indices][other_mask]),
513557
)
514558
final_elements = LayoutElements(element_coords=final_coords, **final_attrs)
515559
return final_elements

0 commit comments

Comments
 (0)