Skip to content

Commit 9dc90ac

Browse files
authored
Merge pull request #179 from foundation-model-stack/fix_add_eos_id
pass EOS token id to validate
2 parents 9595e3d + 2222178 commit 9dc90ac

5 files changed

Lines changed: 11 additions & 0 deletions

File tree

aiu_fms_testing_utils/scripts/drive_paged_programs.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -682,6 +682,7 @@ def __metric_calculator(r: torch.Tensor, t: torch.Tensor):
682682
max_new_tokens,
683683
LogitsExtractorHook(),
684684
attn_algorithm="math",
685+
eos_token_id=tokenizer.eos_token_id,
685686
**extra_kwargs,
686687
)
687688
# save the cpu validation info for later consumption
@@ -706,6 +707,7 @@ def __metric_calculator(r: torch.Tensor, t: torch.Tensor):
706707
input_ids,
707708
max_new_tokens,
708709
GoldenTokenHook(cpu_validation_info.get_info("tokens")),
710+
eos_token_id=tokenizer.eos_token_id,
709711
last_n_tokens=64,
710712
timing=TIMING,
711713
prefill_chunk_size=args.prefill_chunk_size,
@@ -751,6 +753,7 @@ def __metric_calculator(r: torch.Tensor, t: torch.Tensor):
751753
input_ids,
752754
max_new_tokens,
753755
None,
756+
eos_token_id=tokenizer.eos_token_id,
754757
last_n_tokens=64,
755758
timing=TIMING,
756759
prefill_chunk_size=args.prefill_chunk_size,
@@ -794,6 +797,7 @@ def __metric_calculator(r: torch.Tensor, t: torch.Tensor):
794797
input_ids,
795798
max_new_tokens,
796799
None,
800+
eos_token_id=tokenizer.eos_token_id,
797801
last_n_tokens=64,
798802
timing=TIMING,
799803
prefill_chunk_size=args.prefill_chunk_size,

aiu_fms_testing_utils/scripts/generate_metrics.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -245,6 +245,7 @@ def write_csv(metrics, path, metric_name):
245245
args.max_new_tokens,
246246
LogitsExtractorHook(),
247247
attn_algorithm="math",
248+
eos_token_id=tokenizer.eos_token_id,
248249
**padding_kwargs,
249250
)
250251
cpu_static_tokens = cpu_validation_info.get_info("tokens")
@@ -259,6 +260,7 @@ def write_csv(metrics, path, metric_name):
259260
ids.to("cuda"),
260261
args.max_new_tokens,
261262
None,
263+
eos_token_id=tokenizer.eos_token_id,
262264
last_n_tokens=1,
263265
**{k: v.to("cuda") for k, v in padding_kwargs.items()},
264266
)
@@ -325,6 +327,7 @@ def write_csv(metrics, path, metric_name):
325327
args.max_new_tokens,
326328
LogitsExtractorHook(),
327329
attn_algorithm="math",
330+
eos_token_id=tokenizer.eos_token_id,
328331
**padding_kwargs,
329332
)
330333

@@ -334,6 +337,7 @@ def write_csv(metrics, path, metric_name):
334337
ids.to("cuda"),
335338
args.max_new_tokens,
336339
GoldenTokenHook(cpu_validation_info.get_info("tokens"), "cuda"),
340+
eos_token_id=tokenizer.eos_token_id,
337341
last_n_tokens=1,
338342
**{k: v.to("cuda") for k, v in padding_kwargs.items()},
339343
)

aiu_fms_testing_utils/scripts/save_cpu_data.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,7 @@ def process_row(row):
9797
max_new_tokens,
9898
LogitsExtractorHook(),
9999
attn_algorithm="math",
100+
eos_token_id=tokenizer.eos_token_id,
100101
)
101102
return {"id": id, "input_ids": input_ids, "validation": cpu_validation_info}
102103

aiu_fms_testing_utils/scripts/validation.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -690,6 +690,7 @@ def print_result(result, result_idx: int = 0, file_prefix: str = ""):
690690
ids.to(validation_device),
691691
args.max_new_tokens,
692692
LogitsExtractorHook(),
693+
eos_token_id=None if args.no_early_termination else tokenizer.eos_token_id,
693694
attn_algorithm="math",
694695
**padding_kwargs,
695696
)

tests/models/test_decoders.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -612,6 +612,7 @@ def _get_device_validation_information(
612612
max_new_tokens,
613613
post_iteration_hook,
614614
timing=TIMING,
615+
eos_token_id=tokenizer.eos_token_id,
615616
**extra_kwargs,
616617
**device_dependent_kwargs,
617618
)

0 commit comments

Comments
 (0)