 )
 import json
 from aiu_fms_testing_utils.utils.aiu_setup import dprint, aiu_dist_setup
-import shutil
 import os
 
 try:
@@ -787,7 +786,6 @@ def _run_cpu_aiu_validation_test(
     cpu_model,
     aiu_model,
     micro_model_path,
-    verify_cache_state=None,
 ):
     # Get the tokenizer and AIU / CPU models to compare
     tokenizer = tokenizers.get_tokenizer(model_path)
@@ -813,12 +811,6 @@ def _run_cpu_aiu_validation_test(
         aiu_model,
     )
 
-    # Used only for cache tests; this is a nonparametric closure that
-    # should assert the cache for torch sendnn is in the correct state
-    # for this test
-    if verify_cache_state is not None:
-        verify_cache_state()
-
     # if level 0 fails validation, validate level 1
     if FORCE_VALIDATION_LEVEL_1 or failed_validation_level_0:
         if failed_validation_level_0:
@@ -840,87 +832,6 @@ def _run_cpu_aiu_validation_test(
         )
 
 
-def _reset_cache_settings(purge_cache_dir):
-    os.environ["TORCH_SENDNN_CACHE_ENABLE"] = "1"
-    os.environ["COMPILATION_MODE"] = "offline_decoder"
-    cache_dir = os.environ["TORCH_SENDNN_CACHE_DIR"]
-
-    # Ensure we start in clean state
-    if purge_cache_dir and os.path.isdir(cache_dir):
-        shutil.rmtree(cache_dir)
-        os.mkdir(cache_dir)
-
-    from torch_sendnn.backends import cache
-
-    # Explicitly clear cache paths from the global torch sendnn graph;
-    # TODO would be better to add a helper to explicitly do this in
-    # torch sendnn
-    cache.cache = {}
-
-
-@pytest.fixture
-def use_cached_model():
863- """Configures the tochsendnn cache and runs the AIU model prior to test execution;
864- this is computationally expensive and should only be used in situations like testing
865- cache hit correctness;
866- """
-    torch.manual_seed(42)
-    torch.set_grad_enabled(False)
-    _reset_cache_settings(purge_cache_dir=True)
-
-    model_path, batch_size, seq_length, max_new_tokens = _get_cache_test_params()
-    micro_model_path = MICRO_MODEL_MAPPING.get(model_path, None)
-
-    def verify_cache_miss():
-        cache_dir = os.environ.get("TORCH_SENDNN_CACHE_DIR")
-        updated_cache_len = (
-            len(os.listdir(cache_dir)) if os.path.isdir(cache_dir) else 0
-        )
-        assert updated_cache_len == max_new_tokens, (
-            "cache directory not populated on cache miss"
-        )
-
-    dprint(
-        f"Setting up cache [i.e., cache miss check] for model={model_path}, batch_size={batch_size}, seq_length={seq_length}, max_new_tokens={max_new_tokens}, micro_model={USE_MICRO_MODELS}"
-    )
-
-    # we don't currently support inferring gptq from get_model, so we must use an adapter with hf_configured
-    gptq_kwargs_aiu, gptq_kwargs_cpu = __maybe_get_gptq_kwargs(model_path)
-
-    model = _get_aiu_model(
-        model_path,
-        gptq_kwargs_aiu,
-        persistent_model_inst=None,
-    )
-
-    validation_model = _get_cpu_model(
-        model_path,
-        gptq_kwargs_cpu,
-        micro_model_state_dict=model.state_dict() if USE_MICRO_MODELS else None,
-    )
-
-    _run_cpu_aiu_validation_test(
-        model_path,
-        batch_size,
-        seq_length,
-        max_new_tokens,
-        validation_model,
-        model,
-        micro_model_path,
-        verify_cache_state=verify_cache_miss,
-    )
-
-
-def _get_cache_test_params():
-    # NOTE - currently we always use granite 3.3 for the cache test,
-    # TODO make this configurable as tests are refactored
-    model_path = GRANITE_3p3_8B_INSTRUCT
-    batch_size = COMMON_BATCH_SIZES[0]
-    seq_length = COMMON_SEQ_LENGTHS[0]
-    max_new_tokens = COMMON_MAX_NEW_TOKENS[0]
-    return [model_path, batch_size, seq_length, max_new_tokens]
-
-
 @pytest.mark.parametrize(
     "model_path,batch_size,seq_length,max_new_tokens", common_shapes
 )
@@ -959,51 +870,3 @@ def test_common_shapes(
         model,
         micro_model_path,
     )
-
-
-def test_cache(use_cached_model):
-    torch.manual_seed(42)
-    torch.set_grad_enabled(False)
-    _reset_cache_settings(purge_cache_dir=False)
-
-    model_path, batch_size, seq_length, max_new_tokens = _get_cache_test_params()
-    micro_model_path = MICRO_MODEL_MAPPING.get(model_path, None)
-
-    def verify_cache_hit():
-        cache_dir = os.environ.get("TORCH_SENDNN_CACHE_DIR")
-        updated_cache_len = (
-            len(os.listdir(cache_dir)) if os.path.isdir(cache_dir) else 0
-        )
-        assert updated_cache_len == max_new_tokens, (
-            "cache miss occurred when hit was expected"
-        )
-
-    dprint(
-        f"testing: model={model_path}, batch_size={batch_size}, seq_length={seq_length}, max_new_tokens={max_new_tokens}, micro_model={USE_MICRO_MODELS}, for cache hit"
-    )
-
-    # we don't currently support inferring gptq from get_model, so we must use an adapter with hf_configured
-    gptq_kwargs_aiu, gptq_kwargs_cpu = __maybe_get_gptq_kwargs(model_path)
-
-    model = _get_aiu_model(
-        model_path,
-        gptq_kwargs_aiu,
-        persistent_model_inst=None,
-    )
-
-    validation_model = _get_cpu_model(
-        model_path,
-        gptq_kwargs_cpu,
-        micro_model_state_dict=model.state_dict() if USE_MICRO_MODELS else None,
-    )
-
-    _run_cpu_aiu_validation_test(
-        model_path,
-        batch_size,
-        seq_length,
-        max_new_tokens,
-        validation_model,
-        model,
-        micro_model_path,
-        verify_cache_state=verify_cache_hit,
-    )