Skip to content

Commit 7d72283

Browse files
FIX: [DEV-14438] ModernBert Accuracy new predict batch size default (#854)
* FIX: [DEV-14438] ModernBert Accuracy new predict batch size default
* FIX: being defensive
1 parent 2d1ab49 commit 7d72283

3 files changed

Lines changed: 29 additions & 1 deletion

File tree

finetune/base.py

Lines changed: 13 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -187,6 +187,19 @@ def resolve_config(self, no_fp16=False, **kwargs):
187187
raise ValueError("There is no auto setting for {}".format(ak))
188188
config[ak] = overrides[ak]
189189

190+
override_predict_batch = overrides.get("predict_batch_size")
191+
current_predict_batch = config.get("predict_batch_size")
192+
if (
193+
override_predict_batch is not None
194+
and current_predict_batch is not None
195+
and override_predict_batch < current_predict_batch
196+
and hasattr(self, "config")
197+
):
198+
LOGGER.info(
199+
f"Overriding loaded predict batch size from {current_predict_batch} to the new default of {override_predict_batch}"
200+
)
201+
self.config["predict_batch_size"] = override_predict_batch
202+
190203
if hasattr(self, "input_pipeline"):
191204
self.input_pipeline.config = config
192205

finetune/base_models/modern_bert/model.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -73,7 +73,7 @@ def get_optimal_params(cls, config):
7373
"n_epochs": base_n_epochs,
7474
"batch_size": 8,
7575
"chunk_context": None,
76-
"predict_batch_size": 16,
76+
"predict_batch_size": 8,
7777
"mixed_precision": True,
7878
"float_16_predict": True,
7979
"lr": base_learning_rate,

tests/test_backwards_compat.py

Lines changed: 15 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,6 @@
11
import glob
22
import os
3+
from pathlib import Path
34

45
import numpy as np
56
import pytest
@@ -45,6 +46,20 @@ def test_roberta_default_no_change(get_untrained_sequence_labeler):
4546
assert model.config.collapse_whitespace == True
4647

4748

49+
def test_table_model_predict_batch_size_override(monkeypatch):
50+
model_path = (
51+
Path(__file__).parent
52+
/ "backwards_compat_bundles"
53+
/ "ModernBertModel_sequence_labeling_crf_True.jl"
54+
)
55+
model = SequenceLabeler.load(model_path, key="model")
56+
try:
57+
assert model.config.optimize_for.lower() == "accuracy"
58+
assert model.config.predict_batch_size == 8
59+
finally:
60+
model.close()
61+
62+
4863
# TODO: eventually clean this up and push these files to s3.
4964
BUNDLES = glob.glob(os.path.join("/Finetune/tests/backwards_compat_bundles/*.jl"))
5065

0 commit comments

Comments
 (0)