diff --git a/prettyprinter/textractprettyprinter/t_pretty_print_layout.py b/prettyprinter/textractprettyprinter/t_pretty_print_layout.py index b0f997bc..8ff1910e 100644 --- a/prettyprinter/textractprettyprinter/t_pretty_print_layout.py +++ b/prettyprinter/textractprettyprinter/t_pretty_print_layout.py @@ -41,12 +41,36 @@ def _get_layout_blocks(self) -> tuple: for optimum output") return layouts, id2block - def _geometry_match(self, geom1, geom2, tolerance=0.1): - """Check if two geometries match within a given tolerance.""" - for key in ['Width', 'Height', 'Left', 'Top']: - if abs(geom1[key] - geom2[key]) > tolerance: - return False - return True + def _iou(self, boxA, boxB): + """Compute the Intersection-over-Union (IoU) for two bounding boxes.""" + # boxA edges + A_left = boxA["Left"] + A_top = boxA["Top"] + A_right = A_left + boxA["Width"] + A_bottom = A_top + boxA["Height"] + # boxB edges + B_left = boxB["Left"] + B_top = boxB["Top"] + B_right = B_left + boxB["Width"] + B_bottom = B_top + boxB["Height"] + + # intersection + interLeft = max(A_left, B_left) + interRight = min(A_right, B_right) + interTop = max(A_top, B_top) + interBottom = min(A_bottom, B_bottom) + + if interRight < interLeft or interBottom < interTop: + return 0.0 + + interArea = (interRight - interLeft) * (interBottom - interTop) + areaA = (A_right - A_left) * (A_bottom - A_top) + areaB = (B_right - B_left) * (B_bottom - B_top) + union = areaA + areaB - interArea + if union <= 0: + return 0.0 + return interArea / union + def _is_inside(self, inner_geom, outer_geom): """Check if inner geometry is fully contained within the outer geometry.""" @@ -76,7 +100,28 @@ def _validate_block_skip(self, blockType: str) -> bool: return True else: return False - + + def _find_best_table_match(self, layout_box, page_num, tolerance=.2): + """For all 'TABLE' blocks on the same page, compute IoU with 'layout_box'. Return the TABLE block with highest IoU if it exceeds a threshold""" + best_iou = 0.0 + best_table = None + + candidates = [ + b for b in self.j["Blocks"] + if b["BlockType"] == "TABLE" + and b.get("Page", 1) == page_num + ] + + for tb in candidates: + iou_val = self._iou(layout_box, tb["Geometry"]["BoundingBox"]) + if iou_val > best_iou: + best_iou = iou_val + best_table = tb + + if best_iou < tolerance: + return None + return best_table + def _dfs(self, root, id2block): texts = [] stack = [(root, 0)] @@ -90,14 +135,10 @@ def _dfs(self, root, id2block): # Handle LAYOUT_TABLE type if not self.skip_table and block["BlockType"] == "LAYOUT_TABLE": - table_data = [] - # Find the matching TABLE block for the LAYOUT_TABLE - table_block = None - for potential_table in [b for b in self.j['Blocks'] if b['BlockType'] == 'TABLE' and b.get('Page',1) == block.get('Page', 1)]: - if self._geometry_match(block['Geometry']['BoundingBox'], potential_table['Geometry']['BoundingBox']): - table_block = potential_table - break + layout_box = block["Geometry"]["BoundingBox"] + page_num = block.get("Page", 1) + table_block = self._find_best_table_match(layout_box, page_num) if table_block and "Relationships" in table_block: table_content = {} headers = {} @@ -119,7 +160,7 @@ def _dfs(self, root, id2block): headers[col_idx + c] = cell_text else: table_content[(row_idx + r, col_idx + c)] = cell_text - + table_data = [] start_row = 2 if headers else 1 for r in range(start_row, max_row + 1): @@ -129,7 +170,7 @@ def _dfs(self, root, id2block): table_data.append(row_data) header_list = [headers.get(c, "") for c in range(1, max_col + 1)] - + try: from tabulate import tabulate except ImportError: @@ -147,7 +188,7 @@ def _dfs(self, root, id2block): else: logger.warning("LAYOUT_TABLE detected but TABLES feature was not provided in API call. \ Inlcuding TABLES feature may improve the layout output") - + if block["BlockType"] == "LINE" and "Text" in block: if self.exclude_figure_text and self.figures: if any(self._is_inside(block['Geometry']['BoundingBox'], figure_geom["geometry"]) \ @@ -166,13 +207,13 @@ def _dfs(self, root, id2block): elif block["BlockType"] == "LAYOUT_SECTION_HEADER": combined_text = f"## {combined_text}" yield combined_text - + if block["BlockType"].startswith('LAYOUT') and block["BlockType"] not in ["LAYOUT_TITLE", "LAYOUT_SECTION_HEADER"]: if "Relationships" in block: relationships = block["Relationships"] children = [(x, depth + 1) for x in relationships[0]['Ids']] stack.extend(reversed(children)) - + def _save_to_s3(self, page_texts: dict) -> None: try: import boto3