
Commit c69c95e

Bug fixes and performance optimizations for FLUX training (#187)
* fix multiple issues on flux GPU
Parent: 4a6f807

5 files changed: 31 additions & 16 deletions


src/maxdiffusion/checkpointing/flux_checkpointer.py

Lines changed: 2 additions & 2 deletions
@@ -194,7 +194,7 @@ def load_diffusers_checkpoint(self):
     clip_tokenizer = CLIPTokenizer.from_pretrained(self.config.clip_model_name_or_path, max_length=77, use_fast=True)
     t5_encoder = FlaxT5EncoderModel.from_pretrained(self.config.t5xxl_model_name_or_path, dtype=self.config.weights_dtype)
     t5_tokenizer = AutoTokenizer.from_pretrained(
-        self.config.t5xxl_model_name_or_path, max_length=self.config.max_sequence_length, use_fast=True
+        self.config.t5xxl_model_name_or_path, model_max_length=self.config.max_sequence_length, use_fast=True
     )

     vae, vae_params = FlaxAutoencoderKL.from_pretrained(

@@ -263,7 +263,7 @@ def load_checkpoint(self, step=None, scheduler_class=None):
         self.config.t5xxl_model_name_or_path, dtype=self.config.weights_dtype
     )
     t5_tokenizer = AutoTokenizer.from_pretrained(
-        self.config.t5xxl_model_name_or_path, max_length=self.config.max_sequence_length, use_fast=True
+        self.config.t5xxl_model_name_or_path, model_max_length=self.config.max_sequence_length, use_fast=True
     )

     vae = FlaxAutoencoderKL.from_config(
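
Note on this change: in Hugging Face transformers, max_length is a per-call argument to the tokenizer (tokenizer(..., max_length=N)), not a construction kwarg, so passing it to AutoTokenizer.from_pretrained does not cap sequence length and the T5 tokenizer keeps its default model_max_length. Setting model_max_length at load time makes padding and truncation respect config.max_sequence_length. A minimal sketch of the distinction, using an illustrative checkpoint and length:

# Sketch: model_max_length (not max_length) is the load-time kwarg that caps
# tokenizer output length. Checkpoint name and length here are illustrative.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("t5-small", model_max_length=512, use_fast=True)

# Padding/truncation now default to the 512 cap without passing max_length per call:
batch = tok("a prompt", padding="max_length", truncation=True, return_tensors="np")
print(batch["input_ids"].shape)  # (1, 512)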

src/maxdiffusion/input_pipeline/_tfds_data_processing.py

Lines changed: 8 additions & 3 deletions
@@ -18,7 +18,7 @@
 import tensorflow as tf
 import tensorflow.experimental.numpy as tnp
 from datasets import load_dataset, load_from_disk
-
+import jax
 from maxdiffusion import multihost_dataloading

 AUTOTUNE = tf.data.AUTOTUNE

@@ -65,8 +65,13 @@ def make_tf_iterator(
   )
   if config.cache_latents_text_encoder_outputs:
     train_ds.save_to_disk(config.dataset_save_location)
-    train_ds.cleanup_cache_files()
-
+    # Only process 0 should attempt to clean up cache files
+    if jax.process_index() == 0:
+      try:
+        train_ds.cleanup_cache_files()
+      except FileNotFoundError:
+        # Ignore FileNotFoundError as files may have been cleaned up by another process
+        pass
   train_ds = load_as_tf_dataset(train_ds, global_batch_size, True, dataloading_host_count)
   train_ds = train_ds.shard(num_shards=dataloading_host_count, index=dataloading_host_index)
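
Note on this change: cleanup_cache_files() deletes Arrow cache files that, on multi-host runs, typically live on shared storage, so every JAX process was racing to unlink the same files and the losers failed with FileNotFoundError. The fix makes process 0 the single writer and treats deletion as best-effort. A minimal sketch of the pattern; the path is illustrative, and the sync_global_devices barrier at the end is an assumption added here for cases where the other hosts must wait for cleanup:

# Sketch of the single-writer guard for shared caches (path illustrative).
import jax
from jax.experimental import multihost_utils
from datasets import load_from_disk

train_ds = load_from_disk("/tmp/cached_dataset")
if jax.process_index() == 0:  # only one host deletes the shared Arrow cache
  try:
    train_ds.cleanup_cache_files()
  except FileNotFoundError:
    pass  # best-effort: files may already be gone
# Hypothetical extra step: block every host until cleanup has finished.
multihost_utils.sync_global_devices("cache_cleanup")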

src/maxdiffusion/max_utils.py

Lines changed: 0 additions & 1 deletion
@@ -26,7 +26,6 @@
 import os
 from pathlib import Path
 import subprocess
-
 import numpy as np

 import flax

src/maxdiffusion/models/flux/transformers/transformer_flux_flax.py

Lines changed: 15 additions & 9 deletions
@@ -86,7 +86,7 @@ def setup(self):
     self.linear1 = nn.Dense(
         self.dim * 3 + self.mlp_hidden_dim,
         kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), ("embed", "mlp")),
-        bias_init=nn.with_logical_partitioning(nn.initializers.zeros, ("mlp",)),
+        bias_init=nn.with_logical_partitioning(nn.initializers.zeros, (None,)),
         dtype=self.dtype,
         param_dtype=self.weights_dtype,
         precision=self.precision,

@@ -96,7 +96,7 @@ def setup(self):
     self.linear2 = nn.Dense(
         self.dim,
         kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), ("mlp", "embed")),
-        bias_init=nn.with_logical_partitioning(nn.initializers.zeros, ("embed",)),
+        bias_init=nn.with_logical_partitioning(nn.initializers.zeros, (None,)),
         dtype=self.dtype,
         param_dtype=self.weights_dtype,
         precision=self.precision,

@@ -209,7 +209,7 @@ def setup(self):
         int(self.dim * self.mlp_ratio),
         use_bias=True,
         kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), ("embed", "mlp")),
-        bias_init=nn.with_logical_partitioning(nn.initializers.zeros, ("mlp",)),
+        bias_init=nn.with_logical_partitioning(nn.initializers.zeros, (None,)),
         dtype=self.dtype,
         param_dtype=self.weights_dtype,
         precision=self.precision,

@@ -218,8 +218,8 @@ def setup(self):
       nn.Dense(
           self.dim,
           use_bias=True,
-          kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), ("embed", "mlp")),
-          bias_init=nn.with_logical_partitioning(nn.initializers.zeros, ("mlp",)),
+          kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), ("mlp", "embed")),
+          bias_init=nn.with_logical_partitioning(nn.initializers.zeros, (None,)),
           dtype=self.dtype,
           param_dtype=self.weights_dtype,
           precision=self.precision,

@@ -240,7 +240,7 @@ def setup(self):
         int(self.dim * self.mlp_ratio),
         use_bias=True,
         kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), ("embed", "mlp")),
-        bias_init=nn.with_logical_partitioning(nn.initializers.zeros, ("mlp",)),
+        bias_init=nn.with_logical_partitioning(nn.initializers.zeros, (None,)),
         dtype=self.dtype,
         param_dtype=self.weights_dtype,
         precision=self.precision,

@@ -249,8 +249,8 @@ def setup(self):
       nn.Dense(
           self.dim,
           use_bias=True,
-          kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), ("embed", "mlp")),
-          bias_init=nn.with_logical_partitioning(nn.initializers.zeros, ("mlp",)),
+          kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), ("mlp", "embed")),
+          bias_init=nn.with_logical_partitioning(nn.initializers.zeros, (None,)),
           dtype=self.dtype,
           param_dtype=self.weights_dtype,
           precision=self.precision,

@@ -483,6 +483,9 @@ def __call__(
   ):
     hidden_states = self.img_in(hidden_states)
     timestep = self.timestep_embedding(timestep, 256)
+
+    timestep = nn.with_logical_constraint(timestep, ("activation_batch", None))
+
     if self.guidance_embeds:
       guidance = self.timestep_embedding(guidance, 256)
     else:

@@ -492,6 +495,9 @@ def __call__(
         if guidance is None
         else self.time_text_embed(timestep, guidance, pooled_projections)
     )
+
+    temb = nn.with_logical_constraint(temb, ("activation_batch", None))
+
     encoder_hidden_states = self.txt_in(encoder_hidden_states)
     if txt_ids.ndim == 3:
       txt_ids = txt_ids[0]

@@ -501,7 +507,7 @@ def __call__(
     ids = jnp.concatenate((txt_ids, img_ids), axis=0)
     ids = nn.with_logical_constraint(ids, ("activation_batch", None))
     image_rotary_emb = self.pe_embedder(ids)
-    image_rotary_emb = nn.with_logical_constraint(image_rotary_emb, ("activation_batch", "activation_embed"))
+    image_rotary_emb = nn.with_logical_constraint(image_rotary_emb, (None, None))

     for double_block in self.double_blocks:
       hidden_states, encoder_hidden_states = double_block(
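
Note on these changes: nn.with_logical_partitioning attaches one logical axis name per parameter dimension. A 1-D bias annotated ("mlp",) or ("embed",) gets sharded along whatever mesh axis those names map to under the active axis rules, which is rarely intended for a small vector; (None,) replicates it. Likewise, the second projection's kernel has shape (mlp_hidden, embed), so its names must be ("mlp", "embed") in that order, not a copy of the first projection's ("embed", "mlp"). A minimal sketch of the convention, with illustrative sizes:

# Sketch of the logical-axis-name convention (sizes illustrative).
import jax
import jax.numpy as jnp
import flax.linen as nn

class Mlp(nn.Module):
  dim: int = 8
  hidden: int = 32

  @nn.compact
  def __call__(self, x):
    # Up-projection kernel has shape (dim, hidden) -> names ("embed", "mlp").
    x = nn.Dense(
        self.hidden,
        kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), ("embed", "mlp")),
        # 1-D bias: (None,) replicates it instead of sharding it on "mlp".
        bias_init=nn.with_logical_partitioning(nn.initializers.zeros, (None,)),
    )(x)
    # Down-projection kernel has shape (hidden, dim) -> names ("mlp", "embed"),
    # matching the parameter's actual dimension order.
    x = nn.Dense(
        self.dim,
        kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), ("mlp", "embed")),
        bias_init=nn.with_logical_partitioning(nn.initializers.zeros, (None,)),
    )(x)
    return x

variables = Mlp().init(jax.random.PRNGKey(0), jnp.ones((1, 8)))
# Shows the PartitionSpec attached to every kernel and bias.
print(nn.get_partition_spec(variables))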

src/maxdiffusion/trainers/flux_trainer.py

Lines changed: 6 additions & 1 deletion
@@ -252,6 +252,7 @@ def load_dataset(self, pipeline, params, train_states):
         t5_tokenizer=pipeline.t5_tokenizer,
         clip_text_encoder=pipeline.clip_encoder,
         t5_text_encoder=pipeline.t5_encoder,
+        max_sequence_length=config.max_sequence_length,
         encode_in_batches=True,
         encode_batch_size=16,
     )

@@ -348,9 +349,13 @@ def training_loop(self, p_train_step, pipeline, params, train_states, data_iterator):
       example_batch = load_next_batch(data_iterator, example_batch, self.config)
       example_batch = {key: jnp.asarray(value, dtype=self.config.activations_dtype) for key, value in example_batch.items()}

-      with jax.profiler.StepTraceAnnotation("train", step_num=step):
+      if self.config.profiler == 'nsys':
         with self.mesh:
           flux_state, train_metric, train_rngs = p_train_step(flux_state, example_batch, train_rngs)
+      else:
+        with jax.profiler.StepTraceAnnotation("train", step_num=step):
+          with self.mesh:
+            flux_state, train_metric, train_rngs = p_train_step(flux_state, example_batch, train_rngs)

       samples_count = self.total_train_batch_size * (step + 1)
       new_time = datetime.datetime.now()
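
Note on this change: jax.profiler.StepTraceAnnotation emits per-step markers for JAX's own trace viewer; when profiling with Nsight Systems (profiler == 'nsys'), which records the CUDA timeline itself, the commit skips the annotation. The duplicated with-block could also be folded into one call site with contextlib.nullcontext, as in this sketch (step_annotation is a hypothetical helper, not part of the diff):

# Sketch: the same branch collapsed with contextlib.nullcontext so the
# training call appears once. step_annotation is a hypothetical helper.
import contextlib
import jax

def step_annotation(config, step):
  # nsys collects its own timeline, so skip JAX's per-step trace markers.
  if config.profiler == 'nsys':
    return contextlib.nullcontext()
  return jax.profiler.StepTraceAnnotation("train", step_num=step)

# Usage inside the loop body:
#   with step_annotation(self.config, step), self.mesh:
#     flux_state, train_metric, train_rngs = p_train_step(flux_state, example_batch, train_rngs)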
