231231# NOTE: llama already has many adapters for AIU, and they are the same for all models, so we just use llama. This way we don't need to register a new architecture / adapter step (we can simply re-use the existing ones).
232232__custom_adapter = {"architecture" : "llama" , "source" : "fms_aiu" }
233233
234- ### Additional configuration for testing caching correctness
235- # Directory to be used for testing torch sendnn caching
236- CACHE_DIR = os .path .join (os .getcwd (), ".cache" )
237-
238234
239235@pytest .fixture (autouse = True )
240236def reset_compiler ():
@@ -243,8 +239,6 @@ def reset_compiler():
243239 torch .compiler .reset ()
244240 torch ._dynamo .reset ()
245241 os .environ .pop ("COMPILATION_MODE" , None )
246- os .environ .pop ("TORCH_SENDNN_CACHE_ENABLE" , None )
247- # FIXME - this fixture breaks stuff if we don't run the cache test first
248242
249243
250244# TODO: Currently, gptq does not have the same level of support as non-gptq models for get_model. This method provides the extra requirements for gptq for get_model,
@@ -848,12 +842,20 @@ def _run_cpu_aiu_validation_test(
848842
def _reset_cache_settings(purge_cache_dir):
    """Prepare the process environment for a torch sendnn caching test.

    Enables ``TORCH_SENDNN_CACHE_ENABLE``, sets ``COMPILATION_MODE`` to
    ``offline_decoder``, optionally resets the on-disk cache directory named
    by ``TORCH_SENDNN_CACHE_DIR``, and replaces the module-level
    ``torch_sendnn.backends.cache.cache`` mapping with an empty dict.

    Args:
        purge_cache_dir: When truthy, delete the cache directory (if it
            exists) and leave a fresh, empty directory in its place. When
            falsy, the directory is left untouched.

    Raises:
        KeyError: If ``TORCH_SENDNN_CACHE_DIR`` is not set in the environment.
    """
    # Resolve the cache directory before touching any other environment
    # state, so a misconfigured run fails with a clear message and without
    # leaving half-applied settings behind for later tests.
    try:
        cache_dir = os.environ["TORCH_SENDNN_CACHE_DIR"]
    except KeyError:
        raise KeyError(
            "TORCH_SENDNN_CACHE_DIR must be set to run the caching tests"
        ) from None

    os.environ["TORCH_SENDNN_CACHE_ENABLE"] = "1"
    os.environ["COMPILATION_MODE"] = "offline_decoder"

    # Ensure we start in a clean state: an empty directory always exists
    # after a purge, whether or not one was there beforehand.
    if purge_cache_dir:
        if os.path.isdir(cache_dir):
            shutil.rmtree(cache_dir)
        os.makedirs(cache_dir, exist_ok=True)

    # Imported here rather than at module level, as in the original code.
    from torch_sendnn.backends import cache

    # Explicitly clear cache paths from the global torch sendnn graph;
    # TODO would be better to add a helper to explicitly do this in
    # torch sendnn
    cache.cache = {}
857859
858860
859861@pytest .fixture
@@ -865,12 +867,14 @@ def use_cached_model():
865867 torch .manual_seed (42 )
866868 torch .set_grad_enabled (False )
867869 _reset_cache_settings (purge_cache_dir = True )
870+
868871 model_path , batch_size , seq_length , max_new_tokens = _get_cache_test_params ()
869872 micro_model_path = MICRO_MODEL_MAPPING .get (model_path , None )
870873
871874 def verify_cache_miss ():
875+ cache_dir = os .environ .get ("TORCH_SENDNN_CACHE_DIR" )
872876 updated_cache_len = (
873- len (os .listdir (CACHE_DIR )) if os .path .isdir (CACHE_DIR ) else 0
877+ len (os .listdir (cache_dir )) if os .path .isdir (cache_dir ) else 0
874878 )
875879 assert updated_cache_len == max_new_tokens , (
876880 "cache directory not populated on cache miss"
@@ -915,24 +919,19 @@ def _get_cache_test_params():
915919 return [model_path , batch_size , seq_length , max_new_tokens ]
916920
917921
918- def test_cache (use_cached_model ):
922+ @pytest .mark .parametrize (
923+ "model_path,batch_size,seq_length,max_new_tokens" , common_shapes
924+ )
925+ def test_common_shapes (
926+ model_path , batch_size , seq_length , max_new_tokens , persistent_model
927+ ):
919928 torch .manual_seed (42 )
920929 torch .set_grad_enabled (False )
921- _reset_cache_settings (purge_cache_dir = False )
922-
923- model_path , batch_size , seq_length , max_new_tokens = _get_cache_test_params ()
930+ os .environ ["COMPILATION_MODE" ] = "offline_decoder"
924931 micro_model_path = MICRO_MODEL_MAPPING .get (model_path , None )
925932
926- def verify_cache_hit ():
927- updated_cache_len = (
928- len (os .listdir (CACHE_DIR )) if os .path .isdir (CACHE_DIR ) else 0
929- )
930- assert updated_cache_len == max_new_tokens , (
931- "cache miss occurred when hit was expected"
932- )
933-
934933 dprint (
935- f"testing: model={ model_path } , batch_size={ batch_size } , seq_length={ seq_length } , max_new_tokens={ max_new_tokens } , micro_model={ USE_MICRO_MODELS } , for cache hit "
934+ f"testing model={ model_path } , batch_size={ batch_size } , seq_length={ seq_length } , max_new_tokens={ max_new_tokens } , micro_model={ USE_MICRO_MODELS } "
936935 )
937936
938937 # we don't currently support inferring gptq from get_model, so we must use an adapter with hf_configured
@@ -941,15 +940,14 @@ def verify_cache_hit():
941940 model = get_aiu_model (
942941 model_path ,
943942 gptq_kwargs_aiu ,
944- persistent_model_inst = None ,
943+ persistent_model_inst = persistent_model ,
945944 )
946945
947946 validation_model = get_cpu_model (
948947 model_path ,
949948 gptq_kwargs_cpu ,
950949 micro_model_state_dict = model .state_dict () if USE_MICRO_MODELS else None ,
951950 )
952-
953951 _run_cpu_aiu_validation_test (
954952 model_path ,
955953 batch_size ,
@@ -958,23 +956,28 @@ def verify_cache_hit():
958956 validation_model ,
959957 model ,
960958 micro_model_path ,
961- verify_cache_state = verify_cache_hit ,
962959 )
963960
964961
965- @pytest .mark .parametrize (
966- "model_path,batch_size,seq_length,max_new_tokens" , common_shapes
967- )
968- def test_common_shapes (
969- model_path , batch_size , seq_length , max_new_tokens , persistent_model
970- ):
962+ def test_cache (use_cached_model ):
971963 torch .manual_seed (42 )
972964 torch .set_grad_enabled (False )
973- os .environ ["COMPILATION_MODE" ] = "offline_decoder"
965+ _reset_cache_settings (purge_cache_dir = False )
966+
967+ model_path , batch_size , seq_length , max_new_tokens = _get_cache_test_params ()
974968 micro_model_path = MICRO_MODEL_MAPPING .get (model_path , None )
975969
970+ def verify_cache_hit ():
971+ cache_dir = os .environ .get ("TORCH_SENDNN_CACHE_DIR" )
972+ updated_cache_len = (
973+ len (os .listdir (cache_dir )) if os .path .isdir (cache_dir ) else 0
974+ )
975+ assert updated_cache_len == max_new_tokens , (
976+ "cache miss occurred when hit was expected"
977+ )
978+
976979 dprint (
977- f"testing model={ model_path } , batch_size={ batch_size } , seq_length={ seq_length } , max_new_tokens={ max_new_tokens } , micro_model={ USE_MICRO_MODELS } "
980+ f"testing: model={ model_path } , batch_size={ batch_size } , seq_length={ seq_length } , max_new_tokens={ max_new_tokens } , micro_model={ USE_MICRO_MODELS } , for cache hit "
978981 )
979982
980983 # we don't currently support inferring gptq from get_model, so we must use an adapter with hf_configured
@@ -983,14 +986,15 @@ def test_common_shapes(
983986 model = get_aiu_model (
984987 model_path ,
985988 gptq_kwargs_aiu ,
986- persistent_model_inst = persistent_model ,
989+ persistent_model_inst = None ,
987990 )
988991
989992 validation_model = get_cpu_model (
990993 model_path ,
991994 gptq_kwargs_cpu ,
992995 micro_model_state_dict = model .state_dict () if USE_MICRO_MODELS else None ,
993996 )
997+
994998 _run_cpu_aiu_validation_test (
995999 model_path ,
9961000 batch_size ,
@@ -999,4 +1003,5 @@ def test_common_shapes(
9991003 validation_model ,
10001004 model ,
10011005 micro_model_path ,
1006+ verify_cache_state = verify_cache_hit ,
10021007 )
0 commit comments