From ea4f5460a210a80ce63448c3c4d410152af7cacc Mon Sep 17 00:00:00 2001 From: benleetownsend Date: Fri, 24 Apr 2026 11:40:21 +0100 Subject: [PATCH 1/4] fix: make scheduler stats per-model --- finetune/base.py | 2 +- finetune/scheduler.py | 270 ++++++++++++++++++++++++++++----- finetune/util/table_labeler.py | 41 ++--- tests/test_scheduler.py | 268 ++++++++++++++++++++++++++------ 4 files changed, 485 insertions(+), 96 deletions(-) diff --git a/finetune/base.py b/finetune/base.py index 1ebcc109..9a200823 100644 --- a/finetune/base.py +++ b/finetune/base.py @@ -428,7 +428,7 @@ def get_zipped_data(): if v.shape[0] != batch_size or k == "transition_params" } for i in range(batch_size): - progress.update(i) + progress.update(1) step_value = { k: pred_numpy[k] if k in not_batched else pred_numpy[k][i] for k in pred_numpy diff --git a/finetune/scheduler.py b/finetune/scheduler.py index 55add629..ea57aa1c 100644 --- a/finetune/scheduler.py +++ b/finetune/scheduler.py @@ -2,7 +2,7 @@ import gc import logging import sys -from collections import OrderedDict +from collections import OrderedDict, defaultdict, deque import psutil import pynvml @@ -47,9 +47,13 @@ def scheduled_predict( ): # this is just for backwards compat, should always have a blob key going forward cache_key = kwargs.pop("cache_key", None) + resolved_cache_key = self.model_cache_key( + model_file, key=key, cache_key=cache_key + ) model = self._rotate_in_model( model_file, key=key, config_overrides=config_overrides, cache_key=cache_key ) + self._reset_peak_memory_stats() try: preds = fn(self, model_file=model_file, x=x, *args, model=model, **kwargs) except Exception as orig_except: @@ -68,6 +72,7 @@ def scheduled_predict( config_overrides=config_overrides, cache_key=cache_key, ) + self._reset_peak_memory_stats() preds = fn( self, model_file=model_file, x=x, *args, model=model, **kwargs ) @@ -77,6 +82,7 @@ def scheduled_predict( str(orig_except), str(e) ) ) + self._record_prediction_memory(model, resolved_cache_key) self._update_memory_limit(model) return preds @@ -84,45 +90,166 @@ def scheduled_predict( class Scheduler: + # Scheduler behavior overview: + # + # 1. Models are cached by resolved cache key (`model_cache_key`) and tracked in + # `loaded_models` as an LRU queue. Cache hits move the active key to the end + # of that queue. Cache misses load the model, then register its base model + # type so future memory estimates are source-specific. + # + # 2. GPU admission is based on recent observed memory for the model's + # `config.base_model` type, not a single global historical max. For each + # source model type we keep a bounded window of recent: + # - `max_above_resting`: transient peak above steady-state memory during + # prediction + # - `model_size`: approximate resident memory added by loading the model + # This gives fast recovery from one-off spikes while still preserving a + # conservative running max for recent behavior. + # + # 3. On cache miss, the model is loaded before the final GPU headroom decision. + # That lets us discover the source model type and use its own memory history + # instead of falling back to unrelated global spikes. CPU pressure is still + # checked first so we can evict before loading if host RAM is tight. + # + # 4. Before every prediction, including cache hits, we run an execution-time + # headroom check. If the active model's expected transient execution memory + # does not fit alongside currently loaded peers, we evict older unrelated + # models until it does. 
Very large transient models are forced into + # effectively exclusive execution because allocator fragmentation can make a + # "fits on paper" run still fail in practice. + # + # 5. Peak memory statistics are reset around each prediction attempt, then the + # successful attempt records fresh memory samples back into the per-source + # windows. If prediction still fails, we keep the existing recovery path: + # close all models, reload, and retry once before raising a + # `FinetuneSchedulerError`. def __init__( - self, max_models=None, config=None, reserved=750000000, ram_max_frac=0.8 + self, + max_models=None, + config=None, + reserved=750000000, + ram_max_frac=0.8, + memory_window_size=10, + execution_safety_margin=512 * 1024 * 1024, ): self.loaded_models = list() self.max_models = max_models self.gpu_memory_limit = None self.model_cache = dict() self.max_above_resting = None - self.previous_in_use = 0 self.max_model_size = None self.config = config or {} self.reserved = reserved self.ram_max_frac = ram_max_frac self.etl_cache = EtlCache() + self.memory_window_size = memory_window_size + self.execution_safety_margin = execution_safety_margin + self.memory_histories = defaultdict( + lambda: { + "max_above_resting": deque(maxlen=self.memory_window_size), + "model_size": deque(maxlen=self.memory_window_size), + } + ) + self.global_memory_histories = { + "max_above_resting": deque(maxlen=self.memory_window_size), + "model_size": deque(maxlen=self.memory_window_size), + } + self.cache_key_to_source_model = dict() + self.pending_load_baselines = dict() + + def _source_model_key_from_model(self, model): + base_model = getattr(getattr(model, "config", None), "base_model", None) + if base_model is None: + return None + return getattr(base_model, "__name__", str(base_model)) + + def _append_memory_sample(self, stat_name, source_model_key, value): + if value is None: + return + value = max(int(value), 0) + self.global_memory_histories[stat_name].append(value) + if source_model_key is not None: + self.memory_histories[source_model_key][stat_name].append(value) + + def _record_prediction_memory_sample( + self, source_model_key, max_above_resting=None, model_size=None + ): + self._append_memory_sample( + "max_above_resting", source_model_key, max_above_resting + ) + self._append_memory_sample("model_size", source_model_key, model_size) + estimated = self._estimated_memory_stats(source_model_key) + self.max_above_resting = estimated["max_above_resting"] + self.max_model_size = estimated["model_size"] - def _memory_for_one_more(self): - if self.gpu_memory_limit is None: - return True # first run + def _estimated_memory_stats( + self, source_model_key=None, allow_global_fallback=True + ): + stats = {} + for stat_name in ("max_above_resting", "model_size"): + history = None + if source_model_key is not None: + history = self.memory_histories[source_model_key][stat_name] + if history: + stats[stat_name] = max(history) + elif allow_global_fallback and self.global_memory_histories[stat_name]: + stats[stat_name] = max(self.global_memory_histories[stat_name]) + else: + stats[stat_name] = 0 + return stats + + def _reset_peak_memory_stats(self): + if not is_gpu_available(): + return + reset_memory_stats = getattr(tf.config.experimental, "reset_memory_stats", None) + if reset_memory_stats is None: + return + try: + reset_memory_stats("GPU:0") + except Exception: + LOGGER.debug("Failed to reset TensorFlow peak memory stats.", exc_info=True) + + def _record_prediction_memory(self, model, resolved_cache_key): + 
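The per-source bookkeeping above reduces to one idea: keep a bounded window of recent samples per statistic and treat its maximum as the estimate, so a one-off spike ages out once enough newer samples arrive. A minimal, self-contained sketch of that idea (window size and names here are illustrative, not the scheduler's API):

from collections import defaultdict, deque

WINDOW = 3
histories = defaultdict(lambda: deque(maxlen=WINDOW))

def record(source_key, value):
    # Clamp negative deltas (memory freed mid-measurement) to zero before recording.
    histories[source_key].append(max(int(value), 0))

def estimate(source_key):
    # Conservative recent estimate: max over the bounded window, or 0 with no samples.
    window = histories[source_key]
    return max(window) if window else 0

record("model_a", 900)              # one-off spike
for sample in (120, 110, 130):
    record("model_a", sample)
assert estimate("model_a") == 130   # the spike has aged out of the 3-sample window
assert estimate("model_b") == 0     # unseen sources fall back to zero here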
source_model_key = self._source_model_key_from_model(model) if is_gpu_available(): in_use = BytesInUse() peak = MaxBytesInUse() - if ( - self.max_above_resting is None - or (peak - in_use) > self.max_above_resting - ): - self.max_above_resting = peak - in_use + max_above_resting = max(peak - in_use, 0) + else: + in_use = 0 + max_above_resting = 0 - if ( - self.max_model_size is None - or (in_use - self.previous_in_use) > self.max_model_size - ): - self.max_model_size = in_use - self.previous_in_use + load_baseline = self.pending_load_baselines.pop(resolved_cache_key, None) + model_size = None + if load_baseline is not None: + model_size = max(in_use - load_baseline, 0) + + self._record_prediction_memory_sample( + source_model_key, + max_above_resting=max_above_resting, + model_size=model_size, + ) - self.previous_in_use = in_use + def _has_memory_for_model( + self, + source_model_key=None, + include_model_size=True, + allow_global_fallback=True, + extra_model_buffer=0, + ): + if self.gpu_memory_limit is None: + return True # first run + estimated = self._estimated_memory_stats( + source_model_key, allow_global_fallback=allow_global_fallback + ) + self.max_above_resting = estimated["max_above_resting"] + self.max_model_size = estimated["model_size"] + if is_gpu_available(): + in_use = BytesInUse() else: LOGGER.info("No GPU available, skipping GPU memory checks.") self.max_above_resting = 0 self.max_model_size = 0 - self.previous_in_use = 0 in_use = 0 cpu_percent = psutil.virtual_memory().percent @@ -141,18 +268,72 @@ def _memory_for_one_more(self): ) if cpu_percent > self.ram_max_frac * 100: return False - return ( - in_use + self.max_above_resting + self.max_model_size + self.reserved - ) < self.gpu_memory_limit - - def _close_oldest_model(self): - if len(self.loaded_models): - name = self.loaded_models.pop(0) + required_gpu = ( + in_use + self.max_above_resting + self.reserved + extra_model_buffer + ) + if include_model_size: + required_gpu += self.max_model_size + return required_gpu < self.gpu_memory_limit + + def _memory_for_one_more(self, source_model_key=None): + return self._has_memory_for_model(source_model_key=source_model_key) + + def _ensure_cpu_headroom(self, exclude=None): + exclude = set(exclude or []) + while psutil.virtual_memory().percent > self.ram_max_frac * 100: + if not self._close_oldest_model(exclude=exclude): + return False + return True + + def _ensure_prediction_headroom( + self, + source_model_key, + active_cache_key, + include_model_size, + allow_global_fallback, + ): + estimated = self._estimated_memory_stats( + source_model_key, allow_global_fallback=allow_global_fallback + ) + global_model_size = 0 + if self.global_memory_histories["model_size"]: + global_model_size = max(self.global_memory_histories["model_size"]) + transient_buffer = int(estimated["max_above_resting"] * 0.2) + execution_buffer = max( + global_model_size, self.execution_safety_margin, transient_buffer + ) + while not self._has_memory_for_model( + source_model_key=source_model_key, + include_model_size=include_model_size, + allow_global_fallback=allow_global_fallback, + extra_model_buffer=execution_buffer, + ): + if not self._close_oldest_model(exclude={active_cache_key}): + return False + include_model_size = False + if ( + self.gpu_memory_limit is not None + and estimated["max_above_resting"] > self.gpu_memory_limit * 0.5 + ): + while len(self.loaded_models) > 1: + if not self._close_oldest_model(exclude={active_cache_key}): + return False + return True + + def _close_oldest_model(self, 
exclude=None): + exclude = set(exclude or []) + for idx, name in enumerate(self.loaded_models): + if name in exclude: + continue + self.loaded_models.pop(idx) self.model_cache[name].close(update_saver=False) del self.model_cache[name] + self.pending_load_baselines.pop(name, None) + self.cache_key_to_source_model.pop(name, None) gc.collect() - else: - LOGGER.info("No models cached -- cannot remove oldest model.") + return True + LOGGER.info("No models cached -- cannot remove oldest model.") + return False def model_cache_key(self, model, key, cache_key): if cache_key is None: @@ -169,16 +350,26 @@ def model_cache_key(self, model, key, cache_key): def _rotate_in_model(self, model, key, config_overrides=None, cache_key=None): resolved_cache_key = self.model_cache_key(model, key=key, cache_key=cache_key) - if resolved_cache_key not in self.loaded_models: - if ( + source_model_key = self.cache_key_to_source_model.get(resolved_cache_key) + cache_miss = resolved_cache_key not in self.loaded_models + if cache_miss: + while ( self.max_models is not None and len(self.loaded_models) + 1 > self.max_models - ) or not self._memory_for_one_more(): - self._close_oldest_model() + ): + if not self._close_oldest_model(): + break + self._ensure_cpu_headroom() config_overrides = config_overrides or {} merged_config = {**self.config, **config_overrides} + load_baseline = BytesInUse() if is_gpu_available() else 0 out_model = BaseModel.load(model, key=key, **merged_config) self.model_cache[resolved_cache_key] = out_model + self.pending_load_baselines[resolved_cache_key] = load_baseline + self.cache_key_to_source_model[ + resolved_cache_key + ] = self._source_model_key_from_model(out_model) + source_model_key = self.cache_key_to_source_model[resolved_cache_key] else: out_model = self.model_cache[resolved_cache_key] self.loaded_models.remove( @@ -186,6 +377,12 @@ def _rotate_in_model(self, model, key, config_overrides=None, cache_key=None): ) # put it back at the end of the queue self.loaded_models.append(resolved_cache_key) + self._ensure_prediction_headroom( + source_model_key=source_model_key, + active_cache_key=resolved_cache_key, + include_model_size=cache_miss, + allow_global_fallback=False, + ) out_model._cached_predict = True return out_model @@ -195,9 +392,13 @@ def _update_memory_limit(self, model): del model.saver.variables del model.saver.fallback_ if is_gpu_available(): - self.gpu_memory_limit = ( - BytesLimit() - ) # delay this so that any options get applied from finetune. + gpu_memory_limit = BytesLimit() + gpu_fraction = self.config.get("per_process_gpu_memory_fraction") + if gpu_fraction is not None: + gpu_memory_limit = min( + gpu_memory_limit, int(gpu_memory_limit * gpu_fraction) + ) + self.gpu_memory_limit = gpu_memory_limit else: LOGGER.info("No GPU available, skipping GPU memory limit update.") self.gpu_memory_limit = sys.maxsize @@ -205,6 +406,7 @@ def _update_memory_limit(self, model): def close_all(self): while self.loaded_models: self._close_oldest_model() + self.pending_load_baselines.clear() MODEL_REGISTRY.cleanup() @scheduled diff --git a/finetune/util/table_labeler.py b/finetune/util/table_labeler.py index 06bf434e..c8c5093c 100644 --- a/finetune/util/table_labeler.py +++ b/finetune/util/table_labeler.py @@ -1,6 +1,7 @@ """ Finetune-style interface for running a pipeline of table and non-table models. 
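_close_oldest_model above is LRU eviction with an exclusion set: walk the queue from oldest to newest, skip protected keys (such as the model about to run), close and drop the first candidate, and report what was freed. A stand-alone sketch of that shape, with a hypothetical FakeModel standing in for a real loaded model:

def evict_oldest(loaded, cache, exclude=()):
    # Close and drop the least-recently-used entry that is not protected.
    exclude = set(exclude)
    for idx, name in enumerate(loaded):
        if name in exclude:
            continue
        loaded.pop(idx)
        cache.pop(name).close()
        return name
    return None  # nothing evictable

class FakeModel:
    def close(self):
        pass

loaded = ["old", "active", "new"]
cache = {name: FakeModel() for name in loaded}
assert evict_oldest(loaded, cache, exclude={"active"}) == "old"
assert loaded == ["active", "new"]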
""" +import bisect import copy import functools import logging @@ -438,8 +439,9 @@ def __init__( def get_axis_spans(self, context, token_bounds, context_key): max_row = max(r[context_key] for r in context) - row_spans = [ - [ + row_spans = [[] for _ in range(max_row + 1)] + for c in context: + row_spans[c[context_key]].append( { "start": c["start"], "end": c["end"], @@ -448,11 +450,7 @@ def get_axis_spans(self, context, token_bounds, context_key): ) + 1, } - for c in context - if i == c[context_key] - ] - for i in range(max_row + 1) - ] + ) return self.combine_row_spans(row_spans, token_bounds) def chunk(self, table_text_chunks_and_context): @@ -594,10 +592,13 @@ def _make_chunks(self, row_spans): break max_len_chunks = [] temp_rows = [] + temp_row_tokens = 0 context_included = False for row in row_spans[n_rows_context:]: - if self._num_tokens(context + temp_rows + [row]) < self.max_length: + row_tokens = row["num_effective_tokens"] + if (context_tokens + temp_row_tokens + row_tokens) < self.max_length: temp_rows.append(row) + temp_row_tokens += row_tokens elif len(temp_rows) == 0: # The current row is too long to use any context at all. max_len_chunks.append([row]) @@ -605,6 +606,7 @@ def _make_chunks(self, row_spans): context_included = True max_len_chunks.append(copy.deepcopy(context) + temp_rows) temp_rows = [row] + temp_row_tokens = row_tokens if temp_rows or not context_included: max_len_chunks.append(copy.deepcopy(context) + temp_rows) output_spans = [] @@ -626,11 +628,8 @@ def _make_chunks(self, row_spans): return output_spans def combine_row_spans(self, row_spans, token_spans): - def mark_token(t): - t["used"] = True - return t - - total_num_tokens = 0 + token_starts = [t["start"] for t in token_spans] + token_ends = [t["end"] for t in token_spans] combined_rows = [] for row in row_spans: row_out = [] @@ -640,14 +639,22 @@ def mark_token(t): else: row_out.append(span) for row_span in row_out: - row_span["num_tokens"] = len( - [mark_token(t) for t in token_spans if overlaps_token(row_span, t)] - ) + num_tokens = 0 + token_idx = bisect.bisect_left(token_ends, row_span["start"]) + while ( + token_idx < len(token_spans) + and token_starts[token_idx] <= row_span["end"] + ): + token = token_spans[token_idx] + if overlaps_token(row_span, token): + token["used"] = True + num_tokens += 1 + token_idx += 1 + row_span["num_tokens"] = num_tokens # Accounts for the fact that cells are duplicated when they span cells. 
row_span["num_effective_tokens"] = ( row_span["num_tokens"] * row_span["max_cell_span"] ) - total_num_tokens += row_span["num_tokens"] combined_rows.append( { "num_tokens": sum(r["num_tokens"] for r in row_out), diff --git a/tests/test_scheduler.py b/tests/test_scheduler.py index a918c88b..f34ce151 100644 --- a/tests/test_scheduler.py +++ b/tests/test_scheduler.py @@ -1,54 +1,234 @@ -import os -import time +from types import SimpleNamespace import pytest +import finetune.scheduler as scheduler_mod +from finetune.base_models import SourceModel +from finetune.errors import FinetuneSchedulerError from finetune.scheduler import Scheduler +class DummySourceModelA(SourceModel): + encoder = None + featurizer = None + settings = {} + + +class DummySourceModelB(SourceModel): + encoder = None + featurizer = None + settings = {} + + +class DummyLoadedModel: + def __init__(self, base_model, outcomes=None): + self.config = SimpleNamespace(base_model=base_model) + self.saver = SimpleNamespace(variables={}, fallback_=object()) + self._outcomes = list(outcomes or [["ok"]]) + self.close_calls = 0 + self._cached_predict = False + + def predict(self, x, *args, **kwargs): + outcome = self._outcomes.pop(0) + if isinstance(outcome, Exception): + raise outcome + return outcome + + def close(self, update_saver=False): + self.close_calls += 1 + + +@pytest.fixture(autouse=True) +def patch_cleanup(monkeypatch): + monkeypatch.setattr(scheduler_mod.MODEL_REGISTRY, "cleanup", lambda: None) + + @pytest.fixture -def models(save_model_dir, trained_classifier, trained_annotation): - model1 = os.path.join(save_model_dir, "1.jl") - model2 = os.path.join(save_model_dir, "2.jl") - trained_classifier.save(model1) - trained_annotation.save(model2) - yield model1, model2 +def fake_memory(monkeypatch): + state = {"in_use": 0, "peak": 0, "gpu_available": True, "cpu_percent": 10} + + monkeypatch.setattr( + scheduler_mod, + "is_gpu_available", + lambda: state["gpu_available"], + ) + monkeypatch.setattr(scheduler_mod, "BytesInUse", lambda: state["in_use"]) + monkeypatch.setattr(scheduler_mod, "MaxBytesInUse", lambda: state["peak"]) + monkeypatch.setattr(scheduler_mod, "BytesLimit", lambda: 1000) + monkeypatch.setattr( + scheduler_mod.psutil, + "virtual_memory", + lambda: SimpleNamespace(percent=state["cpu_percent"]), + ) + return state + + +def test_memory_windows_are_bounded_per_source_model(): + shed = Scheduler(memory_window_size=3) + + for sample in (100, 10, 9, 8): + shed._record_prediction_memory_sample( + "DummySourceModelA", max_above_resting=sample, model_size=sample // 2 + ) + shed._record_prediction_memory_sample( + "DummySourceModelB", max_above_resting=5, model_size=7 + ) + + assert shed._estimated_memory_stats("DummySourceModelA") == { + "max_above_resting": 10, + "model_size": 5, + } + assert shed._estimated_memory_stats("DummySourceModelB") == { + "max_above_resting": 5, + "model_size": 7, + } + assert shed._estimated_memory_stats("unknown") == { + "max_above_resting": 9, + "model_size": 7, + } + + +def test_cache_miss_loads_first_and_uses_source_specific_stats( + monkeypatch, fake_memory +): + shed = Scheduler(reserved=0, execution_safety_margin=0) + shed.gpu_memory_limit = 1000 + shed._record_prediction_memory_sample( + "UnrelatedSource", max_above_resting=900, model_size=50 + ) + + keep = DummyLoadedModel(DummySourceModelB) + shed.model_cache["keep"] = keep + shed.loaded_models.append("keep") + fake_memory["in_use"] = 100 + + loaded = [] + + def fake_load(model_file, key=None, **kwargs): + model = 
DummyLoadedModel(DummySourceModelA) + loaded.append((model_file, key, model)) + return model + + monkeypatch.setattr(scheduler_mod.BaseModel, "load", fake_load) + + out_model = shed.get_model("fresh.jl") + + assert out_model is loaded[0][2] + assert shed.cache_key_to_source_model["model=fresh.jl"] == "DummySourceModelA" + assert shed.loaded_models == ["keep", "model=fresh.jl"] + assert keep.close_calls == 0 + + +def test_cache_hit_execution_headroom_evicts_unrelated_models(monkeypatch, fake_memory): + shed = Scheduler(reserved=0, execution_safety_margin=0) + shed.gpu_memory_limit = 1000 + shed._record_prediction_memory_sample( + "DummySourceModelA", max_above_resting=400, model_size=0 + ) + + other = DummyLoadedModel(DummySourceModelB) + active = DummyLoadedModel(DummySourceModelA) + shed.model_cache = {"other": other, "active": active} + shed.loaded_models = ["other", "active"] + shed.cache_key_to_source_model = { + "other": "DummySourceModelB", + "active": "DummySourceModelA", + } + fake_memory["in_use"] = 700 -def test_scheduler(models): - model1, model2 = models + def dynamic_in_use(): + return 700 if "other" in shed.loaded_models else 200 + + monkeypatch.setattr(scheduler_mod, "BytesInUse", dynamic_in_use) + + shed._ensure_prediction_headroom( + source_model_key="DummySourceModelA", + active_cache_key="active", + include_model_size=False, + allow_global_fallback=False, + ) + + assert shed.loaded_models == ["active"] + assert other.close_calls == 1 + + +def test_large_transient_models_run_exclusively(fake_memory): + shed = Scheduler(reserved=0, execution_safety_margin=0) + shed.gpu_memory_limit = 1000 + shed._record_prediction_memory_sample( + "DummySourceModelA", max_above_resting=600, model_size=0 + ) + + older = DummyLoadedModel(DummySourceModelB) + newer = DummyLoadedModel(DummySourceModelB) + active = DummyLoadedModel(DummySourceModelA) + shed.model_cache = {"older": older, "newer": newer, "active": active} + shed.loaded_models = ["older", "newer", "active"] + shed.cache_key_to_source_model = { + "older": "DummySourceModelB", + "newer": "DummySourceModelB", + "active": "DummySourceModelA", + } + fake_memory["in_use"] = 100 + + shed._ensure_prediction_headroom( + source_model_key="DummySourceModelA", + active_cache_key="active", + include_model_size=False, + allow_global_fallback=False, + ) + + assert shed.loaded_models == ["active"] + assert older.close_calls == 1 + assert newer.close_calls == 1 + + +def test_scheduler_retries_after_exception(monkeypatch, fake_memory): + fake_memory["gpu_available"] = False shed = Scheduler() - tic_1 = time.time() - preds_m1 = shed.predict(model1, ["A"]) # May need isolation - toc_1 = time.time() - assert len(preds_m1) == 1 - assert isinstance(preds_m1[0], str) # classification - preds_m2 = shed.predict(model2, ["A"]) - assert len(preds_m2) == 1 - assert isinstance( - preds_m2[0], list - ) # Annotation - can't realy expect any labels though. 
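The setup above belongs to test_scheduler_retries_after_exception, which exercises the decorator's recovery path: when a prediction fails, the scheduler closes every cached model, reloads once, and retries before wrapping both errors in a FinetuneSchedulerError. A stripped-down sketch of that control flow, using plain callables and a RuntimeError in place of the real decorator and error type:

def predict_with_retry(load, run, close_all):
    model = load()
    try:
        return run(model)
    except Exception as original_error:
        # Recovery path: drop every cached model, reload once, retry once.
        close_all()
        model = load()
        try:
            return run(model)
        except Exception as retry_error:
            raise RuntimeError(
                f"Failed to run model. Original Error: {original_error} "
                f"Retry Error: {retry_error}"
            ) from retry_error

attempts = []

def flaky_run(model):
    attempts.append(model)
    if len(attempts) == 1:
        raise RuntimeError("oom-ish failure")
    return ["ok"]

assert predict_with_retry(lambda: object(), flaky_run, lambda: None) == ["ok"]
assert len(attempts) == 2  # failed once, succeeded on the retry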
- tic_3 = time.time() - shed.predict(model1, ["something else"]) - toc_3 = time.time() - assert toc_1 - tic_1 > toc_3 - tic_3 - assert len(shed.loaded_models) == 2 - shed.close_all() - assert len(shed.loaded_models) == 0 - shed.predict_proba(model1, ["A"]) - shed.featurize(model1, ["A"]) - shed.featurize_sequence(model1, ["A"]) - - -def test_scheduler_max_models(models): - model1, model2 = models - shed = Scheduler(max_models=1) - time_pre = time.time() - pred1a = shed.predict(model1, ["A"]) - time_mid = time.time() - pred1b = shed.predict(model1, ["A"]) - time_end = time.time() - assert time_end - time_mid < time_mid - time_pre - 1 - assert pred1a == pred1b - shed.predict(model2, ["A"]) # Load another model. - assert len(shed.loaded_models) == 1 + loads = [] + + def fake_load(model_file, key=None, **kwargs): + outcomes = [RuntimeError("oom-ish failure")] if not loads else [["ok"]] + model = DummyLoadedModel(DummySourceModelA, outcomes) + loads.append(model) + return model + + monkeypatch.setattr(scheduler_mod.BaseModel, "load", fake_load) + + result = shed.predict("dummy.jl", ["A"]) + + assert result == ["ok"] + assert len(loads) == 2 + assert loads[0].close_calls == 1 + assert shed.loaded_models == ["model=dummy.jl"] + + +def test_scheduler_wraps_retry_failure(monkeypatch, fake_memory): + fake_memory["gpu_available"] = False + shed = Scheduler() + + def fake_load(model_file, key=None, **kwargs): + return DummyLoadedModel( + DummySourceModelA, + [RuntimeError("first fail"), RuntimeError("second fail")], + ) + + monkeypatch.setattr(scheduler_mod.BaseModel, "load", fake_load) + + with pytest.raises(FinetuneSchedulerError) as excinfo: + shed.predict("dummy.jl", ["A"]) + + message = str(excinfo.value) + assert "Original Error: first fail" in message + assert "Retry Error:" in message + + +def test_update_memory_limit_uses_per_process_fraction(fake_memory): + shed = Scheduler(config={"per_process_gpu_memory_fraction": 0.4}) + model = DummyLoadedModel(DummySourceModelA) + + shed._update_memory_limit(model) + + assert shed.gpu_memory_limit == 400 From ccd7d1b2161a8a041b1a8efd43b1d442365c91c1 Mon Sep 17 00:00:00 2001 From: benleetownsend Date: Fri, 24 Apr 2026 15:23:43 +0100 Subject: [PATCH 2/4] fix: preprocessing changes --- finetune/base.py | 19 +- finetune/base_models/bert/table_utils.py | 352 +++++++++++++++++++++++ tests/test_scheduler.py | 46 +++ 3 files changed, 414 insertions(+), 3 deletions(-) diff --git a/finetune/base.py b/finetune/base.py index 9a200823..bd93cbdf 100644 --- a/finetune/base.py +++ b/finetune/base.py @@ -407,9 +407,22 @@ def _inference( def get_zipped_data(): return iter(zipped_data) - input_fn = self.input_pipeline.get_dataset_from_generator( - get_zipped_data, input_mode=InputMode.PREDICT, update_hook=update_hook - )["predict_dataset"] + predict_batch_iter = getattr( + self.input_pipeline.batch_postprocessor, "iter_predict_batches", None + ) + if predict_batch_iter is not None: + + def feature_iter(): + for data in get_zipped_data(): + yield from self.input_pipeline.text_to_tokens_mask(**data) + + input_fn = predict_batch_iter( + feature_iter(), predict_batch_size=self.config.predict_batch_size + ) + else: + input_fn = self.input_pipeline.get_dataset_from_generator( + get_zipped_data, input_mode=InputMode.PREDICT, update_hook=update_hook + )["predict_dataset"] model = self.model diff --git a/finetune/base_models/bert/table_utils.py b/finetune/base_models/bert/table_utils.py index f8d4484b..fb061a99 100644 --- a/finetune/base_models/bert/table_utils.py +++ 
b/finetune/base_models/bert/table_utils.py @@ -1,5 +1,6 @@ import logging +import numpy as np import tensorflow as tf from finetune.nn.activations import bert_gelu as gelu @@ -7,6 +8,301 @@ LOGGER = logging.getLogger("finetune") +def _split_ragged_rows_numpy(values, row_lengths): + rows = [] + offset = 0 + for row_length in row_lengths: + next_offset = offset + int(row_length) + rows.append(values[offset:next_offset]) + offset = next_offset + return rows + + +def _flatten_ragged_rows_numpy(rows): + if not rows: + return np.zeros((0, 2), dtype=np.int32), np.zeros((0,), dtype=np.int64) + row_lengths = np.asarray([len(row) for row in rows], dtype=np.int64) + non_empty_rows = [row for row in rows if len(row)] + if non_empty_rows: + values = np.concatenate(non_empty_rows, axis=0).astype(np.int32, copy=False) + else: + values = np.zeros((0, 2), dtype=np.int32) + return values, row_lengths + + +def _masked_ragged_indices_numpy_rows(sequence_lengths, start, end): + bs = int(sequence_lengths.shape[0]) + per_batch_groups = [] + global_max_axis = -1 + + for batch_idx in range(bs): + seq_len = int(sequence_lengths[batch_idx]) + batch_groups = [] + for seq_idx in range(seq_len): + start_idx = int(start[batch_idx, seq_idx]) + end_idx = int(end[batch_idx, seq_idx]) + if end_idx >= len(batch_groups): + batch_groups.extend([] for _ in range(end_idx + 1 - len(batch_groups))) + token_idx = (batch_idx, seq_idx) + for axis_idx in range(start_idx, end_idx + 1): + batch_groups[axis_idx].append(token_idx) + per_batch_groups.append(batch_groups) + global_max_axis = max(global_max_axis, len(batch_groups) - 1) + + rows = [] + for axis_idx in range(global_max_axis + 1): + for batch_idx in range(bs): + batch_groups = per_batch_groups[batch_idx] + if axis_idx >= len(batch_groups): + continue + axis_values = batch_groups[axis_idx] + if not axis_values: + continue + rows.append(np.asarray(axis_values, dtype=np.int32)) + + return rows + + +def _masked_ragged_indices_numpy(sequence_lengths, start, end): + rows = _masked_ragged_indices_numpy_rows(sequence_lengths, start, end) + values, row_lengths = _flatten_ragged_rows_numpy(rows) + + if len(row_lengths) == 0: + return tf.RaggedTensor.from_row_lengths( + values=tf.zeros([0, 2], dtype=tf.int32), + row_lengths=tf.zeros([0], dtype=tf.int64), + ) + + return tf.RaggedTensor.from_row_lengths( + values=tf.convert_to_tensor(values, dtype=tf.int32), + row_lengths=tf.convert_to_tensor(row_lengths, dtype=tf.int64), + ) + + +def _chunk_rows_numpy( + rows, + other_end, + include_n_rows=2, + base_model_max_length=512 - 2, +): + chunked_rows = [] + first_n_mask = np.asarray(other_end < include_n_rows, dtype=bool) + + for row in rows: + if len(row) == 0: + chunked_rows.append(row) + continue + row_mask = first_n_mask[row[:, 0], row[:, 1]] + if np.count_nonzero(row_mask) >= base_model_max_length: + row_mask = np.zeros_like(row_mask, dtype=bool) + first_rows = row[row_mask] + other_rows = row[~row_mask] + remaining_budget = base_model_max_length - len(first_rows) + if remaining_budget <= 0: + remaining_budget = base_model_max_length + + if len(other_rows) == 0: + chunked_rows.append(first_rows) + continue + + for start_idx in range(0, len(other_rows), remaining_budget): + chunked_rows.append( + np.concatenate( + [first_rows, other_rows[start_idx : start_idx + remaining_budget]], + axis=0, + ) + ) + + return chunked_rows + + +def _build_block_diagonal_mask_and_pos_ids(row_lengths): + total_length = int(np.sum(row_lengths)) + if total_length == 0: + return np.zeros((0, 0), dtype=np.float32), 
np.zeros((0,), dtype=np.int32) + + mask = np.zeros((total_length, total_length), dtype=np.float32) + pos_ids = np.zeros((total_length,), dtype=np.int32) + offset = 0 + for row_length in row_lengths: + next_offset = offset + int(row_length) + mask[offset:next_offset, offset:next_offset] = 1.0 + pos_ids[offset:next_offset] = np.arange(row_length, dtype=np.int32) + offset = next_offset + return mask, pos_ids + + +def _batch_packed_rows_numpy( + rows, + include_mask=True, + base_model_max_length=512, +): + if not rows: + empty_mask = np.zeros((0, 0, 0), dtype=np.float32) if include_mask else None + return ( + [], + np.zeros((0,), dtype=np.int32), + empty_mask, + np.zeros((0, 0), dtype=np.int32), + ) + + row_lengths = np.asarray([len(row) for row in rows], dtype=np.int32) + max_length = min(int(np.max(row_lengths)), base_model_max_length) + + groups = [[0]] + current_total = int(row_lengths[0]) + for row_idx in range(1, len(rows)): + next_length = int(row_lengths[row_idx]) + if min(current_total, base_model_max_length) <= max_length - next_length: + groups[-1].append(row_idx) + current_total += next_length + else: + groups.append([row_idx]) + current_total = next_length + + packed_rows = [] + packed_seq_lens = [] + packed_pos_ids = [] + packed_masks = [] + + for group in groups: + group_rows = [rows[row_idx] for row_idx in group] + packed_rows.append(np.concatenate(group_rows, axis=0)) + group_row_lengths = [len(group_row) for group_row in group_rows] + packed_seq_lens.append(sum(group_row_lengths)) + mask, pos_ids = _build_block_diagonal_mask_and_pos_ids(group_row_lengths) + packed_pos_ids.append(pos_ids) + if include_mask: + packed_masks.append(mask) + + max_seq_len = max((len(row) for row in packed_rows), default=0) + dense_pos_ids = np.zeros((len(packed_rows), max_seq_len), dtype=np.int32) + dense_masks = ( + np.zeros((len(packed_rows), max_seq_len, max_seq_len), dtype=np.float32) + if include_mask + else None + ) + + for idx, pos_ids in enumerate(packed_pos_ids): + dense_pos_ids[idx, : len(pos_ids)] = pos_ids + if include_mask: + dense_masks[idx, : len(pos_ids), : len(pos_ids)] = packed_masks[idx] + dense_masks[idx, len(pos_ids) :, len(pos_ids) :] = 1.0 + + return ( + packed_rows, + np.asarray(packed_seq_lens, dtype=np.int32), + dense_masks, + dense_pos_ids, + ) + + +def _build_gather_outputs_inference_numpy_arrays( + X, + sequence_lengths, + start, + end, + other_end, + chunk_tables, + include_mask, + base_model_max_length=512, + max_tokens_per_batch=512 * 100, +): + sequence_lengths = np.asarray(sequence_lengths, dtype=np.int32) + start = np.asarray(start, dtype=np.int32) + end = np.asarray(end, dtype=np.int32) + other_end = np.asarray(other_end, dtype=np.int32) + X = np.asarray(X, dtype=np.int32) + + rows = _masked_ragged_indices_numpy_rows( + sequence_lengths=sequence_lengths, start=start, end=end + ) + if chunk_tables: + rows = _chunk_rows_numpy( + rows, + other_end=other_end, + base_model_max_length=base_model_max_length - 2, + ) + + seq_len = int(X.shape[1]) + bos_pad = np.asarray([[0, seq_len + 1]], dtype=np.int32) + eos_pad = np.asarray([[0, seq_len]], dtype=np.int32) + rows = [ + np.concatenate([bos_pad, row, eos_pad], axis=0) + if len(row) + else np.concatenate([bos_pad, eos_pad], axis=0) + for row in rows + ] + + packed_rows, seq_lens, attn_mask, pos_ids = _batch_packed_rows_numpy( + rows, + include_mask=include_mask, + base_model_max_length=base_model_max_length, + ) + + max_seq_len = max((len(row) for row in packed_rows), default=0) + col_values = np.full( + 
(len(packed_rows), max_seq_len, 2), [0, seq_len + 2], dtype=np.int32 + ) + for idx, row in enumerate(packed_rows): + col_values[idx, : len(row)] = row + + max_length = min( + max_tokens_per_batch // max(len(packed_rows), 1), base_model_max_length + ) + if not chunk_tables: + col_values = col_values[:, :max_length] + pos_ids = pos_ids[:, :max_length] + if include_mask: + attn_mask = attn_mask[:, :max_length, :max_length] + + if attn_mask is None: + attn_mask = None + + return { + "seq_lens": seq_lens.astype(np.int32, copy=False), + "values": col_values.astype(np.int32, copy=False), + "attn_mask": None + if attn_mask is None + else attn_mask.astype(np.float32, copy=False), + "pos_ids": pos_ids.astype(np.int32, copy=False), + } + + +def _build_gather_outputs_inference_numpy( + X, + sequence_lengths, + start, + end, + other_end, + chunk_tables, + include_mask, + base_model_max_length=512, + max_tokens_per_batch=512 * 100, +): + arrays = _build_gather_outputs_inference_numpy_arrays( + X=np.asarray(X.numpy(), dtype=np.int32), + sequence_lengths=np.asarray(sequence_lengths.numpy(), dtype=np.int32), + start=np.asarray(start.numpy(), dtype=np.int32), + end=np.asarray(end.numpy(), dtype=np.int32), + other_end=np.asarray(other_end.numpy(), dtype=np.int32), + chunk_tables=chunk_tables, + include_mask=include_mask, + base_model_max_length=base_model_max_length, + max_tokens_per_batch=max_tokens_per_batch, + ) + col_values = tf.convert_to_tensor(arrays["values"], dtype=tf.int32) + col_values.set_shape([None, None, 2]) + return { + "seq_lens": tf.convert_to_tensor(arrays["seq_lens"], dtype=tf.int32), + "values": col_values, + "attn_mask": None + if arrays["attn_mask"] is None + else tf.convert_to_tensor(arrays["attn_mask"], dtype=tf.float32), + "pos_ids": tf.convert_to_tensor(arrays["pos_ids"], dtype=tf.int32), + } + + def batch_packing( ragged_input, include_mask=True, @@ -587,6 +883,62 @@ def modify_input_spec(self, input_spec): types["col_gather"] = gather_types return (types, target_type), (shapes, target_shape) + def _pad_predict_batch(self, feature_batch): + batch_size = len(feature_batch) + max_seq_len = max(len(features["tokens"]) for features in feature_batch) + tokens = np.zeros((batch_size, max_seq_len), dtype=np.int32) + context = np.zeros( + (batch_size, max_seq_len, self.config.context_dim), dtype=np.float32 + ) + lengths = np.zeros((batch_size,), dtype=np.int32) + + for idx, features in enumerate(feature_batch): + seq_len = len(features["tokens"]) + tokens[idx, :seq_len] = features["tokens"] + context[idx, :seq_len] = features["context"] + lengths[idx] = seq_len + + return { + "tokens": tf.convert_to_tensor(tokens, dtype=tf.int32), + "context": tf.convert_to_tensor(context, dtype=tf.float32), + "length": tf.convert_to_tensor(lengths, dtype=tf.int32), + } + + def iter_predict_batches(self, feature_iter, predict_batch_size): + batch = [] + for features in feature_iter: + batch.append(features) + if len(batch) == predict_batch_size: + yield self._postprocess_inference(self._pad_predict_batch(batch)) + batch = [] + if batch: + yield self._postprocess_inference(self._pad_predict_batch(batch)) + + def _postprocess_inference(self, x): + with tf.device("/CPU:0"): + end_col, end_row, start_col, start_row = tf.unstack( + tf.cast(x["context"], tf.int32), num=4, axis=2 + ) + row_gather = _build_gather_outputs_inference_numpy( + X=x["tokens"], + sequence_lengths=x["length"], + start=start_row, + end=end_row, + other_end=end_col, + chunk_tables=self.config.chunk_tables, + include_mask=True, + ) + 
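The attn_mask and pos_ids fields assembled by these gather builders come from packing several logical rows into one sequence: attention is masked block-diagonally so packed rows cannot see each other, and position ids restart at zero for every row. A compact numpy sketch of that construction, mirroring _build_block_diagonal_mask_and_pos_ids with otherwise illustrative names:

import numpy as np

def block_mask_and_positions(row_lengths):
    # Tokens may only attend within their own packed row; positions restart per row.
    total = int(sum(row_lengths))
    mask = np.zeros((total, total), dtype=np.float32)
    pos_ids = np.zeros((total,), dtype=np.int32)
    offset = 0
    for length in row_lengths:
        end = offset + int(length)
        mask[offset:end, offset:end] = 1.0
        pos_ids[offset:end] = np.arange(length)
        offset = end
    return mask, pos_ids

mask, pos_ids = block_mask_and_positions([2, 3])
assert mask[0, 1] == 1.0 and mask[0, 2] == 0.0  # no attention across the row boundary
assert pos_ids.tolist() == [0, 1, 0, 1, 2]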
col_gather = _build_gather_outputs_inference_numpy( + X=x["tokens"], + sequence_lengths=x["length"], + start=start_col, + end=end_col, + other_end=end_row, + chunk_tables=self.config.chunk_tables, + include_mask=True, + ) + return {**x, "row_gather": row_gather, "col_gather": col_gather} + def _postprocess( self, x, diff --git a/tests/test_scheduler.py b/tests/test_scheduler.py index f34ce151..e0f5d5a0 100644 --- a/tests/test_scheduler.py +++ b/tests/test_scheduler.py @@ -1,9 +1,12 @@ from types import SimpleNamespace +import numpy as np import pytest +import tensorflow as tf import finetune.scheduler as scheduler_mod from finetune.base_models import SourceModel +from finetune.base_models.bert.table_utils import TableModelBatchPostprocessor from finetune.errors import FinetuneSchedulerError from finetune.scheduler import Scheduler @@ -232,3 +235,46 @@ def test_update_memory_limit_uses_per_process_fraction(fake_memory): shed._update_memory_limit(model) assert shed.gpu_memory_limit == 400 + + +def test_table_predict_batch_postprocessor_matches_direct_postprocess(): + config = SimpleNamespace( + context_dim=4, + chunk_tables=True, + predict_batch_size=1, + max_length=2048, + ) + postprocessor = TableModelBatchPostprocessor(config=config) + + features = { + "tokens": np.array([1, 2, 3, 4], dtype=np.int32), + "context": np.array( + [ + [0, 0, 0, 0], + [0, 0, 1, 0], + [1, 0, 2, 0], + [1, 0, 3, 0], + ], + dtype=np.float32, + ), + } + + batch = next( + postprocessor.iter_predict_batches(iter([features]), predict_batch_size=1) + ) + direct = postprocessor._postprocess( + { + "tokens": tf.convert_to_tensor([[1, 2, 3, 4]], dtype=tf.int32), + "context": tf.convert_to_tensor([features["context"]], dtype=tf.float32), + "length": tf.convert_to_tensor([4], dtype=tf.int32), + }, + training=False, + ) + + for key in ("tokens", "context", "length"): + np.testing.assert_array_equal(batch[key].numpy(), direct[key].numpy()) + for key in ("row_gather", "col_gather"): + for inner_key in ("seq_lens", "values", "attn_mask", "pos_ids"): + np.testing.assert_array_equal( + batch[key][inner_key].numpy(), direct[key][inner_key].numpy() + ) From 4faa9ca43c213679e82c428fdc75fae0c96b3921 Mon Sep 17 00:00:00 2001 From: benleetownsend Date: Thu, 30 Apr 2026 15:48:03 +0100 Subject: [PATCH 3/4] chore: keep only table chunker optimization --- finetune/base.py | 21 +- finetune/base_models/bert/table_utils.py | 352 ----------------------- finetune/scheduler.py | 270 +++-------------- tests/test_scheduler.py | 314 +++----------------- 4 files changed, 82 insertions(+), 875 deletions(-) diff --git a/finetune/base.py b/finetune/base.py index bd93cbdf..1ebcc109 100644 --- a/finetune/base.py +++ b/finetune/base.py @@ -407,22 +407,9 @@ def _inference( def get_zipped_data(): return iter(zipped_data) - predict_batch_iter = getattr( - self.input_pipeline.batch_postprocessor, "iter_predict_batches", None - ) - if predict_batch_iter is not None: - - def feature_iter(): - for data in get_zipped_data(): - yield from self.input_pipeline.text_to_tokens_mask(**data) - - input_fn = predict_batch_iter( - feature_iter(), predict_batch_size=self.config.predict_batch_size - ) - else: - input_fn = self.input_pipeline.get_dataset_from_generator( - get_zipped_data, input_mode=InputMode.PREDICT, update_hook=update_hook - )["predict_dataset"] + input_fn = self.input_pipeline.get_dataset_from_generator( + get_zipped_data, input_mode=InputMode.PREDICT, update_hook=update_hook + )["predict_dataset"] model = self.model @@ -441,7 +428,7 @@ def feature_iter(): 
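Patch 2 pins the new batched predict path to the output of the existing _postprocess, comparing every key with np.testing.assert_array_equal; that parity check is the main safety net when swapping in a numpy fast path. A minimal sketch of the same pattern, where reference_impl and optimized_impl are hypothetical stand-ins:

import numpy as np

def reference_impl(x):
    return {"values": np.cumsum(x, axis=-1)}

def optimized_impl(x):
    # Hypothetical fast path that must stay numerically identical to the reference.
    return {"values": np.add.accumulate(x, axis=-1)}

batch = np.arange(12, dtype=np.int64).reshape(3, 4)
ref, fast = reference_impl(batch), optimized_impl(batch)
for key in ref:
    np.testing.assert_array_equal(ref[key], fast[key])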
if v.shape[0] != batch_size or k == "transition_params" } for i in range(batch_size): - progress.update(1) + progress.update(i) step_value = { k: pred_numpy[k] if k in not_batched else pred_numpy[k][i] for k in pred_numpy diff --git a/finetune/base_models/bert/table_utils.py b/finetune/base_models/bert/table_utils.py index fb061a99..f8d4484b 100644 --- a/finetune/base_models/bert/table_utils.py +++ b/finetune/base_models/bert/table_utils.py @@ -1,6 +1,5 @@ import logging -import numpy as np import tensorflow as tf from finetune.nn.activations import bert_gelu as gelu @@ -8,301 +7,6 @@ LOGGER = logging.getLogger("finetune") -def _split_ragged_rows_numpy(values, row_lengths): - rows = [] - offset = 0 - for row_length in row_lengths: - next_offset = offset + int(row_length) - rows.append(values[offset:next_offset]) - offset = next_offset - return rows - - -def _flatten_ragged_rows_numpy(rows): - if not rows: - return np.zeros((0, 2), dtype=np.int32), np.zeros((0,), dtype=np.int64) - row_lengths = np.asarray([len(row) for row in rows], dtype=np.int64) - non_empty_rows = [row for row in rows if len(row)] - if non_empty_rows: - values = np.concatenate(non_empty_rows, axis=0).astype(np.int32, copy=False) - else: - values = np.zeros((0, 2), dtype=np.int32) - return values, row_lengths - - -def _masked_ragged_indices_numpy_rows(sequence_lengths, start, end): - bs = int(sequence_lengths.shape[0]) - per_batch_groups = [] - global_max_axis = -1 - - for batch_idx in range(bs): - seq_len = int(sequence_lengths[batch_idx]) - batch_groups = [] - for seq_idx in range(seq_len): - start_idx = int(start[batch_idx, seq_idx]) - end_idx = int(end[batch_idx, seq_idx]) - if end_idx >= len(batch_groups): - batch_groups.extend([] for _ in range(end_idx + 1 - len(batch_groups))) - token_idx = (batch_idx, seq_idx) - for axis_idx in range(start_idx, end_idx + 1): - batch_groups[axis_idx].append(token_idx) - per_batch_groups.append(batch_groups) - global_max_axis = max(global_max_axis, len(batch_groups) - 1) - - rows = [] - for axis_idx in range(global_max_axis + 1): - for batch_idx in range(bs): - batch_groups = per_batch_groups[batch_idx] - if axis_idx >= len(batch_groups): - continue - axis_values = batch_groups[axis_idx] - if not axis_values: - continue - rows.append(np.asarray(axis_values, dtype=np.int32)) - - return rows - - -def _masked_ragged_indices_numpy(sequence_lengths, start, end): - rows = _masked_ragged_indices_numpy_rows(sequence_lengths, start, end) - values, row_lengths = _flatten_ragged_rows_numpy(rows) - - if len(row_lengths) == 0: - return tf.RaggedTensor.from_row_lengths( - values=tf.zeros([0, 2], dtype=tf.int32), - row_lengths=tf.zeros([0], dtype=tf.int64), - ) - - return tf.RaggedTensor.from_row_lengths( - values=tf.convert_to_tensor(values, dtype=tf.int32), - row_lengths=tf.convert_to_tensor(row_lengths, dtype=tf.int64), - ) - - -def _chunk_rows_numpy( - rows, - other_end, - include_n_rows=2, - base_model_max_length=512 - 2, -): - chunked_rows = [] - first_n_mask = np.asarray(other_end < include_n_rows, dtype=bool) - - for row in rows: - if len(row) == 0: - chunked_rows.append(row) - continue - row_mask = first_n_mask[row[:, 0], row[:, 1]] - if np.count_nonzero(row_mask) >= base_model_max_length: - row_mask = np.zeros_like(row_mask, dtype=bool) - first_rows = row[row_mask] - other_rows = row[~row_mask] - remaining_budget = base_model_max_length - len(first_rows) - if remaining_budget <= 0: - remaining_budget = base_model_max_length - - if len(other_rows) == 0: - chunked_rows.append(first_rows) - 
continue - - for start_idx in range(0, len(other_rows), remaining_budget): - chunked_rows.append( - np.concatenate( - [first_rows, other_rows[start_idx : start_idx + remaining_budget]], - axis=0, - ) - ) - - return chunked_rows - - -def _build_block_diagonal_mask_and_pos_ids(row_lengths): - total_length = int(np.sum(row_lengths)) - if total_length == 0: - return np.zeros((0, 0), dtype=np.float32), np.zeros((0,), dtype=np.int32) - - mask = np.zeros((total_length, total_length), dtype=np.float32) - pos_ids = np.zeros((total_length,), dtype=np.int32) - offset = 0 - for row_length in row_lengths: - next_offset = offset + int(row_length) - mask[offset:next_offset, offset:next_offset] = 1.0 - pos_ids[offset:next_offset] = np.arange(row_length, dtype=np.int32) - offset = next_offset - return mask, pos_ids - - -def _batch_packed_rows_numpy( - rows, - include_mask=True, - base_model_max_length=512, -): - if not rows: - empty_mask = np.zeros((0, 0, 0), dtype=np.float32) if include_mask else None - return ( - [], - np.zeros((0,), dtype=np.int32), - empty_mask, - np.zeros((0, 0), dtype=np.int32), - ) - - row_lengths = np.asarray([len(row) for row in rows], dtype=np.int32) - max_length = min(int(np.max(row_lengths)), base_model_max_length) - - groups = [[0]] - current_total = int(row_lengths[0]) - for row_idx in range(1, len(rows)): - next_length = int(row_lengths[row_idx]) - if min(current_total, base_model_max_length) <= max_length - next_length: - groups[-1].append(row_idx) - current_total += next_length - else: - groups.append([row_idx]) - current_total = next_length - - packed_rows = [] - packed_seq_lens = [] - packed_pos_ids = [] - packed_masks = [] - - for group in groups: - group_rows = [rows[row_idx] for row_idx in group] - packed_rows.append(np.concatenate(group_rows, axis=0)) - group_row_lengths = [len(group_row) for group_row in group_rows] - packed_seq_lens.append(sum(group_row_lengths)) - mask, pos_ids = _build_block_diagonal_mask_and_pos_ids(group_row_lengths) - packed_pos_ids.append(pos_ids) - if include_mask: - packed_masks.append(mask) - - max_seq_len = max((len(row) for row in packed_rows), default=0) - dense_pos_ids = np.zeros((len(packed_rows), max_seq_len), dtype=np.int32) - dense_masks = ( - np.zeros((len(packed_rows), max_seq_len, max_seq_len), dtype=np.float32) - if include_mask - else None - ) - - for idx, pos_ids in enumerate(packed_pos_ids): - dense_pos_ids[idx, : len(pos_ids)] = pos_ids - if include_mask: - dense_masks[idx, : len(pos_ids), : len(pos_ids)] = packed_masks[idx] - dense_masks[idx, len(pos_ids) :, len(pos_ids) :] = 1.0 - - return ( - packed_rows, - np.asarray(packed_seq_lens, dtype=np.int32), - dense_masks, - dense_pos_ids, - ) - - -def _build_gather_outputs_inference_numpy_arrays( - X, - sequence_lengths, - start, - end, - other_end, - chunk_tables, - include_mask, - base_model_max_length=512, - max_tokens_per_batch=512 * 100, -): - sequence_lengths = np.asarray(sequence_lengths, dtype=np.int32) - start = np.asarray(start, dtype=np.int32) - end = np.asarray(end, dtype=np.int32) - other_end = np.asarray(other_end, dtype=np.int32) - X = np.asarray(X, dtype=np.int32) - - rows = _masked_ragged_indices_numpy_rows( - sequence_lengths=sequence_lengths, start=start, end=end - ) - if chunk_tables: - rows = _chunk_rows_numpy( - rows, - other_end=other_end, - base_model_max_length=base_model_max_length - 2, - ) - - seq_len = int(X.shape[1]) - bos_pad = np.asarray([[0, seq_len + 1]], dtype=np.int32) - eos_pad = np.asarray([[0, seq_len]], dtype=np.int32) - rows = [ - 
np.concatenate([bos_pad, row, eos_pad], axis=0) - if len(row) - else np.concatenate([bos_pad, eos_pad], axis=0) - for row in rows - ] - - packed_rows, seq_lens, attn_mask, pos_ids = _batch_packed_rows_numpy( - rows, - include_mask=include_mask, - base_model_max_length=base_model_max_length, - ) - - max_seq_len = max((len(row) for row in packed_rows), default=0) - col_values = np.full( - (len(packed_rows), max_seq_len, 2), [0, seq_len + 2], dtype=np.int32 - ) - for idx, row in enumerate(packed_rows): - col_values[idx, : len(row)] = row - - max_length = min( - max_tokens_per_batch // max(len(packed_rows), 1), base_model_max_length - ) - if not chunk_tables: - col_values = col_values[:, :max_length] - pos_ids = pos_ids[:, :max_length] - if include_mask: - attn_mask = attn_mask[:, :max_length, :max_length] - - if attn_mask is None: - attn_mask = None - - return { - "seq_lens": seq_lens.astype(np.int32, copy=False), - "values": col_values.astype(np.int32, copy=False), - "attn_mask": None - if attn_mask is None - else attn_mask.astype(np.float32, copy=False), - "pos_ids": pos_ids.astype(np.int32, copy=False), - } - - -def _build_gather_outputs_inference_numpy( - X, - sequence_lengths, - start, - end, - other_end, - chunk_tables, - include_mask, - base_model_max_length=512, - max_tokens_per_batch=512 * 100, -): - arrays = _build_gather_outputs_inference_numpy_arrays( - X=np.asarray(X.numpy(), dtype=np.int32), - sequence_lengths=np.asarray(sequence_lengths.numpy(), dtype=np.int32), - start=np.asarray(start.numpy(), dtype=np.int32), - end=np.asarray(end.numpy(), dtype=np.int32), - other_end=np.asarray(other_end.numpy(), dtype=np.int32), - chunk_tables=chunk_tables, - include_mask=include_mask, - base_model_max_length=base_model_max_length, - max_tokens_per_batch=max_tokens_per_batch, - ) - col_values = tf.convert_to_tensor(arrays["values"], dtype=tf.int32) - col_values.set_shape([None, None, 2]) - return { - "seq_lens": tf.convert_to_tensor(arrays["seq_lens"], dtype=tf.int32), - "values": col_values, - "attn_mask": None - if arrays["attn_mask"] is None - else tf.convert_to_tensor(arrays["attn_mask"], dtype=tf.float32), - "pos_ids": tf.convert_to_tensor(arrays["pos_ids"], dtype=tf.int32), - } - - def batch_packing( ragged_input, include_mask=True, @@ -883,62 +587,6 @@ def modify_input_spec(self, input_spec): types["col_gather"] = gather_types return (types, target_type), (shapes, target_shape) - def _pad_predict_batch(self, feature_batch): - batch_size = len(feature_batch) - max_seq_len = max(len(features["tokens"]) for features in feature_batch) - tokens = np.zeros((batch_size, max_seq_len), dtype=np.int32) - context = np.zeros( - (batch_size, max_seq_len, self.config.context_dim), dtype=np.float32 - ) - lengths = np.zeros((batch_size,), dtype=np.int32) - - for idx, features in enumerate(feature_batch): - seq_len = len(features["tokens"]) - tokens[idx, :seq_len] = features["tokens"] - context[idx, :seq_len] = features["context"] - lengths[idx] = seq_len - - return { - "tokens": tf.convert_to_tensor(tokens, dtype=tf.int32), - "context": tf.convert_to_tensor(context, dtype=tf.float32), - "length": tf.convert_to_tensor(lengths, dtype=tf.int32), - } - - def iter_predict_batches(self, feature_iter, predict_batch_size): - batch = [] - for features in feature_iter: - batch.append(features) - if len(batch) == predict_batch_size: - yield self._postprocess_inference(self._pad_predict_batch(batch)) - batch = [] - if batch: - yield self._postprocess_inference(self._pad_predict_batch(batch)) - - def 
_postprocess_inference(self, x): - with tf.device("/CPU:0"): - end_col, end_row, start_col, start_row = tf.unstack( - tf.cast(x["context"], tf.int32), num=4, axis=2 - ) - row_gather = _build_gather_outputs_inference_numpy( - X=x["tokens"], - sequence_lengths=x["length"], - start=start_row, - end=end_row, - other_end=end_col, - chunk_tables=self.config.chunk_tables, - include_mask=True, - ) - col_gather = _build_gather_outputs_inference_numpy( - X=x["tokens"], - sequence_lengths=x["length"], - start=start_col, - end=end_col, - other_end=end_row, - chunk_tables=self.config.chunk_tables, - include_mask=True, - ) - return {**x, "row_gather": row_gather, "col_gather": col_gather} - def _postprocess( self, x, diff --git a/finetune/scheduler.py b/finetune/scheduler.py index ea57aa1c..55add629 100644 --- a/finetune/scheduler.py +++ b/finetune/scheduler.py @@ -2,7 +2,7 @@ import gc import logging import sys -from collections import OrderedDict, defaultdict, deque +from collections import OrderedDict import psutil import pynvml @@ -47,13 +47,9 @@ def scheduled_predict( ): # this is just for backwards compat, should always have a blob key going forward cache_key = kwargs.pop("cache_key", None) - resolved_cache_key = self.model_cache_key( - model_file, key=key, cache_key=cache_key - ) model = self._rotate_in_model( model_file, key=key, config_overrides=config_overrides, cache_key=cache_key ) - self._reset_peak_memory_stats() try: preds = fn(self, model_file=model_file, x=x, *args, model=model, **kwargs) except Exception as orig_except: @@ -72,7 +68,6 @@ def scheduled_predict( config_overrides=config_overrides, cache_key=cache_key, ) - self._reset_peak_memory_stats() preds = fn( self, model_file=model_file, x=x, *args, model=model, **kwargs ) @@ -82,7 +77,6 @@ def scheduled_predict( str(orig_except), str(e) ) ) - self._record_prediction_memory(model, resolved_cache_key) self._update_memory_limit(model) return preds @@ -90,166 +84,45 @@ def scheduled_predict( class Scheduler: - # Scheduler behavior overview: - # - # 1. Models are cached by resolved cache key (`model_cache_key`) and tracked in - # `loaded_models` as an LRU queue. Cache hits move the active key to the end - # of that queue. Cache misses load the model, then register its base model - # type so future memory estimates are source-specific. - # - # 2. GPU admission is based on recent observed memory for the model's - # `config.base_model` type, not a single global historical max. For each - # source model type we keep a bounded window of recent: - # - `max_above_resting`: transient peak above steady-state memory during - # prediction - # - `model_size`: approximate resident memory added by loading the model - # This gives fast recovery from one-off spikes while still preserving a - # conservative running max for recent behavior. - # - # 3. On cache miss, the model is loaded before the final GPU headroom decision. - # That lets us discover the source model type and use its own memory history - # instead of falling back to unrelated global spikes. CPU pressure is still - # checked first so we can evict before loading if host RAM is tight. - # - # 4. Before every prediction, including cache hits, we run an execution-time - # headroom check. If the active model's expected transient execution memory - # does not fit alongside currently loaded peers, we evict older unrelated - # models until it does. 
Very large transient models are forced into - # effectively exclusive execution because allocator fragmentation can make a - # "fits on paper" run still fail in practice. - # - # 5. Peak memory statistics are reset around each prediction attempt, then the - # successful attempt records fresh memory samples back into the per-source - # windows. If prediction still fails, we keep the existing recovery path: - # close all models, reload, and retry once before raising a - # `FinetuneSchedulerError`. def __init__( - self, - max_models=None, - config=None, - reserved=750000000, - ram_max_frac=0.8, - memory_window_size=10, - execution_safety_margin=512 * 1024 * 1024, + self, max_models=None, config=None, reserved=750000000, ram_max_frac=0.8 ): self.loaded_models = list() self.max_models = max_models self.gpu_memory_limit = None self.model_cache = dict() self.max_above_resting = None + self.previous_in_use = 0 self.max_model_size = None self.config = config or {} self.reserved = reserved self.ram_max_frac = ram_max_frac self.etl_cache = EtlCache() - self.memory_window_size = memory_window_size - self.execution_safety_margin = execution_safety_margin - self.memory_histories = defaultdict( - lambda: { - "max_above_resting": deque(maxlen=self.memory_window_size), - "model_size": deque(maxlen=self.memory_window_size), - } - ) - self.global_memory_histories = { - "max_above_resting": deque(maxlen=self.memory_window_size), - "model_size": deque(maxlen=self.memory_window_size), - } - self.cache_key_to_source_model = dict() - self.pending_load_baselines = dict() - - def _source_model_key_from_model(self, model): - base_model = getattr(getattr(model, "config", None), "base_model", None) - if base_model is None: - return None - return getattr(base_model, "__name__", str(base_model)) - - def _append_memory_sample(self, stat_name, source_model_key, value): - if value is None: - return - value = max(int(value), 0) - self.global_memory_histories[stat_name].append(value) - if source_model_key is not None: - self.memory_histories[source_model_key][stat_name].append(value) - - def _record_prediction_memory_sample( - self, source_model_key, max_above_resting=None, model_size=None - ): - self._append_memory_sample( - "max_above_resting", source_model_key, max_above_resting - ) - self._append_memory_sample("model_size", source_model_key, model_size) - estimated = self._estimated_memory_stats(source_model_key) - self.max_above_resting = estimated["max_above_resting"] - self.max_model_size = estimated["model_size"] - def _estimated_memory_stats( - self, source_model_key=None, allow_global_fallback=True - ): - stats = {} - for stat_name in ("max_above_resting", "model_size"): - history = None - if source_model_key is not None: - history = self.memory_histories[source_model_key][stat_name] - if history: - stats[stat_name] = max(history) - elif allow_global_fallback and self.global_memory_histories[stat_name]: - stats[stat_name] = max(self.global_memory_histories[stat_name]) - else: - stats[stat_name] = 0 - return stats - - def _reset_peak_memory_stats(self): - if not is_gpu_available(): - return - reset_memory_stats = getattr(tf.config.experimental, "reset_memory_stats", None) - if reset_memory_stats is None: - return - try: - reset_memory_stats("GPU:0") - except Exception: - LOGGER.debug("Failed to reset TensorFlow peak memory stats.", exc_info=True) - - def _record_prediction_memory(self, model, resolved_cache_key): - source_model_key = self._source_model_key_from_model(model) + def _memory_for_one_more(self): + if 
self.gpu_memory_limit is None: + return True # first run if is_gpu_available(): in_use = BytesInUse() peak = MaxBytesInUse() - max_above_resting = max(peak - in_use, 0) - else: - in_use = 0 - max_above_resting = 0 - - load_baseline = self.pending_load_baselines.pop(resolved_cache_key, None) - model_size = None - if load_baseline is not None: - model_size = max(in_use - load_baseline, 0) + if ( + self.max_above_resting is None + or (peak - in_use) > self.max_above_resting + ): + self.max_above_resting = peak - in_use - self._record_prediction_memory_sample( - source_model_key, - max_above_resting=max_above_resting, - model_size=model_size, - ) + if ( + self.max_model_size is None + or (in_use - self.previous_in_use) > self.max_model_size + ): + self.max_model_size = in_use - self.previous_in_use - def _has_memory_for_model( - self, - source_model_key=None, - include_model_size=True, - allow_global_fallback=True, - extra_model_buffer=0, - ): - if self.gpu_memory_limit is None: - return True # first run - estimated = self._estimated_memory_stats( - source_model_key, allow_global_fallback=allow_global_fallback - ) - self.max_above_resting = estimated["max_above_resting"] - self.max_model_size = estimated["model_size"] - if is_gpu_available(): - in_use = BytesInUse() + self.previous_in_use = in_use else: LOGGER.info("No GPU available, skipping GPU memory checks.") self.max_above_resting = 0 self.max_model_size = 0 + self.previous_in_use = 0 in_use = 0 cpu_percent = psutil.virtual_memory().percent @@ -268,72 +141,18 @@ def _has_memory_for_model( ) if cpu_percent > self.ram_max_frac * 100: return False - required_gpu = ( - in_use + self.max_above_resting + self.reserved + extra_model_buffer - ) - if include_model_size: - required_gpu += self.max_model_size - return required_gpu < self.gpu_memory_limit - - def _memory_for_one_more(self, source_model_key=None): - return self._has_memory_for_model(source_model_key=source_model_key) - - def _ensure_cpu_headroom(self, exclude=None): - exclude = set(exclude or []) - while psutil.virtual_memory().percent > self.ram_max_frac * 100: - if not self._close_oldest_model(exclude=exclude): - return False - return True - - def _ensure_prediction_headroom( - self, - source_model_key, - active_cache_key, - include_model_size, - allow_global_fallback, - ): - estimated = self._estimated_memory_stats( - source_model_key, allow_global_fallback=allow_global_fallback - ) - global_model_size = 0 - if self.global_memory_histories["model_size"]: - global_model_size = max(self.global_memory_histories["model_size"]) - transient_buffer = int(estimated["max_above_resting"] * 0.2) - execution_buffer = max( - global_model_size, self.execution_safety_margin, transient_buffer - ) - while not self._has_memory_for_model( - source_model_key=source_model_key, - include_model_size=include_model_size, - allow_global_fallback=allow_global_fallback, - extra_model_buffer=execution_buffer, - ): - if not self._close_oldest_model(exclude={active_cache_key}): - return False - include_model_size = False - if ( - self.gpu_memory_limit is not None - and estimated["max_above_resting"] > self.gpu_memory_limit * 0.5 - ): - while len(self.loaded_models) > 1: - if not self._close_oldest_model(exclude={active_cache_key}): - return False - return True - - def _close_oldest_model(self, exclude=None): - exclude = set(exclude or []) - for idx, name in enumerate(self.loaded_models): - if name in exclude: - continue - self.loaded_models.pop(idx) + return ( + in_use + self.max_above_resting + 
self.max_model_size + self.reserved + ) < self.gpu_memory_limit + + def _close_oldest_model(self): + if len(self.loaded_models): + name = self.loaded_models.pop(0) self.model_cache[name].close(update_saver=False) del self.model_cache[name] - self.pending_load_baselines.pop(name, None) - self.cache_key_to_source_model.pop(name, None) gc.collect() - return True - LOGGER.info("No models cached -- cannot remove oldest model.") - return False + else: + LOGGER.info("No models cached -- cannot remove oldest model.") def model_cache_key(self, model, key, cache_key): if cache_key is None: @@ -350,26 +169,16 @@ def model_cache_key(self, model, key, cache_key): def _rotate_in_model(self, model, key, config_overrides=None, cache_key=None): resolved_cache_key = self.model_cache_key(model, key=key, cache_key=cache_key) - source_model_key = self.cache_key_to_source_model.get(resolved_cache_key) - cache_miss = resolved_cache_key not in self.loaded_models - if cache_miss: - while ( + if resolved_cache_key not in self.loaded_models: + if ( self.max_models is not None and len(self.loaded_models) + 1 > self.max_models - ): - if not self._close_oldest_model(): - break - self._ensure_cpu_headroom() + ) or not self._memory_for_one_more(): + self._close_oldest_model() config_overrides = config_overrides or {} merged_config = {**self.config, **config_overrides} - load_baseline = BytesInUse() if is_gpu_available() else 0 out_model = BaseModel.load(model, key=key, **merged_config) self.model_cache[resolved_cache_key] = out_model - self.pending_load_baselines[resolved_cache_key] = load_baseline - self.cache_key_to_source_model[ - resolved_cache_key - ] = self._source_model_key_from_model(out_model) - source_model_key = self.cache_key_to_source_model[resolved_cache_key] else: out_model = self.model_cache[resolved_cache_key] self.loaded_models.remove( @@ -377,12 +186,6 @@ def _rotate_in_model(self, model, key, config_overrides=None, cache_key=None): ) # put it back at the end of the queue self.loaded_models.append(resolved_cache_key) - self._ensure_prediction_headroom( - source_model_key=source_model_key, - active_cache_key=resolved_cache_key, - include_model_size=cache_miss, - allow_global_fallback=False, - ) out_model._cached_predict = True return out_model @@ -392,13 +195,9 @@ def _update_memory_limit(self, model): del model.saver.variables del model.saver.fallback_ if is_gpu_available(): - gpu_memory_limit = BytesLimit() - gpu_fraction = self.config.get("per_process_gpu_memory_fraction") - if gpu_fraction is not None: - gpu_memory_limit = min( - gpu_memory_limit, int(gpu_memory_limit * gpu_fraction) - ) - self.gpu_memory_limit = gpu_memory_limit + self.gpu_memory_limit = ( + BytesLimit() + ) # delay this so that any options get applied from finetune. 
else: LOGGER.info("No GPU available, skipping GPU memory limit update.") self.gpu_memory_limit = sys.maxsize @@ -406,7 +205,6 @@ def _update_memory_limit(self, model): def close_all(self): while self.loaded_models: self._close_oldest_model() - self.pending_load_baselines.clear() MODEL_REGISTRY.cleanup() @scheduled diff --git a/tests/test_scheduler.py b/tests/test_scheduler.py index e0f5d5a0..a918c88b 100644 --- a/tests/test_scheduler.py +++ b/tests/test_scheduler.py @@ -1,280 +1,54 @@ -from types import SimpleNamespace +import os +import time -import numpy as np import pytest -import tensorflow as tf -import finetune.scheduler as scheduler_mod -from finetune.base_models import SourceModel -from finetune.base_models.bert.table_utils import TableModelBatchPostprocessor -from finetune.errors import FinetuneSchedulerError from finetune.scheduler import Scheduler -class DummySourceModelA(SourceModel): - encoder = None - featurizer = None - settings = {} - - -class DummySourceModelB(SourceModel): - encoder = None - featurizer = None - settings = {} - - -class DummyLoadedModel: - def __init__(self, base_model, outcomes=None): - self.config = SimpleNamespace(base_model=base_model) - self.saver = SimpleNamespace(variables={}, fallback_=object()) - self._outcomes = list(outcomes or [["ok"]]) - self.close_calls = 0 - self._cached_predict = False - - def predict(self, x, *args, **kwargs): - outcome = self._outcomes.pop(0) - if isinstance(outcome, Exception): - raise outcome - return outcome - - def close(self, update_saver=False): - self.close_calls += 1 - - -@pytest.fixture(autouse=True) -def patch_cleanup(monkeypatch): - monkeypatch.setattr(scheduler_mod.MODEL_REGISTRY, "cleanup", lambda: None) - - @pytest.fixture -def fake_memory(monkeypatch): - state = {"in_use": 0, "peak": 0, "gpu_available": True, "cpu_percent": 10} - - monkeypatch.setattr( - scheduler_mod, - "is_gpu_available", - lambda: state["gpu_available"], - ) - monkeypatch.setattr(scheduler_mod, "BytesInUse", lambda: state["in_use"]) - monkeypatch.setattr(scheduler_mod, "MaxBytesInUse", lambda: state["peak"]) - monkeypatch.setattr(scheduler_mod, "BytesLimit", lambda: 1000) - monkeypatch.setattr( - scheduler_mod.psutil, - "virtual_memory", - lambda: SimpleNamespace(percent=state["cpu_percent"]), - ) - return state - - -def test_memory_windows_are_bounded_per_source_model(): - shed = Scheduler(memory_window_size=3) - - for sample in (100, 10, 9, 8): - shed._record_prediction_memory_sample( - "DummySourceModelA", max_above_resting=sample, model_size=sample // 2 - ) - shed._record_prediction_memory_sample( - "DummySourceModelB", max_above_resting=5, model_size=7 - ) - - assert shed._estimated_memory_stats("DummySourceModelA") == { - "max_above_resting": 10, - "model_size": 5, - } - assert shed._estimated_memory_stats("DummySourceModelB") == { - "max_above_resting": 5, - "model_size": 7, - } - assert shed._estimated_memory_stats("unknown") == { - "max_above_resting": 9, - "model_size": 7, - } - - -def test_cache_miss_loads_first_and_uses_source_specific_stats( - monkeypatch, fake_memory -): - shed = Scheduler(reserved=0, execution_safety_margin=0) - shed.gpu_memory_limit = 1000 - shed._record_prediction_memory_sample( - "UnrelatedSource", max_above_resting=900, model_size=50 - ) - - keep = DummyLoadedModel(DummySourceModelB) - shed.model_cache["keep"] = keep - shed.loaded_models.append("keep") - fake_memory["in_use"] = 100 - - loaded = [] - - def fake_load(model_file, key=None, **kwargs): - model = DummyLoadedModel(DummySourceModelA) - 
loaded.append((model_file, key, model)) - return model - - monkeypatch.setattr(scheduler_mod.BaseModel, "load", fake_load) - - out_model = shed.get_model("fresh.jl") - - assert out_model is loaded[0][2] - assert shed.cache_key_to_source_model["model=fresh.jl"] == "DummySourceModelA" - assert shed.loaded_models == ["keep", "model=fresh.jl"] - assert keep.close_calls == 0 - - -def test_cache_hit_execution_headroom_evicts_unrelated_models(monkeypatch, fake_memory): - shed = Scheduler(reserved=0, execution_safety_margin=0) - shed.gpu_memory_limit = 1000 - shed._record_prediction_memory_sample( - "DummySourceModelA", max_above_resting=400, model_size=0 - ) - - other = DummyLoadedModel(DummySourceModelB) - active = DummyLoadedModel(DummySourceModelA) - shed.model_cache = {"other": other, "active": active} - shed.loaded_models = ["other", "active"] - shed.cache_key_to_source_model = { - "other": "DummySourceModelB", - "active": "DummySourceModelA", - } - - fake_memory["in_use"] = 700 - - def dynamic_in_use(): - return 700 if "other" in shed.loaded_models else 200 +def models(save_model_dir, trained_classifier, trained_annotation): + model1 = os.path.join(save_model_dir, "1.jl") + model2 = os.path.join(save_model_dir, "2.jl") + trained_classifier.save(model1) + trained_annotation.save(model2) + yield model1, model2 - monkeypatch.setattr(scheduler_mod, "BytesInUse", dynamic_in_use) - shed._ensure_prediction_headroom( - source_model_key="DummySourceModelA", - active_cache_key="active", - include_model_size=False, - allow_global_fallback=False, - ) - - assert shed.loaded_models == ["active"] - assert other.close_calls == 1 - - -def test_large_transient_models_run_exclusively(fake_memory): - shed = Scheduler(reserved=0, execution_safety_margin=0) - shed.gpu_memory_limit = 1000 - shed._record_prediction_memory_sample( - "DummySourceModelA", max_above_resting=600, model_size=0 - ) - - older = DummyLoadedModel(DummySourceModelB) - newer = DummyLoadedModel(DummySourceModelB) - active = DummyLoadedModel(DummySourceModelA) - shed.model_cache = {"older": older, "newer": newer, "active": active} - shed.loaded_models = ["older", "newer", "active"] - shed.cache_key_to_source_model = { - "older": "DummySourceModelB", - "newer": "DummySourceModelB", - "active": "DummySourceModelA", - } - fake_memory["in_use"] = 100 - - shed._ensure_prediction_headroom( - source_model_key="DummySourceModelA", - active_cache_key="active", - include_model_size=False, - allow_global_fallback=False, - ) - - assert shed.loaded_models == ["active"] - assert older.close_calls == 1 - assert newer.close_calls == 1 - - -def test_scheduler_retries_after_exception(monkeypatch, fake_memory): - fake_memory["gpu_available"] = False - shed = Scheduler() - loads = [] - - def fake_load(model_file, key=None, **kwargs): - outcomes = [RuntimeError("oom-ish failure")] if not loads else [["ok"]] - model = DummyLoadedModel(DummySourceModelA, outcomes) - loads.append(model) - return model - - monkeypatch.setattr(scheduler_mod.BaseModel, "load", fake_load) - - result = shed.predict("dummy.jl", ["A"]) - - assert result == ["ok"] - assert len(loads) == 2 - assert loads[0].close_calls == 1 - assert shed.loaded_models == ["model=dummy.jl"] - - -def test_scheduler_wraps_retry_failure(monkeypatch, fake_memory): - fake_memory["gpu_available"] = False +def test_scheduler(models): + model1, model2 = models shed = Scheduler() - - def fake_load(model_file, key=None, **kwargs): - return DummyLoadedModel( - DummySourceModelA, - [RuntimeError("first fail"), 
RuntimeError("second fail")], - ) - - monkeypatch.setattr(scheduler_mod.BaseModel, "load", fake_load) - - with pytest.raises(FinetuneSchedulerError) as excinfo: - shed.predict("dummy.jl", ["A"]) - - message = str(excinfo.value) - assert "Original Error: first fail" in message - assert "Retry Error:" in message - - -def test_update_memory_limit_uses_per_process_fraction(fake_memory): - shed = Scheduler(config={"per_process_gpu_memory_fraction": 0.4}) - model = DummyLoadedModel(DummySourceModelA) - - shed._update_memory_limit(model) - - assert shed.gpu_memory_limit == 400 - - -def test_table_predict_batch_postprocessor_matches_direct_postprocess(): - config = SimpleNamespace( - context_dim=4, - chunk_tables=True, - predict_batch_size=1, - max_length=2048, - ) - postprocessor = TableModelBatchPostprocessor(config=config) - - features = { - "tokens": np.array([1, 2, 3, 4], dtype=np.int32), - "context": np.array( - [ - [0, 0, 0, 0], - [0, 0, 1, 0], - [1, 0, 2, 0], - [1, 0, 3, 0], - ], - dtype=np.float32, - ), - } - - batch = next( - postprocessor.iter_predict_batches(iter([features]), predict_batch_size=1) - ) - direct = postprocessor._postprocess( - { - "tokens": tf.convert_to_tensor([[1, 2, 3, 4]], dtype=tf.int32), - "context": tf.convert_to_tensor([features["context"]], dtype=tf.float32), - "length": tf.convert_to_tensor([4], dtype=tf.int32), - }, - training=False, - ) - - for key in ("tokens", "context", "length"): - np.testing.assert_array_equal(batch[key].numpy(), direct[key].numpy()) - for key in ("row_gather", "col_gather"): - for inner_key in ("seq_lens", "values", "attn_mask", "pos_ids"): - np.testing.assert_array_equal( - batch[key][inner_key].numpy(), direct[key][inner_key].numpy() - ) + tic_1 = time.time() + preds_m1 = shed.predict(model1, ["A"]) # May need isolation + toc_1 = time.time() + assert len(preds_m1) == 1 + assert isinstance(preds_m1[0], str) # classification + preds_m2 = shed.predict(model2, ["A"]) + assert len(preds_m2) == 1 + assert isinstance( + preds_m2[0], list + ) # Annotation - can't realy expect any labels though. + tic_3 = time.time() + shed.predict(model1, ["something else"]) + toc_3 = time.time() + assert toc_1 - tic_1 > toc_3 - tic_3 + assert len(shed.loaded_models) == 2 + shed.close_all() + assert len(shed.loaded_models) == 0 + shed.predict_proba(model1, ["A"]) + shed.featurize(model1, ["A"]) + shed.featurize_sequence(model1, ["A"]) + + +def test_scheduler_max_models(models): + model1, model2 = models + shed = Scheduler(max_models=1) + time_pre = time.time() + pred1a = shed.predict(model1, ["A"]) + time_mid = time.time() + pred1b = shed.predict(model1, ["A"]) + time_end = time.time() + assert time_end - time_mid < time_mid - time_pre - 1 + assert pred1a == pred1b + shed.predict(model2, ["A"]) # Load another model. 
+ assert len(shed.loaded_models) == 1 From fc41b70cf50acc645042406d2f390ee8c3b9dda1 Mon Sep 17 00:00:00 2001 From: benleetownsend Date: Thu, 30 Apr 2026 16:25:13 +0100 Subject: [PATCH 4/4] fix: add defensive non-monotonic case --- finetune/util/table_labeler.py | 45 ++++++++++++++++++++++++---------- 1 file changed, 32 insertions(+), 13 deletions(-) diff --git a/finetune/util/table_labeler.py b/finetune/util/table_labeler.py index c8c5093c..0c8f776f 100644 --- a/finetune/util/table_labeler.py +++ b/finetune/util/table_labeler.py @@ -441,7 +441,10 @@ def get_axis_spans(self, context, token_bounds, context_key): max_row = max(r[context_key] for r in context) row_spans = [[] for _ in range(max_row + 1)] for c in context: - row_spans[c[context_key]].append( + row_idx = c[context_key] + if row_idx < 0: + continue + row_spans[row_idx].append( { "start": c["start"], "end": c["end"], @@ -628,8 +631,18 @@ def _make_chunks(self, row_spans): return output_spans def combine_row_spans(self, row_spans, token_spans): - token_starts = [t["start"] for t in token_spans] - token_ends = [t["end"] for t in token_spans] + token_starts = [] + token_ends = [] + token_spans_monotonic = True + for token in token_spans: + token_start = token["start"] + token_end = token["end"] + if token_starts and ( + token_start < token_starts[-1] or token_end < token_ends[-1] + ): + token_spans_monotonic = False + token_starts.append(token_start) + token_ends.append(token_end) combined_rows = [] for row in row_spans: row_out = [] @@ -640,16 +653,22 @@ def combine_row_spans(self, row_spans, token_spans): row_out.append(span) for row_span in row_out: num_tokens = 0 - token_idx = bisect.bisect_left(token_ends, row_span["start"]) - while ( - token_idx < len(token_spans) - and token_starts[token_idx] <= row_span["end"] - ): - token = token_spans[token_idx] - if overlaps_token(row_span, token): - token["used"] = True - num_tokens += 1 - token_idx += 1 + if token_spans_monotonic: + token_idx = bisect.bisect_left(token_ends, row_span["start"]) + while ( + token_idx < len(token_spans) + and token_starts[token_idx] <= row_span["end"] + ): + token = token_spans[token_idx] + if overlaps_token(row_span, token): + token["used"] = True + num_tokens += 1 + token_idx += 1 + else: + for token in token_spans: + if overlaps_token(row_span, token): + token["used"] = True + num_tokens += 1 row_span["num_tokens"] = num_tokens # Accounts for the fact that cells are duplicated when they span cells. row_span["num_effective_tokens"] = (
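
For reference, a minimal standalone sketch (not part of the patch) of the fast-path/fallback split that the defensive non-monotonic change introduces in combine_row_spans: the bisect shortcut is only trusted when token spans are ordered, otherwise every token is checked linearly. The overlaps helper and the span dicts below are simplified assumptions for illustration, not the real overlaps_token helper or finetune data structures.

import bisect


def overlaps(span, token):
    # Simplified overlap test standing in for the real overlaps_token helper.
    return token["start"] <= span["end"] and token["end"] >= span["start"]


def count_overlapping_tokens(row_span, token_spans):
    token_starts = [t["start"] for t in token_spans]
    token_ends = [t["end"] for t in token_spans]
    # Both starts and ends must be non-decreasing for bisect to be safe.
    monotonic = all(
        token_starts[i] >= token_starts[i - 1] and token_ends[i] >= token_ends[i - 1]
        for i in range(1, len(token_spans))
    )
    num_tokens = 0
    if monotonic:
        # Fast path: skip every token that ends before the row span starts.
        idx = bisect.bisect_left(token_ends, row_span["start"])
        while idx < len(token_spans) and token_starts[idx] <= row_span["end"]:
            if overlaps(row_span, token_spans[idx]):
                num_tokens += 1
            idx += 1
    else:
        # Defensive path: ordering assumptions are broken, so check every token.
        for token in token_spans:
            if overlaps(row_span, token):
                num_tokens += 1
    return num_tokens


# Example: the middle token is out of order, so the linear fallback is taken.
tokens = [{"start": 0, "end": 3}, {"start": 10, "end": 12}, {"start": 4, "end": 8}]
assert count_overlapping_tokens({"start": 2, "end": 9}, tokens) == 2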