
Commit eafb818 (1 parent: 6109684)

Remove cache tests

1 file changed: tests/models/test_decoders.py
Lines changed: 0 additions & 137 deletions
@@ -24,7 +24,6 @@
 )
 import json
 from aiu_fms_testing_utils.utils.aiu_setup import dprint, aiu_dist_setup
-import shutil
 import os
 
 try:
@@ -787,7 +786,6 @@ def _run_cpu_aiu_validation_test(
     cpu_model,
     aiu_model,
     micro_model_path,
-    verify_cache_state=None,
 ):
     # Get the tokenizer and AIU / CPU models to compare
     tokenizer = tokenizers.get_tokenizer(model_path)
@@ -813,12 +811,6 @@ def _run_cpu_aiu_validation_test(
         aiu_model,
     )
 
-    # Used only for cache tests; this is a nonparametric closure that
-    # should assert the cache for torch sendnn is in the correct state
-    # for this test
-    if verify_cache_state is not None:
-        verify_cache_state()
-
     # if level 0 fails validation, validate level 1
     if FORCE_VALIDATION_LEVEL_1 or failed_validation_level_0:
         if failed_validation_level_0:
@@ -840,87 +832,6 @@ def _run_cpu_aiu_validation_test(
     )
 
 
-def _reset_cache_settings(purge_cache_dir):
-    os.environ["TORCH_SENDNN_CACHE_ENABLE"] = "1"
-    os.environ["COMPILATION_MODE"] = "offline_decoder"
-    cache_dir = os.environ["TORCH_SENDNN_CACHE_DIR"]
-
-    # Ensure we start in clean state
-    if purge_cache_dir and os.path.isdir(cache_dir):
-        shutil.rmtree(cache_dir)
-        os.mkdir(cache_dir)
-
-    from torch_sendnn.backends import cache
-
-    # Explicitly clear cache paths from the global torch sendnn graph;
-    # TODO would be better to add a helper to explicitly do this in
-    # torch sendnn
-    cache.cache = {}
-
-
-@pytest.fixture
-def use_cached_model():
-    """Configures the tochsendnn cache and runs the AIU model prior to test execution;
-    this is computationally expensive and should only be used in situations like testing
-    cache hit correctness;
-    """
-    torch.manual_seed(42)
-    torch.set_grad_enabled(False)
-    _reset_cache_settings(purge_cache_dir=True)
-
-    model_path, batch_size, seq_length, max_new_tokens = _get_cache_test_params()
-    micro_model_path = MICRO_MODEL_MAPPING.get(model_path, None)
-
-    def verify_cache_miss():
-        cache_dir = os.environ.get("TORCH_SENDNN_CACHE_DIR")
-        updated_cache_len = (
-            len(os.listdir(cache_dir)) if os.path.isdir(cache_dir) else 0
-        )
-        assert updated_cache_len == max_new_tokens, (
-            "cache directory not populated on cache miss"
-        )
-
-    dprint(
-        f"Setting up cache [i.e., cache miss check] for model={model_path}, batch_size={batch_size}, seq_length={seq_length}, max_new_tokens={max_new_tokens}, micro_model={USE_MICRO_MODELS}"
-    )
-
-    # we don't currently support inferring gptq from get_model, so we must use an adapter with hf_configured
-    gptq_kwargs_aiu, gptq_kwargs_cpu = __maybe_get_gptq_kwargs(model_path)
-
-    model = _get_aiu_model(
-        model_path,
-        gptq_kwargs_aiu,
-        persistent_model_inst=None,
-    )
-
-    validation_model = _get_cpu_model(
-        model_path,
-        gptq_kwargs_cpu,
-        micro_model_state_dict=model.state_dict() if USE_MICRO_MODELS else None,
-    )
-
-    _run_cpu_aiu_validation_test(
-        model_path,
-        batch_size,
-        seq_length,
-        max_new_tokens,
-        validation_model,
-        model,
-        micro_model_path,
-        verify_cache_state=verify_cache_miss,
-    )
-
-
-def _get_cache_test_params():
-    # NOTE - currently we always use granite 3.3 for the cache test,
-    # TODO make this configurable as tests are refactored
-    model_path = GRANITE_3p3_8B_INSTRUCT
-    batch_size = COMMON_BATCH_SIZES[0]
-    seq_length = COMMON_SEQ_LENGTHS[0]
-    max_new_tokens = COMMON_MAX_NEW_TOKENS[0]
-    return [model_path, batch_size, seq_length, max_new_tokens]
-
-
 @pytest.mark.parametrize(
     "model_path,batch_size,seq_length,max_new_tokens", common_shapes
 )
@@ -959,51 +870,3 @@ def test_common_shapes(
         model,
         micro_model_path,
     )
-
-
-def test_cache(use_cached_model):
-    torch.manual_seed(42)
-    torch.set_grad_enabled(False)
-    _reset_cache_settings(purge_cache_dir=False)
-
-    model_path, batch_size, seq_length, max_new_tokens = _get_cache_test_params()
-    micro_model_path = MICRO_MODEL_MAPPING.get(model_path, None)
-
-    def verify_cache_hit():
-        cache_dir = os.environ.get("TORCH_SENDNN_CACHE_DIR")
-        updated_cache_len = (
-            len(os.listdir(cache_dir)) if os.path.isdir(cache_dir) else 0
-        )
-        assert updated_cache_len == max_new_tokens, (
-            "cache miss occurred when hit was expected"
-        )
-
-    dprint(
-        f"testing: model={model_path}, batch_size={batch_size}, seq_length={seq_length}, max_new_tokens={max_new_tokens}, micro_model={USE_MICRO_MODELS}, for cache hit"
-    )
-
-    # we don't currently support inferring gptq from get_model, so we must use an adapter with hf_configured
-    gptq_kwargs_aiu, gptq_kwargs_cpu = __maybe_get_gptq_kwargs(model_path)
-
-    model = _get_aiu_model(
-        model_path,
-        gptq_kwargs_aiu,
-        persistent_model_inst=None,
-    )
-
-    validation_model = _get_cpu_model(
-        model_path,
-        gptq_kwargs_cpu,
-        micro_model_state_dict=model.state_dict() if USE_MICRO_MODELS else None,
-    )
-
-    _run_cpu_aiu_validation_test(
-        model_path,
-        batch_size,
-        seq_length,
-        max_new_tokens,
-        validation_model,
-        model,
-        micro_model_path,
-        verify_cache_state=verify_cache_hit,
-    )
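
For context, the removed tests exercised the torch_sendnn compile cache through a verify_cache_state hook: the use_cached_model fixture ran the model once with a purged cache directory and asserted that the miss populated TORCH_SENDNN_CACHE_DIR with one entry per generated token, and test_cache then re-ran the same shapes and asserted the count was unchanged (a hit). A minimal sketch of that closure pattern, where make_cache_checker and run_validation are hypothetical placeholders (run_validation stands in for _run_cpu_aiu_validation_test), not the repository's API:

# Sketch only: reproduces the removed verify_cache_state closure pattern
# for checking the torch_sendnn compile cache directory.
import os


def make_cache_checker(expected_entries):
    # Build a nonparametric closure that asserts the cache directory holds
    # exactly `expected_entries` compiled artifacts.
    def verify_cache_state():
        cache_dir = os.environ.get("TORCH_SENDNN_CACHE_DIR")
        cache_len = (
            len(os.listdir(cache_dir)) if cache_dir and os.path.isdir(cache_dir) else 0
        )
        assert cache_len == expected_entries, (
            f"expected {expected_entries} cache entries, found {cache_len}"
        )

    return verify_cache_state


def run_validation(verify_cache_state=None):
    # Placeholder for _run_cpu_aiu_validation_test: run the CPU/AIU comparison,
    # then let the optional hook inspect the cache state afterwards.
    if verify_cache_state is not None:
        verify_cache_state()

In the removed tests the fixture passed a checker expecting max_new_tokens entries after the first (cache-miss) run, and test_cache passed the same expectation after the second (cache-hit) run; keeping the check in a closure let _run_cpu_aiu_validation_test stay agnostic to caching, which is also why this commit can drop the verify_cache_state parameter without touching the rest of the validation flow.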
