Commit 8bcb4c1

Merge branch 'main' of github.com:AI-Hypercomputer/maxtext into shuningjin-fix
2 parents 75e87c8 + c9df820 commit 8bcb4c1

9 files changed

Lines changed: 106 additions & 66 deletions

docs/tutorials/post_training_index.md

Lines changed: 1 addition & 0 deletions
@@ -70,4 +70,5 @@ posttraining/rl_on_multi_host.md
 posttraining/knowledge_distillation.md
 posttraining/multimodal.md
 posttraining/full_finetuning.md
+posttraining/gepa_optimization.md
 ```
Lines changed: 78 additions & 0 deletions
@@ -0,0 +1,78 @@
# GEPA Prompt Optimization for MaxText

## Overview

This document explains how to use **GEPA** (Genetic-Pareto) to optimize system prompts for MaxText models. GEPA is an evolutionary framework ([GitHub Repository](https://github.com/gepa-ai/gepa), [Paper](https://arxiv.org/abs/2507.19457)) that iteratively refines prompts based on evaluation feedback, helping models perform better on specific tasks. A complete, runnable example notebook is provided in the repository at [maxtext_with_gepa.ipynb](../../../src/maxtext/examples/maxtext_with_gepa.ipynb).

## How GEPA Optimization Works

The optimization process relies on a collaborative loop between two Language Models (LMs):

1. **Target Model**: This is the model whose system prompt is being optimized. It attempts to solve the evaluation problems (e.g., AIME questions) using the current candidate system prompt. For example, this can be a `Qwen3-4B` model hosted on a local vLLM server.
2. **Reflection LM**: This model reviews the reasoning traces and failures of the Target Model. It identifies recurring errors (e.g., mathematical errors or formatting issues) and proposes targeted updates to the system prompt. For example, a model like `Gemini 3 Flash Preview` can be used as the reflection model.

### The Evolutionary Loop

1. **Propose**: The Reflection LM proposes a new system prompt based on errors seen in previous runs.
2. **Evaluate (Subsample)**: The Target Model solves a small random subset of problems using the new prompt. This serves as a quick screening step.
3. **Full Evaluation**: If the subsample score improves, the prompt is evaluated on the full validation set.
4. **Selection**: Successful prompts are added to the candidate pool, driving the evolution of domain-specific heuristics (such as circle packing formulas or prime factorization strategies) that eventually form the final optimized prompt.

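
The four steps above can be sketched as a small loop. This is a toy illustration under stated assumptions, not GEPA's implementation: `optimize_prompt`, `toy_score`, and `toy_propose` are hypothetical names, the Target Model and Reflection LM are replaced by plain functions, and the parent is always the best candidate so far (a greedy simplification; the GEPA paper describes Pareto-based candidate selection).

```python
import random
from typing import Callable, List, Sequence, Tuple


def optimize_prompt(
    seed_prompt: str,
    problems: Sequence[str],
    score_fn: Callable[[str, str], float],
    propose_fn: Callable[[str, Sequence[str]], str],
    iterations: int = 10,
    subsample_size: int = 3,
    seed: int = 0,
) -> Tuple[str, float]:
    """Toy propose / screen / full-evaluate / select loop (hypothetical helper)."""
    rng = random.Random(seed)

    def mean_score(prompt: str, batch: Sequence[str]) -> float:
        return sum(score_fn(prompt, p) for p in batch) / len(batch)

    # Candidate pool of (prompt, full-validation score), seeded with the start prompt.
    pool: List[Tuple[str, float]] = [(seed_prompt, mean_score(seed_prompt, problems))]
    for _ in range(iterations):
        # Greedy simplification: always mutate the best candidate found so far.
        parent = max(pool, key=lambda c: c[1])[0]
        batch = rng.sample(list(problems), min(subsample_size, len(problems)))
        # 1. Propose: stand-in for the Reflection LM.
        child = propose_fn(parent, batch)
        # 2. Evaluate (subsample): discard the child unless it beats its parent.
        if mean_score(child, batch) <= mean_score(parent, batch):
            continue
        # 3. Full evaluation and 4. Selection: score on all problems, add to the pool.
        pool.append((child, mean_score(child, problems)))
    return max(pool, key=lambda c: c[1])


def toy_score(prompt: str, problem: str) -> float:
    """Stand-in for the Target Model: a 'problem' is solved if its topic is in the prompt."""
    return 1.0 if problem in prompt else 0.0


def toy_propose(parent: str, batch: Sequence[str]) -> str:
    """Stand-in for the Reflection LM: append a hint for the first unsolved topic."""
    missing = [topic for topic in batch if topic not in parent]
    return f"{parent} Review {missing[0]}." if missing else parent
```

With the toy functions, calling `optimize_prompt("Solve carefully.", ["geometry", "combinatorics", "primes"], toy_score, toy_propose)` grows the prompt one hint at a time, which mirrors how accepted mutations accumulate domain-specific heuristics.
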
### Synergy via Prompt Merging

A key feature used in the AIME experiments was **Prompt Merging** (`use_merge=True`).

As the evolutionary process runs, different branches might discover distinct, valid heuristics (e.g., one branch learns a rule for Geometry, while another learns a rule for Combinatorics).

- **How It Works**: Instead of forcing a choice between two distinct winning paths, GEPA attempts to merge them. The Reflection LM is instructed to synthesize the instructions from both candidates, deduplicating content and integrating the new knowledge into a single, unified system prompt.
- **Why It Is Important**: Merging allows the optimization to achieve synergistic gains. By combining orthogonal prompt improvements, the final system prompt acts as a comprehensive "cheat sheet" covering multiple mathematical domains simultaneously, which is critical for the broad range of problems found in datasets like AIME.

## Robust Evaluation with MathAdapter

A critical component of the optimization setup is the custom `MathAdapter`.

### Why the Custom Logic?

Standard evaluation pipelines often use simple regular expressions to extract the answer from a model's response (e.g., capturing everything inside `\boxed{}`). However, competition math problems like AIME frequently require answers formatted in complex LaTeX (such as fractions `\boxed{\frac{a}{b}}` or nested expressions). A naive regex stops at the first closing brace `}` and fails to capture the full answer.

The `MathAdapter` implements a robust **brace-counting parser** that correctly tracks nested LaTeX structures, ensuring the complete mathematical expression is extracted.

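
To make the difference concrete, here is a minimal sketch of brace counting next to a naive regex. It is not the notebook's `MathAdapter` code: `extract_boxed` and `naive_extract` are hypothetical helpers, taking the last `\boxed{` in the response is an assumption, and escaped braces such as `\{` are not handled.

```python
import re
from typing import Optional


def naive_extract(text: str) -> Optional[str]:
    """Non-greedy regex: stops at the first closing brace it sees."""
    match = re.search(r"\\boxed\{(.*?)\}", text)
    return match.group(1) if match else None


def extract_boxed(text: str) -> Optional[str]:
    """Return the contents of the last \\boxed{...}, tracking nested braces."""
    marker = r"\boxed{"
    start = text.rfind(marker)  # assumption: the final box holds the answer
    if start == -1:
        return None
    depth = 1  # we are inside the opening brace of \boxed{
    chars = []
    for ch in text[start + len(marker):]:
        if ch == "{":
            depth += 1
        elif ch == "}":
            depth -= 1
            if depth == 0:
                # Matching close for \boxed{ found: everything collected is the answer.
                return "".join(chars)
        chars.append(ch)
    return None  # braces never balanced, so treat the answer as unparseable


response = r"So the probability is \boxed{\frac{3}{4}}."
print(naive_extract(response))   # truncated at the first closing brace
print(extract_boxed(response))   # full nested expression
```

On this response the regex returns the truncated `\frac{3`, while the brace counter returns `\frac{3}{4}`.
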
### Why It Is Crucial for GEPA

Prompt optimization frameworks like GEPA are highly sensitive to the reward signal (the evaluation score). If a model generates a correct answer but the evaluation logic fails to parse it correctly (a false negative), the optimization loop receives faulty feedback. This noisy signal can cause GEPA to discard beneficial prompt mutations, ultimately leading to performance degradation instead of improvement.

## Tutorial Notebook

A complete, runnable tutorial is available in the repository as a Jupyter Notebook:
[maxtext_with_gepa.ipynb](../../../src/maxtext/examples/maxtext_with_gepa.ipynb) (provided as an example)

This notebook walks through:

- Streaming the dataset.
- Setting up a custom `MathAdapter` for float extraction.
- Running the GEPA evolutionary loop.
- Comparing accuracy before and after optimization.

> [!NOTE]
> In this tutorial, we use an out-of-tree version of vLLM tailored for MaxText models via the `maxtext_vllm_adapter`. For more information on serving MaxText models with vLLM, refer to the [Inference Guide](../inference.md).

## Pointing GEPA to the Local vLLM Server

By default, optimization frameworks might expect to communicate with remote model APIs. In our setup, we route the evaluation traffic to the locally running MaxText model on the vLLM server by overriding the API base URL.

This is achieved by setting the following environment variables in the script:

```python
import os

# Send OpenAI-style requests to the local vLLM server instead of a remote API.
os.environ["OPENAI_API_BASE"] = "http://localhost:8000/v1"
# Placeholder value; no real OpenAI key is used in this setup.
os.environ["OPENAI_API_KEY"] = "fake-key"
```

When the `MathAdapter` initializes the model (e.g., specifying `openai/Qwen/Qwen3-4B-Instruct-2507`), `litellm` (used by GEPA under the hood) reads these variables and sends the request to the local server running on the TPU host instead of attempting to connect to a remote OpenAI endpoint.

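
For a sense of what ends up on the wire, the sketch below builds (but does not send) an OpenAI-style chat request against the overridden base URL, using only the standard library. `build_chat_request` is a hypothetical helper, and the served model name `Qwen/Qwen3-4B-Instruct-2507` is an assumption that depends on how the vLLM server was launched.

```python
import json
import os
import urllib.request


def build_chat_request(
    system_prompt: str,
    question: str,
    model: str = "Qwen/Qwen3-4B-Instruct-2507",  # assumed served model name
) -> urllib.request.Request:
    """Build a POST request for the OpenAI-compatible chat completions route."""
    base_url = os.environ.get("OPENAI_API_BASE", "http://localhost:8000/v1").rstrip("/")
    api_key = os.environ.get("OPENAI_API_KEY", "fake-key")
    payload = {
        "model": model,
        "messages": [
            {"role": "system", "content": system_prompt},  # candidate prompt under test
            {"role": "user", "content": question},
        ],
    }
    return urllib.request.Request(
        f"{base_url}/chat/completions",
        data=json.dumps(payload).encode("utf-8"),
        headers={
            "Content-Type": "application/json",
            "Authorization": f"Bearer {api_key}",
        },
        method="POST",
    )
```

Passing the result to `urllib.request.urlopen` would only succeed while the local server is running; in the tutorial itself this traffic is generated by `litellm` rather than by hand.
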
## Case Study: AIME Prompt Optimization

In our experiments with the **AIME (American Invitational Mathematics Examination)** dataset, we used **Qwen3-4B** as the Target Model (hosted locally via vLLM) and **Gemini 3 Flash Preview** as the Reflection LM.

With this setup, GEPA improved the model's accuracy from **49.0% to 54.0%** (an absolute improvement of 5 percentage points).

The optimization process discovered that injecting specific domain knowledge and heuristics (like circle packing formulas and square-free parts for number theory) significantly helped the model solve complex competition-level problems.

src/dependencies/requirements/generated_requirements/tpu-post-train-requirements.txt

Lines changed: 1 addition & 0 deletions
@@ -81,6 +81,7 @@ frozenlist>=1.8.0
 fsspec>=2026.1.0
 gast>=0.6.0
 gcsfs>=2026.1.0
+gepa>=0.1.1
 gguf>=0.17.1
 google-api-core>=2.28.1
 google-api-python-client>=2.187.0

src/maxtext/configs/post_train/rl.yml

Lines changed: 3 additions & 0 deletions
@@ -68,6 +68,9 @@ rl:
 degenerate_group_masking: True
 # Upper-bound clipping epsilon for GRPO loss; defaults to grpo_epsilon when null.
 epsilon_high: null
+# Number of model keys to chunk for resharding tensors between trainer and rollout devices.
+# If null, the entire model is resharded at once.
+reshard_chunk_size: null
 
 
 # ====== Models ======
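
As a rough illustration of what the new `reshard_chunk_size` option describes, the sketch below splits a list of model keys into groups of that size, with `None` meaning a single group for the whole model. This is an assumption-laden toy, not the resharding code the trainer uses: `chunk_keys` is a hypothetical helper and the key names are made up.

```python
from typing import Iterator, List, Optional, Sequence


def chunk_keys(keys: Sequence[str], chunk_size: Optional[int]) -> Iterator[List[str]]:
    """Yield groups of model keys; each group would be resharded together."""
    if chunk_size is None:
        # No chunking: the entire model is handled in one pass.
        yield list(keys)
        return
    if chunk_size <= 0:
        raise ValueError("chunk_size must be a positive integer or None")
    for start in range(0, len(keys), chunk_size):
        yield list(keys[start:start + chunk_size])


# Hypothetical key names, for illustration only.
example_keys = ["embed", "layer_0", "layer_1", "layer_2", "lm_head"]
for group in chunk_keys(example_keys, 2):
    print(group)
```

Smaller chunks bound how many tensors are in flight at once, which is the memory trade-off the config description points to.
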

src/maxtext/configs/types.py

Lines changed: 11 additions & 14 deletions
@@ -436,26 +436,16 @@ class Quantization(BaseModel):
 )
 weight_sparsity_n: int | None = Field(
 None,
-description=(
-"The 'N' in N:M sparsity, representing the maximum number of non-zero"
-" values in each block."
-),
+description=("The 'N' in N:M sparsity, representing the maximum number of non-zero" " values in each block."),
 )
 weight_sparsity_m: int | None = Field(
 None,
-description=(
-"The 'M' in N:M sparsity, representing the number of values in each"
-" block."
-),
-)
-weight_sparsity_update_step: int = Field(
-10, description="The step size for updating weight sparsity masks."
+description=("The 'M' in N:M sparsity, representing the number of values in each" " block."),
 )
+weight_sparsity_update_step: int = Field(10, description="The step size for updating weight sparsity masks.")
 weight_sparsity_start_step: int = Field(
 50,
-description=(
-"The first number of steps before updating the sparsity masks."
-),
+description=("The first number of steps before updating the sparsity masks."),
 )
 
 
@@ -1822,6 +1812,13 @@ class RL(BaseModel):
 epsilon_high: Optional[float] = Field(
 None, description="Upper-bound clipping epsilon for GRPO loss. Defaults to epsilon when None (agentic only)."
 )
+reshard_chunk_size: Optional[int] = Field(
+None,
+description=(
+"Number of model keys to chunk for resharding tensors between trainer and rollout devices. "
+"If None, no chunking is applied, which may lead to OOM errors if tensors are too large."
+),
+)
 
 
 class RLDataset(BaseModel):

src/maxtext/models/gemma4.py

Lines changed: 1 addition & 1 deletion
@@ -370,7 +370,7 @@ def __call__(
 
 next_layer_addition = mlp_lnx + residual
 layer_output = next_layer_addition
-layer_output = layer_output * self.layer_scalar.value
+layer_output = layer_output * jnp.asarray(self.layer_scalar.value, cfg.dtype)
 
 layer_output = nn.with_logical_constraint(layer_output, self.activation_axis_names)

src/maxtext/trainers/post_train/rl/train_rl.py

Lines changed: 1 addition & 0 deletions
@@ -405,6 +405,7 @@ def create_rl_components(
 rollout_vllm_max_num_seqs=trainer_config.max_num_seqs,
 rollout_vllm_async_scheduling=trainer_config.async_scheduling,
 rollout_vllm_server_mode=trainer_config.rl.use_agentic_rollout,
+rollout_vllm_reshard_chunk_size=trainer_config.rl.reshard_chunk_size,
 rollout_vllm_kwargs={
 "hf_overrides": trainer_config.vllm_hf_overrides,
 "enable_expert_parallel": sampler_config.enable_expert_parallel,

src/maxtext/trainers/pre_train/train.py

Lines changed: 0 additions & 43 deletions
@@ -387,31 +387,6 @@ def train_step(model, config, state_mesh_shardings, params_shardings, state, dat
 else:
 grads = raw_grads
 
-# fp8 fix: sanitize NaN OWG (overwrite-with-gradient) stats before apply_gradients.
-# Under FSDP, the fp8 output gradient amax can be NaN at step 0, which propagates into
-# amax_history and corrupts future steps. Replace NaN OWG entries with the current state
-# values (skip the amax update for that step) instead of letting NaN flow through.
-# Also restore OWG values after apply_gradients to bypass optimizer corruption
-# (Adam should not update fp8 scale/amax_history).
-fp8_stats = dict(grads).get(maxtext_utils.OVERWRITE_WITH_GRADIENT, None)
-if fp8_stats is not None:
-if maxtext_utils.OVERWRITE_WITH_GRADIENT in state.params:
-current_fp8 = state.params[maxtext_utils.OVERWRITE_WITH_GRADIENT]
-fp8_stats = jax.tree_util.tree_map(
-lambda new, cur: jnp.where(jnp.isnan(new), cur, new),
-fp8_stats,
-current_fp8,
-)
-else:
-fp8_stats = jax.tree_util.tree_map(lambda x: jnp.nan_to_num(x, nan=0.0), fp8_stats)
-grads = dict(grads)
-grads[maxtext_utils.OVERWRITE_WITH_GRADIENT] = fp8_stats
-# Zero out any remaining NaN in float gradients to prevent param corruption
-grads = jax.tree_util.tree_map(
-lambda x: jnp.nan_to_num(x, nan=0.0) if jnp.issubdtype(x.dtype, jnp.floating) else x,
-grads,
-)
-
 if config.optimizer_memory_host_offload:
 state = state.replace(
 opt_state=jax.device_put(
@@ -462,25 +437,7 @@ def move(path, value):
 )
 else:
 new_state = state.apply_gradients(grads=full_grads)
-# fp8 fix: restore sanitized OWG values, bypassing any optimizer update to fp8 stats.
-if fp8_stats is not None:
-new_params = dict(new_state.params)
-new_params[maxtext_utils.OVERWRITE_WITH_GRADIENT] = fp8_stats
-new_state = new_state.replace(params=new_params)
-has_batch_stats = (
-config.weight_sparsity_n
-and config.weight_sparsity_m
-and bool(aux.get("batch_stats"))
-and isinstance(state.params, dict)
-and "batch_stats" in state.params
-)
 
-if has_batch_stats:
-new_params = dict(new_state.params)
-new_params["batch_stats"] = max_utils.unbox_logicallypartioned(
-aux["batch_stats"]
-)
-new_state = new_state.replace(params=new_params)
 # Apply updates for Auxiliary-Loss-Free load balancing for DeepSeek family
 if config.routed_bias and config.routed_bias_update_rate > 0.0 and moe_bias_updates is not None:
 target_path = ("params", "decoder", "moe_layers", "DeepSeekMoeBlock_0", "MoeBlock_0", "gate", "bias")

tests/sparsity_test.py

Lines changed: 10 additions & 8 deletions
@@ -18,6 +18,7 @@
 import tempfile
 from absl.testing import absltest
 from absl.testing import parameterized
+import pytest
 from maxtext.trainers.pre_train import train
 from tests.utils.test_helpers import get_test_config_path
 
@@ -45,9 +46,8 @@ class Train(parameterized.TestCase):
 "use_sparsity": True,
 },
 )
-def test_different_quant_sparsity_configs(
-self, quantization: str, use_sparsity: bool
-):
+@pytest.mark.tpu_only
+def test_different_quant_sparsity_configs(self, quantization: str, use_sparsity: bool):
 test_tmpdir = os.environ.get("TEST_TMPDIR", gettempdir())
 outputs_dir = os.environ.get("TEST_UNDECLARED_OUTPUTS_DIR", test_tmpdir)
 args = [
@@ -81,11 +81,13 @@ def test_different_quant_sparsity_configs(
 f"metrics_file={os.path.join(outputs_dir, 'metrics.json')}",
 ]
 if use_sparsity:
-args.extend([
-"weight_sparsity_n=2",
-"weight_sparsity_m=4",
-"weight_sparsity_update_step=1",
-])
+args.extend(
+[
+"weight_sparsity_n=2",
+"weight_sparsity_m=4",
+"weight_sparsity_update_step=1",
+]
+)
 train_main(args)
 
