Skip to content

Commit 12c9f3b

Browse files
authored
[Fix] AttributeError when skipping CPU validation (#184)
* Fix --skip-validation related logic --------- Signed-off-by: Rafael Vasquez <rafvasq21@gmail.com>
1 parent 736a28c commit 12c9f3b

1 file changed

Lines changed: 116 additions & 113 deletions

File tree

aiu_fms_testing_utils/scripts/drive_paged_programs.py

Lines changed: 116 additions & 113 deletions
Original file line numberDiff line numberDiff line change
@@ -891,33 +891,31 @@ def get_valid_prompts(
891891

892892

893893
def generate_cpu_validation(
894-
skip_validation: bool,
895894
model_variant: str,
896895
max_new_tokens: int,
897896
validation_info_outputs_dir: str,
898897
save_validation_info_outputs: bool,
899-
validation_model: Optional[torch.nn.Module],
898+
validation_model: torch.nn.Module,
900899
valid_prompt,
901900
input_ids: torch.Tensor,
902901
extra_kwargs: Dict[str, Any],
903902
sample_key: str,
904903
attn_name: str,
905904
cpu_dtype: str,
906905
tokenizer: AutoTokenizer,
907-
) -> ValidationInfo | None:
906+
) -> ValidationInfo:
908907
"""Generates or loads CPU validation information for reference comparison.
909908
910-
Attempts to load pre-computed CPU validation data from disk. If not found and
911-
a validation model is provided, runs CPU inference to generate reference outputs
912-
(tokens and logits). Optionally saves the validation info for future use.
909+
Attempts to load pre-computed CPU validation data from disk. If not found,
910+
runs CPU inference to generate reference outputs (tokens and logits).
911+
Optionally saves the validation info for future use.
913912
914913
Args:
915-
skip_validation: Whether to skip validation entirely.
916914
model_variant: Model identifier or path.
917915
max_new_tokens: Maximum number of tokens to generate.
918916
validation_info_outputs_dir: Directory for validation info outputs.
919917
save_validation_info_outputs: Whether to save validation info to disk.
920-
validation_model: Optional CPU model for generating validation data.
918+
validation_model: CPU model for generating validation data.
921919
valid_prompt: Tuple of (batch_size, seq_length) for the prompt shape.
922920
input_ids: Tokenized input tensor.
923921
extra_kwargs: Dictionary with attention mask and other model inputs.
@@ -927,54 +925,51 @@ def generate_cpu_validation(
927925
tokenizer: HuggingFace tokenizer for the model.
928926
929927
Returns:
930-
Optional[ValidationInfo]: ValidationInfo object containing CPU reference outputs
931-
(tokens and logits), or None if validation is skipped.
928+
ValidationInfo: ValidationInfo object containing CPU reference outputs
929+
(tokens and logits).
932930
"""
933-
cpu_validation_info: Optional[ValidationInfo] = None
934-
if not skip_validation:
935-
# attempt to load the cpu validation info if it is already computed
936-
cpu_validation_info = _load_validation_info(
937-
model_variant=model_variant,
938-
batch_size=valid_prompt[0],
939-
seq_length=valid_prompt[1],
931+
# attempt to load the cpu validation info if it is already computed
932+
cpu_validation_info = _load_validation_info(
933+
model_variant=model_variant,
934+
batch_size=valid_prompt[0],
935+
seq_length=valid_prompt[1],
936+
max_new_tokens=max_new_tokens,
937+
tokenizer=tokenizer,
938+
seed=0,
939+
cpu_dtype=cpu_dtype,
940+
attn_type=attn_name,
941+
validation_info_outputs_dir=validation_info_outputs_dir,
942+
sample_key=sample_key,
943+
)
944+
if cpu_validation_info is None:
945+
cpu_validation_info = extract_validation_information(
946+
model=validation_model,
947+
input_ids=input_ids,
940948
max_new_tokens=max_new_tokens,
941-
tokenizer=tokenizer,
942-
seed=0,
943-
cpu_dtype=cpu_dtype,
944-
attn_type=attn_name,
945-
validation_info_outputs_dir=validation_info_outputs_dir,
946-
sample_key=sample_key,
949+
post_iteration_hook=LogitsExtractorHook(),
950+
attn_algorithm="math",
951+
**extra_kwargs,
947952
)
948-
if cpu_validation_info is None and validation_model is not None:
949-
cpu_validation_info = extract_validation_information(
950-
model=validation_model,
951-
input_ids=input_ids,
952-
max_new_tokens=max_new_tokens,
953-
post_iteration_hook=LogitsExtractorHook(),
954-
attn_algorithm="math",
955-
**extra_kwargs,
956-
)
957-
if save_validation_info_outputs:
958-
cpu_validation_info.save(
959-
get_validation_info_path(
960-
validation_info_dir=validation_info_outputs_dir,
961-
model_variant=model_variant,
962-
batch_size=valid_prompt[0],
963-
seq_length=valid_prompt[1],
964-
max_new_tokens=max_new_tokens,
965-
seed=0,
966-
attn_type=attn_name,
967-
dtype=cpu_dtype,
968-
sample_key=sample_key,
969-
)
953+
if save_validation_info_outputs:
954+
cpu_validation_info.save(
955+
get_validation_info_path(
956+
validation_info_dir=validation_info_outputs_dir,
957+
model_variant=model_variant,
958+
batch_size=valid_prompt[0],
959+
seq_length=valid_prompt[1],
960+
max_new_tokens=max_new_tokens,
961+
seed=0,
962+
attn_type=attn_name,
963+
dtype=cpu_dtype,
964+
sample_key=sample_key,
970965
)
966+
)
971967

972968
return cpu_validation_info
973969

974970

975971
def generate_aiu_validation(
976972
test_type: str,
977-
skip_validation: bool,
978973
max_new_tokens: int,
979974
timing: str,
980975
prefill_chunk_size: int,
@@ -991,7 +986,6 @@ def generate_aiu_validation(
991986
992987
Args:
993988
test_type: Type of test being run ("metrics" or "tokens").
994-
skip_validation: Whether to skip validation entirely.
995989
max_new_tokens: Maximum number of tokens to generate.
996990
timing: Whether to collect timing information.
997991
prefill_chunk_size: Chunk size for prefill operations.
@@ -1005,9 +999,8 @@ def generate_aiu_validation(
1005999
and optional timing information).
10061000
"""
10071001
golden_hook = None
1008-
if test_type == "metrics":
1009-
if not skip_validation and cpu_validation_info:
1010-
golden_hook = GoldenTokenHook(cpu_validation_info.get_info("tokens"))
1002+
if test_type == "metrics" and cpu_validation_info:
1003+
golden_hook = GoldenTokenHook(cpu_validation_info.get_info("tokens"))
10111004

10121005
aiu_validation_info = extract_validation_information(
10131006
model=model,
@@ -1277,75 +1270,85 @@ def generate_validation_info_and_test(
12771270
f"program id: {valid_prompt.program_id}, valid prompt: {valid_prompt.shape}, input shape: {valid_prompt.input_ids.shape}"
12781271
)
12791272

1280-
# Returns none if skipping CPU validation
1281-
cpu_validation_info = generate_cpu_validation(
1282-
skip_validation=skip_validation,
1283-
model_variant=model_variant,
1284-
max_new_tokens=max_new_tokens,
1285-
validation_info_outputs_dir=validation_info_outputs_dir,
1286-
save_validation_info_outputs=save_validation_info_outputs,
1287-
validation_model=validation_model,
1288-
valid_prompt=valid_prompt.shape,
1289-
input_ids=valid_prompt.input_ids,
1290-
extra_kwargs=valid_prompt.extra_kwargs,
1291-
sample_key=valid_prompt.sample_key,
1292-
attn_name=env_config.attn_name,
1293-
cpu_dtype=env_config.cpu_dtype,
1294-
tokenizer=tokenizer,
1295-
)
1296-
1297-
aiu_validation_info = generate_aiu_validation(
1298-
test_type=test_type,
1299-
skip_validation=skip_validation,
1300-
max_new_tokens=max_new_tokens,
1301-
timing=timing,
1302-
prefill_chunk_size=prefill_chunk_size,
1303-
model=model,
1304-
input_ids=valid_prompt.input_ids,
1305-
cpu_validation_info=cpu_validation_info,
1306-
extra_kwargs=valid_prompt.extra_kwargs,
1307-
)
1308-
1309-
if test_type == "metrics":
1310-
failure_rate = evaluate_cross_entropy_metrics(
1311-
cross_entropy_threshold=cross_entropy_threshold,
1312-
aiu_validation_info=aiu_validation_info,
1313-
cpu_validation_info=cpu_validation_info,
1314-
program_id=valid_prompt.program_id,
1315-
prompt_shape=valid_prompt.shape,
1273+
if not skip_validation:
1274+
# Generate or load CPU validation info
1275+
cpu_validation_info = generate_cpu_validation(
1276+
model_variant=model_variant,
1277+
max_new_tokens=max_new_tokens,
1278+
validation_info_outputs_dir=validation_info_outputs_dir,
1279+
save_validation_info_outputs=save_validation_info_outputs,
1280+
validation_model=validation_model,
1281+
valid_prompt=valid_prompt.shape,
1282+
input_ids=valid_prompt.input_ids,
1283+
extra_kwargs=valid_prompt.extra_kwargs,
1284+
sample_key=valid_prompt.sample_key,
1285+
attn_name=env_config.attn_name,
1286+
cpu_dtype=env_config.cpu_dtype,
13161287
tokenizer=tokenizer,
13171288
)
1318-
if failure_rate > failure_rate_threshold:
1319-
failed_cases.append(
1320-
(valid_prompt.program_id, valid_prompt.shape, failure_rate)
1321-
)
13221289

1323-
elif test_type == "tokens":
1324-
report_token_comparison(
1290+
aiu_validation_info = generate_aiu_validation(
1291+
test_type=test_type,
13251292
max_new_tokens=max_new_tokens,
1326-
aiu_validation_info=aiu_validation_info,
1293+
timing=timing,
1294+
prefill_chunk_size=prefill_chunk_size,
1295+
model=model,
1296+
input_ids=valid_prompt.input_ids,
13271297
cpu_validation_info=cpu_validation_info,
1328-
program_id=valid_prompt.program_id,
1329-
tokenizer=tokenizer,
1298+
extra_kwargs=valid_prompt.extra_kwargs,
13301299
)
13311300

1332-
else:
1333-
raise ValueError("test type must be one of metrics or tokens")
1334-
1335-
if skip_validation and local_rank == 0:
1336-
for sentence_idx, test_sentence in enumerate(
1337-
aiu_validation_info.get_info("tokens")
1338-
):
1339-
tokens_prompt = [t.item() for t in test_sentence[:-max_new_tokens]]
1340-
aiu_tokens_generated = [
1341-
t.item() for t in test_sentence[-max_new_tokens:]
1342-
]
1343-
dprint(
1344-
f"For Program {valid_prompt.program_id} in sentence {sentence_idx + 1}:"
1301+
if test_type == "metrics":
1302+
failure_rate = evaluate_cross_entropy_metrics(
1303+
cross_entropy_threshold=cross_entropy_threshold,
1304+
aiu_validation_info=aiu_validation_info,
1305+
cpu_validation_info=cpu_validation_info,
1306+
program_id=valid_prompt.program_id,
1307+
prompt_shape=valid_prompt.shape,
1308+
tokenizer=tokenizer,
1309+
)
1310+
if failure_rate > failure_rate_threshold:
1311+
failed_cases.append(
1312+
(valid_prompt.program_id, valid_prompt.shape, failure_rate)
1313+
)
1314+
1315+
elif test_type == "tokens":
1316+
report_token_comparison(
1317+
max_new_tokens=max_new_tokens,
1318+
aiu_validation_info=aiu_validation_info,
1319+
cpu_validation_info=cpu_validation_info,
1320+
program_id=valid_prompt.program_id,
1321+
tokenizer=tokenizer,
13451322
)
1346-
dprint(f"Prompt:\n{tokenizer.decode(tokens_prompt)}")
1347-
dprint(f"AIU tokens:\n{aiu_tokens_generated}")
1348-
dprint(f"AIU output:\n{tokenizer.decode(aiu_tokens_generated)}")
1323+
1324+
else:
1325+
raise ValueError("test type must be one of metrics or tokens")
1326+
else:
1327+
aiu_validation_info = generate_aiu_validation(
1328+
test_type=test_type,
1329+
max_new_tokens=max_new_tokens,
1330+
timing=timing,
1331+
prefill_chunk_size=prefill_chunk_size,
1332+
model=model,
1333+
input_ids=valid_prompt.input_ids,
1334+
cpu_validation_info=None,
1335+
extra_kwargs=valid_prompt.extra_kwargs,
1336+
)
1337+
1338+
if local_rank == 0:
1339+
for sentence_idx, test_sentence in enumerate(
1340+
aiu_validation_info.get_info("tokens")
1341+
):
1342+
tokens_prompt = [t.item() for t in test_sentence[:-max_new_tokens]]
1343+
aiu_tokens_generated = [
1344+
t.item() for t in test_sentence[-max_new_tokens:]
1345+
]
1346+
dprint(
1347+
f"For Program {valid_prompt.program_id} in sentence {sentence_idx + 1}:"
1348+
)
1349+
dprint(f"Prompt:\n{tokenizer.decode(tokens_prompt)}")
1350+
dprint(f"AIU tokens:\n{aiu_tokens_generated}")
1351+
dprint(f"AIU output:\n{tokenizer.decode(aiu_tokens_generated)}")
13491352

13501353
return failed_cases
13511354

@@ -1495,8 +1498,8 @@ def main() -> None:
14951498
dprint(
14961499
f"Program ID: {failed_case[0]}, Prompt Shape: {failed_case[1]}, Failure Rate: {failed_case[2]}"
14971500
)
1498-
else:
1499-
dprint("all tests passed")
1501+
else:
1502+
dprint("all tests passed")
15001503

15011504

15021505
if __name__ == "__main__":

0 commit comments

Comments
 (0)