Skip to content

Commit 96e1937

Browse files
committed
Add hbm utilization and some vllm args
1 parent 278e455 commit 96e1937

4 files changed

Lines changed: 51 additions & 18 deletions

File tree

src/maxtext/eval/runner/common.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ def build_server_manager(cfg: dict, token: str | None) -> "VllmServerManager":
6565
max_num_seqs = int(max_num_seqs)
6666

6767
expert_parallel_size = int(cfg.get("expert_parallel_size") or 1)
68+
hbm_memory_utilization = float(cfg.get("hbm_memory_utilization") or 0.3)
6869

6970
server_env = {"HF_TOKEN": token} if token else None
7071

@@ -79,6 +80,7 @@ def build_server_manager(cfg: dict, token: str | None) -> "VllmServerManager":
7980
max_model_len=max_model_len,
8081
max_num_batched_tokens=max_num_batched_tokens,
8182
max_num_seqs=max_num_seqs,
83+
hbm_memory_utilization=hbm_memory_utilization,
8284
env=server_env,
8385
)
8486

@@ -120,6 +122,14 @@ def add_server_args(parser: argparse.ArgumentParser) -> None:
120122
"Chips allocated to the expert mesh axis (EP). "
121123
),
122124
)
125+
parser.add_argument(
126+
"--hbm_memory_utilization",
127+
type=float,
128+
default=0.3,
129+
help=(
130+
"Fraction of HBM reserved for KV cache."
131+
),
132+
)
123133
parser.add_argument("--hf_token", help="HuggingFace token for gated models.")
124134
parser.add_argument(
125135
"--gcs_results_path", help="Optional secondary GCS path to upload the results JSON."

src/maxtext/eval/runner/harness_runner.py

Lines changed: 30 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -76,22 +76,37 @@ def _map_results(raw_results: dict, tasks: list[str], task_map: dict[str, str])
7676
lm_task = task_map.get(task, task)
7777
task_r = results_section.get(lm_task, {})
7878

79-
acc = task_r.get("acc,none")
80-
if acc is None:
81-
acc = task_r.get("exact_match,none")
82-
if acc is None:
83-
acc = task_r.get("acc")
84-
if acc is None:
85-
acc = task_r.get("score")
86-
87-
acc_norm = task_r.get("acc_norm,none")
88-
if acc_norm is None:
89-
acc_norm = task_r.get("acc_norm")
79+
acc = None
80+
for key in (
81+
"acc,none",
82+
"exact_match,strict-match",
83+
"exact_match,flexible-extract",
84+
"exact_match,none",
85+
"acc",
86+
"score",
87+
):
88+
if task_r.get(key) is not None:
89+
acc = task_r[key]
90+
break
91+
92+
acc_norm = None
93+
for key in ("acc_norm,none", "acc_norm"):
94+
if task_r.get(key) is not None:
95+
acc_norm = task_r[key]
96+
break
9097

9198
if acc is not None:
9299
scores[f"{task}_accuracy"] = round(float(acc) * 100, 2)
93100
if acc_norm is not None:
94101
scores[f"{task}_accuracy_norm"] = round(float(acc_norm) * 100, 2)
102+
103+
if acc is None and task_r:
104+
logger.warning(
105+
"No known accuracy keys found for task '%s'. Available: %s",
106+
task,
107+
list(task_r.keys()),
108+
)
109+
95110
return scores
96111

97112

@@ -206,7 +221,10 @@ def run_harness(cfg: dict, hf_token: str | None = None) -> dict:
206221
benchmark="+".join(tasks),
207222
model_name=model_name,
208223
scores=scores,
209-
generation_stats={f"{backend}_config": raw_results.get("config", {})},
224+
generation_stats={
225+
f"{backend}_config": raw_results.get("config", {}),
226+
f"{backend}_results": raw_results.get("results", {}),
227+
},
210228
config=cfg,
211229
results_path=results_path,
212230
)

src/maxtext/eval/runner/server_manager.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ def _build_app(llm: Any) -> Any:
3333
"""Return a FastAPI app that wraps an in-process vLLM LLM instance."""
3434
import fastapi # pylint: disable=import-outside-toplevel
3535
from vllm.sampling_params import SamplingParams # pylint: disable=import-outside-toplevel
36+
globals()["fastapi"] = fastapi
3637

