Skip to content

Commit 32c1e6c

Browse files
fix fp8 dtype, always use persistent model fixture
1 parent 85d90f0 commit 32c1e6c

1 file changed

Lines changed: 15 additions & 38 deletions

File tree

tests/models/test_decoders.py

Lines changed: 15 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -527,17 +527,13 @@ def _get_common_model_kwargs(is_gptq, model_path):
527527

528528
# NOTE micro_model_state_dict should be None if USE_MICRO_MODELS is true
529529
# Otherwise it should be model.state_dict() where model is the AIU model
530-
def _get_cpu_model(model_path, gptq_kwargs, micro_model_state_dict=None):
531-
is_gptq = len(gptq_kwargs) != 0
532-
model_kwargs = _get_common_model_kwargs(is_gptq, model_path)
533-
530+
def _get_cpu_model(model_path, is_gptq, is_fp8, micro_model_state_dict=None, **kwargs):
534531
# prepare the cpu model
535532
validation_model = get_model(
536533
device_type="cpu",
537-
data_type=None if is_gptq else torch.float32,
534+
data_type=None if is_fp8 or is_gptq else torch.float32,
538535
fused_weights=False,
539-
**gptq_kwargs,
540-
**model_kwargs,
536+
**kwargs,
541537
)
542538

543539
# This is a micro model, so we need to copy the state dict directly.
@@ -548,32 +544,6 @@ def _get_cpu_model(model_path, gptq_kwargs, micro_model_state_dict=None):
548544
return validation_model
549545

550546

551-
def _get_aiu_model(model_path, gptq_kwargs, persistent_model_inst):
552-
is_gptq = len(gptq_kwargs) != 0
553-
is_fp8 = "fp8" in ATTN_NAME
554-
model_kwargs = _get_common_model_kwargs(is_gptq, model_path)
555-
556-
# prepare the AIU model; use the persistent model fixture if the test has it
557-
if persistent_model_inst is not None:
558-
aiu_model = persistent_model_inst.get_or_create(
559-
is_gptq, is_fp8, **gptq_kwargs, **model_kwargs
560-
)
561-
# otherwise create it directly
562-
else:
563-
aiu_model = get_model(
564-
device_type="cpu",
565-
data_type=None if is_gptq else torch.float16,
566-
fused_weights=False,
567-
**gptq_kwargs,
568-
**model_kwargs,
569-
)
570-
aiu_model.eval()
571-
aiu_model.compile(
572-
backend="sendnn",
573-
options={"sendnn.dynamic": COMPILE_DYNAMIC_SENDNN},
574-
)
575-
return aiu_model
576-
577547

578548
def _get_device_validation_information(
579549
model_path,
@@ -891,17 +861,24 @@ def test_common_shapes(
891861
# we don't currently support inferring gptq from get_model, so we must use an adapter with hf_configured
892862
gptq_kwargs_aiu, gptq_kwargs_cpu = __maybe_get_gptq_kwargs(model_path)
893863

894-
model = _get_aiu_model(
895-
model_path,
896-
gptq_kwargs_aiu,
897-
persistent_model_inst=persistent_model,
864+
is_gptq = len(gptq_kwargs_aiu) != 0
865+
is_fp8 = "fp8" in ATTN_NAME
866+
model_kwargs = _get_common_model_kwargs(is_gptq, model_path)
867+
868+
# Get the AIU model w/ the persistent model fixture
869+
model = persistent_model.get_or_create(
870+
is_gptq, is_fp8, **gptq_kwargs_aiu, **model_kwargs
898871
)
899872

900873
validation_model = _get_cpu_model(
901874
model_path,
902-
gptq_kwargs_cpu,
875+
is_gptq,
876+
is_fp8,
903877
micro_model_state_dict=model.state_dict() if USE_MICRO_MODELS else None,
878+
**gptq_kwargs_cpu,
879+
**model_kwargs,
904880
)
881+
905882
_run_cpu_aiu_validation_test(
906883
model_path,
907884
batch_size,

0 commit comments

Comments
 (0)