-
Notifications
You must be signed in to change notification settings - Fork 80
FIX: [DEV-15236] improvements to scheduler and table model predict speed. #856
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,6 +1,7 @@ | ||
| """ | ||
| Finetune-style interface for running a pipeline of table and non-table models. | ||
| """ | ||
| import bisect | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I didn't know this was part of stdlib! |
||
| import copy | ||
| import functools | ||
| import logging | ||
|
|
@@ -438,8 +439,12 @@ def __init__( | |
|
|
||
| def get_axis_spans(self, context, token_bounds, context_key): | ||
| max_row = max(r[context_key] for r in context) | ||
| row_spans = [ | ||
| [ | ||
| row_spans = [[] for _ in range(max_row + 1)] | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This bucketing change is part of the chunker speedup. On the synthetic 150 x 20 table chunking benchmark, the chunker dropped from 55.443s to 0.263s (~99.5% faster) after the get_axis_spans / _make_chunks / combine_row_spans optimization pass. |
||
| for c in context: | ||
| row_idx = c[context_key] | ||
| if row_idx < 0: | ||
| continue | ||
| row_spans[row_idx].append( | ||
| { | ||
| "start": c["start"], | ||
| "end": c["end"], | ||
|
|
@@ -448,11 +453,7 @@ def get_axis_spans(self, context, token_bounds, context_key): | |
| ) | ||
| + 1, | ||
| } | ||
| for c in context | ||
| if i == c[context_key] | ||
| ] | ||
| for i in range(max_row + 1) | ||
| ] | ||
| ) | ||
| return self.combine_row_spans(row_spans, token_bounds) | ||
|
|
||
| def chunk(self, table_text_chunks_and_context): | ||
|
|
@@ -594,17 +595,21 @@ def _make_chunks(self, row_spans): | |
| break | ||
| max_len_chunks = [] | ||
| temp_rows = [] | ||
| temp_row_tokens = 0 | ||
| context_included = False | ||
| for row in row_spans[n_rows_context:]: | ||
| if self._num_tokens(context + temp_rows + [row]) < self.max_length: | ||
| row_tokens = row["num_effective_tokens"] | ||
| if (context_tokens + temp_row_tokens + row_tokens) < self.max_length: | ||
| temp_rows.append(row) | ||
| temp_row_tokens += row_tokens | ||
| elif len(temp_rows) == 0: | ||
| # The current row is too long to use any context at all. | ||
| max_len_chunks.append([row]) | ||
| else: | ||
| context_included = True | ||
| max_len_chunks.append(copy.deepcopy(context) + temp_rows) | ||
| temp_rows = [row] | ||
| temp_row_tokens = row_tokens | ||
| if temp_rows or not context_included: | ||
| max_len_chunks.append(copy.deepcopy(context) + temp_rows) | ||
| output_spans = [] | ||
|
|
@@ -626,11 +631,18 @@ def _make_chunks(self, row_spans): | |
| return output_spans | ||
|
|
||
| def combine_row_spans(self, row_spans, token_spans): | ||
| def mark_token(t): | ||
| t["used"] = True | ||
| return t | ||
|
|
||
| total_num_tokens = 0 | ||
| token_starts = [] | ||
| token_ends = [] | ||
| token_spans_monotonic = True | ||
| for token in token_spans: | ||
| token_start = token["start"] | ||
| token_end = token["end"] | ||
| if token_starts and ( | ||
| token_start < token_starts[-1] or token_end < token_ends[-1] | ||
| ): | ||
| token_spans_monotonic = False | ||
| token_starts.append(token_start) | ||
| token_ends.append(token_end) | ||
| combined_rows = [] | ||
| for row in row_spans: | ||
| row_out = [] | ||
|
|
@@ -640,14 +652,28 @@ def mark_token(t): | |
| else: | ||
| row_out.append(span) | ||
| for row_span in row_out: | ||
| row_span["num_tokens"] = len( | ||
| [mark_token(t) for t in token_spans if overlaps_token(row_span, t)] | ||
| ) | ||
| num_tokens = 0 | ||
| if token_spans_monotonic: | ||
| token_idx = bisect.bisect_left(token_ends, row_span["start"]) | ||
| while ( | ||
| token_idx < len(token_spans) | ||
| and token_starts[token_idx] <= row_span["end"] | ||
| ): | ||
| token = token_spans[token_idx] | ||
| if overlaps_token(row_span, token): | ||
| token["used"] = True | ||
| num_tokens += 1 | ||
| token_idx += 1 | ||
| else: | ||
| for token in token_spans: | ||
| if overlaps_token(row_span, token): | ||
| token["used"] = True | ||
| num_tokens += 1 | ||
| row_span["num_tokens"] = num_tokens | ||
| # Accounts for the fact that cells are duplicated when they span cells. | ||
| row_span["num_effective_tokens"] = ( | ||
| row_span["num_tokens"] * row_span["max_cell_span"] | ||
| ) | ||
| total_num_tokens += row_span["num_tokens"] | ||
| combined_rows.append( | ||
| { | ||
| "num_tokens": sum(r["num_tokens"] for r in row_out), | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This file's changes are probably worth pulling in. Big improvements for the risk on larger tables.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe we drop the other changes, as they are not valuable enough to justify the risk at this point.