Skip to content

Commit 6109684

Browse files
Add leading underscores, revert model name
1 parent 3ee654f commit 6109684

1 file changed

Lines changed: 27 additions & 25 deletions

File tree

tests/models/test_decoders.py

Lines changed: 27 additions & 25 deletions
Original file line number | Diff line number | Diff line change
@@ -455,7 +455,7 @@ def _check_failure_thresholds(
455455
print("passed validation level 0")
456456

457457

458-
def get_common_model_kwargs(is_gptq, model_path):
458+
def _get_common_model_kwargs(is_gptq, model_path):
459459
if is_gptq:
460460
return {}
461461
# Get the micro model kwargs
@@ -491,9 +491,9 @@ def get_common_model_kwargs(is_gptq, model_path):
491491

492492
# NOTE micro_model_state_dict should be None if USE_MICRO_MODELS is true
493493
# Otherwise it should be model.state_dict() where model is the AIU model
494-
def get_cpu_model(model_path, gptq_kwargs, micro_model_state_dict=None):
494+
def _get_cpu_model(model_path, gptq_kwargs, micro_model_state_dict=None):
495495
is_gptq = len(gptq_kwargs) != 0
496-
model_kwargs = get_common_model_kwargs(is_gptq, model_path)
496+
model_kwargs = _get_common_model_kwargs(is_gptq, model_path)
497497

498498
# prepare the cpu model
499499
validation_model = get_model(
@@ -512,9 +512,9 @@ def get_cpu_model(model_path, gptq_kwargs, micro_model_state_dict=None):
512512
return validation_model
513513

514514

515-
def get_aiu_model(model_path, gptq_kwargs, persistent_model_inst):
515+
def _get_aiu_model(model_path, gptq_kwargs, persistent_model_inst):
516516
is_gptq = len(gptq_kwargs) != 0
517-
model_kwargs = get_common_model_kwargs(is_gptq, model_path)
517+
model_kwargs = _get_common_model_kwargs(is_gptq, model_path)
518518

519519
# prepare the AIU model; use the persistent model fixture if the test has it
520520
if persistent_model_inst is not None:
@@ -538,7 +538,7 @@ def get_aiu_model(model_path, gptq_kwargs, persistent_model_inst):
538538
return aiu_model
539539

540540

541-
def get_device_validation_information(
541+
def _get_device_validation_information(
542542
model_path,
543543
batch_size,
544544
seq_length,
@@ -601,7 +601,7 @@ def get_device_validation_information(
601601
return validation_info
602602

603603

604-
def resolve_thresholds(model_path, micro_model_path):
604+
def _resolve_thresholds(model_path, micro_model_path):
605605
# if we do not have real model weights, use a default_metrics_threshold
606606
if USE_MICRO_MODELS and micro_model_path is None:
607607
ce_threshold, diff_threshold = DEFAULT_METRICS_THRESHOLD
@@ -620,7 +620,7 @@ def resolve_thresholds(model_path, micro_model_path):
620620
return ce_threshold, diff_threshold
621621

622622

623-
def run_validation_level_0(
623+
def _run_validation_level_0(
624624
model_path,
625625
batch_size,
626626
seq_length,
@@ -631,7 +631,7 @@ def run_validation_level_0(
631631
extra_kwargs,
632632
model,
633633
):
634-
cpu_validation_info = get_device_validation_information(
634+
cpu_validation_info = _get_device_validation_information(
635635
model_path=model_path,
636636
batch_size=batch_size,
637637
seq_length=seq_length,
@@ -655,7 +655,7 @@ def run_validation_level_0(
655655
)
656656

657657
# first test validation level 0
658-
aiu_validation_info = get_device_validation_information(
658+
aiu_validation_info = _get_device_validation_information(
659659
model_path=model_path,
660660
batch_size=batch_size,
661661
seq_length=seq_length,
@@ -685,7 +685,7 @@ def run_validation_level_0(
685685
return len(failed_responses) != 0, validation_zero_info
686686

687687

688-
def run_validation_level_1(
688+
def _run_validation_level_1(
689689
model_path,
690690
batch_size,
691691
seq_length,
@@ -705,7 +705,7 @@ def run_validation_level_1(
705705
for i in range(iters):
706706
# for iteration 0, we have computed the cpu validation info in the prior step for seed=0, so skip
707707
if i != 0:
708-
cpu_validation_info = get_device_validation_information(
708+
cpu_validation_info = _get_device_validation_information(
709709
model_path=model_path,
710710
batch_size=batch_size,
711711
seq_length=seq_length,
@@ -733,7 +733,7 @@ def run_validation_level_1(
733733
cpu_static_tokens = validation_zero_info["cpu_static_tokens"]
734734
eos_indexes = validation_zero_info["eos_indexes"]
735735

736-
aiu_validation_info = get_device_validation_information(
736+
aiu_validation_info = _get_device_validation_information(
737737
model_path=model_path,
738738
batch_size=batch_size,
739739
seq_length=seq_length,
@@ -758,7 +758,7 @@ def run_validation_level_1(
758758
# only consider those metrics captured prior to the eos
759759
level_1_metrics = __filter_before_eos(level_1_metrics, eos_indexes)
760760

761-
ce_threshold, diff_threshold = resolve_thresholds(model_path, micro_model_path)
761+
ce_threshold, diff_threshold = _resolve_thresholds(model_path, micro_model_path)
762762

763763
# get all failed responses for each metric
764764
ce_fail_responses = filter_failed_level_1_cases(
@@ -801,7 +801,7 @@ def _run_cpu_aiu_validation_test(
801801
)
802802

803803
# Run validation level 0
804-
failed_validation_level_0, validation_zero_info = run_validation_level_0(
804+
failed_validation_level_0, validation_zero_info = _run_validation_level_0(
805805
model_path,
806806
batch_size,
807807
seq_length,
@@ -825,7 +825,7 @@ def _run_cpu_aiu_validation_test(
825825
dprint("failed validation level 0, testing validation level 1")
826826
else:
827827
dprint("passed validation level 0, testing validation level 1")
828-
run_validation_level_1(
828+
_run_validation_level_1(
829829
model_path,
830830
batch_size,
831831
seq_length,
@@ -887,13 +887,13 @@ def verify_cache_miss():
887887
# we don't currently support inferring gptq from get_model, so we must use an adapter with hf_configured
888888
gptq_kwargs_aiu, gptq_kwargs_cpu = __maybe_get_gptq_kwargs(model_path)
889889

890-
model = get_aiu_model(
890+
model = _get_aiu_model(
891891
model_path,
892892
gptq_kwargs_aiu,
893893
persistent_model_inst=None,
894894
)
895895

896-
validation_model = get_cpu_model(
896+
validation_model = _get_cpu_model(
897897
model_path,
898898
gptq_kwargs_cpu,
899899
micro_model_state_dict=model.state_dict() if USE_MICRO_MODELS else None,
@@ -912,9 +912,11 @@ def verify_cache_miss():
912912

913913

914914
def _get_cache_test_params():
915-
model_path = "/models/tiny-models/granite-3.3-8b-layers-3-step-100000" # ibm-granite/granite-3.3-8b-instruct"
916-
batch_size = 1 # common_batch_sizes[0]
917-
seq_length = 128 # common_seq_lengths[0]
915+
# NOTE - currently we always use granite 3.3 for the cache test,
916+
# TODO make this configurable as tests are refactored
917+
model_path = GRANITE_3p3_8B_INSTRUCT
918+
batch_size = COMMON_BATCH_SIZES[0]
919+
seq_length = COMMON_SEQ_LENGTHS[0]
918920
max_new_tokens = COMMON_MAX_NEW_TOKENS[0]
919921
return [model_path, batch_size, seq_length, max_new_tokens]
920922

@@ -937,13 +939,13 @@ def test_common_shapes(
937939
# we don't currently support inferring gptq from get_model, so we must use an adapter with hf_configured
938940
gptq_kwargs_aiu, gptq_kwargs_cpu = __maybe_get_gptq_kwargs(model_path)
939941

940-
model = get_aiu_model(
942+
model = _get_aiu_model(
941943
model_path,
942944
gptq_kwargs_aiu,
943945
persistent_model_inst=persistent_model,
944946
)
945947

946-
validation_model = get_cpu_model(
948+
validation_model = _get_cpu_model(
947949
model_path,
948950
gptq_kwargs_cpu,
949951
micro_model_state_dict=model.state_dict() if USE_MICRO_MODELS else None,
@@ -983,13 +985,13 @@ def verify_cache_hit():
983985
# we don't currently support inferring gptq from get_model, so we must use an adapter with hf_configured
984986
gptq_kwargs_aiu, gptq_kwargs_cpu = __maybe_get_gptq_kwargs(model_path)
985987

986-
model = get_aiu_model(
988+
model = _get_aiu_model(
987989
model_path,
988990
gptq_kwargs_aiu,
989991
persistent_model_inst=None,
990992
)
991993

992-
validation_model = get_cpu_model(
994+
validation_model = _get_cpu_model(
993995
model_path,
994996
gptq_kwargs_cpu,
995997
micro_model_state_dict=model.state_dict() if USE_MICRO_MODELS else None,

0 commit comments

Comments
 (0)