Skip to content

Commit f0a04cd

Browse files
Merge pull request #2891 from AI-Hypercomputer:chengnuojin-estimator
PiperOrigin-RevId: 933448303
2 parents 83412ca + f33cb23 commit f0a04cd

7 files changed

Lines changed: 846 additions & 186 deletions

File tree

docs/guides/optimization/benchmark_and_performance.md

Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,128 @@ To use a custom policy, set `remat_policy` to `custom` and specify the layers in
6161
- `device`: The activation remains on the TPU device.
6262
- `Remat`: Rematerialization is performed during the backward pass.
6363

64+
**Automatic remat policy search with the Estimator**
65+
66+
Finding the optimal remat policy and batch size manually can be time-consuming. MaxText provides an **Estimator** tool (`estimator.py`) that automates this search using [Ahead-of-Time (AOT) compilation](../monitoring_and_debugging/features_and_diagnostics.md#ahead-of-time-compilation-aot). It leverages `train_compile` to test whether a given configuration causes an Out-Of-Memory (OOM) error *without* requiring the target hardware.
67+
68+
The estimator supports two modes:
69+
70+
1. **Search both batch size and remat policy** (when `per_device_batch_size` is *not* provided): It finds the Pareto frontier of batch size vs. remat policy by iterating through policies from full remat to full device, using binary search for the largest non-OOM batch size at each step.
71+
2. **Search remat policy only** (when `per_device_batch_size` *is* provided): It finds the least aggressive (fastest) remat policy that fits in memory for the given fixed batch size.
72+
73+
*Mode 1 example: Search both batch size and remat policy (Llama 3.1 405B on tpu7x-1024)*
74+
75+
```bash
76+
python -m maxtext.utils.estimator \
77+
maxtext/configs/base.yml \
78+
compile_topology=tpu7x-1024 \
79+
compile_topology_num_slices=1 \
80+
model_name=llama3.1-405b \
81+
max_target_length=32768 \
82+
ici_context_parallelism=8 \
83+
ici_fsdp_parallelism=-1 \
84+
log_config=False \
85+
write_estimator_result=False
86+
```
87+
88+
*Mode 2 example: Search best remat policy for a fixed batch size (DeepSeek3 671B on v5p-1024)*
89+
90+
```bash
91+
python3 -m maxtext.utils.estimator maxtext/configs/base.yml \
92+
model_name=deepseek3-671b \
93+
compile_topology=v5p-1024 \
94+
compile_topology_num_slices=1 \
95+
ici_fsdp_parallelism=512 \
96+
per_device_batch_size=2.0 \
97+
dtype=bfloat16 \
98+
weight_dtype=float32 \
99+
max_target_length=8192 \
100+
log_config=False \
101+
write_estimator_result=False \
102+
decoder_layer_input=offload
103+
```
104+
105+
Key options:
106+
107+
- `write_estimator_result=True`: Writes runnable training commands to `remat_commands_from_estimator.txt`.
108+
- `write_estimator_result=False` (default): Prints results to stdout only.
109+
- You can pin specific tensor remat actions (e.g., `context=offload`) to constrain the search space.
110+
111+
*Advanced example: Search remat policy with XLA tuning flags (DeepSeek3 671B on tpu7x-512)*
112+
113+
For production workloads you often want to combine the estimator with XLA compiler tuning flags for SparseCore offloading, latency-hiding scheduling, and other optimizations. Set these via `LIBTPU_INIT_ARGS` before invoking the estimator:
114+
115+
```bash
116+
export LIBTPU_INIT_ARGS=" \
117+
--xla_tpu_dvfs_p_state=7 \
118+
--xla_tpu_scoped_vmem_limit_kib=65536 \
119+
--xla_tpu_bf16_emission_mode=NATIVE_EMISSION \
120+
--xla_tpu_enable_sparse_core_reduce_scatter_v2=true \
121+
--xla_tpu_enable_sparse_core_collective_offload_all_gather=true \
122+
--xla_tpu_enable_sparse_core_collective_offload_2d_all_gather=true \
123+
--xla_tpu_enable_all_gather_offload_tracing=true \
124+
--xla_tpu_use_tc_device_shape_on_sc=True \
125+
--xla_sc_disable_megacore_partitioning=True \
126+
--xla_tpu_enable_async_collective_fusion_fuse_all_gather=false \
127+
--xla_enable_async_all_gather=true \
128+
--xla_tpu_prefer_async_allgather_to_allreduce=true \
129+
--xla_tpu_enable_sparse_core_collective_offload_all_reduce=true \
130+
--xla_tpu_enable_sparse_core_collective_offload_reduce_scatter=true \
131+
--xla_tpu_enable_sparse_core_collective_offload_3d_all_gather=true \
132+
--xla_tpu_use_single_sparse_core_for_all_gather_offload=true \
133+
--xla_tpu_enable_concurrent_sparse_core_offloading=true \
134+
--xla_tpu_aggressive_opt_barrier_removal=true \
135+
--xla_tpu_enable_offloading_gather_to_sparsecore=true \
136+
--xla_tpu_sparse_core_all_gather_latency_multiplier=1 \
137+
--xla_tpu_sparse_core_reduce_scatter_latency_multiplier=3 \
138+
--xla_tpu_enable_sparse_core_collective_aggregator=true \
139+
--xla_tpu_enable_latency_hiding_layer_scheduler=true \
140+
--xla_tpu_scheduler_percent_shared_memory_limit=150 \
141+
--xla_tpu_enable_layer_scheduler_for_dependent_collectives=true \
142+
--xla_tpu_enable_sparse_core_collective_offload_nd_reduce_scatter=true \
143+
--xla_tpu_pcie_bandwidth_multiplier=0.03 \
144+
--xla_tpu_enable_sparse_core_offload_queuing_in_lhs=true \
145+
--xla_tpu_enable_multi_compute_overlap_in_layer_scheduler=false \
146+
--xla_tpu_enable_3d_reduce_scatter_decomposer=false "
147+
148+
python3 -m maxtext.utils.estimator maxtext/configs/base.yml \
149+
compile_topology=tpu7x-512 \
150+
compile_topology_num_slices=1 \
151+
run_name=${WORKLOAD_NAME} \
152+
skip_jax_distributed_system=true \
153+
dtype=bfloat16 \
154+
per_device_batch_size=4.0 \
155+
model_name=deepseek3-671b \
156+
remat_policy=custom \
157+
decoder_layer_input=device \
158+
mu_dtype=bfloat16 \
159+
grad_dtype=bfloat16 \
160+
ici_fsdp_parallelism=128 \
161+
ici_expert_parallelism=4 \
162+
dataset_type=synthetic \
163+
dataset_path=gs://max-datasets-rogue \
164+
opt_type=adamw \
165+
steps=20 \
166+
sa_use_fused_bwd_kernel=true \
167+
use_max_logit_estimate=-1 \
168+
cost_estimate_flops_fwd=5000000000000 \
169+
cost_estimate_flops_bwd=5000000000000 \
170+
float32_weight_sum=False \
171+
megablox=true \
172+
sparse_matmul=true \
173+
use_tokamax_gmm=false \
174+
use_tokamax_splash=true \
175+
max_target_length=4096 \
176+
use_random_routing=true \
177+
use_ring_of_experts=true \
178+
use_ragged_sort=true \
179+
tokenizer_path=assets/tokenizer.mistral-v3 \
180+
base_output_directory=${BASE_OUTPUT_DIR} \
181+
merge_gating_gmm=false
182+
```
183+
184+
This example fixes `per_device_batch_size=4.0` so the estimator runs in **Mode 2** (policy-only search), finding the least aggressive remat policy that fits the DeepSeek3 671B model on a tpu7x-512 pod. The XLA flags enable SparseCore collective offloading and latency-hiding scheduling, which affect compilation memory layout and thus the OOM boundary.
185+
64186
### Low precision training
65187

66188
MaxText supports quantization via QWIX. To enable this, set `use_qwix_quantization=true`.

src/maxtext/configs/base.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -948,6 +948,9 @@ compiled_trainstep_file: "" # Name of saved serialized compiled train_step, e.g.
948948
compile_topology: '' # Target hardware version, e.g. 'v5e-256'
949949
compile_topology_num_slices: -1 # Number of target slices, set to a positive integer.
950950

951+
# MaxText Estimator configs
952+
write_estimator_result: False
953+
951954
decode_sampling_strategy: "greedy" # decode_sampling_strategy should be one of greedy, weighted, nucleus, topk, or composite(top_k -> top_p -> weighted temperature)
952955
decode_sampling_nucleus_p: -1 # set if you're doing nucleus / top-p
953956
decode_sampling_top_k: 0 # set if you're doing top-k

src/maxtext/configs/types.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1674,6 +1674,7 @@ class AOT(BaseModel):
16741674
compiled_trainstep_file: PathStr = Field("", description="Name of saved serialized compiled train_step.")
16751675
compile_topology: str = Field("", description="Target hardware version, e.g. 'v5e-256'.")
16761676
compile_topology_num_slices: int = Field(-1, description="Number of target slices.")
1677+
write_estimator_result: bool = Field(False, description="Write estimator.py results in a separate file.")
16771678

16781679

16791680
class DevelopmentAndDebugging(BaseModel):

src/maxtext/trainers/pre_train/train_compile.py

Lines changed: 28 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -61,12 +61,9 @@
6161
def validate_config(config):
6262
"""Validates the config is is setup correctly to compile, returning a useful error message if not."""
6363
assert config.compile_topology != "", (
64-
"You must pass your desired target hardware in compile_topology, e.g."
65-
" compile_topology=v5e-256"
64+
"You must pass your desired target hardware in compile_topology, e.g." " compile_topology=v5e-256"
6665
)
67-
assert (
68-
config.compile_topology_num_slices > 0
69-
), "You must set compile_topology_num_slices to a positive integer"
66+
assert config.compile_topology_num_slices > 0, "You must set compile_topology_num_slices to a positive integer"
7067

7168

7269
def get_topology_mesh(config):
@@ -78,18 +75,12 @@ def get_topology_mesh(config):
7875
num_slices=config.compile_topology_num_slices,
7976
).devices
8077
else:
81-
target_hardware = accelerator_to_spec_map.get_system_characteristics(
82-
config.compile_topology
83-
)
78+
target_hardware = accelerator_to_spec_map.get_system_characteristics(config.compile_topology)
8479
if target_hardware.platform == "gpu":
8580
# Disable sharded autotuning. This is an optimization to distribute
8681
# autotuning across the fleet, but can cause hangs with AoT compilation.
87-
os.environ["XLA_FLAGS"] = (
88-
os.environ.get("XLA_FLAGS", "") + " --xla_gpu_shard_autotuning=false"
89-
)
90-
jax.config.update(
91-
"mock_num_gpu_processes", config.compile_topology_num_slices
92-
)
82+
os.environ["XLA_FLAGS"] = os.environ.get("XLA_FLAGS", "") + " --xla_gpu_shard_autotuning=false"
83+
jax.config.update("mock_num_gpu_processes", config.compile_topology_num_slices)
9384
topology_devices = jax.devices()
9485
else:
9586
topology_devices = get_topology_desc(
@@ -104,14 +95,8 @@ def get_topology_mesh(config):
10495
"jax_remove_size_one_mesh_axis_from_type",
10596
config.remove_size_one_mesh_axis_from_type,
10697
)
107-
topology_device_mesh = maxtext_utils.create_device_mesh(
108-
config, topology_devices
109-
)
110-
mesh_axis_type = (
111-
AxisType.Explicit
112-
if config.shard_mode == ShardMode.EXPLICIT
113-
else AxisType.Auto
114-
)
98+
topology_device_mesh = maxtext_utils.create_device_mesh(config, topology_devices)
99+
mesh_axis_type = AxisType.Explicit if config.shard_mode == ShardMode.EXPLICIT else AxisType.Auto
115100
topology_mesh = Mesh(
116101
topology_device_mesh,
117102
config.mesh_axes,
@@ -129,9 +114,7 @@ def _collect_nnx_activation_shardings(create_model_fn, config, mesh):
129114
input_shape = (config.micro_batch_size_to_train_on, config.max_target_length)
130115
abstract_input = jax.ShapeDtypeStruct(input_shape, jnp.int32)
131116

132-
def _nnx_forward(
133-
decoder_input_tokens, decoder_positions, decoder_segment_ids
134-
):
117+
def _nnx_forward(decoder_input_tokens, decoder_positions, decoder_segment_ids):
135118
model_instance = create_model_fn()
136119
return model_instance(
137120
decoder_input_tokens=decoder_input_tokens,
@@ -140,9 +123,7 @@ def _nnx_forward(
140123
enable_dropout=False,
141124
)
142125

143-
with jax.set_mesh(mesh), nn_partitioning.axis_rules(
144-
config.logical_axis_rules
145-
):
126+
with jax.set_mesh(mesh), nn_partitioning.axis_rules(config.logical_axis_rules):
146127
jax.eval_shape(_nnx_forward, abstract_input, abstract_input, abstract_input)
147128

148129

@@ -151,13 +132,9 @@ def get_shaped_inputs(topology_mesh, config):
151132
# Construct the model and optimizer to get shaped versions of the state
152133
quant = quantizations.configure_quantization(config)
153134
if config.pure_nnx:
154-
_create_model_partial, model = (
155-
model_creation_utils.create_nnx_abstract_model(config, topology_mesh)
156-
)
135+
_create_model_partial, model = model_creation_utils.create_nnx_abstract_model(config, topology_mesh)
157136
else:
158-
model = Transformer(
159-
config, topology_mesh, quant=quant, model_mode=MODEL_MODE_TRAIN
160-
)
137+
model = Transformer(config, topology_mesh, quant=quant, model_mode=MODEL_MODE_TRAIN)
161138
# The learning_rate_schedule is baked into the compiled object.
162139
learning_rate_schedule = maxtext_utils.create_learning_rate_schedule(config)
163140
# pass in model for muon
@@ -176,20 +153,14 @@ def create_train_state_fn():
176153

177154
init_state_fn = create_train_state_fn
178155
else:
179-
init_state_fn = functools.partial(
180-
maxtext_utils.init_initial_state, model, tx, config, True, example_rng
181-
)
156+
init_state_fn = functools.partial(maxtext_utils.init_initial_state, model, tx, config, True, example_rng)
182157

183158
# Shaped state
184-
abstract_state, _, state_mesh_shardings = maxtext_utils.get_abstract_state(
185-
config, topology_mesh, init_state_fn, True
186-
)
159+
abstract_state, _, state_mesh_shardings = maxtext_utils.get_abstract_state(config, topology_mesh, init_state_fn, True)
187160

188161
if config.pure_nnx:
189162
# NNX doesn't use Linen logical annotations; derive PartitionSpecs from the physical shardings.
190-
logical_annotations = maxtext_utils_nnx.get_partition_spec_nnx(
191-
state_mesh_shardings
192-
)
163+
logical_annotations = maxtext_utils_nnx.get_partition_spec_nnx(state_mesh_shardings)
193164
# For NNX, get_functional_train_with_signature expects the graphdef (static structure),
194165
# not the raw model — mirroring how the training loop does nnx.split(train_state).
195166
with nn_partitioning.axis_rules(config.logical_axis_rules):
@@ -198,9 +169,7 @@ def create_train_state_fn():
198169
model = graphdef
199170
else:
200171
# unsharded logical annotations
201-
logical_annotations = maxtext_utils.get_logical_annotations(
202-
config, topology_mesh, init_state_fn
203-
)
172+
logical_annotations = maxtext_utils.get_logical_annotations(config, topology_mesh, init_state_fn)
204173

205174
# Shaped batch
206175
shaped_batch = maxtext_utils.get_shaped_batch(config)
@@ -217,9 +186,7 @@ def create_train_state_fn():
217186
# Collect NNX activation shardings via an abstract forward pass (must run
218187
# after get_abstract_state, which only traces __init__).
219188
if config.debug_sharding and config.pure_nnx:
220-
_collect_nnx_activation_shardings(
221-
_create_model_partial, config, topology_mesh
222-
)
189+
_collect_nnx_activation_shardings(_create_model_partial, config, topology_mesh)
223190

224191
return (
225192
shaped_train_args,
@@ -256,9 +223,7 @@ def jit_and_compile(
256223
maxtext_utils.maybe_dump_jaxpr(config, jitted, func_input_args)
257224
lowered = jitted.lower(*func_input_args, **func_input_kwargs)
258225
# Import libtpu flags as compiler options. Defaults to empty dict if string is empty.
259-
compiler_options = max_utils.parse_libtpu_flags_to_dict(
260-
config.compile_xla_flags
261-
)
226+
compiler_options = max_utils.parse_libtpu_flags_to_dict(config.compile_xla_flags)
262227
compiled = lowered.compile(compiler_options=compiler_options)
263228
return compiled
264229

@@ -293,18 +258,12 @@ def is_oom(argv: Sequence[str]) -> bool:
293258
) = get_shaped_inputs(topology_mesh, config)
294259

295260
# Update params_shardings when shard_optimizer_over_data is enabled (Zero-1)
296-
params_shardings, state_mesh_shardings = (
297-
sharding.maybe_update_params_sharding_with_opt(
298-
config, state_mesh_shardings
299-
)
300-
)
261+
params_shardings, state_mesh_shardings = sharding.maybe_update_params_sharding_with_opt(config, state_mesh_shardings)
301262

302263
# When ZeRO-1 is enabled, we need to use the original params_shardings for input shardings
303264
# but keep the updated state_mesh_shardings for the optimizer state
304265
if config.shard_optimizer_over_data:
305-
input_state_mesh_shardings = state_mesh_shardings.replace(
306-
params=params_shardings
307-
)
266+
input_state_mesh_shardings = state_mesh_shardings.replace(params=params_shardings)
308267
else:
309268
input_state_mesh_shardings = state_mesh_shardings
310269

@@ -344,6 +303,7 @@ def is_oom(argv: Sequence[str]) -> bool:
344303
except Exception as e:
345304
# return true if OOM error happens
346305
# OOM error looks like
306+
# Check failed: entries[i] <= std::numeric_limits<uint32_t>::max()
347307
# jax.errors.JaxRuntimeError: RESOURCE_EXHAUSTED: Allocation ...
348308
# jax.errors.JaxRuntimeError: INTERNAL: RET_CHECK failure ...
349309
message = str(e).lower()
@@ -355,8 +315,7 @@ def is_oom(argv: Sequence[str]) -> bool:
355315
def main(argv: Sequence[str]) -> None:
356316
jax.config.update("jax_default_prng_impl", "unsafe_rbg")
357317
os.environ["LIBTPU_INIT_ARGS"] = (
358-
os.environ.get("LIBTPU_INIT_ARGS", "")
359-
+ " --xla_tpu_spmd_rng_bit_generator_unsafe=true"
318+
os.environ.get("LIBTPU_INIT_ARGS", "") + " --xla_tpu_spmd_rng_bit_generator_unsafe=true"
360319
)
361320
print("Starting train_compile.py...", flush=True)
362321

@@ -381,18 +340,12 @@ def main(argv: Sequence[str]) -> None:
381340
) = get_shaped_inputs(topology_mesh, config)
382341

383342
# Update params_shardings when shard_optimizer_over_data is enabled (Zero-1)
384-
params_shardings, state_mesh_shardings = (
385-
sharding.maybe_update_params_sharding_with_opt(
386-
config, state_mesh_shardings
387-
)
388-
)
343+
params_shardings, state_mesh_shardings = sharding.maybe_update_params_sharding_with_opt(config, state_mesh_shardings)
389344

390345
# When ZeRO-1 is enabled, we need to use the original params_shardings for input shardings
391346
# but keep the updated state_mesh_shardings for the optimizer state
392347
if config.shard_optimizer_over_data:
393-
input_state_mesh_shardings = state_mesh_shardings.replace(
394-
params=params_shardings
395-
)
348+
input_state_mesh_shardings = state_mesh_shardings.replace(params=params_shardings)
396349
else:
397350
input_state_mesh_shardings = state_mesh_shardings
398351

@@ -401,21 +354,15 @@ def main(argv: Sequence[str]) -> None:
401354
if config.enable_diloco:
402355
# Build abstract DiLoCo state and shardings for AOT compilation
403356
abstract_state = shaped_train_args[0]
404-
diloco_state, state_mesh_shardings, inner_state_shardings = (
405-
diloco.build_abstract_diloco_state(
406-
config, abstract_state, state_mesh_shardings, topology_mesh
407-
)
357+
diloco_state, state_mesh_shardings, inner_state_shardings = diloco.build_abstract_diloco_state(
358+
config, abstract_state, state_mesh_shardings, topology_mesh
408359
)
409360
# For NNX, shaped_train_args has 2 elements (state, batch) — no rng; pass None for prng.
410-
shaped_rng_arg = (
411-
shaped_train_args[2] if len(shaped_train_args) > 2 else None
412-
)
361+
shaped_rng_arg = shaped_train_args[2] if len(shaped_train_args) > 2 else None
413362
shaped_train_args = (diloco_state, shaped_train_args[1], shaped_rng_arg)
414363

415364
# Wrap train_step with diloco
416-
train_step_partial = functools.partial(
417-
train.train_step, model, config, inner_state_shardings, params_shardings
418-
)
365+
train_step_partial = functools.partial(train.train_step, model, config, inner_state_shardings, params_shardings)
419366
train_step_fn = diloco.build_diloco_train_step(config, train_step_partial)
420367

421368
# For DiLoCo, the train_step_fn is already fully wrapped and takes (state, batch, prng)
@@ -480,10 +427,7 @@ def main(argv: Sequence[str]) -> None:
480427
if config.compiled_trainstep_file != "":
481428
print("Saving compiled object...")
482429
save_compiled(compiled, config.compiled_trainstep_file)
483-
print(
484-
"Successfully saved compiled object as"
485-
f" {config.compiled_trainstep_file}"
486-
)
430+
print("Successfully saved compiled object as" f" {config.compiled_trainstep_file}")
487431
print("Finished train_compile.py successfully!", flush=True)
488432
print(f"Cost analysis: {compiled.cost_analysis()}")
489433
print(f"Memory analysis: {compiled.memory_analysis()}")

0 commit comments

Comments
 (0)