Skip to content

Commit f344ab0

Browse files
authored
Optimize batch loading and metrics writing, replace PositionalSharding with NamedSharding (#186)
* fix profiling * Use torch cpu, async write to tensorboard, script to convert latents to tfrecord, batch iterator for tfrecord cached, namedsharding instead of positional sharding Signed-off-by: Kunjan <kunjanp@google.com> * Replace positional sharding with named sharding Signed-off-by: Kunjan <kunjanp@google.com> * Formatting Signed-off-by: Kunjan <kunjanp@google.com> * Formatting Signed-off-by: Kunjan <kunjanp@google.com> * Fallback to regular tfrecord iterator for datasets without all the processed features Signed-off-by: Kunjan <kunjanp@google.com> * README update --------- Signed-off-by: Kunjan <kunjanp@google.com>
1 parent b161ce3 commit f344ab0

11 files changed

Lines changed: 248 additions & 59 deletions

File tree

.github/workflows/UnitTests.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ jobs:
5050
ruff check .
5151
- name: PyTest
5252
run: |
53-
HF_HUB_CACHE=/mnt/disks/github-runner-disk/ HF_HOME=/mnt/disks/github-runner-disk/ python3 -m pytest
53+
HF_HUB_CACHE=/mnt/disks/github-runner-disk/ HF_HOME=/mnt/disks/github-runner-disk/ python3 -m pytest -x
5454
# add_pull_ready:
5555
# if: github.ref != 'refs/heads/main'
5656
# permissions:

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,13 +35,13 @@ MaxDiffusion supports
3535
* Stable Diffusion 2 base (training and inference)
3636
* Stable Diffusion 2.1 (training and inference)
3737
* Stable Diffusion XL (training and inference).
38+
* Flux Dev and Schnell (training and inference).
3839
* Stable Diffusion Lightning (inference).
3940
* Hyper-SD XL LoRA loading (inference).
4041
* Load Multiple LoRA (SDXL inference).
4142
* ControlNet inference (Stable Diffusion 1.4 & SDXL).
4243
* Dreambooth training support for Stable Diffusion 1.x,2.x.
4344

44-
**WARNING: The training code is purely experimental and is under development.**
4545

4646
# Table of Contents
4747

requirements.txt

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
1-
jax>=0.4.30
1+
--extra-index-url https://download.pytorch.org/whl/cpu
2+
jax==0.5.3
23
jaxlib>=0.4.30
34
grain-nightly==0.0.10
45
google-cloud-storage==2.17.0
56
absl-py
67
datasets
78
flax>=0.10.2
89
optax>=0.2.3
9-
torch==2.5.1
10-
torchvision==0.20.1
10+
torch==2.6.0
11+
torchvision>=0.20.1
1112
ftfy
1213
tensorboard>=2.17.0
1314
tensorboardx==2.6.2.2

src/maxdiffusion/generate_flux.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
import numpy as np
2424
from PIL import Image
2525
import jax
26-
from jax.sharding import Mesh, PositionalSharding, PartitionSpec as P
26+
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
2727
import jax.numpy as jnp
2828
import flax.linen as nn
2929
from chex import Array
@@ -343,7 +343,7 @@ def run(config):
343343
config.t5xxl_model_name_or_path, max_length=config.max_sequence_length, use_fast=True
344344
)
345345

346-
encoders_sharding = PositionalSharding(devices_array).replicate()
346+
encoders_sharding = NamedSharding(devices_array, P())
347347
partial_device_put_replicated = functools.partial(device_put_replicated, sharding=encoders_sharding)
348348
clip_text_encoder.params = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), clip_text_encoder.params)
349349
clip_text_encoder.params = jax.tree_util.tree_map(partial_device_put_replicated, clip_text_encoder.params)

src/maxdiffusion/generate_flux_multi_res.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
import numpy as np
2424
from PIL import Image
2525
import jax
26-
from jax.sharding import Mesh, PositionalSharding, PartitionSpec as P
26+
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
2727
import jax.numpy as jnp
2828
import flax.linen as nn
2929
from chex import Array
@@ -381,7 +381,7 @@ def run(config):
381381
config.t5xxl_model_name_or_path, max_length=config.max_sequence_length, use_fast=True
382382
)
383383

