Skip to content

Commit 5cbf844

Browse files
susanbaosusanbao
andauthored
change the flash_block_sizes (#290)
* change the flash_block_sizes * fix unit test --------- Co-authored-by: susanbao <sanbao_google_com@t1v-n-216c02cd-w-0.europe-west4-b.c.cloud-tpu-multipod-dev.internal>
1 parent 6c8fd5a commit 5cbf844

2 files changed

Lines changed: 8 additions & 8 deletions

File tree

.github/workflows/UnitTests.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ jobs:
5858
pip show jax jaxlib flax transformers datasets tensorflow tensorflow_datasets
5959
- name: PyTest
6060
run: | #--deselect=src/maxdiffusion/tests/input_pipeline_interface_test.py
61+
export LIBTPU_INIT_ARGS='--xla_tpu_scoped_vmem_limit_kib=65536'
6162
HF_HUB_CACHE=/mnt/disks/github-runner-disk/ HF_HOME=/mnt/disks/github-runner-disk/ TOKENIZERS_PARALLELISM=false python3 -m pytest --deselect=src/maxdiffusion/tests/ltx_transformer_step_test.py -x
6263
# add_pull_ready:
6364
# if: github.ref != 'refs/heads/main'

src/maxdiffusion/configs/base_wan_14b.yml

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -64,14 +64,13 @@ flash_min_seq_length: 4096
6464
dropout: 0.1
6565

6666
flash_block_sizes: {
67-
"block_q" : 1024,
68-
"block_kv_compute" : 256,
69-
"block_kv" : 1024,
70-
"block_q_dkv" : 1024,
71-
"block_kv_dkv" : 1024,
72-
"block_kv_dkv_compute" : 256,
73-
"block_q_dq" : 1024,
74-
"block_kv_dq" : 1024
67+
"block_q" : 2048,
68+
"block_kv_compute" : 512,
69+
"block_kv" : 2048,
70+
"block_q_dkv" : 2048,
71+
"block_kv_dkv" : 2048,
72+
"block_kv_dkv_compute" : 512,
73+
"use_fused_bwd_kernel" : True
7574
}
7675
# Use on v6e
7776
# flash_block_sizes: {

0 commit comments

Comments
 (0)