You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Copy file name to clipboardExpand all lines: docs/guides/optimization/benchmark_and_performance.md
+122Lines changed: 122 additions & 0 deletions
Display the source diff
Display the rich diff
Original file line number
Diff line number
Diff line change
@@ -61,6 +61,128 @@ To use a custom policy, set `remat_policy` to `custom` and specify the layers in
61
61
-`device`: The activation remains on the TPU device.
62
62
-`Remat`: Rematerialization is performed during the backward pass.
63
63
64
+
**Automatic remat policy search with the Estimator**
65
+
66
+
Finding the optimal remat policy and batch size manually can be time-consuming. MaxText provides an **Estimator** tool (`estimator.py`) that automates this search using [Ahead-of-Time (AOT) compilation](../monitoring_and_debugging/features_and_diagnostics.md#ahead-of-time-compilation-aot). It leverages `train_compile` to test whether a given configuration causes an Out-Of-Memory (OOM) error *without* requiring the target hardware.
67
+
68
+
The estimator supports two modes:
69
+
70
+
1.**Search both batch size and remat policy** (when `per_device_batch_size` is *not* provided): It finds the Pareto frontier of batch size vs. remat policy by iterating through policies from full remat to full device, using binary search for the largest non-OOM batch size at each step.
71
+
2.**Search remat policy only** (when `per_device_batch_size`*is* provided): It finds the least aggressive (fastest) remat policy that fits in memory for the given fixed batch size.
72
+
73
+
*Mode 1 example: Search both batch size and remat policy (Llama 3.1 405B on tpu7x-1024)*
74
+
75
+
```bash
76
+
python -m maxtext.utils.estimator \
77
+
maxtext/configs/base.yml \
78
+
compile_topology=tpu7x-1024 \
79
+
compile_topology_num_slices=1 \
80
+
model_name=llama3.1-405b \
81
+
max_target_length=32768 \
82
+
ici_context_parallelism=8 \
83
+
ici_fsdp_parallelism=-1 \
84
+
log_config=False \
85
+
write_estimator_result=False
86
+
```
87
+
88
+
*Mode 2 example: Search best remat policy for a fixed batch size (DeepSeek3 671B on v5p-1024)*
-`write_estimator_result=True`: Writes runnable training commands to `remat_commands_from_estimator.txt`.
108
+
-`write_estimator_result=False` (default): Prints results to stdout only.
109
+
- You can pin specific tensor remat actions (e.g., `context=offload`) to constrain the search space.
110
+
111
+
*Advanced example: Search remat policy with XLA tuning flags (DeepSeek3 671B on tpu7x-512)*
112
+
113
+
For production workloads you often want to combine the estimator with XLA compiler tuning flags for SparseCore offloading, latency-hiding scheduling, and other optimizations. Set these via `LIBTPU_INIT_ARGS` before invoking the estimator:
This example fixes `per_device_batch_size=4.0` so the estimator runs in **Mode 2** (policy-only search), finding the least aggressive remat policy that fits the DeepSeek3 671B model on a tpu7x-512 pod. The XLA flags enable SparseCore collective offloading and latency-hiding scheduling, which affect compilation memory layout and thus the OOM boundary.
185
+
64
186
### Low precision training
65
187
66
188
MaxText supports quantization via QWIX. To enable this, set `use_qwix_quantization=true`.
Copy file name to clipboardExpand all lines: src/maxtext/configs/base.yml
+3Lines changed: 3 additions & 0 deletions
Original file line number
Diff line number
Diff line change
@@ -948,6 +948,9 @@ compiled_trainstep_file: "" # Name of saved serialized compiled train_step, e.g.
948
948
compile_topology: ''# Target hardware version, e.g. 'v5e-256'
949
949
compile_topology_num_slices: -1# Number of target slices, set to a positive integer.
950
950
951
+
# MaxText Estimator configs
952
+
write_estimator_result: False
953
+
951
954
decode_sampling_strategy: "greedy"# decode_sampling_strategy should be one of greedy, weighted, nucleus, topk, or composite(top_k -> top_p -> weighted temperature)
952
955
decode_sampling_nucleus_p: -1# set if you're doing nucleus / top-p
953
956
decode_sampling_top_k: 0# set if you're doing top-k
0 commit comments