Skip to content

Commit e79e65c

Browse files
Merge pull request #204 from foundation-model-stack/fasoli/upgrade_transf
feat: update transformers to 5.x
2 parents 6de59b9 + d9f67cb commit e79e65c

5 files changed

Lines changed: 37 additions & 8 deletions

File tree

fms_mo/dq.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -88,16 +88,15 @@ def run_dq(model_args, data_args, opt_args, fms_mo_args):
8888
config_kwargs = {
8989
"cache_dir": model_args.cache_dir,
9090
"revision": model_args.model_revision,
91-
"use_auth_token": True if model_args.use_auth_token else None,
92-
"torchscript": True,
91+
"token": True if model_args.use_auth_token else None,
9392
"attn_implementation": attn_implementation,
9493
}
9594
config = AutoConfig.from_pretrained(model_args.model_name_or_path, **config_kwargs)
9695
tokenizer_kwargs = {
9796
"cache_dir": model_args.cache_dir,
9897
"use_fast": model_args.use_fast_tokenizer,
9998
"revision": model_args.model_revision,
100-
"use_auth_token": True if model_args.use_auth_token else None,
99+
"token": True if model_args.use_auth_token else None,
101100
}
102101
tokenizer = AutoTokenizer.from_pretrained(
103102
model_args.model_name_or_path, **tokenizer_kwargs
@@ -121,7 +120,7 @@ def run_dq(model_args, data_args, opt_args, fms_mo_args):
121120
config=config,
122121
cache_dir=model_args.cache_dir,
123122
revision="main",
124-
use_auth_token=True if model_args.use_auth_token else None,
123+
token=True if model_args.use_auth_token else None,
125124
torch_dtype=torch_dtype,
126125
device_map=model_args.device_map,
127126
low_cpu_mem_usage=bool(model_args.device_map),

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ dynamic = ["version"]
2424
dependencies = [
2525
"numpy>=1.26.4,<2.3.0",
2626
"accelerate>=0.20.3,!=0.34,<1.11",
27-
"transformers>=4.45,<4.58",
27+
"transformers>4.45,<5.9",
2828
"torch>=2.2.0,<2.11.0",
2929
"tqdm>=4.66.2,<5.0",
3030
"datasets>=3.0.0,<5.0",
@@ -36,6 +36,7 @@ dependencies = [
3636
[project.optional-dependencies]
3737
examples = ["ninja>=1.11.1.1,<2.0", "evaluate", "huggingface_hub"]
3838
fp8 = ["llmcompressor", "torchao==0.11"] # FP8 matmul on CPU needs a fix before advancing torchao > 0.11
39+
fp8-infer = ["torchao==0.11"]
3940
gptq = ["Cython", "gptqmodel>=1.7.3"]
4041
mx = ["microxcaling>=1.1"]
4142
opt = ["fms-model-optimizer[fp8, gptq, mx]"]

tests/build/test_launch_script.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
# Third Party
2323
import pytest
2424
import torch
25+
import transformers
2526

2627
# First Party
2728
from build.accelerate_launch import main
@@ -241,16 +242,21 @@ def _validate_termination_files_when_quantization_succeeds(base_dir):
241242
"""Check whether the termination log and .complete files exists"""
242243
assert os.path.exists(os.path.join(base_dir, "/termination-log")) is False
243244
assert os.path.exists(os.path.join(base_dir, ".complete")) is True
244-
# assert os.path.exists(os.path.join(base_dir, training_logs_filename)) is True
245245

246246

247247
def _validate_quantization_output(base_dir, quant_method):
248248
"""Check whether the tokenizer and quantized model artifacts exists"""
249249
# Check tokenizer files exist
250250
assert os.path.exists(os.path.join(base_dir, "tokenizer.json")) is True
251-
assert os.path.exists(os.path.join(base_dir, "special_tokens_map.json")) is True
251+
252+
# special_tokens_map.json is optional in transformers 5.0+ for some tokenizers
253+
transformers_version = tuple(
254+
int(x) for x in transformers.__version__.split(".")[:2]
255+
)
256+
if transformers_version < (5, 0):
257+
assert os.path.exists(os.path.join(base_dir, "special_tokens_map.json")) is True
258+
252259
assert os.path.exists(os.path.join(base_dir, "tokenizer_config.json")) is True
253-
# assert os.path.exists(os.path.join(base_dir, "tokenizer.model")) is True
254260

255261
# Check quantized model files exist
256262
if quant_method == "gptq":

tests/models/conftest.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
import pytest
4040
import torch
4141
import torch.nn.functional as F
42+
import transformers
4243

4344
# Local
4445
# fms_mo imports
@@ -1302,6 +1303,12 @@ def model_bert():
13021303
Returns:
13031304
transformers.models.bert.modeling_bert.BertModel: BERT model
13041305
"""
1306+
# torchscript parameter removed in transformers 5.0
1307+
transformers_version = tuple(
1308+
int(x) for x in transformers.__version__.split(".")[:2]
1309+
)
1310+
if transformers_version >= (5, 0):
1311+
return BertModel.from_pretrained("google-bert/bert-base-uncased")
13051312
return BertModel.from_pretrained("google-bert/bert-base-uncased", torchscript=True)
13061313

13071314

@@ -1313,6 +1320,14 @@ def model_bert_eager():
13131320
Returns:
13141321
transformers.models.bert.modeling_bert.BertModel: BERT model
13151322
"""
1323+
# torchscript parameter removed in transformers 5.0
1324+
transformers_version = tuple(
1325+
int(x) for x in transformers.__version__.split(".")[:2]
1326+
)
1327+
if transformers_version >= (5, 0):
1328+
return BertModel.from_pretrained(
1329+
"google-bert/bert-base-uncased", attn_implementation="eager"
1330+
)
13161331
return BertModel.from_pretrained(
13171332
"google-bert/bert-base-uncased", torchscript=True, attn_implementation="eager"
13181333
)

tests/models/test_qmodelprep.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -272,6 +272,10 @@ def test_vit_dynamo(
272272
qmodule_error(model_vit, 2, 36)
273273

274274

275+
@pytest.mark.skipif(
276+
not available_packages["torchvision"],
277+
reason="Requires torchvision",
278+
)
275279
def test_resnet18(
276280
model_resnet18,
277281
batch_resnet18,
@@ -290,6 +294,10 @@ def test_resnet18(
290294
qmodule_error(model_resnet18, 4, 17)
291295

292296

297+
@pytest.mark.skipif(
298+
not available_packages["torchvision"],
299+
reason="Requires torchvision",
300+
)
293301
def test_vit_base(
294302
model_vit_base,
295303
batch_vit_base,

0 commit comments

Comments
 (0)