@@ -455,7 +455,7 @@ def _check_failure_thresholds(
455455 print ("passed validation level 0" )
456456
457457
458- def get_common_model_kwargs (is_gptq , model_path ):
458+ def _get_common_model_kwargs (is_gptq , model_path ):
459459 if is_gptq :
460460 return {}
461461 # Get the micro model kwargs
@@ -491,9 +491,9 @@ def get_common_model_kwargs(is_gptq, model_path):
491491
492492# NOTE micro_model_state_dict should be None if USE_MICRO_MODELS is true
493493# Otherwise it should be model.state_dict() where model is the AIU model
494- def get_cpu_model (model_path , gptq_kwargs , micro_model_state_dict = None ):
494+ def _get_cpu_model (model_path , gptq_kwargs , micro_model_state_dict = None ):
495495 is_gptq = len (gptq_kwargs ) != 0
496- model_kwargs = get_common_model_kwargs (is_gptq , model_path )
496+ model_kwargs = _get_common_model_kwargs (is_gptq , model_path )
497497
498498 # prepare the cpu model
499499 validation_model = get_model (
@@ -512,9 +512,9 @@ def get_cpu_model(model_path, gptq_kwargs, micro_model_state_dict=None):
512512 return validation_model
513513
514514
515- def get_aiu_model (model_path , gptq_kwargs , persistent_model_inst ):
515+ def _get_aiu_model (model_path , gptq_kwargs , persistent_model_inst ):
516516 is_gptq = len (gptq_kwargs ) != 0
517- model_kwargs = get_common_model_kwargs (is_gptq , model_path )
517+ model_kwargs = _get_common_model_kwargs (is_gptq , model_path )
518518
519519 # prepare the AIU model; use the persistent model fixture if the test has it
520520 if persistent_model_inst is not None :
@@ -538,7 +538,7 @@ def get_aiu_model(model_path, gptq_kwargs, persistent_model_inst):
538538 return aiu_model
539539
540540
541- def get_device_validation_information (
541+ def _get_device_validation_information (
542542 model_path ,
543543 batch_size ,
544544 seq_length ,
@@ -601,7 +601,7 @@ def get_device_validation_information(
601601 return validation_info
602602
603603
604- def resolve_thresholds (model_path , micro_model_path ):
604+ def _resolve_thresholds (model_path , micro_model_path ):
605605 # if we do not have real model weights, use a default_metrics_threshold
606606 if USE_MICRO_MODELS and micro_model_path is None :
607607 ce_threshold , diff_threshold = DEFAULT_METRICS_THRESHOLD
@@ -620,7 +620,7 @@ def resolve_thresholds(model_path, micro_model_path):
620620 return ce_threshold , diff_threshold
621621
622622
623- def run_validation_level_0 (
623+ def _run_validation_level_0 (
624624 model_path ,
625625 batch_size ,
626626 seq_length ,
@@ -631,7 +631,7 @@ def run_validation_level_0(
631631 extra_kwargs ,
632632 model ,
633633):
634- cpu_validation_info = get_device_validation_information (
634+ cpu_validation_info = _get_device_validation_information (
635635 model_path = model_path ,
636636 batch_size = batch_size ,
637637 seq_length = seq_length ,
@@ -655,7 +655,7 @@ def run_validation_level_0(
655655 )
656656
657657 # first test validation level 0
658- aiu_validation_info = get_device_validation_information (
658+ aiu_validation_info = _get_device_validation_information (
659659 model_path = model_path ,
660660 batch_size = batch_size ,
661661 seq_length = seq_length ,
@@ -685,7 +685,7 @@ def run_validation_level_0(
685685 return len (failed_responses ) != 0 , validation_zero_info
686686
687687
688- def run_validation_level_1 (
688+ def _run_validation_level_1 (
689689 model_path ,
690690 batch_size ,
691691 seq_length ,
@@ -705,7 +705,7 @@ def run_validation_level_1(
705705 for i in range (iters ):
706706 # for iteration 0, we have computed the cpu validation info in the prior step for seed=0, so skip
707707 if i != 0 :
708- cpu_validation_info = get_device_validation_information (
708+ cpu_validation_info = _get_device_validation_information (
709709 model_path = model_path ,
710710 batch_size = batch_size ,
711711 seq_length = seq_length ,
@@ -733,7 +733,7 @@ def run_validation_level_1(
733733 cpu_static_tokens = validation_zero_info ["cpu_static_tokens" ]
734734 eos_indexes = validation_zero_info ["eos_indexes" ]
735735
736- aiu_validation_info = get_device_validation_information (
736+ aiu_validation_info = _get_device_validation_information (
737737 model_path = model_path ,
738738 batch_size = batch_size ,
739739 seq_length = seq_length ,
@@ -758,7 +758,7 @@ def run_validation_level_1(
758758 # only consider those metrics captured prior to the eos
759759 level_1_metrics = __filter_before_eos (level_1_metrics , eos_indexes )
760760
761- ce_threshold , diff_threshold = resolve_thresholds (model_path , micro_model_path )
761+ ce_threshold , diff_threshold = _resolve_thresholds (model_path , micro_model_path )
762762
763763 # get all failed responses for each metric
764764 ce_fail_responses = filter_failed_level_1_cases (
@@ -801,7 +801,7 @@ def _run_cpu_aiu_validation_test(
801801 )
802802
803803 # Run validation level 0
804- failed_validation_level_0 , validation_zero_info = run_validation_level_0 (
804+ failed_validation_level_0 , validation_zero_info = _run_validation_level_0 (
805805 model_path ,
806806 batch_size ,
807807 seq_length ,
@@ -825,7 +825,7 @@ def _run_cpu_aiu_validation_test(
825825 dprint ("failed validation level 0, testing validation level 1" )
826826 else :
827827 dprint ("passed validation level 0, testing validation level 1" )
828- run_validation_level_1 (
828+ _run_validation_level_1 (
829829 model_path ,
830830 batch_size ,
831831 seq_length ,
@@ -887,13 +887,13 @@ def verify_cache_miss():
887887 # we don't currently support inferring gptq from get_model, so we must use an adapter with hf_configured
888888 gptq_kwargs_aiu , gptq_kwargs_cpu = __maybe_get_gptq_kwargs (model_path )
889889
890- model = get_aiu_model (
890+ model = _get_aiu_model (
891891 model_path ,
892892 gptq_kwargs_aiu ,
893893 persistent_model_inst = None ,
894894 )
895895
896- validation_model = get_cpu_model (
896+ validation_model = _get_cpu_model (
897897 model_path ,
898898 gptq_kwargs_cpu ,
899899 micro_model_state_dict = model .state_dict () if USE_MICRO_MODELS else None ,
@@ -912,9 +912,11 @@ def verify_cache_miss():
912912
913913
914914def _get_cache_test_params ():
915- model_path = "/models/tiny-models/granite-3.3-8b-layers-3-step-100000" # ibm-granite/granite-3.3-8b-instruct"
916- batch_size = 1 # common_batch_sizes[0]
917- seq_length = 128 # common_seq_lengths[0]
915+ # NOTE: currently we always use granite 3.3 for the cache test.
916+ # TODO: make this configurable as tests are refactored
917+ model_path = GRANITE_3p3_8B_INSTRUCT
918+ batch_size = COMMON_BATCH_SIZES [0 ]
919+ seq_length = COMMON_SEQ_LENGTHS [0 ]
918920 max_new_tokens = COMMON_MAX_NEW_TOKENS [0 ]
919921 return [model_path , batch_size , seq_length , max_new_tokens ]
920922
@@ -937,13 +939,13 @@ def test_common_shapes(
937939 # we don't currently support inferring gptq from get_model, so we must use an adapter with hf_configured
938940 gptq_kwargs_aiu , gptq_kwargs_cpu = __maybe_get_gptq_kwargs (model_path )
939941
940- model = get_aiu_model (
942+ model = _get_aiu_model (
941943 model_path ,
942944 gptq_kwargs_aiu ,
943945 persistent_model_inst = persistent_model ,
944946 )
945947
946- validation_model = get_cpu_model (
948+ validation_model = _get_cpu_model (
947949 model_path ,
948950 gptq_kwargs_cpu ,
949951 micro_model_state_dict = model .state_dict () if USE_MICRO_MODELS else None ,
@@ -983,13 +985,13 @@ def verify_cache_hit():
983985 # we don't currently support inferring gptq from get_model, so we must use an adapter with hf_configured
984986 gptq_kwargs_aiu , gptq_kwargs_cpu = __maybe_get_gptq_kwargs (model_path )
985987
986- model = get_aiu_model (
988+ model = _get_aiu_model (
987989 model_path ,
988990 gptq_kwargs_aiu ,
989991 persistent_model_inst = None ,
990992 )
991993
992- validation_model = get_cpu_model (
994+ validation_model = _get_cpu_model (
993995 model_path ,
994996 gptq_kwargs_cpu ,
995997 micro_model_state_dict = model .state_dict () if USE_MICRO_MODELS else None ,
0 commit comments