Skip to content

Commit 209a547

Browse files
FIX: [DEV-15236] improvements to scheduler and table model predict speed. (#856)
* fix: make scheduler stats per-model
* fix: preprocessing changes
* chore: keep only table chunker optimization
* fix: add defensive non-monotonic case
1 parent 7d72283 commit 209a547

1 file changed

Lines changed: 43 additions & 17 deletions

File tree

finetune/util/table_labeler.py

Lines changed: 43 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""
22
Finetune-style interface for running a pipeline of table and non-table models.
33
"""
4+
import bisect
45
import copy
56
import functools
67
import logging
@@ -438,8 +439,12 @@ def __init__(
438439

439440
def get_axis_spans(self, context, token_bounds, context_key):
440441
max_row = max(r[context_key] for r in context)
441-
row_spans = [
442-
[
442+
row_spans = [[] for _ in range(max_row + 1)]
443+
for c in context:
444+
row_idx = c[context_key]
445+
if row_idx < 0:
446+
continue
447+
row_spans[row_idx].append(
443448
{
444449
"start": c["start"],
445450
"end": c["end"],
@@ -448,11 +453,7 @@ def get_axis_spans(self, context, token_bounds, context_key):
448453
)
449454
+ 1,
450455
}
451-
for c in context
452-
if i == c[context_key]
453-
]
454-
for i in range(max_row + 1)
455-
]
456+
)
456457
return self.combine_row_spans(row_spans, token_bounds)
457458

458459
def chunk(self, table_text_chunks_and_context):
@@ -594,17 +595,21 @@ def _make_chunks(self, row_spans):
594595
break
595596
max_len_chunks = []
596597
temp_rows = []
598+
temp_row_tokens = 0
597599
context_included = False
598600
for row in row_spans[n_rows_context:]:
599-
if self._num_tokens(context + temp_rows + [row]) < self.max_length:
601+
row_tokens = row["num_effective_tokens"]
602+
if (context_tokens + temp_row_tokens + row_tokens) < self.max_length:
600603
temp_rows.append(row)
604+
temp_row_tokens += row_tokens
601605
elif len(temp_rows) == 0:
602606
# The current row is too long to use any context at all.
603607
max_len_chunks.append([row])
604608
else:
605609
context_included = True
606610
max_len_chunks.append(copy.deepcopy(context) + temp_rows)
607611
temp_rows = [row]
612+
temp_row_tokens = row_tokens
608613
if temp_rows or not context_included:
609614
max_len_chunks.append(copy.deepcopy(context) + temp_rows)
610615
output_spans = []
@@ -626,11 +631,18 @@ def _make_chunks(self, row_spans):
626631
return output_spans
627632

628633
def combine_row_spans(self, row_spans, token_spans):
629-
def mark_token(t):
630-
t["used"] = True
631-
return t
632-
633-
total_num_tokens = 0
634+
token_starts = []
635+
token_ends = []
636+
token_spans_monotonic = True
637+
for token in token_spans:
638+
token_start = token["start"]
639+
token_end = token["end"]
640+
if token_starts and (
641+
token_start < token_starts[-1] or token_end < token_ends[-1]
642+
):
643+
token_spans_monotonic = False
644+
token_starts.append(token_start)
645+
token_ends.append(token_end)
634646
combined_rows = []
635647
for row in row_spans:
636648
row_out = []
@@ -640,14 +652,28 @@ def mark_token(t):
640652
else:
641653
row_out.append(span)
642654
for row_span in row_out:
643-
row_span["num_tokens"] = len(
644-
[mark_token(t) for t in token_spans if overlaps_token(row_span, t)]
645-
)
655+
num_tokens = 0
656+
if token_spans_monotonic:
657+
token_idx = bisect.bisect_left(token_ends, row_span["start"])
658+
while (
659+
token_idx < len(token_spans)
660+
and token_starts[token_idx] <= row_span["end"]
661+
):
662+
token = token_spans[token_idx]
663+
if overlaps_token(row_span, token):
664+
token["used"] = True
665+
num_tokens += 1
666+
token_idx += 1
667+
else:
668+
for token in token_spans:
669+
if overlaps_token(row_span, token):
670+
token["used"] = True
671+
num_tokens += 1
672+
row_span["num_tokens"] = num_tokens
646673
# Accounts for the fact that cells are duplicated when they span cells.
647674
row_span["num_effective_tokens"] = (
648675
row_span["num_tokens"] * row_span["max_cell_span"]
649676
)
650-
total_num_tokens += row_span["num_tokens"]
651677
combined_rows.append(
652678
{
653679
"num_tokens": sum(r["num_tokens"] for r in row_out),

0 commit comments

Comments
 (0)