3738
app = fastapi.FastAPI()
3839

@@ -195,6 +196,7 @@ class VllmServerManager:
195196
max_num_batched_tokens: Tokens per scheduler step (None = vLLM default).
196197
max_num_seqs: Max concurrent sequences (None = vLLM default).
197198
startup_timeout: Seconds to wait for /health to return healthy.
199+
hbm_memory_utilization: Fraction of HBM reserved for KV cache.
198200
env: Optional environment-variable overrides.
199201
additional_vllm_kwargs: Extra kwargs merged into the vLLM LLM() constructor.
200202
"""
@@ -213,6 +215,7 @@ def __init__(
213215
max_num_batched_tokens: int | None = None,
214216
max_num_seqs: int | None = None,
215217
startup_timeout: int = 600,
218+
hbm_memory_utilization: float = 0.3,
216219
env: dict[str, str] | None = None,
217220
additional_vllm_kwargs: dict | None = None,
218221
):
@@ -235,6 +238,7 @@ def __init__(
235238
self.max_num_batched_tokens = max_num_batched_tokens
236239
self.max_num_seqs = max_num_seqs
237240
self.startup_timeout = startup_timeout
241+
self.hbm_memory_utilization = hbm_memory_utilization
238242
self.env = env
239243
self.additional_vllm_kwargs = additional_vllm_kwargs or {}
240244

@@ -255,6 +259,8 @@ def start(self) -> None:
255259
# V1 engine architecture is otherwise preserved (tpu-inference plugin works),
256260
# and JAX/TPU is initialised exactly once inside LLM() in this process.
257261
os.environ.setdefault("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
262+
os.environ.setdefault("NEW_MODEL_DESIGN", "1")
263+
os.environ.setdefault("SKIP_JAX_PRECOMPILE", "1")
258264

259265
if self.env:
260266
os.environ.update(self.env)
@@ -268,6 +274,7 @@ def start(self) -> None:
268274
"tensor_parallel_size": ici_tp,
269275
"max_model_len": self.max_model_len,
270276
"dtype": self.dtype,
277+
"gpu_memory_utilization": self.hbm_memory_utilization,
271278
}
272279
if self.max_num_batched_tokens is not None:
273280
vllm_kwargs["max_num_batched_tokens"] = self.max_num_batched_tokens

src/maxtext/integration/vllm/maxtext_vllm_adapter/adapter.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,11 @@
2020
from flax import nnx
2121
import flax.linen as nn
2222
from jax import numpy as jnp
23-
from jax.sharding import AxisType, Mesh
23+
from jax.sharding import Mesh
2424
from maxtext.configs import pyconfig
2525
from maxtext.utils.globals import MAXTEXT_CONFIGS_DIR
2626
from maxtext.common.common_types import MODEL_MODE_AUTOREGRESSIVE
2727
from maxtext.utils import max_logging
28-
from maxtext.utils import maxtext_utils
2928
from maxtext.utils import model_creation_utils
3029

3130

@@ -99,9 +98,8 @@ def __init__(self, vllm_config: VllmConfig, rng_key: jax.Array, mesh: Mesh):
9998
self.cfg = vllm_config.model_config
10099
self.maxtext_config = generate_maxtext_config(vllm_config)
101100

102-
devices_array = maxtext_utils.create_device_mesh(self.maxtext_config)
103-
axis_types = tuple([AxisType.Auto] * len(self.maxtext_config.mesh_axes))
104-
self.mesh = Mesh(devices_array, self.maxtext_config.mesh_axes, axis_types=axis_types)
101+
# Model configuration
102+
self.mesh = mesh
105103
self.model_mode = MODEL_MODE_AUTOREGRESSIVE
106104
self.is_text_generation_model = True
107105

@@ -238,4 +236,4 @@ def load_weights(self, rng_key: jax.Array) -> None:
238236
model, _ = model_creation_utils.create_nnx_model(
239237
self.maxtext_config, mesh=self.mesh, model_mode=self.model_mode, rng_key=rng_key
240238
)
241-
self.model = nnx.data(model)
239+
self.model = nnx.data(model)

0 commit comments

Comments
 (0)