We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
2 parents 9c9e095 + 0ad8d3f commit b2153a3Copy full SHA for b2153a3
1 file changed
src/maxtext/configs/tpu/v4/22b.sh
@@ -56,6 +56,6 @@ fi
56
# Train
57
export LIBTPU_INIT_ARGS="--xla_enable_async_all_gather=true TPU_MEGACORE=MEGACORE_DENSE"
58
python3 -m maxtext.trainers.pre_train.$EXECUTABLE "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//base.yml\
59
- ici_fsdp_parallelism=64 steps=10 per_device_batch_size=13 profiler=xplane remat_policy=full\
+ ici_fsdp_parallelism=64 steps=10 per_device_batch_size=13 profiler=xplane remat_policy=full attention=flash num_vocab_tiling=8\
60
base_emb_dim=6144 base_num_kv_heads=24 base_num_query_heads=24 base_mlp_dim=24576 base_num_decoder_layers=48\
61
base_output_directory=$OUTPUT_PATH dataset_path=$DATASET_PATH
0 commit comments