
Commit fc2dab7

Merge branch 'main' into vladk/sft-completion-fix2
2 parents d99f227 + b5f41ec

21 files changed: 634 additions & 454 deletions

README.md

Lines changed: 1 addition & 0 deletions
@@ -41,6 +41,7 @@ See our guide on running MaxText in decoupled mode, without any GCP dependencies

 ## 🔥 Latest news 🔥

+* \[February 27, 2026\] New MaxText structure! MaxText has been restructured according to [RESTRUCTURE.md](https://github.com/AI-Hypercomputer/maxtext/blob/1b9e38aa0a19b6018feb3aed757406126b6953a1/RESTRUCTURE.md). Please feel free to share your thoughts and feedback.
 * \[December 22, 2025\] [Muon optimizer](https://kellerjordan.github.io/posts/muon) is now supported.
 * \[December 10, 2025\] DeepSeek V3.1 is now supported. Use existing configs for [DeepSeek V3 671B](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/configs/models/deepseek3-671b.yml) and load in V3.1 checkpoint to use model.
 * \[December 9, 2025\] [New RL and SFT Notebook tutorials](https://github.com/AI-Hypercomputer/maxtext/tree/main/src/maxtext/examples) are available.
Lines changed: 85 additions & 0 deletions
@@ -0,0 +1,85 @@
<!--
Copyright 2026 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->
# Batch Size

This document explains the different concepts of "batch size" within MaxText and how to configure them to tune performance and manage memory.

## Per-Device Batch Size

`per_device_batch_size` is the number of training examples processed by a single device in one forward and backward pass. This value impacts the memory usage on each device and is a configuration parameter in `configs/base.yml`.
## Global Batch Size

`global_batch_size_to_train_on` is the total number of training examples processed before the optimizer performs a single weight update. It is the effective batch size for training, calculated as:

`global_batch_size_to_train_on = per_device_batch_size x number_of_devices x gradient_accumulation_steps`

You can set `per_device_batch_size` and `gradient_accumulation_steps` in `configs/base.yml`.

`global_batch_size_to_load` is the total number of examples the data input pipeline loads from storage at once. It can be larger than the training batch size to optimize I/O performance, and is calculated as:

`global_batch_size_to_load = global_batch_size_to_train_on x expansion_factor_real_data`

When `expansion_factor_real_data > 1`, only a subset of hosts read data from the source (e.g., a GCS bucket). These "loading hosts" read more data than they need for their own devices and distribute the surplus to other "non-loading" hosts. This reduces the number of concurrent connections to the data source, which can significantly improve I/O throughput. When `expansion_factor_real_data` is between 0 and 1, the Grain pipeline can use a smaller chip count to read a data-input checkpoint produced by a job with a larger chip count. See [the Grain data input guide](https://github.com/AI-Hypercomputer/maxtext/blob/main/docs/guides/data_input_pipeline/data_input_grain.md#using-grain) for details.
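To make the arithmetic concrete, here is a minimal sketch with hypothetical values (not MaxText defaults):

```python
# Hypothetical values for illustration only.
per_device_batch_size = 4
number_of_devices = 256
gradient_accumulation_steps = 2
expansion_factor_real_data = 2  # e.g. half the hosts load data and share it

# Examples per optimizer update:
global_batch_size_to_train_on = (
    per_device_batch_size * number_of_devices * gradient_accumulation_steps
)  # 4 * 256 * 2 = 2048

# Examples fetched by the input pipeline per step:
global_batch_size_to_load = (
    global_batch_size_to_train_on * expansion_factor_real_data
)  # 2048 * 2 = 4096
```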
## Gradient Accumulation Steps

`gradient_accumulation_steps` defines how many forward/backward passes are performed before the optimizer updates the model weights. The gradients from each pass are accumulated (summed). It is discussed in more detail [here](https://maxtext.readthedocs.io/en/latest/reference/core_concepts/tiling.html#gradient-accumulation).

For example, if `gradient_accumulation_steps` is set to `4`, the model will execute four forward and backward passes, sum the gradients, and then apply a single optimizer step. This achieves the same effective global batch size as quadrupling the `per_device_batch_size` with significantly less memory, but can potentially lead to lower MFU.
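As an illustration, the same behavior can be sketched with optax's `MultiSteps` wrapper; the toy model and data below are placeholders, not MaxText internals:

```python
import jax
import jax.numpy as jnp
import optax

def loss_fn(params, batch):
  pred = batch["x"] @ params["w"]
  return jnp.mean((pred - batch["y"]) ** 2)

# Apply the optimizer once every 4 gradient computations.
tx = optax.MultiSteps(optax.adamw(1e-3), every_k_schedule=4)
params = {"w": jnp.zeros((8, 1))}
opt_state = tx.init(params)

key = jax.random.PRNGKey(0)
batches = [{"x": jax.random.normal(key, (4, 8)), "y": jnp.ones((4, 1))}
           for _ in range(8)]

for batch in batches:
  grads = jax.grad(loss_fn)(params, batch)
  # Updates are zero on 3 of every 4 calls; the 4th applies the accumulated gradients.
  updates, opt_state = tx.update(grads, opt_state, params)
  params = optax.apply_updates(params, updates)
```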
## Pipeline Microbatches

When pipeline parallelism is enabled, the global batch is split into smaller chunks called **microbatches**. These are fed into the pipeline sequentially, allowing different stages of the model to work on different microbatches simultaneously.

The `num_pipeline_microbatches` parameter in `configs/base.yml` configures how many of these smaller chunks the global batch is divided into. It must be a multiple of the total number of pipeline stages (`ici_pipeline_parallelism` x `dcn_pipeline_parallelism`).

The choice of `num_pipeline_microbatches` is a trade-off between pipeline idle time and computational efficiency within each stage. More microbatches reduce the pipeline bubble but lead to smaller matrix multiplications within each stage. Very small operations may not fully saturate the hardware's compute units, potentially lowering arithmetic intensity.
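For example, a quick sanity check of this constraint (values are illustrative):

```python
ici_pipeline_parallelism = 4
dcn_pipeline_parallelism = 2
num_pipeline_stages = ici_pipeline_parallelism * dcn_pipeline_parallelism  # 8

num_pipeline_microbatches = 16
assert num_pipeline_microbatches % num_pipeline_stages == 0

global_batch_size_to_train_on = 2048
microbatch_size = global_batch_size_to_train_on // num_pipeline_microbatches  # 128
```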
## Batch Size Ramp-up

MaxText supports gradually increasing the batch size during the initial phase of training to improve stability, a technique also used in frameworks like [NVIDIA's NeMo Megatron](https://docs.nvidia.com/nemo-framework/user-guide/24.09/nemotoolkit/nlp/nemo_megatron/rampup_batch_size.html). This can be configured in `configs/base.yml`:

- `enable_rampup_batch_size`: Set to `True` to activate the ramp-up process.
- `per_device_batch_size_start`: The minimum batch size to start training on.
- `per_device_batch_size`: The target batch size to stabilize on at the end of the ramp-up process.
- `per_device_batch_size_increment`: How much the batch size increases at each ramp-up stage.
- `global_rampup_samples`: The total number of samples to process across all ramp-up stages.

The ramp-up is based on the number of samples processed, not the number of training steps. Each stage processes an equal number of samples before the batch size is increased.
The number of stages is determined by:

`num_increments = (per_device_batch_size - per_device_batch_size_start) / per_device_batch_size_increment`

The total number of ramp-up samples (`global_rampup_samples`) is then distributed equally across these stages, so the number of samples processed in each stage is:

`samples_per_increment = global_rampup_samples / num_increments`

During training, the model processes `samples_per_increment` samples at the current batch size. Once this threshold is reached, the batch size is increased by `per_device_batch_size_increment`, until the target `per_device_batch_size` is reached. This entire process is managed by the `RampupBatchManager` class.
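A worked example of this schedule with illustrative values (the real bookkeeping lives in `RampupBatchManager`):

```python
per_device_batch_size_start = 2
per_device_batch_size = 8            # target
per_device_batch_size_increment = 2
global_rampup_samples = 600_000

num_increments = (
    per_device_batch_size - per_device_batch_size_start
) // per_device_batch_size_increment                             # (8 - 2) / 2 = 3
samples_per_increment = global_rampup_samples // num_increments  # 200_000

batch_size = per_device_batch_size_start
for stage in range(num_increments):
  print(f"stage {stage}: per-device batch size {batch_size} "
        f"for {samples_per_increment} samples")
  batch_size += per_device_batch_size_increment
# Prints stages at batch sizes 2, 4, 6; training then continues at the target, 8.
```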
## Reinforcement Learning (RL) Batch Size

The batch size parameters for RL training are defined in `configs/post_train/rl.yml`:

- `batch_size` is the number of unique prompts loaded from the dataset in a single batch. For instance, `batch_size=1` means one prompt is processed at a time by the data loader.

- `num_generations` is the number of responses the policy generates for each prompt within a single training step.

- The effective training batch is the total number of prompt-response pairs used in a training step, calculated as `batch_size x num_generations` (see the sketch after this list).

- `micro_batch_size` splits the batch of prompt-response pairs into smaller chunks for memory management. This enables overlapping the rollout phase (generating responses) of one micro-batch with the training phase (updating model weights) of the previous micro-batch, which can improve hardware utilization. A value of `-1` disables micro-batching.
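Putting these together (illustrative values):

```python
batch_size = 4         # unique prompts per step
num_generations = 8    # responses sampled per prompt
effective_batch = batch_size * num_generations  # 32 prompt-response pairs

micro_batch_size = 8   # -1 would disable micro-batching
num_micro_batches = effective_batch // micro_batch_size  # 4 chunks
# Rollout for chunk i can overlap with the training step on chunk i - 1.
```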

src/maxtext/configs/base.yml

Lines changed: 1 addition & 0 deletions
@@ -779,6 +779,7 @@ adam_b2: 0.95 # Exponential decay rate to track the second moment of past gradie
 adam_eps: 1.e-8 # A small constant applied to denominator outside of the square root.
 adam_eps_root: 0. # A small constant applied to denominator inside the square root.
 adam_weight_decay: 0.1 # AdamW Weight decay
+adamw_mask: [] # List of parameter names/patterns to exclude from weight decay in AdamW, like ['bias', '.*norm', '.*ln.*'].
 mu_dtype: "" # data type to store "mu" of AdamW tracking the first moment. Inherits from weight_dtype if unset.
 # Setting nu_dtype is not yet supported by optax, instead nu_dtype is always inherited from weights.
 # See b/399961932 for more.
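The diff does not show how this mask is consumed. As a hedged sketch, a name-pattern list like `adamw_mask` could be turned into the boolean mask that `optax.adamw` accepts for weight decay; the `build_decay_mask` helper below is hypothetical, not MaxText code:

```python
import re

import jax
import jax.numpy as jnp
import optax

def build_decay_mask(params, patterns):
  """True where weight decay applies; False where a pattern matches the name."""
  def keep_decay(path, _):
    name = jax.tree_util.keystr(path)  # e.g. "['norm']['scale']"
    return not any(re.search(p, name) for p in patterns)
  return jax.tree_util.tree_map_with_path(keep_decay, params)

params = {"dense": {"kernel": jnp.ones((4, 4)), "bias": jnp.zeros(4)},
          "norm": {"scale": jnp.ones(4)}}
mask = build_decay_mask(params, ["bias", ".*norm"])
tx = optax.adamw(1e-3, weight_decay=0.1, mask=mask)  # decay skips masked leaves
```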

src/maxtext/configs/pyconfig.py

Lines changed: 4 additions & 1 deletion
@@ -135,7 +135,10 @@ def _prepare_for_pydantic(raw_keys: dict[str, Any]) -> dict[str, Any]:
       new_value = [new_value]

     # An empty value provided in the configuration is treated as None
-    if key in ("hf_train_files", "hf_eval_files") and new_value == "":
+    if (
+        key in ("hf_train_files", "hf_eval_files", "hf_access_token", "hf_name", "hf_data_dir", "hf_eval_split")
+        and new_value == ""
+    ):
       new_value = None

     if key == "run_name" and new_value is None:
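A hedged, standalone sketch of the normalization this hunk implements (not the actual `_prepare_for_pydantic`): empty-string values for these HF keys become `None`, so the now-optional pydantic fields validate cleanly:

```python
_EMPTY_AS_NONE = ("hf_train_files", "hf_eval_files", "hf_access_token",
                  "hf_name", "hf_data_dir", "hf_eval_split")

def normalize(raw: dict) -> dict:
  # Map "" to None only for the keys that were made Optional.
  return {k: (None if k in _EMPTY_AS_NONE and v == "" else v)
          for k, v in raw.items()}

assert normalize({"hf_name": "", "hf_path": ""}) == {"hf_name": None, "hf_path": ""}
```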

src/maxtext/configs/types.py

Lines changed: 11 additions & 5 deletions
@@ -994,11 +994,11 @@ class HfDataset(BaseModel):
   """Configuration specific to HuggingFace datasets."""

   hf_path: str = Field("", description="Path of the Hugging Face dataset.")
-  hf_name: str = Field("", description="Name of the Hugging Face dataset.")
-  hf_data_dir: PathStr = Field("", description="Data directory for the HF dataset.")
-  hf_train_files: Optional[str] = Field(None, description="Files for the HF training split.")
-  hf_eval_split: str = Field("", description="Name of the HF evaluation split.")
-  hf_eval_files: Optional[str] = Field(None, description="Files for the HF evaluation split.")
+  hf_name: None | str = Field(None, description="Name of the Hugging Face dataset.")
+  hf_data_dir: None | PathStr = Field(None, description="Data directory for the HF dataset.")
+  hf_train_files: None | str = Field(None, description="Files for the HF training split.")
+  hf_eval_split: None | str = Field(None, description="Name of the HF evaluation split.")
+  hf_eval_files: None | str = Field(None, description="Files for the HF evaluation split.")
   hf_access_token: None | str = Field(None, description="Hugging Face API access token.")


@@ -1175,6 +1175,12 @@ class AdamW(BaseModel):
       description="A small constant for numerical stability (epsilon), applied inside of the square root.",
   )
   adam_weight_decay: float = Field(0.1, description="Weight decay regularization.")
+  adamw_mask: list[str] = Field(
+      default_factory=list,
+      description=(
+          "List of parameter names/patterns to exclude from weight decay in AdamW, like ['bias', '.*norm', '.*ln.*']"
+      ),
+  )
   mu_dtype: str = Field(
       "",
       description="Data type for 'mu' (first moment) in AdamW. Inherits from weight_dtype if empty.",

src/maxtext/input_pipeline/input_pipeline_utils.py

Lines changed: 8 additions & 6 deletions
@@ -187,7 +187,8 @@ def _get_completion_in_chat_template(tokenizer_model, round_msgs):
     A string representing the completion formatted by the chat template.
   """
   prompt_completion_tokens = tokenizer_model.apply_chat_template(round_msgs, add_generation_prompt=False, tokenize=True)
-  prompt_tokens = tokenizer_model.apply_chat_template(round_msgs[:-1], add_generation_prompt=False, tokenize=True)
+  # include generation_prompt as part of the prompt tokens
+  prompt_tokens = tokenizer_model.apply_chat_template(round_msgs[:-1], add_generation_prompt=True, tokenize=True)

   # attention masks in BatchEncoding are effectively ignored
   if hasattr(prompt_completion_tokens, INPUT_TOKENS_KEY):
@@ -209,7 +210,7 @@ def _get_completion_in_chat_template(tokenizer_model, round_msgs):

 def apply_chat_template(example, tokenizer_model, data_column_name):
   """Formats conversational data by applying the tokenizer's chat template
-  and identifying prompt/completion segments.
+  and identifying prompt/completion segments for SFT masking.

   Args:
     example: A dictionary containing conversational data. It is expected to have a key
@@ -223,9 +224,10 @@ def apply_chat_template(example, tokenizer_model, data_column_name):
     The modified `example` dictionary.
     - The `data_column_name` column will be updated to a list of
       messages, each formatted according to the tokenizer's chat template.
-    - A new column named "is_prompt" will be added, where `True`
-      indicates a system message or a user message (prompt) and `False` indicates an assistant
-      message (completion).
+    - A new column "is_prompt" is added, where `True` indicates the
+      tokens contain the system message, user message, and generation
+      prompt (if applicable). `False` indicates the expected LLM
+      completion, excluding the assistant's start tokens.
   """
   messages = []
   is_prompt = []
@@ -239,7 +241,7 @@
     elif message["role"] == "user":
       round_msgs.append(message)
       prompt_in_chat_template = tokenizer_model.apply_chat_template(
-          round_msgs, add_generation_prompt=False, tokenize=False
+          round_msgs, add_generation_prompt=True, tokenize=False
       )
       messages.append(prompt_in_chat_template)
       is_prompt.append(True)
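To see why including the generation prompt matters, here is a hedged illustration using Hugging Face's `apply_chat_template` directly; the model name and messages are examples, not what MaxText uses:

```python
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
msgs = [{"role": "user", "content": "Hi"},
        {"role": "assistant", "content": "Hello!"}]

full = tok.apply_chat_template(msgs, add_generation_prompt=False, tokenize=True)
# With add_generation_prompt=True the prompt ends with the assistant's start
# tokens (e.g. "<|im_start|>assistant\n"), so the completion below is exactly
# the assistant's reply and its end-of-turn tokens, with no header tokens.
prompt = tok.apply_chat_template(msgs[:-1], add_generation_prompt=True, tokenize=True)
completion = full[len(prompt):]
```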