Skip to content

Commit ced76d0

Browse files
authored
Add WAN 1.3B Config and parameterize num_layers (#324)
1 parent f23746b commit ced76d0

2 files changed

Lines changed: 343 additions & 5 deletions

File tree

Lines changed: 332 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,332 @@
1+
# Copyright 2023 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
# This sentinel is a reminder to choose a real run name.
16+
run_name: ''
17+
18+
metrics_file: "" # for testing, local file that stores scalar metrics. If empty, no metrics are written.
19+
# If true save metrics such as loss and TFLOPS to GCS in {base_output_directory}/{run_name}/metrics/
20+
write_metrics: True
21+
22+
timing_metrics_file: "" # for testing, local file that stores function timing metrics such as state creation, compilation. If empty, no metrics are written.
23+
write_timing_metrics: True
24+
25+
gcs_metrics: False
26+
# If true save config to GCS in {base_output_directory}/{run_name}/
27+
save_config_to_gcs: False
28+
log_period: 100
29+
30+
pretrained_model_name_or_path: 'Wan-AI/Wan2.1-T2V-1.3B-Diffusers'
31+
model_name: wan2.1
32+
model_type: 'T2V'
33+
34+
# Overrides the transformer from pretrained_model_name_or_path
35+
wan_transformer_pretrained_model_name_or_path: ''
36+
37+
unet_checkpoint: ''
38+
revision: ''
39+
# This will convert the weights to this dtype.
40+
# When running inference on TPUv5e, use weights_dtype: 'bfloat16'
41+
weights_dtype: 'bfloat16'
42+
# This sets the layer's dtype in the model. Ex: nn.Dense(dtype=activations_dtype)
43+
activations_dtype: 'bfloat16'
44+
45+
# Replicates vae across devices instead of using the model's sharding annotations for sharding.
46+
replicate_vae: False
47+
48+
# matmul and conv precision from https://jax.readthedocs.io/en/latest/jax.lax.html#jax.lax.Precision
49+
# Options are "DEFAULT", "HIGH", "HIGHEST"
50+
# fp32 activations and fp32 weights with HIGHEST will provide the best precision
51+
# at the cost of time.
52+
precision: "DEFAULT"
53+
# Use jax.lax.scan for transformer layers
54+
scan_layers: True
55+
56+
# if False state is not jitted and instead replicate is called. This is good for debugging on single host
57+
# It must be True for multi-host.
58+
jit_initializers: True
59+
60+
# Set true to load weights from pytorch
61+
from_pt: True
62+
split_head_dim: True
63+
attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te, ring
64+
flash_min_seq_length: 0
65+
66+
# If mask_padding_tokens is True, we pass in segment ids to splash attention to avoid attending to padding tokens.
67+
# Else we do not pass in segment ids and on vpu bound hardware like trillium this is faster.
68+
# However, when padding tokens are significant, this will lead to worse quality and should be set to True.
69+
mask_padding_tokens: True
70+
# Maxdiffusion has 2 types of attention sharding strategies:
71+
# 1. attention_sharding_uniform = True : same sequence sharding rules applied for q in both (self and cross attention)
72+
# 2. attention_sharding_uniform = False : Heads are sharded uniformly across devices for self attention while sequence is sharded
73+
# in cross attention q.
74+
attention_sharding_uniform: True
75+
dropout: 0.1
76+
77+
flash_block_sizes: {
78+
"block_q" : 512,
79+
"block_kv_compute" : 512,
80+
"block_kv" : 512,
81+
"block_q_dkv" : 512,
82+
"block_kv_dkv" : 512,
83+
"block_kv_dkv_compute" : 512,
84+
"block_q_dq" : 512,
85+
"block_kv_dq" : 512,
86+
"use_fused_bwd_kernel": False,
87+
}
88+
# GroupNorm groups
89+
norm_num_groups: 32
90+
91+
# train text_encoder - Currently not supported for SDXL
92+
train_text_encoder: False
93+
text_encoder_learning_rate: 4.25e-6
94+
95+
# https://arxiv.org/pdf/2305.08891.pdf
96+
snr_gamma: -1.0
97+
98+
timestep_bias: {
99+
# a value of later will increase the frequence of the model's final training steps.
100+
# none, earlier, later, range
101+
strategy: "none",
102+
# multiplier for bias, a value of 2.0 will double the weight of the bias, 0.5 will halve it.
103+
multiplier: 1.0,
104+
# when using strategy=range, the beginning (inclusive) timestep to bias.
105+
begin: 0,
106+
# when using strategy=range, the final step (inclusive) to bias.
107+
end: 1000,
108+
# portion of timesteps to bias.
109+
# 0.5 will bias one half of the timesteps. Value of strategy determines
110+
# whether the biased portions are in the earlier or later timesteps.
111+
portion: 0.25
112+
}
113+
114+
# Override parameters from checkpoints's scheduler.
115+
diffusion_scheduler_config: {
116+
_class_name: 'FlaxEulerDiscreteScheduler',
117+
prediction_type: 'epsilon',
118+
rescale_zero_terminal_snr: False,
119+
timestep_spacing: 'trailing'
120+
}
121+
122+
# Output directory
123+
# Create a GCS bucket, e.g. my-maxtext-outputs and set this to "gs://my-maxtext-outputs/"
124+
base_output_directory: ""
125+
126+
# Hardware
127+
hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu'
128+
skip_jax_distributed_system: False
129+
130+
# Parallelism
131+
mesh_axes: ['data', 'fsdp', 'context', 'tensor']
132+
133+
# batch : batch dimension of data and activations
134+
# hidden :
135+
# embed : attention qkv dense layer hidden dim named as embed
136+
# heads : attention head dim = num_heads * head_dim
137+
# length : attention sequence length
138+
# temb_in : dense.shape[0] of resnet dense before conv
139+
# out_c : dense.shape[1] of resnet dense before conv
140+
# out_channels : conv.shape[-1] activation
141+
# keep_1 : conv.shape[0] weight
142+
# keep_2 : conv.shape[1] weight
143+
# conv_in : conv.shape[2] weight
144+
# conv_out : conv.shape[-1] weight
145+
logical_axis_rules: [
146+
['batch', ['data', 'fsdp']],
147+
['activation_batch', ['data', 'fsdp']],
148+
['activation_self_attn_heads', ['context', 'tensor']],
149+
['activation_cross_attn_q_length', ['context', 'tensor']],
150+
['activation_length', 'context'],
151+
['activation_heads', 'tensor'],
152+
['mlp','tensor'],
153+
['embed', ['context', 'fsdp']],
154+
['heads', 'tensor'],
155+
['norm', 'tensor'],
156+
['conv_batch', ['data', 'context', 'fsdp']],
157+
['out_channels', 'tensor'],
158+
['conv_out', 'context'],
159+
]
160+
data_sharding: [['data', 'fsdp', 'context', 'tensor']]
161+
162+
# One axis for each parallelism type may hold a placeholder (-1)
163+
# value to auto-shard based on available slices and devices.
164+
# By default, product of the DCN axes should equal number of slices
165+
# and product of the ICI axes should equal number of devices per slice.
166+
dcn_data_parallelism: 1 # recommended DCN axis to be auto-sharded
167+
dcn_fsdp_parallelism: 1
168+
dcn_context_parallelism: -1
169+
dcn_tensor_parallelism: 1
170+
ici_data_parallelism: 1
171+
ici_fsdp_parallelism: 1
172+
ici_context_parallelism: -1 # recommended ICI axis to be auto-sharded
173+
ici_tensor_parallelism: 1
174+
175+
allow_split_physical_axes: False
176+
177+
# Dataset
178+
# Replace with dataset path or train_data_dir. One has to be set.
179+
dataset_name: 'diffusers/pokemon-gpt4-captions'
180+
train_split: 'train'
181+
dataset_type: 'tfrecord'
182+
cache_latents_text_encoder_outputs: True
183+
# cache_latents_text_encoder_outputs only apply to dataset_type="tf",
184+
# only apply to small dataset that fits in memory
185+
# prepare image latents and text encoder outputs
186+
# Reduce memory consumption and reduce step time during training
187+
# transformed dataset is saved at dataset_save_location
188+
dataset_save_location: ''
189+
load_tfrecord_cached: True
190+
train_data_dir: ''
191+
dataset_config_name: ''
192+
jax_cache_dir: ''
193+
hf_data_dir: ''
194+
hf_train_files: ''
195+
hf_access_token: ''
196+
image_column: 'image'
197+
caption_column: 'text'
198+
resolution: 1024
199+
center_crop: False
200+
random_flip: False
201+
# If cache_latents_text_encoder_outputs is True
202+
# the num_proc is set to 1
203+
tokenize_captions_num_proc: 4
204+
transform_images_num_proc: 4
205+
reuse_example_batch: False
206+
enable_data_shuffling: True
207+
208+
# Defines the type of gradient checkpoint to enable.
209+
# NONE - means no gradient checkpoint
210+
# FULL - means full gradient checkpoint, whenever possible (minimum memory usage)
211+
# MATMUL_WITHOUT_BATCH - means gradient checkpoint for every linear/matmul operation,
212+
# except for ones that involve batch dimension - that means that all attention and projection
213+
# layers will have gradient checkpoint, but not the backward with respect to the parameters.
214+
# OFFLOAD_MATMUL_WITHOUT_BATCH - same as MATMUL_WITHOUT_BATCH but offload instead of recomputing.
215+
# CUSTOM - set names to offload and save.
216+
remat_policy: "NONE"
217+
# For CUSTOM policy set below, current annotations are for: attn_output, query_proj, key_proj, value_proj
218+
# xq_out, xk_out, ffn_activation
219+
names_which_can_be_saved: []
220+
names_which_can_be_offloaded: []
221+
222+
# checkpoint every number of samples, -1 means don't checkpoint.
223+
checkpoint_every: -1
224+
checkpoint_dir: ""
225+
# enables one replica to read the ckpt then broadcast to the rest
226+
enable_single_replica_ckpt_restoring: False
227+
228+
# Training loop
229+
learning_rate: 1.e-5
230+
scale_lr: False
231+
max_train_samples: -1
232+
# max_train_steps takes priority over num_train_epochs.
233+
max_train_steps: 1500
234+
num_train_epochs: 1
235+
seed: 0
236+
output_dir: 'sdxl-model-finetuned'
237+
per_device_batch_size: 1.0
238+
# If global_batch_size % jax.device_count is not 0, use FSDP sharding.
239+
global_batch_size: 0
240+
241+
# For creating tfrecords from dataset
242+
tfrecords_dir: ''
243+
no_records_per_shard: 0
244+
enable_eval_timesteps: False
245+
timesteps_list: [125, 250, 375, 500, 625, 750, 875]
246+
num_eval_samples: 420
247+
248+
warmup_steps_fraction: 0.1
249+
learning_rate_schedule_steps: -1 # By default the length of the schedule is set to the number of steps.
250+
save_optimizer: False
251+
252+
# However you may choose a longer schedule (learning_rate_schedule_steps > steps), in which case the training will end before
253+
# dropping fully down. Or you may choose a shorter schedule, where the unspecified steps will have a learning rate of 0.
254+
255+
# AdamW optimizer parameters
256+
adam_b1: 0.9 # Exponential decay rate to track the first moment of past gradients.
257+
adam_b2: 0.999 # Exponential decay rate to track the second moment of past gradients.
258+
adam_eps: 1.e-8 # A small constant applied to denominator outside of the square root.
259+
adam_weight_decay: 0 # AdamW Weight decay
260+
max_grad_norm: 1.0
261+
262+
enable_profiler: False
263+
# Skip first n steps for profiling, to omit things like compilation and to give
264+
# the iteration time a chance to stabilize.
265+
skip_first_n_steps_for_profiler: 5
266+
profiler_steps: 10
267+
268+
# Enable JAX named scopes for detailed profiling and debugging
269+
# When enabled, adds named scopes around key operations in transformer and attention layers
270+
enable_jax_named_scopes: False
271+
272+
# Generation parameters
273+
prompt: "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window."
274+
prompt_2: "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window."
275+
negative_prompt: "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
276+
do_classifier_free_guidance: True
277+
height: 480
278+
width: 832
279+
num_frames: 81
280+
guidance_scale: 5.0
281+
flow_shift: 3.0
282+
283+
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
284+
guidance_rescale: 0.0
285+
num_inference_steps: 30
286+
fps: 16
287+
save_final_checkpoint: False
288+
289+
# SDXL Lightning parameters
290+
lightning_from_pt: True
291+
# Empty or "ByteDance/SDXL-Lightning" to enable lightning.
292+
lightning_repo: ""
293+
# Empty or "sdxl_lightning_4step_unet.safetensors" to enable lightning.
294+
lightning_ckpt: ""
295+
296+
# LoRA parameters
297+
enable_lora: False
298+
# Values are lists to support multiple LoRA loading during inference in the future.
299+
lora_config: {
300+
rank: [64],
301+
lora_model_name_or_path: [""],
302+
weight_name: [""],
303+
adapter_name: [""],
304+
scale: [1.0],
305+
from_pt: []
306+
}
307+
308+
enable_mllog: False
309+
310+
#controlnet
311+
controlnet_model_name_or_path: 'diffusers/controlnet-canny-sdxl-1.0'
312+
controlnet_from_pt: True
313+
controlnet_conditioning_scale: 0.5
314+
controlnet_image: 'https://upload.wikimedia.org/wikipedia/commons/thumb/c/c1/Google_%22G%22_logo.svg/1024px-Google_%22G%22_logo.svg.png'
315+
quantization: ''
316+
# Shard the range finding operation for quantization. By default this is set to number of slices.
317+
quantization_local_shard_count: -1
318+
compile_topology_num_slices: -1 # Number of target slices, set to a positive integer.
319+
use_qwix_quantization: False # Whether to use qwix for quantization. If set to True, the transformer of WAN will be quantized using qwix.
320+
# Quantization calibration method used for weights, activations and bwd. Supported methods can be found in https://github.com/google/qwix/blob/dc2a0770351c740e5ab3cce7c0efe9f7beacce9e/qwix/qconfig.py#L70-L80
321+
weight_quantization_calibration_method: "absmax"
322+
act_quantization_calibration_method: "absmax"
323+
bwd_quantization_calibration_method: "absmax"
324+
qwix_module_path: ".*"
325+
326+
# Eval model on per eval_every steps. -1 means don't eval.
327+
eval_every: -1
328+
eval_data_dir: ""
329+
enable_generate_video_for_eval: False # This will increase the used TPU memory.
330+
eval_max_number_of_samples_in_bucket: 60 # The number of samples per bucket for evaluation. This is calculated by num_eval_samples / len(timesteps_list).
331+
332+
enable_ssim: False

src/maxdiffusion/models/wan/wan_utils.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ def rename_for_custom_trasformer(key):
7272
return renamed_pt_key
7373

7474

75-
def get_key_and_value(pt_tuple_key, tensor, flax_state_dict, random_flax_state_dict, scan_layers):
75+
def get_key_and_value(pt_tuple_key, tensor, flax_state_dict, random_flax_state_dict, scan_layers, num_layers=40):
7676
if scan_layers:
7777
if "blocks" in pt_tuple_key:
7878
new_key = ("blocks",) + pt_tuple_key[2:]
@@ -89,7 +89,7 @@ def get_key_and_value(pt_tuple_key, tensor, flax_state_dict, random_flax_state_d
8989
if flax_key in flax_state_dict:
9090
new_tensor = flax_state_dict[flax_key]
9191
else:
92-
new_tensor = jnp.zeros((40,) + flax_tensor.shape)
92+
new_tensor = jnp.zeros((num_layers,) + flax_tensor.shape)
9393
flax_tensor = new_tensor.at[block_index].set(flax_tensor)
9494
return flax_key, flax_tensor
9595

@@ -127,7 +127,9 @@ def load_fusionx_transformer(
127127

128128
pt_tuple_key = tuple(renamed_pt_key.split("."))
129129

130-
flax_key, flax_tensor = get_key_and_value(pt_tuple_key, tensor, flax_state_dict, random_flax_state_dict, scan_layers)
130+
flax_key, flax_tensor = get_key_and_value(
131+
pt_tuple_key, tensor, flax_state_dict, random_flax_state_dict, scan_layers, num_layers
132+
)
131133
flax_state_dict[flax_key] = jax.device_put(jnp.asarray(flax_tensor), device=cpu)
132134

133135
validate_flax_state_dict(eval_shapes, flax_state_dict)
@@ -167,7 +169,9 @@ def load_causvid_transformer(
167169
renamed_pt_key = rename_for_custom_trasformer(renamed_pt_key)
168170

169171
pt_tuple_key = tuple(renamed_pt_key.split("."))
170-
flax_key, flax_tensor = get_key_and_value(pt_tuple_key, tensor, flax_state_dict, random_flax_state_dict, scan_layers)
172+
flax_key, flax_tensor = get_key_and_value(
173+
pt_tuple_key, tensor, flax_state_dict, random_flax_state_dict, scan_layers, num_layers
174+
)
171175
flax_state_dict[flax_key] = jax.device_put(jnp.asarray(flax_tensor), device=cpu)
172176

173177
validate_flax_state_dict(eval_shapes, flax_state_dict)
@@ -284,7 +288,9 @@ def load_base_wan_transformer(
284288
renamed_pt_key = renamed_pt_key.replace("ffn.net_0", "ffn.act_fn")
285289
renamed_pt_key = renamed_pt_key.replace("norm2", "norm2.layer_norm")
286290
pt_tuple_key = tuple(renamed_pt_key.split("."))
287-
flax_key, flax_tensor = get_key_and_value(pt_tuple_key, tensor, flax_state_dict, random_flax_state_dict, scan_layers)
291+
flax_key, flax_tensor = get_key_and_value(
292+
pt_tuple_key, tensor, flax_state_dict, random_flax_state_dict, scan_layers, num_layers
293+
)
288294
flax_state_dict[flax_key] = jax.device_put(jnp.asarray(flax_tensor), device=cpu)
289295

290296
validate_flax_state_dict(eval_shapes, flax_state_dict)

0 commit comments

Comments
 (0)