Skip to content

Commit 3ee654f

Browse files
Explicitly clear cache paths
1 parent 6a7d19b commit 3ee654f

1 file changed

Lines changed: 40 additions & 35 deletions

File tree

tests/models/test_decoders.py

Lines changed: 40 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -231,10 +231,6 @@
231231
# note: llama already has many adapters for aiu, and they are the same for all models, so just use llama. This way we don't need to register a new architecture / adapter step (we can just re-use the existing ones)
232232
__custom_adapter = {"architecture": "llama", "source": "fms_aiu"}
233233

234-
### Additional configuration for testing caching correctness
235-
# Directory to be used for testing torch sendnn caching
236-
CACHE_DIR = os.path.join(os.getcwd(), ".cache")
237-
238234

239235
@pytest.fixture(autouse=True)
240236
def reset_compiler():
@@ -243,8 +239,6 @@ def reset_compiler():
243239
torch.compiler.reset()
244240
torch._dynamo.reset()
245241
os.environ.pop("COMPILATION_MODE", None)
246-
os.environ.pop("TORCH_SENDNN_CACHE_ENABLE", None)
247-
# FIXME - this fixture breaks stuff if we don't run the cache test first
248242

249243

250244
# TODO: Currently, gptq does not have the same level of support as non-gptq models for get_model. This method provides the extra requirements for gptq for get_model,
@@ -848,12 +842,20 @@ def _run_cpu_aiu_validation_test(
848842

849843
def _reset_cache_settings(purge_cache_dir):
850844
os.environ["TORCH_SENDNN_CACHE_ENABLE"] = "1"
851-
os.environ["TORCH_SENDNN_CACHE_DIR"] = CACHE_DIR
852845
os.environ["COMPILATION_MODE"] = "offline_decoder"
846+
cache_dir = os.environ["TORCH_SENDNN_CACHE_DIR"]
853847

854848
# Ensure we start in clean state
855-
if purge_cache_dir and os.path.isdir(CACHE_DIR):
856-
shutil.rmtree(CACHE_DIR)
849+
if purge_cache_dir and os.path.isdir(cache_dir):
850+
shutil.rmtree(cache_dir)
851+
os.mkdir(cache_dir)
852+
853+
from torch_sendnn.backends import cache
854+
855+
# Explicitly clear cache paths from the global torch sendnn graph;
856+
# TODO would be better to add a helper to explicitly do this in
857+
# torch sendnn
858+
cache.cache = {}
857859

858860

859861
@pytest.fixture
@@ -865,12 +867,14 @@ def use_cached_model():
865867
torch.manual_seed(42)
866868
torch.set_grad_enabled(False)
867869
_reset_cache_settings(purge_cache_dir=True)
870+
868871
model_path, batch_size, seq_length, max_new_tokens = _get_cache_test_params()
869872
micro_model_path = MICRO_MODEL_MAPPING.get(model_path, None)
870873

871874
def verify_cache_miss():
875+
cache_dir = os.environ.get("TORCH_SENDNN_CACHE_DIR")
872876
updated_cache_len = (
873-
len(os.listdir(CACHE_DIR)) if os.path.isdir(CACHE_DIR) else 0
877+
len(os.listdir(cache_dir)) if os.path.isdir(cache_dir) else 0
874878
)
875879
assert updated_cache_len == max_new_tokens, (
876880
"cache directory not populated on cache miss"
@@ -915,24 +919,19 @@ def _get_cache_test_params():
915919
return [model_path, batch_size, seq_length, max_new_tokens]
916920

917921

918-
def test_cache(use_cached_model):
922+
@pytest.mark.parametrize(
923+
"model_path,batch_size,seq_length,max_new_tokens", common_shapes
924+
)
925+
def test_common_shapes(
926+
model_path, batch_size, seq_length, max_new_tokens, persistent_model
927+
):
919928
torch.manual_seed(42)
920929
torch.set_grad_enabled(False)
921-
_reset_cache_settings(purge_cache_dir=False)
922-
923-
model_path, batch_size, seq_length, max_new_tokens = _get_cache_test_params()
930+
os.environ["COMPILATION_MODE"] = "offline_decoder"
924931
micro_model_path = MICRO_MODEL_MAPPING.get(model_path, None)
925932

926-
def verify_cache_hit():
927-
updated_cache_len = (
928-
len(os.listdir(CACHE_DIR)) if os.path.isdir(CACHE_DIR) else 0
929-
)
930-
assert updated_cache_len == max_new_tokens, (
931-
"cache miss occurred when hit was expected"
932-
)
933-
934933
dprint(
935-
f"testing: model={model_path}, batch_size={batch_size}, seq_length={seq_length}, max_new_tokens={max_new_tokens}, micro_model={USE_MICRO_MODELS}, for cache hit"
934+
f"testing model={model_path}, batch_size={batch_size}, seq_length={seq_length}, max_new_tokens={max_new_tokens}, micro_model={USE_MICRO_MODELS}"
936935
)
937936

938937
# we don't currently support inferring gptq from get_model, so we must use an adapter with hf_configured
@@ -941,15 +940,14 @@ def verify_cache_hit():
941940
model = get_aiu_model(
942941
model_path,
943942
gptq_kwargs_aiu,
944-
persistent_model_inst=None,
943+
persistent_model_inst=persistent_model,
945944
)
946945

947946
validation_model = get_cpu_model(
948947
model_path,
949948
gptq_kwargs_cpu,
950949
micro_model_state_dict=model.state_dict() if USE_MICRO_MODELS else None,
951950
)
952-
953951
_run_cpu_aiu_validation_test(
954952
model_path,
955953
batch_size,
@@ -958,23 +956,28 @@ def verify_cache_hit():
958956
validation_model,
959957
model,
960958
micro_model_path,
961-
verify_cache_state=verify_cache_hit,
962959
)
963960

964961

965-
@pytest.mark.parametrize(
966-
"model_path,batch_size,seq_length,max_new_tokens", common_shapes
967-
)
968-
def test_common_shapes(
969-
model_path, batch_size, seq_length, max_new_tokens, persistent_model
970-
):
962+
def test_cache(use_cached_model):
971963
torch.manual_seed(42)
972964
torch.set_grad_enabled(False)
973-
os.environ["COMPILATION_MODE"] = "offline_decoder"
965+
_reset_cache_settings(purge_cache_dir=False)
966+
967+
model_path, batch_size, seq_length, max_new_tokens = _get_cache_test_params()
974968
micro_model_path = MICRO_MODEL_MAPPING.get(model_path, None)
975969

970+
def verify_cache_hit():
971+
cache_dir = os.environ.get("TORCH_SENDNN_CACHE_DIR")
972+
updated_cache_len = (
973+
len(os.listdir(cache_dir)) if os.path.isdir(cache_dir) else 0
974+
)
975+
assert updated_cache_len == max_new_tokens, (
976+
"cache miss occurred when hit was expected"
977+
)
978+
976979
dprint(
977-
f"testing model={model_path}, batch_size={batch_size}, seq_length={seq_length}, max_new_tokens={max_new_tokens}, micro_model={USE_MICRO_MODELS}"
980+
f"testing: model={model_path}, batch_size={batch_size}, seq_length={seq_length}, max_new_tokens={max_new_tokens}, micro_model={USE_MICRO_MODELS}, for cache hit"
978981
)
979982

980983
# we don't currently support inferring gptq from get_model, so we must use an adapter with hf_configured
@@ -983,14 +986,15 @@ def test_common_shapes(
983986
model = get_aiu_model(
984987
model_path,
985988
gptq_kwargs_aiu,
986-
persistent_model_inst=persistent_model,
989+
persistent_model_inst=None,
987990
)
988991

989992
validation_model = get_cpu_model(
990993
model_path,
991994
gptq_kwargs_cpu,
992995
micro_model_state_dict=model.state_dict() if USE_MICRO_MODELS else None,
993996
)
997+
994998
_run_cpu_aiu_validation_test(
995999
model_path,
9961000
batch_size,
@@ -999,4 +1003,5 @@ def test_common_shapes(
9991003
validation_model,
10001004
model,
10011005
micro_model_path,
1006+
verify_cache_state=verify_cache_hit,
10021007
)

0 commit comments

Comments
 (0)