Skip to content
This repository was archived by the owner on May 20, 2026. It is now read-only.

Commit cabae30

Browse files
authored
Add FLUX Model (#99)
* initial version with model fwd done Signed-off-by: Ao Tang <aot@nvidia.com> * flux training Signed-off-by: Ao Tang <aot@nvidia.com> * add real dataset pipeline Signed-off-by: Ao Tang <aot@nvidia.com> * Remove unused utils.py file and clean up code in flux_pipeline.py and convert_checkpoints.py. Add functionality to save individual sample files in prepare_energon_dataset_flux.py. Signed-off-by: Ao Tang <aot@nvidia.com> * Fix dist ckpt issue Signed-off-by: Ao Tang <aot@nvidia.com> * simplify conversion script Signed-off-by: Ao Tang <aot@nvidia.com> * ruff Signed-off-by: Ao Tang <aot@nvidia.com> * lint Signed-off-by: Ao Tang <aot@nvidia.com> * improve dataset preprocessing script for center cropping Signed-off-by: Ao Tang <aot@nvidia.com> * unified flow matching pipeline Signed-off-by: Ao Tang <aot@nvidia.com> * lint Signed-off-by: Ao Tang <aot@nvidia.com> * replicate id for tp sharding dist ckpt Signed-off-by: Ao Tang <aot@nvidia.com> * Add unit test Signed-off-by: Ao Tang <aot@nvidia.com> * functional test Signed-off-by: Ao Tang <aot@nvidia.com> * improve inf Signed-off-by: Ao Tang <aot@nvidia.com> * fix tests Signed-off-by: Ao Tang <aot@nvidia.com> * add readme and rename inference pipeline Signed-off-by: Ao Tang <aot@nvidia.com> * comment improve Signed-off-by: Ao Tang <aot@nvidia.com> * restructure Signed-off-by: Ao Tang <aot@nvidia.com> * select to run with either flux_step or flux_step_with_automodel Signed-off-by: Ao Tang <aot@nvidia.com> * trailing whitespace removed Signed-off-by: Ao Tang <aot@nvidia.com> * use timesteps from diffusers Signed-off-by: Ao Tang <aot@nvidia.com> * Add finetune script Signed-off-by: Ao Tang <aot@nvidia.com> * doc fix Signed-off-by: Ao Tang <aot@nvidia.com> * fix unit test Signed-off-by: Ao Tang <aot@nvidia.com> --------- Signed-off-by: Ao Tang <aot@nvidia.com>
1 parent 5b81246 commit cabae30

42 files changed

Lines changed: 10380 additions & 16 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
# pylint: disable=C0115,C0116,C0301
16+
17+
from dataclasses import dataclass
18+
19+
from megatron.bridge.data.utils import DatasetBuildContext
20+
from torch import int_repr
21+
22+
from dfm.src.megatron.data.common.diffusion_energon_datamodule import DiffusionDataModule, DiffusionDataModuleConfig
23+
from dfm.src.megatron.data.flux.flux_taskencoder import FluxTaskEncoder
24+
25+
26+
@dataclass(kw_only=True)
class FluxDataModuleConfig(DiffusionDataModuleConfig):
    """Configuration that builds a ``DiffusionDataModule`` for FLUX training.

    When the dataclass is instantiated, ``__post_init__`` constructs a
    ``DiffusionDataModule`` driven by a ``FluxTaskEncoder`` and stores it on
    ``self.dataset``. ``build_datasets`` then hands out dataloaders obtained
    from that data module.

    Attributes:
        path: Dataset location, forwarded to ``DiffusionDataModule``.
        seq_length: Sequence length, forwarded to both the data module and
            the task encoder.
        packing_buffer_size: Packing buffer size, forwarded to both the data
            module and the task encoder.
        micro_batch_size: Micro batch size, forwarded to the data module.
        global_batch_size: Global batch size, forwarded to the data module.
        num_workers: Number of dataloader workers, forwarded to the data module.
        dataloader_type: Dataloader type tag; defaults to ``"external"``.
        vae_scale_factor: Forwarded to ``FluxTaskEncoder``.
        latent_channels: Forwarded to ``FluxTaskEncoder``.
    """

    path: str
    seq_length: int
    packing_buffer_size: int
    micro_batch_size: int
    global_batch_size: int
    # Was annotated with ``torch.int_repr`` (a function, not a type); the value
    # is a plain worker count.
    num_workers: int
    dataloader_type: str = "external"
    vae_scale_factor: int = 8
    latent_channels: int = 16

    def __post_init__(self):
        """Create the underlying data module and mirror its sequence length."""
        self.dataset = DiffusionDataModule(
            path=self.path,
            seq_length=self.seq_length,
            packing_buffer_size=self.packing_buffer_size,
            task_encoder=FluxTaskEncoder(
                seq_length=self.seq_length,
                packing_buffer_size=self.packing_buffer_size,
                vae_scale_factor=self.vae_scale_factor,
                latent_channels=self.latent_channels,
            ),
            micro_batch_size=self.micro_batch_size,
            global_batch_size=self.global_batch_size,
            num_workers=self.num_workers,
            use_train_split_for_val=True,
        )
        self.sequence_length = self.dataset.seq_length

    def build_datasets(self, context: DatasetBuildContext):
        """Return a 3-tuple of dataloaders, each from ``train_dataloader()``.

        Args:
            context: Build context supplied by the caller; not read here.

        Returns:
            tuple: Three results of separate ``self.dataset.train_dataloader()``
            calls.
        """
        # NOTE(review): all three slots come from train_dataloader(); presumably
        # they stand for train/valid/test, reusing the train split — confirm.
        return self.dataset.train_dataloader(), self.dataset.train_dataloader(), self.dataset.train_dataloader()
Lines changed: 255 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,255 @@
1+
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Mock data module for FLUX model training."""
16+
17+
from dataclasses import dataclass
18+
19+
import torch
20+
from megatron.bridge.data.utils import DatasetBuildContext, DatasetProvider
21+
from torch.utils.data import DataLoader, Dataset
22+
23+
24+
class _MockT2IDataset(Dataset):
    """
    A mock text-to-image dataset that produces random samples for training and testing.

    Each sample is generated on the fly: either pre-cached style latents and
    text embeddings, or raw image tensors and a fixed caption string,
    depending on the ``image_precached`` / ``text_precached`` flags.

    Attributes:
        length (int): Total number of samples in the dataset.
        H (int): Height of the generated images.
        W (int): Width of the generated images.
        image_precached (bool): Whether samples carry latents instead of raw images.
        text_precached (bool): Whether samples carry text embeddings instead of a caption.
        prompt_seq_len (int): Sequence length for text prompts.
        pooled_prompt_dim (int): Dimensionality of pooled text embeddings.
        context_dim (int): Dimensionality of the text embedding context.
        vae_scale_factor (int): Scaling factor for the VAE latent representation.
        vae_channels (int): Number of channels in the VAE latent representation.
    """

    def __init__(
        self,
        image_H: int = 1024,
        image_W: int = 1024,
        length: int = 100000,
        image_precached: bool = True,
        text_precached: bool = True,
        prompt_seq_len: int = 512,
        pooled_prompt_dim: int = 768,
        context_dim: int = 4096,
        vae_scale_factor: int = 8,
        vae_channels: int = 16,
    ):
        """
        Store the configuration and precompute the tensor shapes that will be sampled.

        Args:
            image_H (int): Image height in pixels.
            image_W (int): Image width in pixels.
            length (int): Number of samples reported by ``__len__``.
            image_precached (bool): Emit ``latents`` when True, ``images`` otherwise.
            text_precached (bool): Emit text embeddings when True, ``txt`` otherwise.
            prompt_seq_len (int): Sequence length of the prompt embeddings.
            pooled_prompt_dim (int): Size of the pooled prompt embedding.
            context_dim (int): Feature size of the prompt embeddings.
            vae_scale_factor (int): Divisor applied to H and W to get the latent size.
            vae_channels (int): Channel count of the latents.
        """
        super().__init__()
        self.length = length
        self.H = image_H
        self.W = image_W
        self.image_precached = image_precached
        self.text_precached = text_precached
        self.vae_channels = vae_channels
        self.vae_scale_factor = vae_scale_factor
        self.prompt_seq_len = prompt_seq_len
        self.pooled_prompt_dim = pooled_prompt_dim
        self.context_dim = context_dim

        if self.image_precached:
            self.latent_shape = (
                vae_channels,
                int(image_H // vae_scale_factor),
                int(image_W // vae_scale_factor),
            )
        if self.text_precached:
            self.prompt_embeds_shape = (prompt_seq_len, context_dim)
            self.pooled_prompt_embeds_shape = (pooled_prompt_dim,)
            self.text_ids_shape = (prompt_seq_len, 3)

    def __getitem__(self, index):
        """
        Generate a single random sample.

        Args:
            index (int): Index of the sample. Only used for bounds checking;
                the content is random regardless of its value.

        Returns:
            dict: A dictionary whose keys depend on the precache flags:
                - 'latents': Latent tensor [C, H, W] (``image_precached``), or
                  'images': Raw image tensor [3, H, W] otherwise.
                - 'prompt_embeds': Text prompt embeddings [seq_len, context_dim],
                  'pooled_prompt_embeds': Pooled embeddings [pooled_dim] and
                  'text_ids': Text position IDs [seq_len, 3] (``text_precached``), or
                  'txt': A fixed caption string otherwise.

        Raises:
            IndexError: If ``index`` lies outside ``[-length, length)``.
        """
        # Without this check plain iteration (``for s in ds`` / ``list(ds)``)
        # never stops: the class has no ``__iter__``, so Python falls back to
        # calling ``__getitem__`` with 0, 1, 2, ... until IndexError is raised.
        if not -self.length <= index < self.length:
            raise IndexError(f"index {index} is out of range for dataset of length {self.length}")

        item = {}

        if self.image_precached:
            # Latents in [C, H, W] format - will be batched to [B, C, H, W]
            item["latents"] = torch.randn(self.latent_shape, dtype=torch.bfloat16)
        else:
            # Raw images [3, H, W]
            item["images"] = torch.randn(3, self.H, self.W, dtype=torch.bfloat16)

        if self.text_precached:
            # T5 embeddings [seq_len, context_dim]
            item["prompt_embeds"] = torch.randn(self.prompt_embeds_shape, dtype=torch.bfloat16)
            # CLIP pooled embeddings [pooled_dim]
            item["pooled_prompt_embeds"] = torch.randn(self.pooled_prompt_embeds_shape, dtype=torch.bfloat16)
            # Text position IDs [seq_len, 3]
            item["text_ids"] = torch.zeros(self.text_ids_shape, dtype=torch.bfloat16)
        else:
            item["txt"] = "This is a sample caption input"

        return item

    def __len__(self):
        """Returns the total number of samples in the dataset."""
        return self.length
121+
122+
123+
def _collate_fn(samples):
    """
    Merge a list of _MockT2IDataset samples into one batch dictionary.

    The keys present in the first sample decide which fields are batched.

    Args:
        samples: List of sample dictionaries from the dataset.

    Returns:
        dict: Tensors stacked along a new leading batch dimension, captions
        gathered into a list, plus an all-ones 'loss_mask' when 'latents'
        are present.
    """
    first = samples[0]

    def _stack(key):
        # New leading dimension is the batch dimension.
        return torch.stack([sample[key] for sample in samples], dim=0)

    batch = {}

    # Visual input: latents take precedence over raw images -> [B, C, H, W]
    if "latents" in first:
        batch["latents"] = _stack("latents")
    elif "images" in first:
        batch["images"] = _stack("images")

    # Text input: precomputed embeddings take precedence over raw captions
    if "prompt_embeds" in first:
        # [B, seq_len, context_dim], [B, pooled_dim], [B, seq_len, 3]
        for key in ("prompt_embeds", "pooled_prompt_embeds", "text_ids"):
            batch[key] = _stack(key)
    elif "txt" in first:
        batch["txt"] = [sample["txt"] for sample in samples]

    # All-ones loss mask with one entry per latent spatial position
    if "latents" in batch:
        latents_shape = batch["latents"].shape
        num_positions = latents_shape[2] * latents_shape[3]
        batch["loss_mask"] = torch.ones(latents_shape[0], num_positions, dtype=torch.bfloat16)

    return batch
161+
162+
163+
@dataclass(kw_only=True)
class FluxMockDataModuleConfig(DatasetProvider):
    """
    Configuration for FLUX mock data module.

    This data module generates synthetic data for FLUX model training,
    matching the expected input format of FluxForwardStep.

    Attributes:
        path: Unused, kept for interface compatibility.
        seq_length: Sequence length (unused for FLUX, kept for interface compatibility).
        packing_buffer_size: Packing buffer size (unused for FLUX); may be None.
        micro_batch_size: Micro batch size for training.
        global_batch_size: Global batch size for training.
        num_workers: Number of data loading workers.
        dataloader_type: Type of dataloader ("external" for mock data).
        image_H: Height of input images.
        image_W: Width of input images.
        vae_channels: Number of VAE latent channels.
        vae_scale_factor: VAE spatial downsampling factor.
        prompt_seq_len: Sequence length for T5 text embeddings.
        context_dim: Dimensionality of T5 text embeddings.
        pooled_prompt_dim: Dimensionality of CLIP pooled embeddings.
        image_precached: Whether images are pre-encoded as VAE latents.
        text_precached: Whether text is pre-encoded as embeddings.
        num_train_samples: Number of training samples.
    """

    path: str = ""
    seq_length: int = 1024
    # Default is None, so the annotation must admit it.
    packing_buffer_size: int | None = None
    micro_batch_size: int = 1
    global_batch_size: int = 4
    num_workers: int = 8
    dataloader_type: str = "external"

    # Image dimensions
    image_H: int = 1024
    image_W: int = 1024

    # VAE settings
    vae_channels: int = 16
    vae_scale_factor: int = 8

    # Text embedding settings
    prompt_seq_len: int = 512
    context_dim: int = 4096
    pooled_prompt_dim: int = 768

    # Precaching settings (FLUX typically uses precached data)
    image_precached: bool = True
    text_precached: bool = True

    # Dataset size
    num_train_samples: int = 10000

    def __post_init__(self):
        """Initialize the mock dataset, its dataloader and a single iterator over it."""
        mock_ds = _MockT2IDataset(
            image_H=self.image_H,
            image_W=self.image_W,
            length=self.num_train_samples,
            image_precached=self.image_precached,
            text_precached=self.text_precached,
            prompt_seq_len=self.prompt_seq_len,
            pooled_prompt_dim=self.pooled_prompt_dim,
            context_dim=self.context_dim,
            vae_scale_factor=self.vae_scale_factor,
            vae_channels=self.vae_channels,
        )

        # DataLoader only accepts these options when worker processes are
        # used, so they are added conditionally.
        kwargs = {}
        if self.num_workers > 0:
            kwargs["prefetch_factor"] = 8
            kwargs["persistent_workers"] = True

        self._train_dl = DataLoader(
            mock_ds,
            batch_size=self.micro_batch_size,
            num_workers=self.num_workers,
            collate_fn=_collate_fn,
            shuffle=True,
            drop_last=True,
            pin_memory=True,
            **kwargs,
        )
        # NOTE(review): the iterator is created once, here; it covers a single
        # pass over the mock dataset and is not recreated afterwards — confirm
        # that the training loop never needs more batches than that.
        self._train_dl_iter = iter(self._train_dl)
        self.sequence_length = self.seq_length

    def build_datasets(self, _context: DatasetBuildContext):
        """
        Build and return train/val/test dataloaders.

        Args:
            _context: Build context supplied by the caller; not read here.

        Returns:
            tuple: The same iterator object in all three positions.
        """
        # Return iterator for external dataloader type.
        # NOTE(review): the three slots share one iterator, so a batch drawn
        # through one slot is no longer available to the others.
        return self._train_dl_iter, self._train_dl_iter, self._train_dl_iter

0 commit comments

Comments
 (0)