Skip to content

Commit b0cf126

Browse files
authored
Add Model Config to allow default env variables per model architecture (#178)
* configure model defined parameters — Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>
* model config class — Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>
* set class config - no env variables — Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>
* print statement — Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>
* refactor utilities — Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>
* fix conditions — Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>
* ruff formats fix — Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>
* formatting dpp config — Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>
* fix formatting — Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>
* review updates — Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>
* linting fixes — Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>
* leave num_blocks default as None — Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>
* fix default value to num_blocks — Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>
* fix dist world size param — Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>
* guard against distributed flags — Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>
* pass default value to test — Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>
* fix linting — Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>
--------
Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>
1 parent aa4b3cf commit b0cf126

5 files changed

Lines changed: 155 additions & 29 deletions

File tree

aiu_fms_testing_utils/scripts/drive_paged_programs.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,9 @@
3939
from aiu_fms_testing_utils.utils.paged import (
4040
ProgramCriteria,
4141
get_programs_prompts,
42-
KVCACHE_NUM_BLOCKS_HINT,
4342
)
43+
from aiu_fms_testing_utils.utils.dpp_config import DPPRunnerConfig
44+
from aiu_fms_testing_utils.utils.env_utils import scoped_environ
4445
from aiu_fms_testing_utils.testing.utils import format_kwargs_to_string
4546

4647
parser = argparse.ArgumentParser(
@@ -378,7 +379,15 @@ def __load_validation_info(
378379

379380
model.eval()
380381
fx_config.backed_size_oblivious = True
381-
model.compile(backend="sendnn", options={"sendnn.dynamic": True})
382+
383+
model_config = DPPRunnerConfig()
384+
world_size = dist.get_world_size() if USE_DISTRIBUTED and dist.is_initialized() else 1
385+
model_config.setup_config(
386+
model_variant, USE_DISTRIBUTED, world_size, args.prefill_chunk_size
387+
)
388+
with scoped_environ(model_config.env_updates()):
389+
# Temporarily set environment variables needed for compile
390+
model.compile(backend="sendnn", options={"sendnn.dynamic": True})
382391

383392
__maybe_prepare_fp8_weights(model, is_fp8)
384393

@@ -402,15 +411,10 @@ def __load_validation_info(
402411
if is_fp8:
403412
prompt_list = prompt_list * 2
404413
input_ids, extra_kwargs = pad_input_ids(prompt_list, min_pad_length=64)
405-
extra_kwargs["mask"] = extra_kwargs["mask"].to(torch.float16)
406414

415+
extra_kwargs["mask"] = extra_kwargs["mask"].to(torch.float16)
407416
extra_kwargs["attn_name"] = ATTN_NAME
408-
if (
409-
"granite-3.3-8b-instruct" in model_variant
410-
and USE_DISTRIBUTED
411-
and dist.get_world_size() == 4
412-
):
413-
extra_kwargs["_kvcache_num_blocks_hint"] = KVCACHE_NUM_BLOCKS_HINT
417+
extra_kwargs["_kvcache_num_blocks_hint"] = model_config.num_blocks
414418
warmup_model(
415419
model,
416420
input_ids,
@@ -513,6 +517,7 @@ def parse_program_limit(limit_str: str) -> tuple[int, str]:
513517
max_batch_size=max_batch_size,
514518
max_tkv=max_tkv,
515519
program_cycles=max_new_tokens,
520+
tkv_limit=model_config.tkv_limit,
516521
prioritize_large_batch_sizes=args.prioritize_large_batch_sizes,
517522
)
518523
for v in program_map.values():
@@ -649,12 +654,7 @@ def __metric_calculator(r: torch.Tensor, t: torch.Tensor):
649654
sample_key,
650655
) in get_program_prompt_list():
651656
extra_kwargs["attn_name"] = ATTN_NAME
652-
if (
653-
"granite-3.3-8b-instruct" in model_variant
654-
and USE_DISTRIBUTED
655-
and dist.get_world_size() == 4
656-
):
657-
extra_kwargs["_kvcache_num_blocks_hint"] = KVCACHE_NUM_BLOCKS_HINT
657+
extra_kwargs["_kvcache_num_blocks_hint"] = model_config.num_blocks
658658

659659
if local_rank == 0:
660660
dprint(f"*** testing program {program_id} ***")
Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
import os
2+
from dataclasses import dataclass
3+
4+
from aiu_fms_testing_utils.utils.aiu_setup import dprint
5+
6+
7+
@dataclass
class DPPRunnerConfig:
    """Configure parameters that may vary with model architecture.

    Holds per-model defaults for paged-attention settings, resolving each
    value from an environment variable when set, or falling back to a
    built-in default. Call ``setup_config(...)`` once before reading
    ``num_blocks`` / ``tkv_limit`` or calling ``env_updates()``.
    """

    # Populated during setup_config(); num_blocks may legitimately stay None
    # (the caller computes a fallback when no hint is given).
    num_blocks: int | None = None
    tkv_limit: int | None = None

    def _get_int_env(self, key: str, default: int, context: str) -> int:
        """Read an integer environment variable or use a default.

        Always emits a debug message explaining the choice.

        Args:
            key: environment variable name to read.
            default: value used when the variable is unset.
            context: human-readable prefix for the debug/error message.

        Returns:
            The parsed integer value.

        Raises:
            ValueError: if the variable is set but not a valid integer.
        """
        value = os.environ.get(key)
        if value is None:
            dprint(f"{context}. Using default {key}={default}")
            return default

        try:
            parsed = int(value)
        except ValueError as e:
            raise ValueError(
                f"{context}. Invalid value for environment variable {key}: "
                f"expected an integer, got '{value}'"
            ) from e

        dprint(f"{context}. Using {key} from environment: {parsed}")
        return parsed

    def _configure_granite_3_8b(self, use_distributed, world_size, prefill_chunk_size):
        """Configure defaults for the granite 3 8b architecture.

        Sets defaults for env variables not provided; only applies when
        running distributed with tensor parallel size 4. Called from the
        wrapper setup_config function.
        """
        if use_distributed and world_size == 4:
            # Only set defaults for TP=4
            context = (
                "Model granite-3.3-8b (or compatible) "
                "with tensor parallel size 4 detected"
            )
            self.tkv_limit = self._get_int_env(
                key="VLLM_DT_MAX_BATCH_TKV_LIMIT",
                default=524288,
                context=context,
            )

            # these values are to be consistent with vllm for granite 3.3 8b instruct
            # (8192 with chunked prefill, 2080 otherwise); treat a None
            # prefill_chunk_size as "chunked prefill disabled" rather than crashing
            blocks_override = 8192 if (prefill_chunk_size or 0) > 0 else 2080

            self.num_blocks = self._get_int_env(
                key="AFTU_PAGED_KVCACHE_NUM_BLOCKS_HINT",
                default=blocks_override,
                context=context,
            )

    def setup_config(
        self, model_variant, use_distributed, world_size, prefill_chunk_size
    ):
        """Set up environment variables and default values if not specified.

        Args:
            model_variant: model name/path used to select an architecture.
            use_distributed: whether distributed execution is enabled.
            world_size: tensor-parallel world size (1 when not distributed).
            prefill_chunk_size: chunked-prefill size; <= 0 or None disables it.
        """
        ## configure per model architecture
        if (
            "granite-3.3-8b-instruct" in model_variant
            or "granite-4.0-8b" in model_variant
        ):
            self._configure_granite_3_8b(
                use_distributed, world_size, prefill_chunk_size
            )

        ## global defaults (fallback)
        ## TODO: IN future we may remove defaults for unknown configurations
        ## and require users to set the environment variables
        ## num_blocks is set in generate if not set here
        if self.tkv_limit is None:
            self.tkv_limit = self._get_int_env(
                key="VLLM_DT_MAX_BATCH_TKV_LIMIT",
                default=524288,
                context="Unknown model configuration",
            )

    def env_updates(self) -> dict[str, str]:
        """Returns a key/value of environment variables needed for model compile.

        Raises:
            RuntimeError: if called before setup_config() populated tkv_limit.
        """
        if self.tkv_limit is None:
            # Fixed: message previously referenced a stale class name ("ModelConfig")
            raise RuntimeError(
                "DPPRunnerConfig.env_updates() called before setup_config(). "
                "Call setup_config(...) first."
            )

        return {"VLLM_DT_MAX_BATCH_TKV_LIMIT": str(self.tkv_limit)}
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
import os
2+
from contextlib import contextmanager
3+
from typing import Optional
4+
5+
6+
@contextmanager
7+
def scoped_environ(updates: dict[str, Optional[str]]):
8+
"""
9+
Temporarily set environment variables.
10+
Restores original values on exit.
11+
12+
updates:
13+
key -> value
14+
value=None means unset the variable
15+
"""
16+
old_env = {}
17+
18+
try:
19+
# Save old values and apply updates
20+
for key, value in updates.items():
21+
old_env[key] = os.environ.get(key)
22+
if value is None:
23+
os.environ.pop(key, None)
24+
else:
25+
os.environ[key] = str(value)
26+
yield
27+
finally:
28+
# Restore original environment
29+
for key, old_value in old_env.items():
30+
if old_value is None:
31+
os.environ.pop(key, None)
32+
else:
33+
os.environ[key] = old_value

aiu_fms_testing_utils/utils/paged.py

Lines changed: 9 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -130,9 +130,9 @@ def generate(
130130
_MAX_BATCH = int(os.environ["VLLM_DT_MAX_BATCH_SIZE"])
131131
_MAX_CONTEXT_LENGTH = int(os.environ["VLLM_DT_MAX_CONTEXT_LEN"])
132132
# if the user provides a hint to the number of blocks to use, use it directly
133-
NUM_BLOCKS = kwargs.get(
134-
"_kvcache_num_blocks_hint", (_MAX_BATCH * _MAX_CONTEXT_LENGTH) // BLOCK_SIZE
135-
)
133+
NUM_BLOCKS = kwargs.get("_kvcache_num_blocks_hint")
134+
if NUM_BLOCKS is None:
135+
NUM_BLOCKS = (_MAX_BATCH * _MAX_CONTEXT_LENGTH) // BLOCK_SIZE
136136

137137
if hasattr(model, "head"):
138138
model_dtype = model.head.weight.dtype
@@ -634,14 +634,6 @@ def generate(
634634
return result
635635

636636

637-
# this value is default to 8192 to be consistent with vllm for granite 3.3 8b instruct w/ chunked prefill
638-
KVCACHE_NUM_BLOCKS_HINT = int(
639-
os.environ.get("AFTU_PAGED_KVCACHE_NUM_BLOCKS_HINT", 8192)
640-
)
641-
642-
VLLM_DT_MAX_BATCH_TKV_LIMIT = int(os.environ.get("VLLM_DT_MAX_BATCH_TKV_LIMIT", 524288))
643-
644-
645637
class ProgramCriteria:
646638
def __init__(
647639
self, program_id, max_batch, max_tkv, batch_granularity, tkv_granularity
@@ -652,9 +644,9 @@ def __init__(
652644
self.batch_granularity = batch_granularity
653645
self.tkv_granularity = tkv_granularity
654646

655-
def is_possible(self, batch_size, tkv):
647+
def is_possible(self, batch_size, tkv, tkv_limit):
656648
return (
657-
(batch_size * tkv <= VLLM_DT_MAX_BATCH_TKV_LIMIT)
649+
(batch_size * tkv <= tkv_limit)
658650
and (batch_size <= self.max_batch)
659651
and (tkv <= self.max_tkv)
660652
)
@@ -690,6 +682,7 @@ def get_programs_prompts(
690682
max_batch_size,
691683
max_tkv,
692684
program_cycles,
685+
tkv_limit,
693686
prioritize_large_batch_sizes=True,
694687
):
695688
program_map = {}
@@ -702,7 +695,9 @@ def get_programs_prompts(
702695
for program_index in range(possible_program_switches):
703696
context_length = prompt_len + (multiple * program_index) + 1
704697

705-
if program_criteria.is_possible(batch_size, context_length):
698+
if program_criteria.is_possible(
699+
batch_size, context_length, tkv_limit
700+
):
706701
padding = program_criteria.calculate_padding(
707702
batch_size, context_length
708703
)

tests/models/test_scripts.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -328,12 +328,15 @@ def test_dpp_script(
328328
program_assertions = [i for i in range(len(program_criteria_list))]
329329
shape_assertions = [">=0", ">=0"]
330330
else:
331+
# sets default of tkv_limit
332+
tkv_limit = int(os.environ.get("VLLM_DT_MAX_BATCH_TKV_LIMIT", 524288))
331333
program_map = get_programs_prompts(
332334
program_criteria_list,
333335
multiple=64,
334336
max_batch_size=2,
335337
max_tkv=512,
336338
program_cycles=max_new_tokens,
339+
tkv_limit=tkv_limit,
337340
)
338341
programs_split = programs.split(":")
339342
program_ids_str = programs_split[0]

0 commit comments

Comments
 (0)