@@ -4251,44 +4251,50 @@ def test_causal_lm_training_multi_gpu_eetq(self):
42514251
42524252@require_non_cpu
42534253@require_torchao
4254- class PeftTorchaoGPUTests (unittest .TestCase ):
4255- r"""
4256- torchao + peft tests
4257- """
4258-
4254+ class TestPeftTorchao :
4255+ causal_lm_model_id = "peft-internal-testing/opt-125m"
42594256 supported_quant_types = [
42604257 "int8_weight_only" ,
42614258 "int8_dynamic_activation_int8_weight" ,
42624259 # int4_weight_only raises an error:
4263- # RuntimeError: derivative for aten::_weight_int4pack_mm is not implemented
4260+ # RuntimeError: We encountered some issues during automatic conversion of the weights
42644261 # "int4_weight_only",
42654262 ]
42664263
4267- def setUp (self ):
4268- self .causal_lm_model_id = "peft-internal-testing/opt-125m"
4269- self .tokenizer = AutoTokenizer .from_pretrained (self .causal_lm_model_id )
4270- # torchao breaks with fp16 and if a previous test uses fp16, transformers will set this env var, which affects
4271- # subsequent tests, therefore the env var needs to be cleared explicitly
4272- #
4273- # TODO: remove this once https://github.com/huggingface/transformers/pull/39483 is merged
4274- os .environ .pop ("ACCELERATE_MIXED_PRECISION" , None )
4264+ @pytest .fixture (scope = "class" )
4265+ def tokenizer (self ):
4266+ return AutoTokenizer .from_pretrained (self .causal_lm_model_id )
42754267
4276- def tearDown ( self ):
4277- r"""
4278- Efficient mechanism to free GPU memory after each test. Based on
4279- https://github.com/huggingface/transformers/issues/21094
4280- """
4268+ @ pytest . fixture ( scope = "class" , autouse = True )
4269+ def setup_teardown ( self ):
4270+ # Efficient mechanism to free GPU memory after each test. Based on
4271+ # https://github.com/huggingface/transformers/issues/21094
4272+ yield
42814273 clear_device_cache (garbage_collection = True )
42824274
4283- @parameterized .expand (supported_quant_types )
4275+ @staticmethod
4276+ def get_quant_type (quant_type : str ):
4277+ from torchao .quantization import (
4278+ Int4WeightOnlyConfig ,
4279+ Int8DynamicActivationInt8WeightConfig ,
4280+ Int8WeightOnlyConfig ,
4281+ )
4282+
4283+ return {
4284+ "int4_weight_only" : Int4WeightOnlyConfig (),
4285+ "int8_weight_only" : Int8WeightOnlyConfig (),
4286+ "int8_dynamic_activation_int8_weight" : Int8DynamicActivationInt8WeightConfig (),
4287+ }[quant_type ]
4288+
4289+ @pytest .mark .parametrize ("quant_type" , supported_quant_types )
42844290 @pytest .mark .single_gpu_tests
4285- def test_causal_lm_training_single_gpu_torchao (self , quant_type ):
4291+ def test_causal_lm_training_single_gpu_torchao (self , quant_type , tokenizer ):
42864292 from transformers import TorchAoConfig
42874293
42884294 device = 0
42894295
42904296 with tempfile .TemporaryDirectory () as tmp_dir :
4291- quantization_config = TorchAoConfig (quant_type = quant_type )
4297+ quantization_config = TorchAoConfig (quant_type = self . get_quant_type ( quant_type ) )
42924298 model = AutoModelForCausalLM .from_pretrained (
42934299 self .causal_lm_model_id , device_map = device , quantization_config = quantization_config
42944300 )
@@ -4305,7 +4311,7 @@ def test_causal_lm_training_single_gpu_torchao(self, quant_type):
43054311 model = get_peft_model (model , config )
43064312
43074313 data = load_dataset_english_quotes ()
4308- data = data .map (lambda samples : self . tokenizer (samples ["quote" ]), batched = True )
4314+ data = data .map (lambda samples : tokenizer (samples ["quote" ]), batched = True )
43094315
43104316 trainer = Trainer (
43114317 model = model ,
@@ -4319,7 +4325,7 @@ def test_causal_lm_training_single_gpu_torchao(self, quant_type):
43194325 logging_steps = 1 ,
43204326 output_dir = tmp_dir ,
43214327 ),
4322- data_collator = DataCollatorForLanguageModeling (self . tokenizer , mlm = False ),
4328+ data_collator = DataCollatorForLanguageModeling (tokenizer , mlm = False ),
43234329 )
43244330 trainer .model .config .use_cache = False
43254331 trainer .train ()
@@ -4333,13 +4339,13 @@ def test_causal_lm_training_single_gpu_torchao(self, quant_type):
43334339 assert trainer .state .log_history [- 1 ]["train_loss" ] is not None
43344340
43354341 @pytest .mark .single_gpu_tests
4336- def test_causal_lm_training_single_gpu_torchao_dora_int8_weight_only (self ):
4342+ def test_causal_lm_training_single_gpu_torchao_dora_int8_weight_only (self , tokenizer ):
43374343 from transformers import TorchAoConfig
43384344
43394345 device = 0
43404346
43414347 with tempfile .TemporaryDirectory () as tmp_dir :
4342- quantization_config = TorchAoConfig (quant_type = "int8_weight_only" )
4348+ quantization_config = TorchAoConfig (quant_type = self . get_quant_type ( "int8_weight_only" ) )
43434349 model = AutoModelForCausalLM .from_pretrained (
43444350 self .causal_lm_model_id , device_map = device , quantization_config = quantization_config
43454351 )
@@ -4357,7 +4363,7 @@ def test_causal_lm_training_single_gpu_torchao_dora_int8_weight_only(self):
43574363 model = get_peft_model (model , config )
43584364
43594365 data = load_dataset_english_quotes ()
4360- data = data .map (lambda samples : self . tokenizer (samples ["quote" ]), batched = True )
4366+ data = data .map (lambda samples : tokenizer (samples ["quote" ]), batched = True )
43614367
43624368 trainer = Trainer (
43634369 model = model ,
@@ -4371,7 +4377,7 @@ def test_causal_lm_training_single_gpu_torchao_dora_int8_weight_only(self):
43714377 logging_steps = 1 ,
43724378 output_dir = tmp_dir ,
43734379 ),
4374- data_collator = DataCollatorForLanguageModeling (self . tokenizer , mlm = False ),
4380+ data_collator = DataCollatorForLanguageModeling (tokenizer , mlm = False ),
43754381 )
43764382 trainer .model .config .use_cache = False
43774383 trainer .train ()
@@ -4390,7 +4396,7 @@ def test_causal_lm_training_single_gpu_torchao_dora_int8_dynamic_activation_int8
43904396
43914397 device = 0
43924398
4393- quantization_config = TorchAoConfig (quant_type = "int8_dynamic_activation_int8_weight" )
4399+ quantization_config = TorchAoConfig (quant_type = self . get_quant_type ( "int8_dynamic_activation_int8_weight" ) )
43944400 model = AutoModelForCausalLM .from_pretrained (
43954401 self .causal_lm_model_id , device_map = device , quantization_config = quantization_config
43964402 )
@@ -4419,7 +4425,7 @@ def test_causal_lm_training_single_gpu_torchao_int4_raises(self):
44194425
44204426 device = 0
44214427
4422- quantization_config = TorchAoConfig (quant_type = "int4_weight_only" )
4428+ quantization_config = TorchAoConfig (quant_type = self . get_quant_type ( "int4_weight_only" ) )
44234429 model = AutoModelForCausalLM .from_pretrained (
44244430 self .causal_lm_model_id , device_map = device , quantization_config = quantization_config
44254431 )
@@ -4441,10 +4447,10 @@ def test_causal_lm_training_single_gpu_torchao_int4_raises(self):
44414447 # tested in multiple matchines
44424448 model (inputs )
44434449
4444- @parameterized . expand ( supported_quant_types )
4450+ @pytest . mark . parametrize ( "quant_type" , supported_quant_types )
44454451 @pytest .mark .multi_gpu_tests
44464452 @require_torch_multi_accelerator
4447- def test_causal_lm_training_multi_accelerator_torchao (self , quant_type ):
4453+ def test_causal_lm_training_multi_accelerator_torchao (self , quant_type , tokenizer ):
44484454 from transformers import TorchAoConfig
44494455
44504456 device_map = {
@@ -4469,7 +4475,7 @@ def test_causal_lm_training_multi_accelerator_torchao(self, quant_type):
44694475 }
44704476
44714477 with tempfile .TemporaryDirectory () as tmp_dir :
4472- quantization_config = TorchAoConfig (quant_type = quant_type )
4478+ quantization_config = TorchAoConfig (quant_type = self . get_quant_type ( quant_type ) )
44734479 model = AutoModelForCausalLM .from_pretrained (
44744480 self .causal_lm_model_id ,
44754481 device_map = device_map ,
@@ -4495,7 +4501,7 @@ def test_causal_lm_training_multi_accelerator_torchao(self, quant_type):
44954501 model = get_peft_model (model , config )
44964502
44974503 data = load_dataset_english_quotes ()
4498- data = data .map (lambda samples : self . tokenizer (samples ["quote" ]), batched = True )
4504+ data = data .map (lambda samples : tokenizer (samples ["quote" ]), batched = True )
44994505
45004506 trainer = Trainer (
45014507 model = model ,
@@ -4509,7 +4515,7 @@ def test_causal_lm_training_multi_accelerator_torchao(self, quant_type):
45094515 logging_steps = 1 ,
45104516 output_dir = tmp_dir ,
45114517 ),
4512- data_collator = DataCollatorForLanguageModeling (self . tokenizer , mlm = False ),
4518+ data_collator = DataCollatorForLanguageModeling (tokenizer , mlm = False ),
45134519 )
45144520 trainer .model .config .use_cache = False
45154521 trainer .train ()
@@ -4550,7 +4556,7 @@ def test_causal_lm_training_multi_accelerator_torchao_int4_raises(self):
45504556 "model.decoder.layers.11" : 1 ,
45514557 "model.decoder.final_layer_norm" : 1 ,
45524558 }
4553- quantization_config = TorchAoConfig (quant_type = "int4_weight_only" )
4559+ quantization_config = TorchAoConfig (self . get_quant_type ( quant_type = "int4_weight_only" ) )
45544560 model = AutoModelForCausalLM .from_pretrained (
45554561 self .causal_lm_model_id ,
45564562 device_map = device_map ,
@@ -4588,7 +4594,7 @@ def test_torchao_merge_layers_int8_weight_only(self):
45884594 device = 0
45894595 dummy_input = torch .arange (10 ).view (- 1 , 1 ).to (device )
45904596
4591- quantization_config = TorchAoConfig (quant_type = quant_type )
4597+ quantization_config = TorchAoConfig (self . get_quant_type ( quant_type = quant_type ) )
45924598 model = AutoModelForCausalLM .from_pretrained (
45934599 self .causal_lm_model_id , device_map = device , quantization_config = quantization_config
45944600 ).eval ()
@@ -4641,7 +4647,7 @@ def test_torchao_merge_layers_int8_dynamic_activation_int8_weight_raises(self):
46414647 torch .manual_seed (0 )
46424648 device = 0
46434649
4644- quantization_config = TorchAoConfig (quant_type = quant_type )
4650+ quantization_config = TorchAoConfig (quant_type = self . get_quant_type ( quant_type ) )
46454651 model = AutoModelForCausalLM .from_pretrained (
46464652 self .causal_lm_model_id , device_map = device , quantization_config = quantization_config
46474653 ).eval ()
0 commit comments