384-
encoders_sharding = PositionalSharding(devices_array).replicate()
384+
encoders_sharding = NamedSharding(devices_array, P())
385385
partial_device_put_replicated = functools.partial(device_put_replicated, sharding=encoders_sharding)
386386
clip_text_encoder.params = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), clip_text_encoder.params)
387387
clip_text_encoder.params = jax.tree_util.tree_map(partial_device_put_replicated, clip_text_encoder.params)

src/maxdiffusion/input_pipeline/_tfds_data_processing.py

Lines changed: 59 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121

2222
from maxdiffusion import multihost_dataloading
2323

24-
AUTOTUNE = tf.data.experimental.AUTOTUNE
24+
AUTOTUNE = tf.data.AUTOTUNE
2525

2626

2727
def load_as_tf_dataset(dataset, global_batch_size, shuffle, dataloading_host_count):
@@ -31,7 +31,7 @@ def load_as_tf_dataset(dataset, global_batch_size, shuffle, dataloading_host_cou
3131
if shuffle:
3232
tf_dataset = tf_dataset.shuffle(len(tf_dataset))
3333
tf_dataset = tf_dataset.batch(global_batch_size // dataloading_host_count, drop_remainder=True)
34-
tf_dataset = tf_dataset.prefetch(tf.data.experimental.AUTOTUNE)
34+
tf_dataset = tf_dataset.prefetch(AUTOTUNE)
3535
tf_dataset = tf_dataset.repeat(-1)
3636

3737
return tf_dataset
@@ -73,6 +73,53 @@ def make_tf_iterator(
7373
train_iter = multihost_dataloading.MultiHostDataLoadIterator(train_ds, mesh)
7474
return train_iter
7575

76+
def make_cached_tfrecord_iterator(
    config,
    dataloading_host_index,
    dataloading_host_count,
    mesh,
    global_batch_size,
):
  """Build a multi-host iterator over TFRecords with fully pre-computed features.

  Each record is expected to carry four serialized tensors produced by the
  caching pipeline: pixel_values (latents), input_ids, prompt_embeds and
  text_embeds. Records are sharded across dataloading hosts, decoded,
  shuffled, batched per host, repeated forever and prefetched.
  """
  # Every feature is stored as a serialized tensor, i.e. a bytes scalar.
  serialized_scalar = tf.io.FixedLenFeature([], tf.string)
  feature_description = {
      "pixel_values": serialized_scalar,
      "input_ids": serialized_scalar,
      "prompt_embeds": serialized_scalar,
      "text_embeds": serialized_scalar,
  }

  # dtype each serialized tensor decodes back to.
  decode_dtypes = {
      "pixel_values": tf.float32,
      "input_ids": tf.int32,
      "prompt_embeds": tf.float32,
      "text_embeds": tf.float32,
  }

  def _deserialize(example):
    return tf.io.parse_single_example(example, feature_description)

  def _decode(features):
    return {name: tf.io.parse_tensor(features[name], out_type=dtype) for name, dtype in decode_dtypes.items()}

  # The conversion script writes all shard files directly into train_data_dir.
  filenames = tf.io.gfile.glob(os.path.join(config.train_data_dir, "*"))

  train_ds = (
      tf.data.TFRecordDataset(filenames, num_parallel_reads=AUTOTUNE)
      .shard(num_shards=dataloading_host_count, index=dataloading_host_index)
      .map(_deserialize, num_parallel_calls=AUTOTUNE)
      .map(_decode, num_parallel_calls=AUTOTUNE)
      .shuffle(global_batch_size * 10)
      .batch(global_batch_size // dataloading_host_count, drop_remainder=True)
      .repeat(-1)
      .prefetch(AUTOTUNE)
  )

  # Wrap the tf.data.Dataset for use in the multi-host JAX environment.
  return multihost_dataloading.MultiHostDataLoadIterator(train_ds, mesh)
122+
76123

77124
# TODO - https://github.com/google/array_record/blob/main/beam/examples/example_gcs_conversion.py
78125
def make_tfrecord_iterator(
@@ -86,6 +133,16 @@ def make_tfrecord_iterator(
86133
check out preparation script
87134
maxdiffusion/pedagogical_examples/to_tfrecords.py
88135
"""
136+
137+
# set load_tfrecord_cached to True in config to use pre-processed tfrecord dataset.
138+
# pedagogical_examples/dataset_tf_cache_to_tfrecord.py to convert tf preprocessed dataset to tfrecord.
139+
# Dataset cache in the GitHub runner test doesn't contain all the features since it's shared; use the default tfrecord iterator.
140+
if (config.cache_latents_text_encoder_outputs
141+
and os.path.isdir(config.dataset_save_location)
142+
and 'load_tfrecord_cached'in config.get_keys()
143+
and config.load_tfrecord_cached):
144+
return make_cached_tfrecord_iterator(config, dataloading_host_index, dataloading_host_count, mesh, global_batch_size)
145+
89146
feature_description = {
90147
"moments": tf.io.FixedLenFeature([], tf.string),
91148
"clip_embeddings": tf.io.FixedLenFeature([], tf.string),
Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
import os
2+
import argparse
3+
import tensorflow as tf
4+
from datasets import load_from_disk
5+
import numpy as np
6+
7+
8+
def _bytes_feature(value):
  """Wrap a serialized tensor (or raw value) in a tf.train bytes Feature."""
  tensor = value if isinstance(value, tf.Tensor) else tf.constant(value)
  payload = tensor.numpy()
  return tf.train.Feature(bytes_list=tf.train.BytesList(value=[payload]))
13+
14+
15+
def create_4_feature_example(record):
  """Build a tf.train.Example carrying the four pre-computed training features.

  Each of pixel_values, input_ids, prompt_embeds and text_embeds is
  serialized with tf.io.serialize_tensor and stored as a bytes feature.
  """
  feature = {
      name: _bytes_feature(tf.io.serialize_tensor(record[name]))
      for name in ("pixel_values", "input_ids", "prompt_embeds", "text_embeds")
  }
  return tf.train.Example(features=tf.train.Features(feature=feature))
29+
30+
31+
def run(args):
  """Convert a cached Hugging Face dataset into sharded TFRecord files.

  Loads the dataset saved by the training pipeline from
  ``args.dataset_save_location``, casts the float features to float32, and
  round-robins the records across ``args.data_num_shards`` TFRecord files
  written to ``args.tfrecords_dir``.
  """
  # Load the cached dataset from the location specified in the arguments.
  print(f"Loading processed dataset from disk: {args.dataset_save_location}")
  processed_ds = load_from_disk(args.dataset_save_location)
  print("Dataset loaded successfully.")

  # Get sharding and output directory from the arguments.
  tfrecords_dir = args.tfrecords_dir
  num_shards = args.data_num_shards
  os.makedirs(tfrecords_dir, exist_ok=True)

  writers = [
      tf.io.TFRecordWriter(os.path.join(tfrecords_dir, f"shard-{i:05d}-of-{num_shards:05d}.tfrecord"))
      for i in range(num_shards)
  ]

  print(f"Writing {len(processed_ds)} records into {num_shards} TFRecord shards...")

  # Fix: close every writer even if conversion fails part-way, so no shard
  # is left with an unflushed buffer (the original leaked the handles on error).
  try:
    for i, record in enumerate(processed_ds):
      # Serialized tensors need a deterministic dtype; the dataset may hand
      # back float64, so cast the float features to float32 explicitly.
      casted_record = {
          "pixel_values": np.float32(record["pixel_values"]),
          "input_ids": record["input_ids"],  # already integer type
          "prompt_embeds": np.float32(record["prompt_embeds"]),
          "text_embeds": np.float32(record["text_embeds"]),
      }

      tf_example = create_4_feature_example(casted_record)
      # Round-robin across shards to keep them evenly sized.
      writer_index = i % num_shards
      writers[writer_index].write(tf_example.SerializeToString())
  finally:
    for writer in writers:
      writer.close()

  print("TFRecord conversion complete.")
67+
68+
69+
def main():
  """CLI entry point: parse command-line arguments and run the conversion."""
  parser = argparse.ArgumentParser(description="Convert a cached Hugging Face dataset to sharded TFRecords.")
  # The two path options share the same shape; declare them table-driven.
  string_options = (
      (
          "--dataset_save_location",
          "/tmp/pokemon-gpt4-captions_xl",
          "Path to the cached dataset created by the training pipeline.",
      ),
      (
          "--tfrecords_dir",
          "/tmp/cached_pokemon_tfrecords_sharded",
          "Output directory to save the sharded TFRecord files.",
      ),
  )
  for flag, default, help_text in string_options:
    parser.add_argument(flag, type=str, required=False, default=default, help=help_text)
  parser.add_argument(
      "--data_num_shards", type=int, default=128, help="Number of shards to split the TFRecord dataset into."
  )

  run(parser.parse_args())


if __name__ == "__main__":
  main()

src/maxdiffusion/pipelines/wan/wan_pipeline.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
import numpy as np
1818
import jax
1919
import jax.numpy as jnp
20-
from jax.sharding import Mesh, PositionalSharding, PartitionSpec as P
20+
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
2121
import flax
2222
import flax.linen as nn
2323
from flax import nnx
@@ -195,7 +195,7 @@ def load_vae(cls, devices_array: np.array, mesh: Mesh, rngs: nnx.Rngs, config: H
195195
# This replaces random params with the model.
196196
params = load_wan_vae(config.pretrained_model_name_or_path, params, "cpu")
197197
params = jax.tree_util.tree_map(lambda x: x.astype(config.weights_dtype), params)
198-
params = jax.device_put(params, PositionalSharding(devices_array).replicate())
198+
params = jax.device_put(params, NamedSharding(devices_array, P()))
199199
wan_vae = nnx.merge(graphdef, params)
200200
p_create_sharded_logical_model = partial(create_sharded_logical_model, logical_axis_rules=config.logical_axis_rules)
201201
# Shard
@@ -395,7 +395,7 @@ def __call__(
395395
num_channels_latents=num_channel_latents,
396396
)
397397

398-
data_sharding = PositionalSharding(self.devices_array).replicate()
398+
data_sharding = NamedSharding(self.devices_array, P())
399399
if len(prompt) % jax.device_count() == 0:
400400
data_sharding = jax.sharding.NamedSharding(self.mesh, P(*self.config.data_sharding))
401401

src/maxdiffusion/train_utils.py

Lines changed: 35 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import numpy as np
1818
import jax
1919
import jax.numpy as jnp
20+
import queue
2021

2122
from maxdiffusion import max_utils, max_logging
2223

@@ -68,10 +69,31 @@ def record_scalar_metrics(metrics, step_time_delta, per_device_tflops, lr):
6869
metrics["scalar"].update({"learning/current_learning_rate": lr})
6970

7071

72+
_metrics_queue = queue.Queue()
7173
_buffered_step = None
7274
_buffered_metrics = None
7375

7476

77+
def _tensorboard_writer_worker(writer, config):
  """Background-thread loop that drains the metrics queue into TensorBoard.

  Blocks on the module-level metrics queue; a ``None`` sentinel shuts the
  worker down. Only process 0 writes metrics, and the writer is flushed
  every ``config.log_period`` steps.
  """
  while (item := _metrics_queue.get()) is not None:
    metrics, step = item
    if jax.process_index() != 0:
      continue
    scalar_metrics = metrics.get("scalar", [])
    for name in scalar_metrics:
      writer.add_scalar(name, np.array(scalar_metrics[name]), step)
    grouped_metrics = metrics.get("scalars", [])
    for name in grouped_metrics:
      writer.add_scalars(name, grouped_metrics[name], step)

    if step % config.log_period == 0:
      writer.flush()
96+
7597
def write_metrics(writer, local_metrics_file, running_gcs_metrics, metrics, step, config):
7698
"""Entry point for all metrics writing in Train's Main.
7799
TODO: would be better as a Class in the future (that initialized all state!)
@@ -81,16 +103,18 @@ def write_metrics(writer, local_metrics_file, running_gcs_metrics, metrics, step
81103
The logic is that this ensures that Jax is able to queues train_steps and we
82104
don't block when turning "lazy" Jax arrays into real Python numbers.
83105
"""
84-
global _buffered_step, _buffered_metrics
106+
global _buffered_step, _buffered_metrics, _metrics_queue
85107

108+
if metrics:
109+
_metrics_queue.put((metrics, step))
86110
if _buffered_metrics is not None:
111+
if config.metrics_file:
112+
max_utils.write_metrics_locally(_buffered_metrics, _buffered_step, config, local_metrics_file)
113+
87114
if _buffered_step is None:
88115
raise ValueError(f"When writing metrics, {_buffered_step=} was none")
89116
write_metrics_to_tensorboard(writer, _buffered_metrics, _buffered_step, config)
90117

91-
if config.metrics_file:
92-
max_utils.write_metrics_locally(_buffered_metrics, _buffered_step, config, local_metrics_file)
93-
94118
if config.gcs_metrics and jax.process_index() == 0:
95119
running_gcs_metrics = max_utils.write_metrics_for_gcs(_buffered_metrics, _buffered_step, config, running_gcs_metrics)
96120

@@ -100,13 +124,6 @@ def write_metrics(writer, local_metrics_file, running_gcs_metrics, metrics, step
100124

101125
def write_metrics_to_tensorboard(writer, metrics, step, config):
102126
"""Writes metrics to tensorboard"""
103-
if jax.process_index() == 0:
104-
for metric_name in metrics.get("scalar", []):
105-
writer.add_scalar(metric_name, np.array(metrics["scalar"][metric_name]), step)
106-
for metric_name in metrics.get("scalars", []):
107-
writer.add_scalars(metric_name, metrics["scalars"][metric_name], step)
108-
109-
full_log = step % config.log_period == 0
110127
if jax.process_index() == 0:
111128
max_logging.log(
112129
"completed step: {}, seconds: {:.3f}, TFLOP/s/device: {:.3f}, loss: {:.3f}".format(
@@ -116,6 +133,13 @@ def write_metrics_to_tensorboard(writer, metrics, step, config):
116133
float(metrics["scalar"]["learning/loss"]),
117134
)
118135
)
136+
if jax.process_index() == 0:
137+
for metric_name in metrics.get("scalar", []):
138+
writer.add_scalar(metric_name, np.array(metrics["scalar"][metric_name]), step)
139+
for metric_name in metrics.get("scalars", []):
140+
writer.add_scalars(metric_name, metrics["scalars"][metric_name], step)
141+
142+
full_log = step % config.log_period == 0
119143

120144
if full_log and jax.process_index() == 0:
121145
max_logging.log(f"To see full metrics 'tensorboard --logdir={config.tensorboard_dir}'")

src/maxdiffusion/trainers/flux_trainer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
import numpy as np
2222
import jax
2323
import jax.numpy as jnp
24-
from jax.sharding import PositionalSharding, PartitionSpec as P
24+
from jax.sharding import NamedSharding, PartitionSpec as P
2525
from flax.linen import partitioning as nn_partitioning
2626
from maxdiffusion.checkpointing.flux_checkpointer import (
2727
FluxCheckpointer,
@@ -87,7 +87,7 @@ def start_training(self):
8787
state_shardings = {}
8888

8989
# move params to accelerator
90-
encoders_sharding = PositionalSharding(self.devices_array).replicate()
90+
encoders_sharding = NamedSharding(self.mesh, P(None))
9191
partial_device_put_replicated = partial(max_utils.device_put_replicated, sharding=encoders_sharding)
9292
pipeline.clip_encoder.params = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), pipeline.clip_encoder.params)
9393
pipeline.clip_encoder.params = jax.tree_util.tree_map(partial_device_put_replicated, pipeline.clip_encoder.params)

0 commit comments

Comments
 (0)