diff --git a/finetune/util/table_labeler.py b/finetune/util/table_labeler.py
index 06bf434e..0c8f776f 100644
--- a/finetune/util/table_labeler.py
+++ b/finetune/util/table_labeler.py
@@ -1,6 +1,7 @@
 """
 Finetune-style interface for running a pipeline of table and non-table models.
 """
+import bisect
 import copy
 import functools
 import logging
@@ -438,8 +439,12 @@ def __init__(
 
     def get_axis_spans(self, context, token_bounds, context_key):
         max_row = max(r[context_key] for r in context)
-        row_spans = [
-            [
+        row_spans = [[] for _ in range(max_row + 1)]
+        for c in context:
+            row_idx = c[context_key]
+            if row_idx < 0:
+                continue
+            row_spans[row_idx].append(
                 {
                     "start": c["start"],
                     "end": c["end"],
@@ -448,11 +453,7 @@ def get_axis_spans(self, context, token_bounds, context_key):
                     )
                     + 1,
                 }
-                for c in context
-                if i == c[context_key]
-            ]
-            for i in range(max_row + 1)
-        ]
+            )
         return self.combine_row_spans(row_spans, token_bounds)
 
     def chunk(self, table_text_chunks_and_context):
@@ -594,10 +595,13 @@ def _make_chunks(self, row_spans):
                 break
         max_len_chunks = []
         temp_rows = []
+        temp_row_tokens = 0
         context_included = False
         for row in row_spans[n_rows_context:]:
-            if self._num_tokens(context + temp_rows + [row]) < self.max_length:
+            row_tokens = row["num_effective_tokens"]
+            if (context_tokens + temp_row_tokens + row_tokens) < self.max_length:
                 temp_rows.append(row)
+                temp_row_tokens += row_tokens
             elif len(temp_rows) == 0:
                 # The current row is too long to use any context at all.
                 max_len_chunks.append([row])
@@ -605,6 +609,7 @@
                 context_included = True
                 max_len_chunks.append(copy.deepcopy(context) + temp_rows)
                 temp_rows = [row]
+                temp_row_tokens = row_tokens
         if temp_rows or not context_included:
             max_len_chunks.append(copy.deepcopy(context) + temp_rows)
         output_spans = []
@@ -626,11 +631,18 @@
         return output_spans
 
     def combine_row_spans(self, row_spans, token_spans):
-        def mark_token(t):
-            t["used"] = True
-            return t
-
-        total_num_tokens = 0
+        token_starts = []
+        token_ends = []
+        token_spans_monotonic = True
+        for token in token_spans:
+            token_start = token["start"]
+            token_end = token["end"]
+            if token_starts and (
+                token_start < token_starts[-1] or token_end < token_ends[-1]
+            ):
+                token_spans_monotonic = False
+            token_starts.append(token_start)
+            token_ends.append(token_end)
         combined_rows = []
         for row in row_spans:
             row_out = []
@@ -640,14 +652,28 @@ def mark_token(t):
                 else:
                     row_out.append(span)
             for row_span in row_out:
-                row_span["num_tokens"] = len(
-                    [mark_token(t) for t in token_spans if overlaps_token(row_span, t)]
-                )
+                num_tokens = 0
+                if token_spans_monotonic:
+                    token_idx = bisect.bisect_left(token_ends, row_span["start"])
+                    while (
+                        token_idx < len(token_spans)
+                        and token_starts[token_idx] <= row_span["end"]
+                    ):
+                        token = token_spans[token_idx]
+                        if overlaps_token(row_span, token):
+                            token["used"] = True
+                            num_tokens += 1
+                        token_idx += 1
+                else:
+                    for token in token_spans:
+                        if overlaps_token(row_span, token):
+                            token["used"] = True
+                            num_tokens += 1
+                row_span["num_tokens"] = num_tokens
                 # Accounts for the fact that cells are duplicated when they span cells.
                 row_span["num_effective_tokens"] = (
                     row_span["num_tokens"] * row_span["max_cell_span"]
                 )
-                total_num_tokens += row_span["num_tokens"]
             combined_rows.append(
                 {
                     "num_tokens": sum(r["num_tokens"] for r in row_out),
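
Note on the bisect-based counting introduced in `combine_row_spans`: because `token_ends` is sorted whenever the token spans are monotonic, `bisect.bisect_left(token_ends, row_span["start"])` jumps straight to the first token that could possibly overlap the span, and the scan stops as soon as token starts pass the span's end. That turns the old per-span linear sweep over every token into O(log n) plus the number of actual overlaps, with the unsorted case falling back to the original full scan. A minimal standalone sketch of the same technique follows; `count_overlapping_tokens` and the toy data are hypothetical, and `overlaps_token` is re-implemented here assuming inclusive boundaries, which may differ from the real finetune helper.

import bisect

def overlaps_token(span, token):
    # Assumed inclusive-boundary overlap test between two character ranges.
    return token["start"] <= span["end"] and token["end"] >= span["start"]

def count_overlapping_tokens(span, token_spans, token_starts, token_ends):
    # token_ends is sorted, so the first candidate is the first token whose
    # end reaches the span's start; the scan stops once token starts pass
    # the span's end.
    num_tokens = 0
    idx = bisect.bisect_left(token_ends, span["start"])
    while idx < len(token_spans) and token_starts[idx] <= span["end"]:
        if overlaps_token(span, token_spans[idx]):
            num_tokens += 1
        idx += 1
    return num_tokens

# Toy data: ten 5-character tokens covering offsets 0-49.
tokens = [{"start": s, "end": s + 4} for s in range(0, 50, 5)]
starts = [t["start"] for t in tokens]
ends = [t["end"] for t in tokens]
print(count_overlapping_tokens({"start": 7, "end": 21}, tokens, starts, ends))  # -> 4

For a long document this is the difference between a quadratic and a near-linear pass over (row spans x tokens), which is presumably why the diff precomputes `token_starts`/`token_ends` once and only keeps the linear scan as a fallback for out-of-order token spans.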