Skip to content

Commit 4a6f807

Browse files
Add Context parallelism to Wan 2.1 (#200)
Added context parallelism for seq_len sharding.

Co-authored-by: Juan Acevedo <jfacevedo@google.com>
1 parent 462f463 commit 4a6f807

24 files changed

Lines changed: 568 additions & 458 deletions

requirements.txt

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -23,7 +23,6 @@ pytest==8.2.2
2323
tensorflow>=2.17.0
2424
tensorflow-datasets>=4.9.6
2525
ruff>=0.1.5,<=0.2
26-
git+https://github.com/mlperf/logging.git
2726
opencv-python-headless==4.10.0.84
2827
orbax-checkpoint==0.10.3
2928
tokenizers==0.21.0

src/maxdiffusion/common_types.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -36,6 +36,7 @@
3636

3737
BATCH = "activation_batch"
3838
LENGTH = "activation_length"
39+
KV_LENGTH = "activation_kv_length"
3940
EMBED = "activation_embed"
4041
HEAD = "activation_heads"
4142
D_KV = "activation_kv"

src/maxdiffusion/configs/base14.yml

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -135,6 +135,8 @@ ici_data_parallelism: -1 # recommended ICI axis to be auto-sharded for TPUv5e
135135
ici_fsdp_parallelism: 1 # recommended ICI axis to be auto-sharded
136136
ici_tensor_parallelism: 1
137137

138+
allow_split_physical_axes: False
139+
138140
# Dataset
139141
# Replace with dataset path or train_data_dir. One has to be set.
140142
dataset_name: 'diffusers/pokemon-gpt4-captions'

src/maxdiffusion/configs/base21.yml

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -136,6 +136,8 @@ ici_data_parallelism: -1 # recommended ICI axis to be auto-sharded for TPUv5e
136136
ici_fsdp_parallelism: 1 # recommended ICI axis to be auto-sharded
137137
ici_tensor_parallelism: 1
138138

139+
allow_split_physical_axes: False
140+
139141
# Dataset
140142
# Replace with dataset path or train_data_dir. One has to be set.
141143
dataset_name: 'diffusers/pokemon-gpt4-captions'

src/maxdiffusion/configs/base_2_base.yml

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -149,6 +149,8 @@ ici_data_parallelism: -1 # recommended ICI axis to be auto-sharded for TPUv5e
149149
ici_fsdp_parallelism: 1 # recommended ICI axis to be auto-sharded
150150
ici_tensor_parallelism: 1
151151

152+
allow_split_physical_axes: False
153+
152154
# Dataset
153155
# Replace with dataset path or train_data_dir. One has to be set.
154156
dataset_name: 'diffusers/pokemon-gpt4-captions'

src/maxdiffusion/configs/base_flux_dev.yml

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -162,6 +162,8 @@ ici_data_parallelism: -1
162162
ici_fsdp_parallelism: 1 # recommended ICI axis to be auto-sharded
163163
ici_tensor_parallelism: 1
164164

165+
allow_split_physical_axes: False
166+
165167
# Dataset
166168
# Replace with dataset path or train_data_dir. One has to be set.
167169
dataset_name: 'diffusers/pokemon-gpt4-captions'

src/maxdiffusion/configs/base_flux_dev_multi_res.yml

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -162,6 +162,8 @@ ici_data_parallelism: -1
162162
ici_fsdp_parallelism: 1 # recommended ICI axis to be auto-sharded
163163
ici_tensor_parallelism: 1
164164

165+
allow_split_physical_axes: False
166+
165167
# Dataset
166168
# Replace with dataset path or train_data_dir. One has to be set.
167169
dataset_name: 'diffusers/pokemon-gpt4-captions'

src/maxdiffusion/configs/base_flux_schnell.yml

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -170,6 +170,8 @@ ici_data_parallelism: -1
170170
ici_fsdp_parallelism: 1 # recommended ICI axis to be auto-sharded
171171
ici_tensor_parallelism: 1
172172

173+
allow_split_physical_axes: False
174+
173175
# Dataset
174176
# Replace with dataset path or train_data_dir. One has to be set.
175177
dataset_name: 'diffusers/pokemon-gpt4-captions'

src/maxdiffusion/configs/base_wan_14b.yml

Lines changed: 18 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -56,6 +56,17 @@ split_head_dim: True
5656
attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te
5757

5858
flash_block_sizes: {}
59+
# Use on v6e
60+
# flash_block_sizes: {
61+
# "block_q" : 3024,
62+
# "block_kv_compute" : 1024,
63+
# "block_kv" : 2048,
64+
# "block_q_dkv" : 3024,
65+
# "block_kv_dkv" : 2048,
66+
# "block_kv_dkv_compute" : 2048,
67+
# "block_q_dq" : 3024,
68+
# "block_kv_dq" : 2048
69+
# }
5970
# GroupNorm groups
6071
norm_num_groups: 32
6172

@@ -115,17 +126,15 @@ mesh_axes: ['data', 'fsdp', 'tensor']
115126
# conv_out : conv.shape[-1] weight
116127
logical_axis_rules: [
117128
['batch', 'data'],
118-
['activation_heads', 'fsdp'],
119-
['activation_batch', ['data','fsdp']],
120-
['activation_kv', 'tensor'],
129+
['activation_length', 'fsdp'],
130+
['activation_heads', 'tensor'],
131+
['activation_batch', 'data'],
121132
['mlp','tensor'],
122133
['embed','fsdp'],
123-
['heads', 'tensor'],
124-
['norm', 'fsdp'],
134+
['norm', 'tensor'],
125135
['conv_batch', ['data','fsdp']],
126136
['out_channels', 'tensor'],
127-
['conv_out', 'fsdp'],
128-
['conv_in', 'fsdp']
137+
['conv_in', 'fsdp'],
129138
]
130139
data_sharding: [['data', 'fsdp', 'tensor']]
131140

@@ -140,6 +149,8 @@ ici_data_parallelism: 1
140149
ici_fsdp_parallelism: -1 # recommended ICI axis to be auto-sharded
141150
ici_tensor_parallelism: 1
142151

152+
allow_split_physical_axes: False
153+
143154
# Dataset
144155
# Replace with dataset path or train_data_dir. One has to be set.
145156
dataset_name: 'diffusers/pokemon-gpt4-captions'

src/maxdiffusion/configs/base_xl.yml

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -135,6 +135,8 @@ ici_data_parallelism: -1
135135
ici_fsdp_parallelism: 1 # recommended ICI axis to be auto-sharded
136136
ici_tensor_parallelism: 1
137137

138+
allow_split_physical_axes: False
139+
138140
# Dataset
139141
# Replace with dataset path or train_data_dir. One has to be set.
140142
dataset_name: 'diffusers/pokemon-gpt4-captions'

0 commit comments

Comments
 (0)