Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
117 changes: 55 additions & 62 deletions unstructured/partition/utils/sorting.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

import os
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, List

import numpy as np

Expand Down Expand Up @@ -35,17 +35,7 @@ def coordinates_to_bbox(coordinates: CoordinatesMetadata) -> tuple[int, int, int
def shrink_bbox(bbox: tuple[int, int, int, int], shrink_factor) -> tuple[int, int, int, int]:
"""
Shrink a bounding box by a given shrink factor while maintaining its top and left.

Parameters:
bbox (tuple[int, int, int, int]): The original bounding box represented by
(left, top, right, bottom).
shrink_factor (float): The factor by which to shrink the bounding box (0.0 to 1.0).

Returns:
tuple[int, int, int, int]: The shrunken bounding box represented by
(left, top, right, bottom).
"""

left, top, right, bottom = bbox
width = right - left
height = bottom - top
Expand Down Expand Up @@ -82,7 +72,6 @@ def bbox_is_valid(bbox: Any) -> bool:
"""
Verifies all 4 values in a bounding box exist and are positive.
"""

if not bbox:
return False
if len(bbox) != 4:
Expand All @@ -104,23 +93,7 @@ def sort_page_elements(
) -> list[Element]:
"""
Sorts a list of page elements based on the specified sorting mode.

Parameters:
- page_elements (list[Element]): A list of elements representing parts of a page. Each element
should have metadata containing coordinates.
- sort_mode (str, optional): The mode by which the elements will be sorted. Default is
SORT_MODE_XY_CUT.
- SORT_MODE_XY_CUT: Sorts elements based on XY-cut sorting approach. Requires the
recursive_xy_cut function and coordinates_to_bbox function to be defined. And requires all
elements to have valid cooridnates
- SORT_MODE_BASIC: Sorts elements based on their coordinates. Elements without coordinates
will be pushed to the end.
- If an unrecognized sort_mode is provided, the function returns the elements as-is.

Returns:
- list[Element]: A list of sorted page elements.
"""

shrink_factor = float(
os.environ.get("UNSTRUCTURED_XY_CUT_BBOX_SHRINK_FACTOR", shrink_factor),
)
Expand All @@ -133,20 +106,23 @@ def sort_page_elements(
if not page_elements:
return []

# --- ACADEMIC FIX START ---
# Force column-aware sorting if elements are dense (Academic papers)
if sort_mode == "COORDINATE_COLUMNS" or len(page_elements) > 80:
return sort_page_elements_columns(page_elements)
# --- ACADEMIC FIX END ---

coordinates_list = [el.metadata.coordinates for el in page_elements]

def _coords_ok(strict_points: bool):
warned = False

for coord in coordinates_list:
if coord is None or not coord.points:
trace_logger.detail( # type: ignore
"some or all elements are missing coordinates, skipping sort",
)
trace_logger.detail("some or all elements are missing coordinates, skipping sort")
return False
elif not coord_has_valid_points(coord):
if not warned:
trace_logger.detail(f"coord {coord} does not have valid points") # type: ignore
trace_logger.detail(f"coord {coord} does not have valid points")
warned = True
if strict_points:
return False
Expand Down Expand Up @@ -187,13 +163,56 @@ def _coords_ok(strict_points: bool):
return sorted_page_elements


def sort_page_elements_columns(page_elements: list[Element]) -> list[Element]:
"""
Handles academic double-column sorting by binning into Top, Left-Col, Right-Col, and Bottom zones.
"""
if not page_elements:
return []

all_coords = [el.metadata.coordinates.points for el in page_elements if el.metadata.coordinates]
if not all_coords:
return page_elements

max_x = max([p[2][0] for p in all_coords])
max_y = max([p[2][1] for p in all_coords])
mid_x = max_x / 2

top_block, left_col, right_col, bottom_block = [], [], [], []

for el in page_elements:
if not el.metadata.coordinates:
top_block.append(el)
continue

x_start, y_start = el.metadata.coordinates.points[0]
x_end, y_end = el.metadata.coordinates.points[2]

# Logic: Separating Footer, Header, and Two Columns
if y_start > (max_y * 0.92):
bottom_block.append(el)
elif (x_end - x_start) > (max_x * 0.65) or y_start < (max_y * 0.12):
top_block.append(el)
elif x_start < mid_x:
left_col.append(el)
else:
right_col.append(el)

y_sort = lambda e: e.metadata.coordinates.points[0][1] if e.metadata.coordinates else 0
top_block.sort(key=y_sort)
left_col.sort(key=y_sort)
right_col.sort(key=y_sort)
bottom_block.sort(key=y_sort)

return top_block + left_col + right_col + bottom_block


def sort_bboxes_by_xy_cut(
bboxes,
shrink_factor: float = 0.9,
xy_cut_primary_direction: str = "x",
):
"""Sort bounding boxes using XY-cut algorithm."""

shrunken_bboxes = []
for bbox in bboxes:
shrunken_bbox = shrink_bbox(bbox, shrink_factor)
Expand All @@ -218,51 +237,25 @@ def sort_text_regions(
xy_cut_primary_direction: str = "x",
) -> TextRegions:
"""Sort a list of TextRegion elements based on the specified sorting mode."""

if not elements:
return elements

bboxes = elements.element_coords

def _bboxes_ok(strict_points: bool):
if np.isnan(bboxes).any():
trace_logger.detail( # type: ignore
"some or all elements are missing bboxes, skipping sort",
)
return False

if bboxes.shape[1] != 4 or np.where(bboxes < 0)[0].size:
trace_logger.detail("at least one bbox contains invalid values") # type: ignore
if strict_points:
return False
return True

if sort_mode == SORT_MODE_XY_CUT:
if not _bboxes_ok(strict_points=True):
if np.isnan(bboxes).any():
return elements

shrink_factor = float(
os.environ.get("UNSTRUCTURED_XY_CUT_BBOX_SHRINK_FACTOR", shrink_factor),
)

xy_cut_primary_direction = os.environ.get(
"UNSTRUCTURED_XY_CUT_PRIMARY_DIRECTION",
xy_cut_primary_direction,
)

res = sort_bboxes_by_xy_cut(
bboxes=bboxes,
shrink_factor=shrink_factor,
xy_cut_primary_direction=xy_cut_primary_direction,
)
sorted_elements = elements.slice(res)
elif sort_mode == SORT_MODE_BASIC:
# NOTE (yao): lexsort order is revese from the input sequence; so below is first sort by y1,
# then x1, then y2, lastly x2
sorted_elements = elements.slice(
np.lexsort((elements.x2, elements.y2, elements.x1, elements.y1))
)
else:
sorted_elements = elements

return sorted_elements
return sorted_elements
Loading