@@ -891,33 +891,31 @@ def get_valid_prompts(


 def generate_cpu_validation(
-    skip_validation: bool,
     model_variant: str,
     max_new_tokens: int,
     validation_info_outputs_dir: str,
     save_validation_info_outputs: bool,
-    validation_model: Optional[torch.nn.Module],
+    validation_model: torch.nn.Module,
     valid_prompt,
     input_ids: torch.Tensor,
     extra_kwargs: Dict[str, Any],
     sample_key: str,
     attn_name: str,
     cpu_dtype: str,
     tokenizer: AutoTokenizer,
-) -> ValidationInfo | None:
+) -> ValidationInfo:
     """Generates or loads CPU validation information for reference comparison.

-    Attempts to load pre-computed CPU validation data from disk. If not found and
-    a validation model is provided, runs CPU inference to generate reference outputs
-    (tokens and logits). Optionally saves the validation info for future use.
+    Attempts to load pre-computed CPU validation data from disk. If not found,
+    runs CPU inference to generate reference outputs (tokens and logits).
+    Optionally saves the validation info for future use.

     Args:
-        skip_validation: Whether to skip validation entirely.
         model_variant: Model identifier or path.
         max_new_tokens: Maximum number of tokens to generate.
         validation_info_outputs_dir: Directory for validation info outputs.
         save_validation_info_outputs: Whether to save validation info to disk.
-        validation_model: Optional CPU model for generating validation data.
+        validation_model: CPU model for generating validation data.
         valid_prompt: Tuple of (batch_size, seq_length) for the prompt shape.
         input_ids: Tokenized input tensor.
         extra_kwargs: Dictionary with attention mask and other model inputs.
@@ -927,54 +925,51 @@ def generate_cpu_validation(
         tokenizer: HuggingFace tokenizer for the model.

     Returns:
-        Optional[ValidationInfo]: ValidationInfo object containing CPU reference outputs
-            (tokens and logits), or None if validation is skipped.
+        ValidationInfo: ValidationInfo object containing CPU reference outputs
+            (tokens and logits).
     """
-    cpu_validation_info: Optional[ValidationInfo] = None
-    if not skip_validation:
-        # attempt to load the cpu validation info if it is already computed
-        cpu_validation_info = _load_validation_info(
-            model_variant=model_variant,
-            batch_size=valid_prompt[0],
-            seq_length=valid_prompt[1],
+    # attempt to load the cpu validation info if it is already computed
+    cpu_validation_info = _load_validation_info(
+        model_variant=model_variant,
+        batch_size=valid_prompt[0],
+        seq_length=valid_prompt[1],
+        max_new_tokens=max_new_tokens,
+        tokenizer=tokenizer,
+        seed=0,
+        cpu_dtype=cpu_dtype,
+        attn_type=attn_name,
+        validation_info_outputs_dir=validation_info_outputs_dir,
+        sample_key=sample_key,
+    )
+    if cpu_validation_info is None:
+        cpu_validation_info = extract_validation_information(
+            model=validation_model,
+            input_ids=input_ids,
             max_new_tokens=max_new_tokens,
-            tokenizer=tokenizer,
-            seed=0,
-            cpu_dtype=cpu_dtype,
-            attn_type=attn_name,
-            validation_info_outputs_dir=validation_info_outputs_dir,
-            sample_key=sample_key,
+            post_iteration_hook=LogitsExtractorHook(),
+            attn_algorithm="math",
+            **extra_kwargs,
         )
-        if cpu_validation_info is None and validation_model is not None:
-            cpu_validation_info = extract_validation_information(
-                model=validation_model,
-                input_ids=input_ids,
-                max_new_tokens=max_new_tokens,
-                post_iteration_hook=LogitsExtractorHook(),
-                attn_algorithm="math",
-                **extra_kwargs,
-            )
-            if save_validation_info_outputs:
-                cpu_validation_info.save(
-                    get_validation_info_path(
-                        validation_info_dir=validation_info_outputs_dir,
-                        model_variant=model_variant,
-                        batch_size=valid_prompt[0],
-                        seq_length=valid_prompt[1],
-                        max_new_tokens=max_new_tokens,
-                        seed=0,
-                        attn_type=attn_name,
-                        dtype=cpu_dtype,
-                        sample_key=sample_key,
-                    )
+        if save_validation_info_outputs:
+            cpu_validation_info.save(
+                get_validation_info_path(
+                    validation_info_dir=validation_info_outputs_dir,
+                    model_variant=model_variant,
+                    batch_size=valid_prompt[0],
+                    seq_length=valid_prompt[1],
+                    max_new_tokens=max_new_tokens,
+                    seed=0,
+                    attn_type=attn_name,
+                    dtype=cpu_dtype,
+                    sample_key=sample_key,
                 )
+            )

     return cpu_validation_info


 def generate_aiu_validation(
     test_type: str,
-    skip_validation: bool,
     max_new_tokens: int,
     timing: str,
     prefill_chunk_size: int,
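
With skip_validation dropped from the signature and the Optional return gone, the decision to skip CPU validation now lives at the call site rather than inside the helper. A minimal sketch of the intended call-site pattern, assuming only the variable names that appear elsewhere in this diff:

cpu_validation_info = None
if not skip_validation:
    # the helper now always returns a ValidationInfo, so the guard sits here
    cpu_validation_info = generate_cpu_validation(
        model_variant=model_variant,
        max_new_tokens=max_new_tokens,
        validation_info_outputs_dir=validation_info_outputs_dir,
        save_validation_info_outputs=save_validation_info_outputs,
        validation_model=validation_model,
        valid_prompt=valid_prompt.shape,
        input_ids=valid_prompt.input_ids,
        extra_kwargs=valid_prompt.extra_kwargs,
        sample_key=valid_prompt.sample_key,
        attn_name=env_config.attn_name,
        cpu_dtype=env_config.cpu_dtype,
        tokenizer=tokenizer,
    )

This mirrors the restructured caller in generate_validation_info_and_test further down in this diff.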
@@ -991,7 +986,6 @@ def generate_aiu_validation(

     Args:
         test_type: Type of test being run ("metrics" or "tokens").
-        skip_validation: Whether to skip validation entirely.
         max_new_tokens: Maximum number of tokens to generate.
         timing: Whether to collect timing information.
         prefill_chunk_size: Chunk size for prefill operations.
@@ -1005,9 +999,8 @@ def generate_aiu_validation(
             and optional timing information).
     """
     golden_hook = None
-    if test_type == "metrics":
-        if not skip_validation and cpu_validation_info:
-            golden_hook = GoldenTokenHook(cpu_validation_info.get_info("tokens"))
+    if test_type == "metrics" and cpu_validation_info:
+        golden_hook = GoldenTokenHook(cpu_validation_info.get_info("tokens"))

     aiu_validation_info = extract_validation_information(
         model=model,
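
generate_aiu_validation gets the same treatment: rather than consulting a separate skip flag, it keys off whether CPU reference data was passed in at all. A minimal sketch of the simplified guard, using the names this script already imports:

golden_hook = None
if test_type == "metrics" and cpu_validation_info:
    # the golden-token hook is attached only when CPU reference tokens exist;
    # a skipped run passes cpu_validation_info=None and decodes without it
    golden_hook = GoldenTokenHook(cpu_validation_info.get_info("tokens"))

Collapsing the two nested conditions into one truthiness check lets callers on the skip path simply forward None instead of threading an extra boolean through.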
@@ -1277,75 +1270,85 @@ def generate_validation_info_and_test(
             f"program id: {valid_prompt.program_id}, valid prompt: {valid_prompt.shape}, input shape: {valid_prompt.input_ids.shape}"
         )

-        # Returns none if skipping CPU validation
-        cpu_validation_info = generate_cpu_validation(
-            skip_validation=skip_validation,
-            model_variant=model_variant,
-            max_new_tokens=max_new_tokens,
-            validation_info_outputs_dir=validation_info_outputs_dir,
-            save_validation_info_outputs=save_validation_info_outputs,
-            validation_model=validation_model,
-            valid_prompt=valid_prompt.shape,
-            input_ids=valid_prompt.input_ids,
-            extra_kwargs=valid_prompt.extra_kwargs,
-            sample_key=valid_prompt.sample_key,
-            attn_name=env_config.attn_name,
-            cpu_dtype=env_config.cpu_dtype,
-            tokenizer=tokenizer,
-        )
-
-        aiu_validation_info = generate_aiu_validation(
-            test_type=test_type,
-            skip_validation=skip_validation,
-            max_new_tokens=max_new_tokens,
-            timing=timing,
-            prefill_chunk_size=prefill_chunk_size,
-            model=model,
-            input_ids=valid_prompt.input_ids,
-            cpu_validation_info=cpu_validation_info,
-            extra_kwargs=valid_prompt.extra_kwargs,
-        )
-
-        if test_type == "metrics":
-            failure_rate = evaluate_cross_entropy_metrics(
-                cross_entropy_threshold=cross_entropy_threshold,
-                aiu_validation_info=aiu_validation_info,
-                cpu_validation_info=cpu_validation_info,
-                program_id=valid_prompt.program_id,
-                prompt_shape=valid_prompt.shape,
+        if not skip_validation:
+            # Generate or load CPU validation info
+            cpu_validation_info = generate_cpu_validation(
+                model_variant=model_variant,
+                max_new_tokens=max_new_tokens,
+                validation_info_outputs_dir=validation_info_outputs_dir,
+                save_validation_info_outputs=save_validation_info_outputs,
+                validation_model=validation_model,
+                valid_prompt=valid_prompt.shape,
+                input_ids=valid_prompt.input_ids,
+                extra_kwargs=valid_prompt.extra_kwargs,
+                sample_key=valid_prompt.sample_key,
+                attn_name=env_config.attn_name,
+                cpu_dtype=env_config.cpu_dtype,
                 tokenizer=tokenizer,
             )
-            if failure_rate > failure_rate_threshold:
-                failed_cases.append(
-                    (valid_prompt.program_id, valid_prompt.shape, failure_rate)
-                )

-        elif test_type == "tokens":
-            report_token_comparison(
+            aiu_validation_info = generate_aiu_validation(
+                test_type=test_type,
                 max_new_tokens=max_new_tokens,
-                aiu_validation_info=aiu_validation_info,
+                timing=timing,
+                prefill_chunk_size=prefill_chunk_size,
+                model=model,
+                input_ids=valid_prompt.input_ids,
                 cpu_validation_info=cpu_validation_info,
-                program_id=valid_prompt.program_id,
-                tokenizer=tokenizer,
+                extra_kwargs=valid_prompt.extra_kwargs,
             )

-        else:
-            raise ValueError("test type must be one of metrics or tokens")
-
-        if skip_validation and local_rank == 0:
-            for sentence_idx, test_sentence in enumerate(
-                aiu_validation_info.get_info("tokens")
-            ):
-                tokens_prompt = [t.item() for t in test_sentence[:-max_new_tokens]]
-                aiu_tokens_generated = [
-                    t.item() for t in test_sentence[-max_new_tokens:]
-                ]
-                dprint(
-                    f"For Program {valid_prompt.program_id} in sentence {sentence_idx + 1}:"
+            if test_type == "metrics":
+                failure_rate = evaluate_cross_entropy_metrics(
+                    cross_entropy_threshold=cross_entropy_threshold,
+                    aiu_validation_info=aiu_validation_info,
+                    cpu_validation_info=cpu_validation_info,
+                    program_id=valid_prompt.program_id,
+                    prompt_shape=valid_prompt.shape,
+                    tokenizer=tokenizer,
+                )
+                if failure_rate > failure_rate_threshold:
+                    failed_cases.append(
+                        (valid_prompt.program_id, valid_prompt.shape, failure_rate)
+                    )
+
+            elif test_type == "tokens":
+                report_token_comparison(
+                    max_new_tokens=max_new_tokens,
+                    aiu_validation_info=aiu_validation_info,
+                    cpu_validation_info=cpu_validation_info,
+                    program_id=valid_prompt.program_id,
+                    tokenizer=tokenizer,
                 )
-                dprint(f"Prompt:\n{tokenizer.decode(tokens_prompt)}")
-                dprint(f"AIU tokens:\n{aiu_tokens_generated}")
-                dprint(f"AIU output:\n{tokenizer.decode(aiu_tokens_generated)}")
+
+            else:
+                raise ValueError("test type must be one of metrics or tokens")
+        else:
+            aiu_validation_info = generate_aiu_validation(
+                test_type=test_type,
+                max_new_tokens=max_new_tokens,
+                timing=timing,
+                prefill_chunk_size=prefill_chunk_size,
+                model=model,
+                input_ids=valid_prompt.input_ids,
+                cpu_validation_info=None,
+                extra_kwargs=valid_prompt.extra_kwargs,
+            )
+
+            if local_rank == 0:
+                for sentence_idx, test_sentence in enumerate(
+                    aiu_validation_info.get_info("tokens")
+                ):
+                    tokens_prompt = [t.item() for t in test_sentence[:-max_new_tokens]]
+                    aiu_tokens_generated = [
+                        t.item() for t in test_sentence[-max_new_tokens:]
+                    ]
+                    dprint(
+                        f"For Program {valid_prompt.program_id} in sentence {sentence_idx + 1}:"
+                    )
+                    dprint(f"Prompt:\n{tokenizer.decode(tokens_prompt)}")
+                    dprint(f"AIU tokens:\n{aiu_tokens_generated}")
+                    dprint(f"AIU output:\n{tokenizer.decode(aiu_tokens_generated)}")

     return failed_cases

@@ -1495,8 +1498,8 @@ def main() -> None:
             dprint(
                 f"Program ID: {failed_case[0]}, Prompt Shape: {failed_case[1]}, Failure Rate: {failed_case[2]}"
             )
-        else:
-            dprint("all tests passed")
+    else:
+        dprint("all tests passed")


 if __name__ == "__main__":
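
The final hunk changes indentation only: the else and its success message are dedented, presumably so that "all tests passed" pairs with the check for failed cases rather than with the loop. A sketch of the assumed resulting structure in main(), where the guarding condition (written here as if failed_cases:) falls outside this excerpt:

if failed_cases:
    for failed_case in failed_cases:
        dprint(
            f"Program ID: {failed_case[0]}, Prompt Shape: {failed_case[1]}, Failure Rate: {failed_case[2]}"
        )
else:
    dprint("all tests passed")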