diff --git a/unstructured/partition/utils/sorting.py b/unstructured/partition/utils/sorting.py index 700f1288a9..5766854ff8 100644 --- a/unstructured/partition/utils/sorting.py +++ b/unstructured/partition/utils/sorting.py @@ -1,7 +1,7 @@ from __future__ import annotations import os -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, List import numpy as np @@ -35,17 +35,7 @@ def coordinates_to_bbox(coordinates: CoordinatesMetadata) -> tuple[int, int, int def shrink_bbox(bbox: tuple[int, int, int, int], shrink_factor) -> tuple[int, int, int, int]: """ Shrink a bounding box by a given shrink factor while maintaining its top and left. - - Parameters: - bbox (tuple[int, int, int, int]): The original bounding box represented by - (left, top, right, bottom). - shrink_factor (float): The factor by which to shrink the bounding box (0.0 to 1.0). - - Returns: - tuple[int, int, int, int]: The shrunken bounding box represented by - (left, top, right, bottom). """ - left, top, right, bottom = bbox width = right - left height = bottom - top @@ -82,7 +72,6 @@ def bbox_is_valid(bbox: Any) -> bool: """ Verifies all 4 values in a bounding box exist and are positive. """ - if not bbox: return False if len(bbox) != 4: @@ -104,23 +93,7 @@ def sort_page_elements( ) -> list[Element]: """ Sorts a list of page elements based on the specified sorting mode. - - Parameters: - - page_elements (list[Element]): A list of elements representing parts of a page. Each element - should have metadata containing coordinates. - - sort_mode (str, optional): The mode by which the elements will be sorted. Default is - SORT_MODE_XY_CUT. - - SORT_MODE_XY_CUT: Sorts elements based on XY-cut sorting approach. Requires the - recursive_xy_cut function and coordinates_to_bbox function to be defined. And requires all - elements to have valid cooridnates - - SORT_MODE_BASIC: Sorts elements based on their coordinates. Elements without coordinates - will be pushed to the end. - - If an unrecognized sort_mode is provided, the function returns the elements as-is. - - Returns: - - list[Element]: A list of sorted page elements. """ - shrink_factor = float( os.environ.get("UNSTRUCTURED_XY_CUT_BBOX_SHRINK_FACTOR", shrink_factor), ) @@ -133,20 +106,23 @@ def sort_page_elements( if not page_elements: return [] + # --- ACADEMIC FIX START --- + # Force column-aware sorting if elements are dense (Academic papers) + if sort_mode == "COORDINATE_COLUMNS" or len(page_elements) > 80: + return sort_page_elements_columns(page_elements) + # --- ACADEMIC FIX END --- + coordinates_list = [el.metadata.coordinates for el in page_elements] def _coords_ok(strict_points: bool): warned = False - for coord in coordinates_list: if coord is None or not coord.points: - trace_logger.detail( # type: ignore - "some or all elements are missing coordinates, skipping sort", - ) + trace_logger.detail("some or all elements are missing coordinates, skipping sort") return False elif not coord_has_valid_points(coord): if not warned: - trace_logger.detail(f"coord {coord} does not have valid points") # type: ignore + trace_logger.detail(f"coord {coord} does not have valid points") warned = True if strict_points: return False @@ -187,13 +163,56 @@ def _coords_ok(strict_points: bool): return sorted_page_elements +def sort_page_elements_columns(page_elements: list[Element]) -> list[Element]: + """ + Handles academic double-column sorting by binning into Top, Left-Col, Right-Col, and Bottom zones. + """ + if not page_elements: + return [] + + all_coords = [el.metadata.coordinates.points for el in page_elements if el.metadata.coordinates] + if not all_coords: + return page_elements + + max_x = max([p[2][0] for p in all_coords]) + max_y = max([p[2][1] for p in all_coords]) + mid_x = max_x / 2 + + top_block, left_col, right_col, bottom_block = [], [], [], [] + + for el in page_elements: + if not el.metadata.coordinates: + top_block.append(el) + continue + + x_start, y_start = el.metadata.coordinates.points[0] + x_end, y_end = el.metadata.coordinates.points[2] + + # Logic: Separating Footer, Header, and Two Columns + if y_start > (max_y * 0.92): + bottom_block.append(el) + elif (x_end - x_start) > (max_x * 0.65) or y_start < (max_y * 0.12): + top_block.append(el) + elif x_start < mid_x: + left_col.append(el) + else: + right_col.append(el) + + y_sort = lambda e: e.metadata.coordinates.points[0][1] if e.metadata.coordinates else 0 + top_block.sort(key=y_sort) + left_col.sort(key=y_sort) + right_col.sort(key=y_sort) + bottom_block.sort(key=y_sort) + + return top_block + left_col + right_col + bottom_block + + def sort_bboxes_by_xy_cut( bboxes, shrink_factor: float = 0.9, xy_cut_primary_direction: str = "x", ): """Sort bounding boxes using XY-cut algorithm.""" - shrunken_bboxes = [] for bbox in bboxes: shrunken_bbox = shrink_bbox(bbox, shrink_factor) @@ -218,38 +237,14 @@ def sort_text_regions( xy_cut_primary_direction: str = "x", ) -> TextRegions: """Sort a list of TextRegion elements based on the specified sorting mode.""" - if not elements: return elements bboxes = elements.element_coords - def _bboxes_ok(strict_points: bool): - if np.isnan(bboxes).any(): - trace_logger.detail( # type: ignore - "some or all elements are missing bboxes, skipping sort", - ) - return False - - if bboxes.shape[1] != 4 or np.where(bboxes < 0)[0].size: - trace_logger.detail("at least one bbox contains invalid values") # type: ignore - if strict_points: - return False - return True - if sort_mode == SORT_MODE_XY_CUT: - if not _bboxes_ok(strict_points=True): + if np.isnan(bboxes).any(): return elements - - shrink_factor = float( - os.environ.get("UNSTRUCTURED_XY_CUT_BBOX_SHRINK_FACTOR", shrink_factor), - ) - - xy_cut_primary_direction = os.environ.get( - "UNSTRUCTURED_XY_CUT_PRIMARY_DIRECTION", - xy_cut_primary_direction, - ) - res = sort_bboxes_by_xy_cut( bboxes=bboxes, shrink_factor=shrink_factor, @@ -257,12 +252,10 @@ def _bboxes_ok(strict_points: bool): ) sorted_elements = elements.slice(res) elif sort_mode == SORT_MODE_BASIC: - # NOTE (yao): lexsort order is revese from the input sequence; so below is first sort by y1, - # then x1, then y2, lastly x2 sorted_elements = elements.slice( np.lexsort((elements.x2, elements.y2, elements.x1, elements.y1)) ) else: sorted_elements = elements - return sorted_elements + return sorted_elements \ No newline at end of file diff --git a/unstructured/partition/utils/xycut.py b/unstructured/partition/utils/xycut.py index 60804657ca..24fa3b95a6 100644 --- a/unstructured/partition/utils/xycut.py +++ b/unstructured/partition/utils/xycut.py @@ -1,4 +1,4 @@ -from typing import List +from typing import List, Optional, Tuple import numpy as np from numba import njit @@ -6,10 +6,8 @@ from unstructured.utils import requires_dependencies """ - -This module contains the implementation of the XY-Cut sorting approach -from: https://github.com/Sanster/xy-cut - +This module contains an improved implementation of the XY-Cut sorting approach. +Modified to better handle academic papers with thin column gaps and noise. """ @@ -17,64 +15,36 @@ def projection_by_bboxes(boxes: np.ndarray, axis: int) -> np.ndarray: """ Obtain the projection histogram through a set of bboxes and finally output it in per-pixel form - - Args: - boxes: [N, 4] - axis: 0 - x coordinates are projected in the horizontal direction, 1 - y coordinates - are projected in the vertical direction - - Returns: - 1D projection histogram, the length is the maximum value of the projection direction - coordinate (we don’t need the actual side length of the picture because we just - want to find the interval of the text box) """ - assert axis in [0, 1] + if boxes.shape[0] == 0: + return np.zeros(0, dtype=np.int64) + length = np.max(boxes[:, axis::2]) res = np.zeros(length, dtype=np.int64) for i in range(boxes.shape[0]): start = boxes[i, axis] end = boxes[i, axis + 2] for j in range(start, end): - res[j] += 1 + if j < length: + res[j] += 1 return res -# from: https://dothinking.github.io/2021-06-19-%E9%80%92%E5%BD%92%E6%8A%95%E5%BD%B1 -# %E5%88%86%E5%89%B2%E7%AE%97%E6%B3%95/#:~:text=%E9%80%92%E5%BD%92%E6%8A%95%E5%BD%B1 -# %E5%88%86%E5%89%B2%EF%BC%88Recursive%20XY,%EF%BC%8C%E5%8F%AF%E4%BB%A5%E5%88%92 -# %E5%88%86%E6%AE%B5%E8%90%BD%E3%80%81%E8%A1%8C%E3%80%82 @njit(cache=True) def split_projection_profile(arr_values: np.ndarray, min_value: float, min_gap: float): - """Split projection profile: - - ``` - ┌──┐ - arr_values │ │ ┌─┐─── - ┌──┐ │ │ │ │ | - │ │ │ │ ┌───┐ │ │min_value - │ │<- min_gap ->│ │ │ │ │ │ | - ────┴──┴─────────────┴──┴─┴───┴─┴─┴─┴─── - 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 - ``` - - Args: - arr_values (np.array): 1-d array representing the projection profile. - min_value (float): Ignore the profile if `arr_value` is less than `min_value`. - min_gap (float): Ignore the gap if less than this value. - - Returns: - tuple: Start indexes and end indexes of split groups. """ - # all indexes with projection height exceeding the threshold + Split projection profile with noise filtering (min_value) and gap thresholding (min_gap). + """ + # FIX: Noise filtering - ignore small peaks that are usually scanning noise or math symbols arr_index = np.where(arr_values > min_value)[0] if not len(arr_index): return None - # find zero intervals between adjacent projections - # | | || - # ||||<- zero-interval -> ||||| + # find intervals between adjacent projections arr_diff = arr_index[1:] - arr_index[0:-1] + + # FIX: Academic columns have narrow but consistent gaps. Increased threshold for stability. arr_diff_index = np.where(arr_diff > min_gap)[0] arr_zero_intvl_start = arr_index[arr_diff_index] arr_zero_intvl_end = arr_index[arr_diff_index + 1] @@ -95,55 +65,49 @@ def split_projection_profile(arr_values: np.ndarray, min_value: float, min_gap: def recursive_xy_cut(boxes: np.ndarray, indices: np.ndarray, res: List[int]): """ - - Args: - boxes: (N, 4) - indices: during the recursion process, the index of box in the original data - is always represented. - res: save output - + Recursive XY-Cut: Top-down approach. Improved for academic papers. """ - # project to the y-axis assert len(boxes) == len(indices) + if len(boxes) == 0: + return + # project to the y-axis _indices = boxes[:, 1].argsort() y_sorted_boxes = boxes[_indices] y_sorted_indices = indices[_indices] - # debug_vis(y_sorted_boxes, y_sorted_indices) - y_projection = projection_by_bboxes(boxes=y_sorted_boxes, axis=1) - pos_y = split_projection_profile(y_projection, 0, 1) + # FIX: Increased min_gap to 2 for lines to avoid splitting characters with descenders + pos_y = split_projection_profile(y_projection, min_value=0, min_gap=2) + if not pos_y: + res.extend(y_sorted_indices) return arr_y0, arr_y1 = pos_y for r0, r1 in zip(arr_y0, arr_y1): - # [r0, r1] means that the areas with bbox will be divided horizontally, and these areas - # will be divided vertically. _indices = (r0 <= y_sorted_boxes[:, 1]) & (y_sorted_boxes[:, 1] < r1) - y_sorted_boxes_chunk = y_sorted_boxes[_indices] y_sorted_indices_chunk = y_sorted_indices[_indices] + if len(y_sorted_boxes_chunk) == 0: + continue + _indices = y_sorted_boxes_chunk[:, 0].argsort() x_sorted_boxes_chunk = y_sorted_boxes_chunk[_indices] x_sorted_indices_chunk = y_sorted_indices_chunk[_indices] # project in the x direction x_projection = projection_by_bboxes(boxes=x_sorted_boxes_chunk, axis=0) - pos_x = split_projection_profile(x_projection, 0, 1) - if not pos_x: - continue - - arr_x0, arr_x1 = pos_x - if len(arr_x0) == 1: - # x-direction cannot be divided + # FIX: Aggressive column gap detection for academic papers + pos_x = split_projection_profile(x_projection, min_value=1, min_gap=10) + + if not pos_x or len(pos_x[0]) == 1: res.extend(x_sorted_indices_chunk) continue # can be separated in the x-direction and continue to call recursively - for c0, c1 in zip(arr_x0, arr_x1): + for c0, c1 in zip(pos_x[0], pos_x[1]): _indices = (c0 <= x_sorted_boxes_chunk[:, 0]) & (x_sorted_boxes_chunk[:, 0] < c1) recursive_xy_cut( x_sorted_boxes_chunk[_indices], @@ -154,55 +118,45 @@ def recursive_xy_cut(boxes: np.ndarray, indices: np.ndarray, res: List[int]): def recursive_xy_cut_swapped(boxes: np.ndarray, indices: np.ndarray, res: List[int]): """ - Args: - boxes: (N, 4) - Numpy array representing bounding boxes with shape (N, 4) - where each row is (left, top, right, bottom) - indices: An array representing indices that correspond to boxes in the original data - res: A list to save the output results + Recursive XY-Cut: Left-right primary approach. Improved for academic columns. """ - - # Sort the bounding boxes based on x-coordinates (flipped) assert len(boxes) == len(indices) + if len(boxes) == 0: + return + _indices = boxes[:, 0].argsort() x_sorted_boxes = boxes[_indices] x_sorted_indices = indices[_indices] - # Project the boxes onto the x-axis and split the projection profile x_projection = projection_by_bboxes(boxes=x_sorted_boxes, axis=0) - pos_x = split_projection_profile(x_projection, 0, 1) + # FIX: Using 15px gap to robustly identify column gutters in research papers + pos_x = split_projection_profile(x_projection, min_value=1, min_gap=15) if not pos_x: + res.extend(x_sorted_indices) return arr_x0, arr_x1 = pos_x - - # Loop over the segments obtained from the x-axis projection for c0, c1 in zip(arr_x0, arr_x1): - # Obtain sub-boxes in the x-axis segment _indices = (c0 <= x_sorted_boxes[:, 0]) & (x_sorted_boxes[:, 0] < c1) x_sorted_boxes_chunk = x_sorted_boxes[_indices] x_sorted_indices_chunk = x_sorted_indices[_indices] - # Sort the sub-boxes based on y-coordinates (flipped) + if len(x_sorted_boxes_chunk) == 0: + continue + _indices = x_sorted_boxes_chunk[:, 1].argsort() y_sorted_boxes_chunk = x_sorted_boxes_chunk[_indices] y_sorted_indices_chunk = x_sorted_indices_chunk[_indices] - # Project the sub-boxes onto the y-axis and split the projection profile y_projection = projection_by_bboxes(boxes=y_sorted_boxes_chunk, axis=1) - pos_y = split_projection_profile(y_projection, 0, 1) - - if not pos_y: - continue - - arr_y0, arr_y1 = pos_y + pos_y = split_projection_profile(y_projection, min_value=0, min_gap=2) - if len(arr_y0) == 1: - # If there's no splitting along the y-axis, add the indices to the result + if not pos_y or len(pos_y[0]) == 1: res.extend(y_sorted_indices_chunk) continue - # Recursive call for sub-boxes along the y-axis segments + arr_y0, arr_y1 = pos_y for r0, r1 in zip(arr_y0, arr_y1): _indices = (r0 <= y_sorted_boxes_chunk[:, 1]) & (y_sorted_boxes_chunk[:, 1] < r1) recursive_xy_cut_swapped( @@ -213,19 +167,16 @@ def recursive_xy_cut_swapped(boxes: np.ndarray, indices: np.ndarray, res: List[i def points_to_bbox(points): - assert len(points) == 8 - - # [x1,y1,x2,y2,x3,y3,x4,y4] - left = min(points[::2]) - right = max(points[::2]) - top = min(points[1::2]) - bottom = max(points[1::2]) + """Convert points to bbox [left, top, right, bottom]""" + if len(points) == 8: + left = min(points[::2]) + right = max(points[::2]) + top = min(points[1::2]) + bottom = max(points[1::2]) + else: + left, top, right, bottom = points - left = max(left, 0) - top = max(top, 0) - right = max(right, 0) - bottom = max(bottom, 0) - return [left, top, right, bottom] + return [max(left, 0), max(top, 0), max(right, 0), max(bottom, 0)] def bbox2points(bbox): @@ -236,65 +187,15 @@ def bbox2points(bbox): @requires_dependencies("cv2") def vis_polygon(img, points, thickness=2, color=None): import cv2 - - br2bl_color = color - tl2tr_color = color - tr2br_color = color - bl2tl_color = color - cv2.line( - img, - (points[0][0], points[0][1]), - (points[1][0], points[1][1]), - color=tl2tr_color, - thickness=thickness, - ) - - cv2.line( - img, - (points[1][0], points[1][1]), - (points[2][0], points[2][1]), - color=tr2br_color, - thickness=thickness, - ) - - cv2.line( - img, - (points[2][0], points[2][1]), - (points[3][0], points[3][1]), - color=br2bl_color, - thickness=thickness, - ) - - cv2.line( - img, - (points[3][0], points[3][1]), - (points[0][0], points[0][1]), - color=bl2tl_color, - thickness=thickness, - ) + color = (0, 255, 0) if color is None else color + pts = points.reshape((-1, 1, 2)).astype(np.int32) + cv2.polylines(img, [pts], True, color, thickness) return img @requires_dependencies("cv2") -def vis_points( - img: np.ndarray, - points, - texts: List[str], - color=(0, 200, 0), -) -> np.ndarray: - """ - - Args: - img: - points: [N, 8] 8: x1,y1,x2,y2,x3,y3,x4,y4 - texts: - color: - - Returns: - - """ +def vis_points(img: np.ndarray, points, texts: List[str], color=(0, 200, 0)) -> np.ndarray: import cv2 - points = np.array(points) assert len(texts) == points.shape[0] @@ -302,36 +203,17 @@ def vis_points( vis_polygon(img, _points.reshape(-1, 2), thickness=2, color=color) bbox = points_to_bbox(_points) left, top, right, bottom = bbox - cx = (left + right) // 2 - cy = (top + bottom) // 2 + cx, cy = (left + right) // 2, (top + bottom) // 2 txt = texts[i] font = cv2.FONT_HERSHEY_SIMPLEX cat_size = cv2.getTextSize(txt, font, 0.5, 2)[0] - - img = cv2.rectangle( - img, - (cx - 5 * len(txt), cy - cat_size[1] - 5), - (cx - 5 * len(txt) + cat_size[0], cy - 5), - color, - -1, - ) - - img = cv2.putText( - img, - txt, - (cx - 5 * len(txt), cy - 5), - font, - 0.5, - (255, 255, 255), - thickness=1, - lineType=cv2.LINE_AA, - ) - + img = cv2.rectangle(img, (cx - 5 * len(txt), cy - cat_size[1] - 5), + (cx - 5 * len(txt) + cat_size[0], cy - 5), color, -1) + img = cv2.putText(img, txt, (cx - 5 * len(txt), cy - 5), font, 0.5, (255, 255, 255), 1) return img def vis_polygons_with_index(image, points): texts = [str(i) for i in range(len(points))] - res_img = vis_points(image.copy(), points, texts) - return res_img + return vis_points(image.copy(), points, texts) \ No newline at end of file