Skip to content

Commit 788da99

Browse files
authored
Merge branch 'main' into dandragona/mhc_k4_shortcut
2 parents eeb0154 + 9071a8b commit 788da99

81 files changed

Lines changed: 3450 additions & 247 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

.github/workflows/AddLabel.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ jobs:
112112
// Ignore the current running workflow
113113
if (checkRun.name.endsWith(context.job)) continue
114114
115-
if (checkRun.status !== 'completed' || checkRun.conclusion !== 'success') {
115+
if (checkRun.status !== 'completed' || !['success', 'skipped'].includes(checkRun.conclusion)) {
116116
core.info(`Waiting for check: ${checkRun.name} (Status: ${checkRun.status}, Conclusion: ${checkRun.conclusion})`);
117117
return; // Exit without failing
118118
}

.github/workflows/gemini-investigate.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ jobs:
8585
settings: |-
8686
{
8787
"model": {
88-
"maxSessionTurns": 15
88+
"maxSessionTurns": 50
8989
},
9090
"mcpServers": {
9191
"github": {

.github/workflows/run_tests_coordinator.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ jobs:
6666
strategy:
6767
fail-fast: false
6868
matrix:
69-
worker_group: ${{ fromJSON(contains(inputs.flavor, 'cpu-unit') && '[1, 2]' || '[1]') }}
69+
worker_group: ${{ fromJSON(contains(inputs.flavor, 'cpu-unit') && '[1, 2, 3, 4]' || '[1]') }}
7070

7171
uses: ./.github/workflows/run_tests_against_package.yml
7272
with:
@@ -158,6 +158,6 @@ jobs:
158158
is_scheduled_run: ${{ inputs.is_scheduled_run }}
159159
maxtext_installed: ${{ inputs.maxtext_installed }}
160160
worker_group: ${{ matrix.worker_group }}
161-
total_workers: ${{ contains(inputs.flavor, 'cpu-unit') && 2 || 1 }}
161+
total_workers: ${{ contains(inputs.flavor, 'cpu-unit') && 4 || 1 }}
162162
maxtext_sha: ${{ inputs.maxtext_sha }}
163163
is_update_hlo: ${{ inputs.is_update_hlo }}

src/maxtext/checkpoint_conversion/standalone_scripts/llama4_ckpt_unscanned.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -600,9 +600,8 @@ def _convert_pytorch_to_jax_weights(base_model_path: str, model_size: str, model
600600
for i, ckpt_path in enumerate(ckpt_paths):
601601
max_logging.log(f"Loading checkpoint {i+1} of {len(ckpt_paths)} ...")
602602
# NOTE: starting in PT2.6, `weights_only` was switched from the default of `False` to `True`
603-
# thus we need to specify this or else loading will fail
604603
chkpt_vars[int(ckpt_path.name.split(".", maxsplit=2)[1])] = torch.load(
605-
ckpt_path, map_location="cpu", weights_only=False
604+
ckpt_path, map_location="cpu", weights_only=True
606605
)
607606
chkpt_vars = [chkpt_vars[i] for i in sorted(list(chkpt_vars.keys()))]
608607
# map weight names if they use HuggingFace instead of PyTorch convention

src/maxtext/checkpoint_conversion/standalone_scripts/llama_ckpt_conversion_inference_only.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,7 @@ def convert(base_model_path, maxtext_model_path, model_size):
157157
for i, ckpt_path in enumerate(ckpt_paths):
158158
print(f"Loading checkpoint {i+1} of {len(ckpt_paths)} ...")
159159

160-
checkpoint = torch.load(ckpt_path, map_location="cpu")
160+
checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=True)
161161
pytorch_vars[int(ckpt_path.name.split(".", maxsplit=2)[1])] = checkpoint
162162
print("memory usage in GB: ", psutil.Process().memory_info().rss / (1024 * 1024))
163163

src/maxtext/checkpoint_conversion/standalone_scripts/llama_or_mistral_ckpt.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -428,7 +428,7 @@ def convert_lora_weights_to_jax_weights(lora_config: dict, model_size: str):
428428

429429
max_logging.log(f"Loading the lora model from {lora_config['lora_model_path']}")
430430
# Load LoRA model weights
431-
lora_chkpt_vars = torch.load(lora_config["lora_model_path"])
431+
lora_chkpt_vars = torch.load(lora_config["lora_model_path"], weights_only=True)
432432
lora_chkpt_vars = _NamespaceMapper(lora_chkpt_vars)
433433

434434
jax_weights_lora = {
@@ -1112,9 +1112,8 @@ def _convert_pytorch_to_jax_weights(base_model_path: str, model_size: str, model
11121112
for i, ckpt_path in enumerate(ckpt_paths):
11131113
max_logging.log(f"Loading checkpoint {i+1} of {len(ckpt_paths)} ...")
11141114
# NOTE: starting in PT2.6, `weights_only` was switched from the default of `False` to `True`
1115-
# thus we need to specify this or else loading will fail
11161115
chkpt_vars[int(ckpt_path.name.split(".", maxsplit=2)[1])] = torch.load(
1117-
ckpt_path, map_location="cpu", weights_only=False
1116+
ckpt_path, map_location="cpu", weights_only=True
11181117
)
11191118
chkpt_vars = [chkpt_vars[i] for i in sorted(list(chkpt_vars.keys()))]
11201119
# map weight names if they use HuggingFace instead of PyTorch convention

src/maxtext/configs/pyconfig.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -324,6 +324,11 @@ def initialize_pydantic(argv: list[str] | None = None, **kwargs) -> MaxTextConfi
324324

325325
# 2. Get overrides from CLI and kwargs
326326
cli_cfg = omegaconf.OmegaConf.from_cli(cli_args)
327+
if "hf_access_token" in cli_cfg:
328+
logger.warning(
329+
"WARNING: Passing 'hf_access_token' via command-line arguments is deprecated and insecure because it makes "
330+
"your token visible in 'ps' and shell history. Please set the 'HF_TOKEN' environment variable instead."
331+
)
327332
kwargs_cfg = omegaconf.OmegaConf.create(kwargs)
328333
overrides_cfg = omegaconf.OmegaConf.merge(cli_cfg, kwargs_cfg)
329334

src/maxtext/eval/README.md

Lines changed: 169 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,169 @@
1+
# MaxText vLLM Eval Framework
2+
3+
A vLLM-native evaluation framework for MaxText models supporting harness-based eval (lm-eval, evalchemy) and custom datasets.
4+
5+
## Quick Start
6+
7+
All runners share a single entry point:
8+
9+
```bash
10+
python -m maxtext.eval.runner.run --runner <eval|lm_eval|evalchemy> [flags]
11+
```
12+
13+
### Custom dataset (MLPerf OpenOrca, ROUGE scoring, Other)
14+
15+
```bash
16+
python -m maxtext.eval.runner.run \
17+
--runner eval \
18+
--config src/maxtext/eval/configs/mlperf.yml \
19+
--checkpoint_path gs://<bucket>/checkpoints/0/items \
20+
--model_name llama3.1-8b \
21+
--hf_path meta-llama/Llama-3.1-8B-Instruct \
22+
--base_output_directory gs://<bucket>/ \
23+
--run_name eval_run \
24+
--max_model_len 8192 \
25+
--hf_token $HF_TOKEN
26+
```
27+
28+
HF safetensors mode (no MaxText checkpoint):
29+
30+
```bash
31+
python -m maxtext.eval.runner.run \
32+
--runner eval \
33+
--config src/maxtext/eval/configs/mlperf.yml \
34+
--hf_path TinyLlama/TinyLlama-1.1B-Chat-v1.0 \
35+
--model_name tinyllama \
36+
--base_output_directory gs://<bucket>/ \
37+
--run_name eval_test \
38+
--hf_mode \
39+
--num_samples 20 \
40+
--max_model_len 2048 \
41+
--tensor_parallel_size 1
42+
```
43+
44+
### LM Eval
45+
46+
Requires: `pip install "lm_eval[api]"`
47+
48+
```bash
49+
python -m maxtext.eval.runner.run \
50+
--runner lm_eval \
51+
--checkpoint_path gs://<bucket>/checkpoints/0/items \
52+
--model_name qwen3-30b-a3b \
53+
--hf_path Qwen/Qwen3-30B-A3B \
54+
--tasks gsm8k \
55+
--base_output_directory gs://<bucket>/ \
56+
--run_name my_run \
57+
--max_model_len 8192 \
58+
--tensor_parallel_size 8 \
59+
--expert_parallel_size 8 \
60+
--hf_token $HF_TOKEN
61+
```
62+
63+
### Evalchemy
64+
65+
Requires: `pip install git+https://github.com/mlfoundations/evalchemy.git`
66+
67+
```bash
68+
python -m maxtext.eval.runner.run \
69+
--runner evalchemy \
70+
--checkpoint_path gs://<bucket>/checkpoints/0/items \
71+
--model_name llama3.1-8b \
72+
--hf_path meta-llama/Llama-3.1-8B-Instruct \
73+
--tasks ifeval math500 gpqa_diamond \
74+
--base_output_directory gs://<bucket>/ \
75+
--run_name eval_run \
76+
--max_model_len 8192 \
77+
--tensor_parallel_size 4 \
78+
--hf_token $HF_TOKEN
79+
```
80+
81+
## Common Flags
82+
83+
| Flag | Description |
84+
|---|---|
85+
| `--checkpoint_path` | MaxText Orbax checkpoint path. Enables `MaxTextForCausalLM` mode. |
86+
| `--model_name` | MaxText model name (e.g. `llama3.1-8b`) |
87+
| `--hf_path` | HF model ID or local path |
88+
| `--max_model_len` | vLLM max context length. |
89+
| `--tensor_parallel_size` | Chips per model replica |
90+
| `--expert_parallel_size` | Chips for the expert mesh axis |
91+
| `--data_parallel_size` | Number of model replicas |
92+
| `--hbm_memory_utilization` | Fraction of HBM reserved for KV cache |
93+
| `--hf_token` | HF token (or set `HF_TOKEN` env var) |
94+
| `--hf_mode` | HF safetensors mode, no MaxText checkpoint loading |
95+
| `--server_host` / `--server_port` | vLLM server address (default: localhost:8000) |
96+
| `--max_num_batched_tokens` | vLLM tokens per scheduler step |
97+
| `--max_num_seqs` | vLLM max concurrent sequences |
98+
| `--gcs_results_path` | GCS path to upload results JSON |
99+
| `--log_level` | Logging verbosity (default: INFO) |
100+
101+
Custom `eval` specific:
102+
103+
| Flag | Description |
104+
|---|---|
105+
| `--config` | Benchmark YAML config (required) |
106+
| `--num_samples` | Limit eval samples |
107+
| `--max_tokens` | Max tokens per generation |
108+
| `--temperature` | Sampling temperature (default: 0.0) |
109+
| `--concurrency` | HTTP request concurrency (default: 64) |
110+
111+
Harness `lm_eval` / `evalchemy` specific:
112+
113+
| Flag | Description |
114+
|---|---|
115+
| `--tasks` | Space-separated task names |
116+
| `--num_fewshot` | Few-shot examples per task (default: 0) |
117+
| `--num_samples` | Limit samples per task (default: full dataset) |
118+
119+
## Eval on RL Checkpoints
120+
121+
122+
123+
Example (Qwen3-30B-A3B, v6e-8):
124+
125+
```bash
126+
STEP=244
127+
MODEL=qwen3-30b-a3b
128+
HF_PATH=Qwen/Qwen3-30B-A3B
129+
CHECKPOINT=gs://<bucket>/run/checkpoints/actor/${STEP}/model_params
130+
OUTPUT=gs://<bucket>/eval/
131+
132+
python -m maxtext.eval.runner.run \
133+
--runner lm_eval \
134+
--checkpoint_path ${CHECKPOINT} \
135+
--model_name ${MODEL} \
136+
--hf_path ${HF_PATH} \
137+
--tasks gsm8k \
138+
--base_output_directory ${OUTPUT} \
139+
--run_name rl_${MODEL}_step${STEP} \
140+
--max_model_len 4096 \
141+
--tensor_parallel_size 8 \
142+
--expert_parallel_size 8 \
143+
--num_samples 20 \
144+
--hf_token $HF_TOKEN
145+
```
146+
147+
148+
## Adding a Custom Benchmark
149+
150+
1. Implement `BenchmarkDataset` in `src/maxtext/eval/datasets/`:
151+
152+
```python
153+
from maxtext.eval.datasets.base import BenchmarkDataset, SampleRequest
154+
155+
class MyDataset(BenchmarkDataset):
156+
name = "my_benchmark"
157+
158+
def sample_requests(self, num_samples, tokenizer) -> list[SampleRequest]:
159+
# load dataset, build prompts, return SampleRequest list
160+
```
161+
162+
2. Register in `src/maxtext/eval/datasets/registry.py`:
163+
164+
```python
165+
from maxtext.eval.datasets.my_dataset import MyDataset
166+
DATASET_REGISTRY["my_benchmark"] = MyDataset
167+
```
168+
169+
3. Add a scorer in `src/maxtext/eval/scoring/` and register it in `src/maxtext/eval/scoring/registry.py`.
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
# Base evaluation configuration.
2+
3+
temperature: 0.0
4+
concurrency: 64
5+
server_host: "localhost"
6+
server_port: 8000
7+
tensor_parallel_size: 4
8+
num_samples: null
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
# MLPerf OpenOrca evaluation config.
2+
3+
benchmark: "mlperf_openorca"
4+
max_tokens: 1024
5+
num_samples: 5000

0 commit comments

Comments
 (0)