11"""
22Finetune-style interface for running a pipeline of table and non-table models.
33"""
4+ import bisect
45import copy
56import functools
67import logging
@@ -438,8 +439,12 @@ def __init__(
438439
439440 def get_axis_spans (self , context , token_bounds , context_key ):
440441 max_row = max (r [context_key ] for r in context )
441- row_spans = [
442- [
442+ row_spans = [[] for _ in range (max_row + 1 )]
443+ for c in context :
444+ row_idx = c [context_key ]
445+ if row_idx < 0 :
446+ continue
447+ row_spans [row_idx ].append (
443448 {
444449 "start" : c ["start" ],
445450 "end" : c ["end" ],
@@ -448,11 +453,7 @@ def get_axis_spans(self, context, token_bounds, context_key):
448453 )
449454 + 1 ,
450455 }
451- for c in context
452- if i == c [context_key ]
453- ]
454- for i in range (max_row + 1 )
455- ]
456+ )
456457 return self .combine_row_spans (row_spans , token_bounds )
457458
458459 def chunk (self , table_text_chunks_and_context ):
@@ -594,17 +595,21 @@ def _make_chunks(self, row_spans):
594595 break
595596 max_len_chunks = []
596597 temp_rows = []
598+ temp_row_tokens = 0
597599 context_included = False
598600 for row in row_spans [n_rows_context :]:
599- if self ._num_tokens (context + temp_rows + [row ]) < self .max_length :
601+ row_tokens = row ["num_effective_tokens" ]
602+ if (context_tokens + temp_row_tokens + row_tokens ) < self .max_length :
600603 temp_rows .append (row )
604+ temp_row_tokens += row_tokens
601605 elif len (temp_rows ) == 0 :
602606 # The current row is too long to use any context at all.
603607 max_len_chunks .append ([row ])
604608 else :
605609 context_included = True
606610 max_len_chunks .append (copy .deepcopy (context ) + temp_rows )
607611 temp_rows = [row ]
612+ temp_row_tokens = row_tokens
608613 if temp_rows or not context_included :
609614 max_len_chunks .append (copy .deepcopy (context ) + temp_rows )
610615 output_spans = []
@@ -626,11 +631,18 @@ def _make_chunks(self, row_spans):
626631 return output_spans
627632
628633 def combine_row_spans (self , row_spans , token_spans ):
629- def mark_token (t ):
630- t ["used" ] = True
631- return t
632-
633- total_num_tokens = 0
634+ token_starts = []
635+ token_ends = []
636+ token_spans_monotonic = True
637+ for token in token_spans :
638+ token_start = token ["start" ]
639+ token_end = token ["end" ]
640+ if token_starts and (
641+ token_start < token_starts [- 1 ] or token_end < token_ends [- 1 ]
642+ ):
643+ token_spans_monotonic = False
644+ token_starts .append (token_start )
645+ token_ends .append (token_end )
634646 combined_rows = []
635647 for row in row_spans :
636648 row_out = []
@@ -640,14 +652,28 @@ def mark_token(t):
640652 else :
641653 row_out .append (span )
642654 for row_span in row_out :
643- row_span ["num_tokens" ] = len (
644- [mark_token (t ) for t in token_spans if overlaps_token (row_span , t )]
645- )
655+ num_tokens = 0
656+ if token_spans_monotonic :
657+ token_idx = bisect .bisect_left (token_ends , row_span ["start" ])
658+ while (
659+ token_idx < len (token_spans )
660+ and token_starts [token_idx ] <= row_span ["end" ]
661+ ):
662+ token = token_spans [token_idx ]
663+ if overlaps_token (row_span , token ):
664+ token ["used" ] = True
665+ num_tokens += 1
666+ token_idx += 1
667+ else :
668+ for token in token_spans :
669+ if overlaps_token (row_span , token ):
670+ token ["used" ] = True
671+ num_tokens += 1
672+ row_span ["num_tokens" ] = num_tokens
646673 # Accounts for the fact that cells are duplicated when they span cells.
647674 row_span ["num_effective_tokens" ] = (
648675 row_span ["num_tokens" ] * row_span ["max_cell_span" ]
649676 )
650- total_num_tokens += row_span ["num_tokens" ]
651677 combined_rows .append (
652678 {
653679 "num_tokens" : sum (r ["num_tokens" ] for r in row_out ),
0 commit comments