
Commit fc46fcc

Video training part 2 (#218)
1 parent 8274aca commit fc46fcc

11 files changed: 243 additions & 79 deletions

src/maxdiffusion/checkpointing/wan_checkpointer.py

Lines changed: 1 addition & 2 deletions
@@ -15,7 +15,6 @@
 """

 from abc import ABC
-from flax import nnx
 from maxdiffusion.checkpointing.checkpointing_utils import (create_orbax_checkpoint_manager)
 from ..pipelines.wan.wan_pipeline import WanPipeline
 from .. import max_logging, max_utils
@@ -42,7 +41,7 @@ def _create_optimizer(self, model, config, learning_rate):
         learning_rate, config.learning_rate_schedule_steps, config.warmup_steps_fraction, config.max_train_steps
     )
     tx = max_utils.create_optimizer(config, learning_rate_scheduler)
-    return nnx.Optimizer(model, tx), learning_rate_scheduler
+    return tx, learning_rate_scheduler

   def load_wan_configs_from_orbax(self, step):
     max_logging.log("Restoring stable diffusion configs")
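Note on the change above: _create_optimizer now returns the bare optax transform instead of a ready-made nnx.Optimizer, so building the optimizer state moves to the caller. A minimal caller-side sketch under that assumption (stand-in model and transform, not code from this commit):

import optax
from flax import nnx

model = nnx.Linear(4, 4, rngs=nnx.Rngs(0))  # stand-in for the WAN transformer
tx = optax.adamw(1e-4)                      # stand-in for the tx returned by _create_optimizer
optimizer = nnx.Optimizer(model, tx)        # the trainer now wraps model and tx itself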

src/maxdiffusion/configs/base_wan_14b.yml

Lines changed: 14 additions & 3 deletions
@@ -54,6 +54,7 @@ jit_initializers: True
 from_pt: True
 split_head_dim: True
 attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te
+flash_min_seq_length: 4096

 flash_block_sizes: {}
 # Use on v6e
@@ -126,15 +127,17 @@ mesh_axes: ['data', 'fsdp', 'tensor']
 # conv_out : conv.shape[-1] weight
 logical_axis_rules: [
   ['batch', 'data'],
+  ['activation_batch', 'data'],
   ['activation_length', 'fsdp'],
+
   ['activation_heads', 'tensor'],
-  ['activation_batch', 'data'],
   ['mlp','tensor'],
   ['embed','fsdp'],
+  ['heads', 'tensor'],
   ['norm', 'tensor'],
   ['conv_batch', ['data','fsdp']],
   ['out_channels', 'tensor'],
-  ['conv_in', 'fsdp'],
+  ['conv_out', 'fsdp'],
 ]
 data_sharding: [['data', 'fsdp', 'tensor']]
@@ -182,6 +185,14 @@ transform_images_num_proc: 4
 reuse_example_batch: False
 enable_data_shuffling: True

+# Defines the type of gradient checkpoint to enable.
+# NONE - means no gradient checkpoint
+# FULL - means full gradient checkpoint, whenever possible (minimum memory usage)
+# MATMUL_WITHOUT_BATCH - means gradient checkpoint for every linear/matmul operation,
+# except for ones that involve batch dimension - that means that all attention and projection
+# layers will have gradient checkpoint, but not the backward with respect to the parameters
+remat_policy: "NONE"
+
 # checkpoint every number of samples, -1 means don't checkpoint.
 checkpoint_every: -1
 # enables one replica to read the ckpt then broadcast to the rest
@@ -196,7 +207,7 @@ max_train_steps: 1500
 num_train_epochs: 1
 seed: 0
 output_dir: 'sdxl-model-finetuned'
-per_device_batch_size: 1
+per_device_batch_size: 1.0
 # If global_batch_size % jax.device_count is not 0, use FSDP sharding.
 global_batch_size: 0
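For context on the logical_axis_rules edits above: each rule maps a logical axis name used in the model to a mesh axis, and flax resolves a tuple of logical names into a PartitionSpec. A small standalone illustration using names from this config (not code from this commit):

import flax.linen as nn

# A subset of the rules table, written as (logical_name, mesh_axis) pairs.
rules = (("activation_batch", "data"), ("activation_heads", "tensor"), ("activation_length", "fsdp"))
spec = nn.logical_to_mesh_axes(("activation_batch", "activation_heads", "activation_length"), rules)
print(spec)  # PartitionSpec('data', 'tensor', 'fsdp')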

src/maxdiffusion/generate_wan.py

Lines changed: 3 additions & 9 deletions
@@ -29,15 +29,9 @@ def run(config, pipeline=None, filename_prefix=""):
     pipeline = WanPipeline.from_pretrained(config)
   s0 = time.perf_counter()

-  # If global_batch_size % jax.device_count is not 0, use FSDP sharding.
-  global_batch_size = config.global_batch_size
-  if global_batch_size != 0:
-    batch_multiplier = global_batch_size
-  else:
-    batch_multiplier = jax.device_count() * config.per_device_batch_size
-
-  prompt = [config.prompt] * batch_multiplier
-  negative_prompt = [config.negative_prompt] * batch_multiplier
+  # Using global_batch_size_to_train_on so not to create more config variables
+  prompt = [config.prompt] * config.global_batch_size_to_train_on
+  negative_prompt = [config.negative_prompt] * config.global_batch_size_to_train_on

   max_logging.log(
       f"Num steps: {config.num_inference_steps}, height: {config.height}, width: {config.width}, frames: {config.num_frames}"

src/maxdiffusion/input_pipeline/_tfds_data_processing.py

Lines changed: 5 additions & 2 deletions
@@ -117,13 +117,16 @@ def make_tfrecord_iterator(
   check out preparation script
   maxdiffusion/pedagogical_examples/to_tfrecords.py
   """
-
   # set load_tfrecord_cached to True in config to use pre-processed tfrecord dataset.
   # pedagogical_examples/dataset_tf_cache_to_tfrecord.py to convert tf preprocessed dataset to tfrecord.
   # Dataset cache in github runner test doesn't contain all the features since its shared, Use the default tfrecord iterator.
+
+  # checks that the dataset path is valid. In case of gcs, the existance of the dir is not checked.
+  is_dataset_dir_valid = "gs://" in config.dataset_save_location or os.path.isdir(config.dataset_save_location)
+
   if (
       config.cache_latents_text_encoder_outputs
-      and os.path.isdir(config.dataset_save_location)
+      and is_dataset_dir_valid
      and "load_tfrecord_cached" in config.get_keys()
      and config.load_tfrecord_cached
  ):
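A standalone restatement of the new check (the bucket path below is hypothetical): GCS locations are accepted without probing the filesystem, while local paths must exist as directories.

import os

def is_dataset_dir_valid(dataset_save_location: str) -> bool:
  # os.path.isdir cannot see into GCS, so gs:// paths are accepted as-is.
  return "gs://" in dataset_save_location or os.path.isdir(dataset_save_location)

print(is_dataset_dir_valid("gs://my-bucket/wan-tfrecords"))  # True, existence not verified
print(is_dataset_dir_valid("/tmp"))                          # True only if /tmp exists locally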

src/maxdiffusion/models/attention_flax.py

Lines changed: 19 additions & 29 deletions
@@ -18,6 +18,7 @@
 import flax.linen as nn
 from flax import nnx
 import jax
+from jax.ad_checkpoint import checkpoint_name
 from jax.sharding import PartitionSpec
 import jax.numpy as jnp
 from jax.experimental import shard_map
@@ -187,30 +188,6 @@ def _tpu_flash_attention(
   value, _, _ = _reshape_data_for_flash(value, heads, block_sizes.block_kv_compute, num_fsdp_shards)
   q_axis_names = nn.logical_to_mesh_axes(axis_names_q)
   kv_axis_names = nn.logical_to_mesh_axes(axis_names_kv)
-  flash_axis_names_splash_kernel: AxisNames = (HEAD, LENGTH)
-  axis_names_splash_kernel = nn.logical_to_mesh_axes(flash_axis_names_splash_kernel)
-  named_sharding = jax.sharding.NamedSharding(mesh, axis_names_splash_kernel)
-
-  shard_head_size = mesh.shape["tensor"]
-
-  @functools.partial(
-      jax.jit,
-      static_argnames=["multi_head_mask", "shard_head_size"],
-  )
-  def wrap_splash_kernel(multi_head_mask, shard_head_size=1):
-    splash_kernel = splash_attention_kernel.make_splash_mha(
-        mask=multi_head_mask,
-        head_shards=shard_head_size,  # the sizes of the axis is sharding over heads
-        q_seq_shards=1,  # the sizes of the axis is sharding over seq_len
-        block_sizes=block_sizes,
-    )
-    return splash_kernel
-
-  mask = splash_attention_mask.FullMask(_shape=(query.shape[2], key.shape[2]))
-
-  multi_head_mask = splash_attention_mask.MultiHeadMask(masks=(mask,) * query.shape[1])
-  splash_kernel = wrap_splash_kernel(multi_head_mask, int(shard_head_size))
-  segment_axis_names_splash_kernel = splash_kernel.manual_sharding_spec(named_sharding)

   @functools.partial(
       shard_map.shard_map,
@@ -219,12 +196,21 @@ def wrap_splash_kernel(multi_head_mask, shard_head_size=1):
           q_axis_names,
           kv_axis_names,
           kv_axis_names,
-          segment_axis_names_splash_kernel,
       ),
       out_specs=q_axis_names,
       check_rep=False,
   )
-  def wrap_flash_attention(query, key, value, splash_kernel):
+  def wrap_flash_attention(query, key, value):
+    mask = splash_attention_mask.FullMask(_shape=(query.shape[2], key.shape[2]))
+    multi_head_mask = splash_attention_mask.MultiHeadMask(masks=(mask,) * query.shape[1])
+    # make_splash_mha is wrapped around shardmap and seq and head is already
+    # sharded based on in_specs, therefore setting head_shards=1 and q_seq_shards=1.
+    splash_kernel = splash_attention_kernel.make_splash_mha(
+        mask=multi_head_mask,
+        head_shards=1,  # the sizes of the axis is sharding over heads
+        q_seq_shards=1,  # the sizes of the axis is sharding over seq_len
+        block_sizes=block_sizes,
+    )
     attention_output = jax.vmap(splash_kernel)(query, key, value)
     return attention_output

@@ -236,7 +222,7 @@ def wrap_flash_attention(query, key, value, splash_kernel):
       "Warning, batch dimension should be shardable among the devices in data and fsdp"
       f" axis, batch dimension: {query.shape[0]}, devices_in_data_fsdp: {devices_in_data_fsdp}"
   )
-  x = wrap_flash_attention(query, key, value, splash_kernel)
+  x = wrap_flash_attention(query, key, value)
   x = x[:, :, :query_seq_len, :kv_size]
   x = _reshape_heads_to_head_dim(x)

@@ -632,7 +618,7 @@ def __init__(
       use_memory_efficient_attention: bool = False,
       split_head_dim: bool = False,
       attention_kernel: str = "flash",
-      flash_min_seq_length: int = 4096,
+      flash_min_seq_length: int = 0,
       flash_block_sizes: BlockSizes = None,
       mesh: jax.sharding.Mesh = None,
       dtype: jnp.dtype = jnp.float32,
@@ -809,12 +795,16 @@ def __call__(
     query_proj = _unflatten_heads(query_proj, self.heads)
     key_proj = _unflatten_heads(key_proj, self.heads)
     value_proj = _unflatten_heads(value_proj, self.heads)
+    # output of _unflatten_heads Batch, heads, seq_len, head_dim
     query_proj, key_proj = self._apply_rope(query_proj, key_proj, rotary_emb)

+    query_proj = checkpoint_name(query_proj, "query_proj")
+    key_proj = checkpoint_name(key_proj, "key_proj")
+    value_proj = checkpoint_name(value_proj, "value_proj")
     attn_output = self.attention_op.apply_attention(query_proj, key_proj, value_proj)

     attn_output = attn_output.astype(dtype=dtype)
-
+    attn_output = checkpoint_name(attn_output, "attn_output")
     hidden_states = self.proj_attn(attn_output)
     return hidden_states
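The checkpoint_name tags introduced above are what rematerialization policies match on. A self-contained toy example of that mechanism (a stand-in attention function, not the splash kernel in this file), using policies that ship with jax:

import functools

import jax
import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint_name

# Keep only the tagged projections resident; everything else is recomputed in the backward pass.
save_qkv = jax.checkpoint_policies.save_only_these_names("query_proj", "key_proj", "value_proj")

@functools.partial(jax.checkpoint, policy=save_qkv)
def toy_attention(q, k, v):
  q = checkpoint_name(q, "query_proj")
  k = checkpoint_name(k, "key_proj")
  v = checkpoint_name(v, "value_proj")
  scores = jax.nn.softmax(q @ k.T / jnp.sqrt(q.shape[-1]))
  return checkpoint_name(scores @ v, "attn_output")

x = jnp.ones((8, 8))
grads = jax.grad(lambda q: toy_attention(q, x, x).sum())(x)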

Lines changed: 93 additions & 0 deletions
@@ -0,0 +1,93 @@
+"""
+ Copyright 2025 Google LLC
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+      https://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+"""
+
+from enum import Enum, auto
+from typing import Optional
+
+import jax
+from jax import checkpoint_policies as cp
+from flax import nnx
+
+SKIP_GRADIENT_CHECKPOINT_KEY = "skip"
+
+
+# This class only works with NNX modules.
+class GradientCheckpointType(Enum):
+  """
+  Defines the type of the gradient checkpoint we will have
+
+  NONE - means no gradient checkpoint
+  FULL - means full gradient checkpoint, wherever possible (minimum memory usage)
+  MATMUL_WITHOUT_BATCH - means gradient checkpoint for every linear/matmul operation,
+  except for ones that involve batch dimension - that means that all attention and projection
+  layers will have gradient checkpoint, but not the backward with respect to the parameters
+  """
+
+  NONE = auto()
+  FULL = auto()
+  MATMUL_WITHOUT_BATCH = auto()
+  ATTN = auto()
+
+  @classmethod
+  def from_str(cls, s: Optional[str] = None) -> "GradientCheckpointType":
+    """
+    Constructs the gradient checkpoint type from a string
+
+    Args:
+        s (Optional[str], optional): The name of the gradient checkpointing policy. Defaults to None.
+
+    Returns:
+        GradientCheckpointType: The policy that corresponds to the string
+    """
+    if s is None:
+      s = "none"
+    return GradientCheckpointType[s.upper()]
+
+  def to_jax_policy(self):
+    """
+    Converts the gradient checkpoint type to a jax policy
+    """
+    match self:
+      case GradientCheckpointType.NONE:
+        return SKIP_GRADIENT_CHECKPOINT_KEY
+      case GradientCheckpointType.FULL:
+        return None
+      case GradientCheckpointType.ATTN:
+        return cp.save_and_offload_only_these_names(
+            names_which_can_be_saved=[], names_which_can_be_offloaded=[], offload_src="device", offload_dst="pinned_host"
+        )
+      case GradientCheckpointType.MATMUL_WITHOUT_BATCH:
+        return jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims
+
+  def apply(self, module: nnx.Module) -> nnx.Module:
+    """
+    Applies a gradient checkpoint policy to a module
+    if no policy is needed, it will return the module as is
+
+    Args:
+        module (nn.Module): the module to apply the policy to
+
+    Returns:
+        nn.Module: the module with the policy applied
+    """
+    policy = self.to_jax_policy()
+    if policy == SKIP_GRADIENT_CHECKPOINT_KEY:
+      return module
+    return nnx.remat(  # pylint: disable=invalid-name
+        module,
+        prevent_cse=False,
+        policy=policy,
+    )
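A short usage sketch for the class above (the module's path is not visible in this extract, so the import is omitted; the string would normally come from the new remat_policy yaml field):

from flax import nnx

block = nnx.Linear(4, 4, rngs=nnx.Rngs(0))        # stand-in for a WAN transformer block
policy = GradientCheckpointType.from_str("full")  # e.g. GradientCheckpointType.from_str(config.remat_policy)
block = policy.apply(block)                       # wrapped with nnx.remat; "none" would return it unchanged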
