Skip to content

Commit 935e77f

Browse files
authored
Merge branch 'foundation-model-stack:main' into tinto_play_main
2 parents 645bec8 + fc3a30f commit 935e77f

3 files changed

Lines changed: 133 additions & 85 deletions

File tree

aiu_fms_testing_utils/scripts/drive_paged_programs.py

Lines changed: 11 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -264,9 +264,7 @@ def __custom_line_sampler(*args, **kwargs):
264264
max_tkv = int(os.environ["VLLM_DT_MAX_CONTEXT_LEN"])
265265

266266

267-
def __prepare_inputs(
268-
batch_size, seq_length, tokenizer, enforce_sizes=[], seed=0, pad_multiple=64
269-
):
267+
def __prepare_inputs(batch_size, seq_length, tokenizer, enforce_sizes=[], seed=0):
270268
start = time.time()
271269
prompts_and_sizes, sample_key = sampler(
272270
DATASET_PATH,
@@ -278,7 +276,6 @@ def __prepare_inputs(
278276
enforce_sizes=enforce_sizes,
279277
truncation=allow_truncation,
280278
return_key=True,
281-
pad_multiple=pad_multiple,
282279
)
283280
end = time.time()
284281
if local_rank == 0:
@@ -291,6 +288,10 @@ def __prepare_inputs(
291288
encoded = encoded[:seq_length]
292289
prompt_list.append(encoded)
293290

291+
if not prompt_list:
292+
raise ValueError(
293+
f"No valid prompt sample exists in dataset for input shape (Batch Size={batch_size}, Seq Length={seq_length})"
294+
)
294295
if len(prompt_list) < batch_size:
295296
dprint(
296297
f"You requested {batch_size} prompts but we were only able to get {len(prompt_list)} valid prompts. We will be repeating the first prompt."
@@ -396,13 +397,7 @@ def __load_validation_info(
396397
# warmup with any input so compiler produces criteria json
397398
# TODO: Swap this with __prepare_inputs once fix for shape_id is available
398399
# input_ids, extra_kwargs, sample_key = __prepare_inputs(2, max_tkv, tokenizer)
399-
pad_multiple = 64
400-
if args.prefill_chunk_size > 0:
401-
assert args.prefill_chunk_size % 64 == 0, (
402-
"Chunk size must be a multiple of the page size"
403-
)
404-
pad_multiple = args.prefill_chunk_size
405-
prompt_list = [torch.arange(0, pad_multiple, dtype=torch.int64)]
400+
prompt_list = [torch.arange(0, 64, dtype=torch.int64)]
406401
# matching vllm warmup to pad to 2 on fp8, and no pad for fp16
407402
if is_fp8:
408403
prompt_list = prompt_list * 2
@@ -526,7 +521,7 @@ def parse_program_limit(limit_str: str) -> tuple[int, str]:
526521
# FIXME: filter condition for this on prompt and batch
527522
program_map = get_programs_prompts(
528523
program_criteria_list,
529-
multiple=pad_multiple,
524+
multiple=64,
530525
max_batch_size=max_batch_size,
531526
max_tkv=max_tkv,
532527
program_cycles=max_new_tokens,
@@ -547,7 +542,6 @@ def parse_program_limit(limit_str: str) -> tuple[int, str]:
547542
valid_prompt_shape[1],
548543
tokenizer,
549544
enforce_sizes=enforce_sizes,
550-
pad_multiple=pad_multiple,
551545
)
552546
valid_prompts = [
553547
(
@@ -601,29 +595,14 @@ def parse_program_limit(limit_str: str) -> tuple[int, str]:
601595
# if there does not exist enough sequence sizes between this range, we will cycle back to the beginning
602596
# in the event we don't have enough sequences that satisfy the enforce_sizes, we will repeat sequences and warn the user
603597
enforce_sizes = [valid_prompt_shape[1]]
604-
if (
605-
args.enforce_homogeneous_prompt_programs
606-
or args.prefill_chunk_size > 0
607-
):
608-
# if enforcing homogeneous prompt programs, this will get the number of bits for the sequence length and shift to get the power of 2 that is less than or equal to the sequence length
609-
tkv_cutoff = (
610-
1 << (valid_prompt_shape[1].bit_length() - 1)
611-
if args.enforce_homogeneous_prompt_programs
612-
else pad_multiple
613-
)
614-
598+
if args.enforce_homogeneous_prompt_programs:
599+
# this will get the number of bits for the sequence length and shift to get the power of 2 that is less than or equal to the sequence length
600+
tkv_cutoff = 1 << (valid_prompt_shape[1].bit_length() - 1)
615601
possible_seq_lengths = [
616-
_
617-
for _ in range(
618-
tkv_cutoff, valid_prompt_shape[1], pad_multiple
619-
)
602+
_ for _ in range(tkv_cutoff, valid_prompt_shape[1], 64)
620603
]
621604
# favor sequences that are close to the valid prompt length
622605
possible_seq_lengths.reverse()
623-
# add the valid prompt size to the end since it will already exist in the above enforce_sizes
624-
possible_seq_lengths = possible_seq_lengths + [
625-
valid_prompt_shape[1]
626-
]
627606
enforce_sizes = enforce_sizes + list(
628607
itertools.islice(
629608
itertools.cycle(possible_seq_lengths),
@@ -636,7 +615,6 @@ def parse_program_limit(limit_str: str) -> tuple[int, str]:
636615
valid_prompt_shape[1],
637616
tokenizer,
638617
enforce_sizes=enforce_sizes,
639-
pad_multiple=64, # this should be the smallest granularity to ensure we get the largest enforce_size (if we choose chunked prefill, we want to make sure we pad to the full enforced size)
640618
)
641619
valid_prompts.append(
642620
(

aiu_fms_testing_utils/utils/paged.py

Lines changed: 97 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from typing import Any, Callable, List, MutableMapping, Optional, Tuple, Union
66
import torch
77
import fms.utils.spyre.paged # noqa
8+
from aiu_fms_testing_utils.utils import get_pad_size
89

910

1011
def adjust_inputs_to_batch(input_ids: torch.Tensor, **extra_kwargs):
@@ -226,6 +227,12 @@ def generate(
226227
# left_padded_prompt_mask - empty_slots + context_lengths
227228
current_tkv_mask = torch.fill(context_lengths, input_ids.shape[1])
228229

230+
# if using chunked prefill, reserve a pad block
231+
# reserving a pad block is required as writes to pad are done in parallel and could corrupt the real blocks
232+
if prefill_chunk_size > 0:
233+
pad_block_id = block_numbers.pop(0)
234+
pad_slots = [(pad_block_id * BLOCK_SIZE) + pos_i for pos_i in range(BLOCK_SIZE)]
235+
229236
slot_mapping = []
230237
block_table = []
231238
# each sequence has the possibility of a different tkv, so loop over that
@@ -244,6 +251,7 @@ def generate(
244251
slot_mapping_i.append(slot)
245252
slot_mapping.append(slot_mapping_i)
246253
block_table.append(block_table_i)
254+
247255
kwargs["current_tkv_mask"] = None
248256
kwargs["left_padded_prompt_mask"] = None
249257
kwargs["use_cache"] = use_cache
@@ -300,64 +308,110 @@ def generate(
300308
last_n_tokens = kwargs.get("last_n_tokens", 0)
301309

302310
if prefill_chunk_size > 0:
303-
left_padded_prompt_mask_seq_chunk = None
311+
required_extra_pads = (
312+
get_pad_size(current_tkv.item(), prefill_chunk_size)
313+
- current_tkv.item()
314+
)
315+
left_padded_prompt_mask_seq_chunk = (
316+
(kwargs["position_ids"][seq_i][-current_tkv.item() :] == 0).sum(
317+
dim=0
318+
)
319+
- 1
320+
+ required_extra_pads
321+
)
322+
left_padded_prompt_mask_seq_chunk = (
323+
left_padded_prompt_mask_seq_chunk.unsqueeze(0)
324+
)
325+
block_seq_left_padding = required_extra_pads // BLOCK_SIZE
326+
304327
# Chunked prefill
305328
for chunk_j in range(math.ceil(current_tkv / prefill_chunk_size)):
306-
chunk_start = -current_tkv + chunk_j * prefill_chunk_size
307-
chunk_end = -current_tkv + min(
308-
(chunk_j + 1) * prefill_chunk_size, current_tkv
309-
)
329+
# chunk_start and chunk_end are the index mappings from the original sequence
330+
if chunk_j == 0:
331+
chunk_start = 0
332+
chunk_end = prefill_chunk_size - required_extra_pads
333+
else:
334+
required_extra_pads = 0
335+
chunk_start = chunk_end
336+
chunk_end += prefill_chunk_size
337+
338+
input_ids_seq_chunk = input_ids[seq_i][-current_tkv:][
339+
chunk_start:chunk_end
340+
]
341+
slot_mapping_seq_chunk = slot_mapping[seq_i][-current_tkv:][
342+
chunk_start:chunk_end
343+
]
344+
position_ids_seq_chunk = kwargs["position_ids"][seq_i][
345+
-current_tkv:
346+
][chunk_start:chunk_end]
347+
348+
# add the extra required padding to chunk
349+
if required_extra_pads > 0:
350+
input_ids_seq_chunk = torch.cat(
351+
(
352+
torch.zeros(
353+
required_extra_pads,
354+
dtype=torch.int64,
355+
device=input_ids_seq_chunk.device,
356+
),
357+
input_ids_seq_chunk,
358+
)
359+
)
360+
slot_mapping_seq_chunk = (
361+
pad_slots * (required_extra_pads // BLOCK_SIZE)
362+
+ slot_mapping_seq_chunk
363+
)
364+
position_ids_seq_chunk = torch.cat(
365+
(
366+
torch.zeros(
367+
required_extra_pads,
368+
dtype=torch.int64,
369+
device=position_ids_seq_chunk.device,
370+
),
371+
position_ids_seq_chunk,
372+
)
373+
)
374+
375+
input_ids_seq_chunk = input_ids_seq_chunk.unsqueeze(0).clone()
310376

311-
ids_length = input_ids[seq_i].shape[0]
312-
input_ids_seq_chunk = (
313-
input_ids[seq_i][
314-
chunk_start + ids_length : chunk_end + ids_length
315-
]
316-
.unsqueeze(0)
317-
.clone()
318-
)
319-
assert input_ids_seq_chunk.size(1) == prefill_chunk_size, (
320-
f"prefill chunk size was not equal to the chunk size. Found {input_ids_seq_chunk.size(0)}"
321-
)
322-
slots_length = len(slot_mapping[seq_i])
323377
slot_mapping_seq_chunk = (
324378
torch.tensor(
325-
slot_mapping[seq_i][
326-
chunk_start + slots_length : chunk_end
327-
+ slots_length
328-
],
379+
slot_mapping_seq_chunk,
329380
dtype=torch.int64,
330381
)
331382
.unsqueeze(0)
332383
.clone()
333384
)
334-
pids_length = kwargs["position_ids"][seq_i].shape[0]
335-
position_ids_seq_chunk = (
336-
kwargs["position_ids"][seq_i][
337-
chunk_start + pids_length : chunk_end + pids_length
338-
]
339-
.unsqueeze(0)
340-
.clone()
385+
386+
position_ids_seq_chunk = position_ids_seq_chunk.unsqueeze(
387+
0
388+
).clone()
389+
390+
assert input_ids_seq_chunk.size(1) == prefill_chunk_size, (
391+
f"prefill chunk size was not equal to the chunk size for input_ids. Found {input_ids_seq_chunk.size(0)}"
341392
)
342393

343-
# This view will result in a discontiguous tensor (creates a new graph during compile)
344-
# For this reason, we must explicitly make contiguous
345-
if left_padded_prompt_mask_seq_chunk is None:
346-
left_padded_prompt_mask_seq_chunk = (
347-
position_ids_seq_chunk == 0
348-
).sum(dim=1) - 1
349-
current_tkv_mask_seq_chunk = torch.min(
350-
torch.tensor(
351-
(chunk_j + 1) * prefill_chunk_size, dtype=torch.int64
352-
),
353-
current_tkv,
394+
assert slot_mapping_seq_chunk.size(1) == prefill_chunk_size, (
395+
f"prefill chunk size was not equal to the chunk size for slot_mapping. Found {slot_mapping_seq_chunk.size(0)}"
396+
)
397+
398+
assert position_ids_seq_chunk.size(1) == prefill_chunk_size, (
399+
f"prefill chunk size was not equal to the chunk size for position_ids. Found {position_ids_seq_chunk.size(0)}"
400+
)
401+
402+
current_tkv_mask_seq_chunk = torch.tensor(
403+
(chunk_j + 1) * prefill_chunk_size, dtype=torch.int64
354404
).unsqueeze(0)
355405

356-
table_length = len(block_table[seq_i])
357-
block_start = -current_tkv // BLOCK_SIZE + table_length
358-
block_end = chunk_end // BLOCK_SIZE + table_length
406+
block_end = chunk_end // BLOCK_SIZE
407+
# length of padding or index until padding has occurred in block table
408+
block_pad_len = (input_ids.shape[1] - current_tkv) // BLOCK_SIZE
359409
block_table_seq_chunk = torch.tensor(
360-
block_table[seq_i][block_start:block_end], dtype=torch.int64
410+
[pad_block_id] * (block_seq_left_padding)
411+
+ block_table[seq_i][
412+
block_pad_len : block_pad_len + block_end
413+
],
414+
dtype=torch.int64,
361415
).unsqueeze(0)
362416

363417
chunked_kwargs = {

tests/models/test_scripts.py

Lines changed: 25 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -175,17 +175,25 @@ def execute_dpp(
175175
test_type,
176176
skip_validation,
177177
enforce_homogeneous_prompt_programs,
178+
prefill_chunk_size,
178179
shared_tmp_path,
179180
isolated_env,
180181
):
181182
isolated_env["VLLM_DT_MAX_BATCH_TKV_LIMIT"] = "1024"
182183
isolated_env["VLLM_DT_MAX_CONTEXT_LEN"] = "512"
183184
isolated_env["VLLM_DT_MAX_BATCH_SIZE"] = "2"
185+
if prefill_chunk_size > 0:
186+
isolated_env["VLLM_DT_CHUNK_LEN"] = f"{prefill_chunk_size}"
184187
Path(os.path.join(shared_tmp_path, "sendnn_cache")).mkdir(exist_ok=True)
185-
os.environ.setdefault(
186-
"TORCH_SENDNN_CACHE_DIR", os.path.join(shared_tmp_path, "sendnn_cache")
187-
)
188-
isolated_env["TORCH_SENDNN_CACHE_ENABLE"] = "1"
188+
189+
# only enable for non-chunk
190+
if prefill_chunk_size == 0:
191+
os.environ.setdefault(
192+
"TORCH_SENDNN_CACHE_DIR", os.path.join(shared_tmp_path, "sendnn_cache")
193+
)
194+
isolated_env["TORCH_SENDNN_CACHE_ENABLE"] = "1"
195+
else:
196+
isolated_env["TORCH_SENDNN_CACHE_ENABLE"] = "0"
189197

190198
command_list = [
191199
"python3",
@@ -239,6 +247,9 @@ def execute_dpp(
239247
if enforce_homogeneous_prompt_programs:
240248
command_list += ["--enforce_homogeneous_prompt_programs"]
241249

250+
if prefill_chunk_size > 0:
251+
command_list += [f"--prefill_chunk_size={prefill_chunk_size}"]
252+
242253
# add program criteria path
243254
command_list += [
244255
f"--program_criteria_json_path={os.environ['DT_PROG_CRITERIA_FILEPATH']}"
@@ -249,21 +260,24 @@ def execute_dpp(
249260

250261
dpp_possibilities = []
251262
dpp_possibilities.append(
252-
("paged", None, 8, "sharegpt", "metrics", False, False)
263+
("paged", None, 8, "sharegpt", "metrics", False, False, 0)
253264
) # metrics and run all programs
254265
dpp_possibilities.append(
255-
("paged", "*:0,==256", 65, "sharegpt", "tokens", False, False)
266+
("paged", "*:0,==256", 65, "sharegpt", "tokens", False, False, 0)
256267
) # tokens and run all programs that satisfy 256 sequence length
257268
dpp_possibilities.append(
258-
("paged", "*:>=2,0", 65, "sharegpt", None, True, True)
269+
("paged", "*:>=2,0", 65, "sharegpt", None, True, True, 0)
259270
) # metrics and run all programs that have >=2 batch size
260271
dpp_possibilities.append(
261-
("paged", None, 8, "custom", "tokens", False, False)
272+
("paged", None, 8, "custom", "tokens", False, False, 0)
262273
) # tokens running with specific custom dataset
274+
dpp_possibilities.append(
275+
("paged", None, 8, "sharegpt", "tokens", False, False, 128)
276+
) # tokens and run all programs with chunked prefill
263277

264278

265279
@pytest.mark.parametrize(
266-
"attn_type,programs,max_new_tokens,dataset_type,test_type,skip_validation,enforce_homogeneous_prompt_programs",
280+
"attn_type,programs,max_new_tokens,dataset_type,test_type,skip_validation,enforce_homogeneous_prompt_programs,prefill_chunk_size",
267281
dpp_possibilities,
268282
)
269283
def test_dpp_script(
@@ -274,6 +288,7 @@ def test_dpp_script(
274288
test_type,
275289
skip_validation,
276290
enforce_homogeneous_prompt_programs,
291+
prefill_chunk_size,
277292
shared_tmp_path,
278293
isolated_env,
279294
):
@@ -290,6 +305,7 @@ def test_dpp_script(
290305
test_type,
291306
skip_validation,
292307
enforce_homogeneous_prompt_programs,
308+
prefill_chunk_size,
293309
shared_tmp_path,
294310
isolated_env,
295311
)

0 commit comments

Comments
 (0)