@@ -527,17 +527,13 @@ def _get_common_model_kwargs(is_gptq, model_path):
527527
528528# NOTE micro_model_state_dict should be the AIU model's state_dict() if USE_MICRO_MODELS is true
529529# Otherwise it should be None
530- def _get_cpu_model (model_path , gptq_kwargs , micro_model_state_dict = None ):
531- is_gptq = len (gptq_kwargs ) != 0
532- model_kwargs = _get_common_model_kwargs (is_gptq , model_path )
533-
530+ def _get_cpu_model (model_path , is_gptq , is_fp8 , micro_model_state_dict = None , ** kwargs ):
534531 # prepare the cpu model
535532 validation_model = get_model (
536533 device_type = "cpu" ,
537- data_type = None if is_gptq else torch .float32 ,
534+ data_type = None if is_fp8 or is_gptq else torch .float32 ,
538535 fused_weights = False ,
539- ** gptq_kwargs ,
540- ** model_kwargs ,
536+ ** kwargs ,
541537 )
542538
543539 # This is a micro model, so we need to copy the state dict directly.
@@ -548,32 +544,6 @@ def _get_cpu_model(model_path, gptq_kwargs, micro_model_state_dict=None):
548544 return validation_model
549545
550546
551- def _get_aiu_model (model_path , gptq_kwargs , persistent_model_inst ):
552- is_gptq = len (gptq_kwargs ) != 0
553- is_fp8 = "fp8" in ATTN_NAME
554- model_kwargs = _get_common_model_kwargs (is_gptq , model_path )
555-
556- # prepare the AIU model; use the persistent model fixture if the test has it
557- if persistent_model_inst is not None :
558- aiu_model = persistent_model_inst .get_or_create (
559- is_gptq , is_fp8 , ** gptq_kwargs , ** model_kwargs
560- )
561- # otherwise create it directly
562- else :
563- aiu_model = get_model (
564- device_type = "cpu" ,
565- data_type = None if is_gptq else torch .float16 ,
566- fused_weights = False ,
567- ** gptq_kwargs ,
568- ** model_kwargs ,
569- )
570- aiu_model .eval ()
571- aiu_model .compile (
572- backend = "sendnn" ,
573- options = {"sendnn.dynamic" : COMPILE_DYNAMIC_SENDNN },
574- )
575- return aiu_model
576-
577547
578548def _get_device_validation_information (
579549 model_path ,
@@ -891,17 +861,24 @@ def test_common_shapes(
891861 # we don't currently support inferring gptq from get_model, so we must use an adapter with hf_configured
892862 gptq_kwargs_aiu , gptq_kwargs_cpu = __maybe_get_gptq_kwargs (model_path )
893863
894- model = _get_aiu_model (
895- model_path ,
896- gptq_kwargs_aiu ,
897- persistent_model_inst = persistent_model ,
864+ is_gptq = len (gptq_kwargs_aiu ) != 0
865+ is_fp8 = "fp8" in ATTN_NAME
866+ model_kwargs = _get_common_model_kwargs (is_gptq , model_path )
867+
868+ # Get the AIU model w/ the persistent model fixture
869+ model = persistent_model .get_or_create (
870+ is_gptq , is_fp8 , ** gptq_kwargs_aiu , ** model_kwargs
898871 )
899872
900873 validation_model = _get_cpu_model (
901874 model_path ,
902- gptq_kwargs_cpu ,
875+ is_gptq ,
876+ is_fp8 ,
903877 micro_model_state_dict = model .state_dict () if USE_MICRO_MODELS else None ,
878+ ** gptq_kwargs_cpu ,
879+ ** model_kwargs ,
904880 )
881+
905882 _run_cpu_aiu_validation_test (
906883 model_path ,
907884 batch_size ,
0 commit comments