Skip to content

Commit 714297b

Browse files
authored
Merge pull request #102 from kcirred/test_decoder_update
test_decoder update
2 parents 9cd18ac + 9cb8ff8 commit 714297b

2 files changed

Lines changed: 79 additions & 22 deletions

File tree

aiu_fms_testing_utils/testing/validation.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -124,9 +124,14 @@ def __len__(self):
124124

125125

126126
def get_default_validation_prefix(
127-
model_id: str, max_new_tokens: int, batch_size: int, seq_length: int, dtype: str
127+
model_id: str,
128+
max_new_tokens: int,
129+
batch_size: int,
130+
seq_length: int,
131+
dtype: str,
132+
attn_type: str,
128133
):
129-
return f"{model_id.replace('/', '--')}_max-new-tokens-{max_new_tokens}_batch-size-{batch_size}_seq-length-{seq_length}_dtype-{dtype}"
134+
return f"{model_id.replace('/', '--')}_max-new-tokens-{max_new_tokens}_batch-size-{batch_size}_seq-length-{seq_length}_dtype-{dtype}_attn-type-{attn_type}"
130135

131136

132137
def load_validation_information(

tests/models/test_decoders.py

Lines changed: 72 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
from fms.models import get_model
55
from fms.utils.generation import pad_input_ids
66
import itertools
7+
import warnings
8+
import re
79
import torch
810
from torch import distributed as dist
911
from aiu_fms_testing_utils.testing.validation import (
@@ -171,7 +173,7 @@
171173
# the compiler supports certain max context lengths (VLLM_DT_MAX_CONTEXT_LEN)
172174
# this will ensure that we select the smallest supported VLLM_DT_MAX_CONTEXT_LEN that fits the largest possible context (prompt size + max_new_tokens)
173175
__largest_context = max(common_seq_lengths) + max(common_max_new_tokens)
174-
__supported_context_lengths = [256, 512, 1024, 2048, 4096, 8192]
176+
__supported_context_lengths = [256, 512, 1024, 2048, 4096, 8192, 16384, 32768]
175177
os.environ["VLLM_DT_MAX_CONTEXT_LEN"] = str(
176178
__supported_context_lengths[
177179
bisect.bisect_left(__supported_context_lengths, __largest_context)
@@ -301,14 +303,28 @@ def __maybe_get_gptq_kwargs(model_path):
301303

302304

303305
def __prepare_inputs(batch_size, seq_length, tokenizer, seed=0):
304-
prompts_and_sizes = sample_sharegpt_requests(
305-
SHARE_GPT_DATASET_PATH,
306-
batch_size,
307-
tokenizer,
308-
seq_length // 2,
309-
seq_length,
310-
seed,
311-
)
306+
if "paged" in ATTN_NAME:
307+
prompts_and_sizes = sample_sharegpt_requests(
308+
SHARE_GPT_DATASET_PATH,
309+
batch_size,
310+
tokenizer,
311+
32,
312+
seq_length,
313+
seed,
314+
enforce_heterogeneous=True,
315+
enforce_sizes=[seq_length], # ensure at least the max seq length is sampled
316+
pad_multiple=64,
317+
)
318+
else:
319+
prompts_and_sizes = sample_sharegpt_requests(
320+
SHARE_GPT_DATASET_PATH,
321+
batch_size,
322+
tokenizer,
323+
seq_length // 2,
324+
seq_length,
325+
seed,
326+
)
327+
312328
prompt_list = []
313329
for prompt, _ in prompts_and_sizes:
314330
prompt_list.append(tokenizer.encode(prompt, return_tensors="pt").squeeze(0))
@@ -341,25 +357,44 @@ def __filter_before_eos(metrics, filter_indexes):
341357

342358

343359
def __get_validation_info_full_path(
344-
model_path, batch_size, seq_length, max_new_tokens, seed, device_type="cpu"
360+
model_path,
361+
batch_size,
362+
seq_length,
363+
max_new_tokens,
364+
seed,
365+
attn_type: str,
366+
device_type="cpu",
345367
):
346-
validation_file_name = f"{get_default_validation_prefix(model_path, max_new_tokens, batch_size, seq_length, 'fp16')}.{device_type}_validation_info.{seed}.out"
368+
validation_file_name = f"{get_default_validation_prefix(model_path, max_new_tokens, batch_size, seq_length, 'fp16', attn_type)}.{device_type}_validation_info.{seed}.out"
347369
full_path = os.path.join(validation_info_dir, validation_file_name)
348370
return full_path
349371

350372

351373
def __load_validation_info(
352-
model_path, batch_size, seq_length, max_new_tokens, tokenizer, seed
374+
model_path, batch_size, seq_length, max_new_tokens, tokenizer, seed, attn_type: str
353375
):
376+
# if the path doesn't exist and "paged" isn't in the attention name, strip the `attn_type` suffix from the path and check again; warn that paths without `attn_type` will no longer be used in the future
354377
full_path = __get_validation_info_full_path(
355-
model_path, batch_size, seq_length, max_new_tokens, seed
378+
model_path, batch_size, seq_length, max_new_tokens, seed, attn_type
356379
)
357380

358381
if os.path.exists(full_path):
359382
dprint(f"cpu validation info found for seed={seed} -- loading it")
360383
return load_validation_information(full_path, "logits", batch_size, tokenizer)
361-
else:
362-
return None
384+
elif "paged" not in attn_type:
385+
# This regex applies to a very specific file name format
386+
modified_full_path = re.sub(r"_attn-type[^.]*", "", full_path)
387+
388+
if os.path.exists(modified_full_path):
389+
warnings.warn(
390+
f"All future paths should contain attn_type prefix information in path name, please modify {full_path=} to {modified_full_path=}",
391+
stacklevel=2,
392+
)
393+
dprint(f"cpu validation info found for seed={seed} -- loading it")
394+
return load_validation_information(
395+
modified_full_path, "logits", batch_size, tokenizer
396+
)
397+
return None
363398

364399

365400
class PersistentModel:
@@ -513,7 +548,7 @@ def test_common_shapes(
513548

514549
# generate cpu validation info
515550
cpu_validation_info = __load_validation_info(
516-
model_path, batch_size, seq_length, max_new_tokens, tokenizer, 0
551+
model_path, batch_size, seq_length, max_new_tokens, tokenizer, 0, ATTN_NAME
517552
)
518553
if cpu_validation_info is None:
519554
cpu_validation_info = extract_validation_information(
@@ -529,7 +564,7 @@ def test_common_shapes(
529564
if save_validation_info_outputs:
530565
cpu_validation_info.save(
531566
__get_validation_info_full_path(
532-
model_path, batch_size, seq_length, max_new_tokens, 0
567+
model_path, batch_size, seq_length, max_new_tokens, 0, ATTN_NAME
533568
)
534569
)
535570
cpu_static_tokens = cpu_validation_info.get_info("tokens")
@@ -591,7 +626,13 @@ def _metric_calculator(r: torch.Tensor, t: torch.Tensor):
591626
)
592627
extra_kwargs["attn_name"] = ATTN_NAME
593628
cpu_validation_info = __load_validation_info(
594-
model_path, batch_size, seq_length, max_new_tokens, tokenizer, i
629+
model_path,
630+
batch_size,
631+
seq_length,
632+
max_new_tokens,
633+
tokenizer,
634+
i,
635+
ATTN_NAME,
595636
)
596637
if cpu_validation_info is None:
597638
cpu_validation_info = extract_validation_information(
@@ -609,7 +650,12 @@ def _metric_calculator(r: torch.Tensor, t: torch.Tensor):
609650
if save_validation_info_outputs:
610651
cpu_validation_info.save(
611652
__get_validation_info_full_path(
612-
model_path, batch_size, seq_length, max_new_tokens, i
653+
model_path,
654+
batch_size,
655+
seq_length,
656+
max_new_tokens,
657+
i,
658+
ATTN_NAME,
613659
)
614660
)
615661
cpu_static_tokens = cpu_validation_info.get_info("tokens")
@@ -634,7 +680,13 @@ def _metric_calculator(r: torch.Tensor, t: torch.Tensor):
634680
if save_validation_info_outputs:
635681
aiu_validation_info.save(
636682
__get_validation_info_full_path(
637-
model_path, batch_size, seq_length, max_new_tokens, i, "aiu"
683+
model_path,
684+
batch_size,
685+
seq_length,
686+
max_new_tokens,
687+
i,
688+
ATTN_NAME,
689+
"aiu",
638690
)
639691
)
640692

0 commit comments

Comments
 (0)