finetune/util/table_labeler.py (60 changes: 43 additions & 17 deletions)
@@ -1,6 +1,7 @@
"""
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This files changes are probably worth pulling in. Big improvements for the risk on larger tables.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The other changes maybe we drop as being not valuable enough for the risk at this point.

Finetune-style interface for running a pipeline of table and non-table models.
"""
+import bisect
 import copy
 import functools
 import logging

Contributor Author: This file's changes are probably worth pulling in. Big improvements relative to the risk on larger tables.

Contributor Author: The other changes we could maybe drop as not valuable enough for the risk at this point.

Contributor (on `+import bisect`): I didn't know this was part of stdlib!
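For reference, `bisect.bisect_left` binary-searches a sorted list in O(log n); that is the call the `combine_row_spans` change below leans on. A minimal sketch, with offsets invented for illustration (not taken from the PR):

```python
# Minimal illustration of the stdlib bisect call used later in this diff.
# The offsets below are made up for the example.
import bisect

token_ends = [5, 9, 14, 20]  # sorted "end" offsets of token spans
row_start = 10               # start offset of a row span

# Index of the first token whose end is >= row_start, found in O(log n).
idx = bisect.bisect_left(token_ends, row_start)
print(idx)  # 2 -- the tokens ending at 5 and 9 are skipped without a scan
```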
@@ -438,8 +439,12 @@ def __init__(

     def get_axis_spans(self, context, token_bounds, context_key):
         max_row = max(r[context_key] for r in context)
-        row_spans = [
-            [
+        row_spans = [[] for _ in range(max_row + 1)]
+        for c in context:
+            row_idx = c[context_key]
+            if row_idx < 0:
+                continue
+            row_spans[row_idx].append(
                 {
                     "start": c["start"],
                     "end": c["end"],
@@ -448,11 +453,7 @@
                 )
                 + 1,
                 }
-                for c in context
-                if i == c[context_key]
-            ]
-            for i in range(max_row + 1)
-        ]
+            )
         return self.combine_row_spans(row_spans, token_bounds)

     def chunk(self, table_text_chunks_and_context):

Contributor Author: This bucketing change is part of the chunker speedup. On the synthetic 150 x 20 table chunking benchmark, the chunker dropped from 55.443s to 0.263s (~99.5% faster) after the get_axis_spans / _make_chunks / combine_row_spans optimization pass.
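To make the change easier to review in isolation: the old nested comprehension rescans every cell once per row index, O(rows x cells), while the new version preallocates one bucket per row and places each cell in a single pass, O(cells). A standalone sketch of the same pattern with simplified cell dicts (the real code builds the fuller span dicts shown above):

```python
# Old pattern: one full scan of `cells` per row index -- O(n_rows * n_cells).
def bucket_by_row_slow(cells, max_row):
    return [[c for c in cells if c["row"] == i] for i in range(max_row + 1)]

# New pattern: preallocate buckets, one pass, skip negative rows -- O(n_cells).
def bucket_by_row_fast(cells, max_row):
    buckets = [[] for _ in range(max_row + 1)]
    for c in cells:
        if c["row"] < 0:
            continue
        buckets[c["row"]].append(c)
    return buckets

cells = [{"row": 2}, {"row": 0}, {"row": -1}, {"row": 2}]
assert bucket_by_row_slow(cells, 2) == bucket_by_row_fast(cells, 2)
```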
@@ -594,17 +595,21 @@ def _make_chunks(self, row_spans):
             break
         max_len_chunks = []
         temp_rows = []
+        temp_row_tokens = 0
         context_included = False
         for row in row_spans[n_rows_context:]:
-            if self._num_tokens(context + temp_rows + [row]) < self.max_length:
+            row_tokens = row["num_effective_tokens"]
+            if (context_tokens + temp_row_tokens + row_tokens) < self.max_length:
                 temp_rows.append(row)
+                temp_row_tokens += row_tokens
             elif len(temp_rows) == 0:
                 # The current row is too long to use any context at all.
                 max_len_chunks.append([row])
             else:
                 context_included = True
                 max_len_chunks.append(copy.deepcopy(context) + temp_rows)
                 temp_rows = [row]
+                temp_row_tokens = row_tokens
         if temp_rows or not context_included:
             max_len_chunks.append(copy.deepcopy(context) + temp_rows)
         output_spans = []
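The win in `_make_chunks` is that `self._num_tokens(context + temp_rows + [row])` rebuilt and re-measured the whole candidate chunk on every iteration; the new code carries a running total instead. A minimal sketch of the pattern (hypothetical `pack_rows` helper; the oversized-row and context-copying branches above are omitted for brevity):

```python
# Greedy row packing with a running token count: O(n) over all rows instead
# of re-summing the whole candidate chunk at every step.
def pack_rows(rows, context_tokens, max_length):
    chunks, temp_rows, temp_tokens = [], [], 0
    for row in rows:
        row_tokens = row["num_effective_tokens"]
        if context_tokens + temp_tokens + row_tokens < max_length:
            temp_rows.append(row)
            temp_tokens += row_tokens
        else:
            # Current chunk is full; start a new one with this row.
            chunks.append(temp_rows)
            temp_rows, temp_tokens = [row], row_tokens
    if temp_rows:
        chunks.append(temp_rows)
    return chunks

rows = [{"num_effective_tokens": t} for t in (40, 40, 40, 40)]
print([len(c) for c in pack_rows(rows, context_tokens=10, max_length=100)])  # [2, 2]
```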
@@ -626,11 +631,18 @@
         return output_spans

     def combine_row_spans(self, row_spans, token_spans):
-        def mark_token(t):
-            t["used"] = True
-            return t
-
         total_num_tokens = 0
+        token_starts = []
+        token_ends = []
+        token_spans_monotonic = True
+        for token in token_spans:
+            token_start = token["start"]
+            token_end = token["end"]
+            if token_starts and (
+                token_start < token_starts[-1] or token_end < token_ends[-1]
+            ):
+                token_spans_monotonic = False
+            token_starts.append(token_start)
+            token_ends.append(token_end)
         combined_rows = []
         for row in row_spans:
             row_out = []
@@ -640,14 +652,28 @@ def mark_token(t):
                 else:
                     row_out.append(span)
             for row_span in row_out:
-                row_span["num_tokens"] = len(
-                    [mark_token(t) for t in token_spans if overlaps_token(row_span, t)]
-                )
+                num_tokens = 0
+                if token_spans_monotonic:
+                    token_idx = bisect.bisect_left(token_ends, row_span["start"])
+                    while (
+                        token_idx < len(token_spans)
+                        and token_starts[token_idx] <= row_span["end"]
+                    ):
+                        token = token_spans[token_idx]
+                        if overlaps_token(row_span, token):
+                            token["used"] = True
+                            num_tokens += 1
+                        token_idx += 1
+                else:
+                    for token in token_spans:
+                        if overlaps_token(row_span, token):
+                            token["used"] = True
+                            num_tokens += 1
+                row_span["num_tokens"] = num_tokens
                 # Accounts for the fact that cells are duplicated when they span multiple cells.
                 row_span["num_effective_tokens"] = (
                     row_span["num_tokens"] * row_span["max_cell_span"]
                 )
+                total_num_tokens += row_span["num_tokens"]
             combined_rows.append(
                 {
                     "num_tokens": sum(r["num_tokens"] for r in row_out),
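The `combine_row_spans` rewrite is where the `bisect` import pays off: when token spans arrive sorted (the monotonic check computed at the top of the method), `bisect_left` on the token end offsets skips every token that ends before the row span starts, and the scan stops once token starts pass the row span's end; non-monotonic input falls back to the original linear scan. A self-contained sketch of the fast path with simplified span dicts (for inclusive offsets the window test below already implies overlap, whereas the real code still re-checks overlaps_token on each candidate):

```python
import bisect

def count_overlaps(row_span, token_starts, token_ends):
    # token_starts/token_ends are parallel, sorted lists of inclusive offsets.
    count = 0
    # First token whose end reaches the row span; everything before it
    # ends too early to overlap.
    idx = bisect.bisect_left(token_ends, row_span["start"])
    # Stop as soon as tokens start past the row span's end.
    while idx < len(token_ends) and token_starts[idx] <= row_span["end"]:
        count += 1
        idx += 1
    return count

token_starts = [0, 6, 12, 18]
token_ends = [5, 11, 17, 23]
print(count_overlaps({"start": 7, "end": 13}, token_starts, token_ends))  # 2
```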