|
| 1 | +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | + |
| 15 | +"""Mock data module for FLUX model training.""" |
| 16 | + |
| 17 | +from dataclasses import dataclass |
| 18 | + |
| 19 | +import torch |
| 20 | +from megatron.bridge.data.utils import DatasetBuildContext, DatasetProvider |
| 21 | +from torch.utils.data import DataLoader, Dataset |
| 22 | + |
| 23 | + |
| 24 | +class _MockT2IDataset(Dataset): |
| 25 | + """ |
| 26 | + A mock dataset class for text-to-image tasks, simulating data samples for training and testing. |
| 27 | +
|
| 28 | + This dataset generates synthetic data for both image and text inputs, with options to use |
| 29 | + pre-cached latent representations or raw data. The class is designed for use in testing and |
| 30 | + prototyping machine learning models. |
| 31 | +
|
| 32 | + Attributes: |
| 33 | + image_H (int): Height of the generated images. |
| 34 | + image_W (int): Width of the generated images. |
| 35 | + length (int): Total number of samples in the dataset. |
| 36 | + image_precached (bool): Whether to use pre-cached latent representations for images. |
| 37 | + text_precached (bool): Whether to use pre-cached embeddings for text. |
| 38 | + prompt_seq_len (int): Sequence length for text prompts. |
| 39 | + pooled_prompt_dim (int): Dimensionality of pooled text embeddings. |
| 40 | + context_dim (int): Dimensionality of the text embedding context. |
| 41 | + vae_scale_factor (int): Scaling factor for the VAE latent representation. |
| 42 | + vae_channels (int): Number of channels in the VAE latent representation. |
| 43 | + """ |
| 44 | + |
| 45 | + def __init__( |
| 46 | + self, |
| 47 | + image_H: int = 1024, |
| 48 | + image_W: int = 1024, |
| 49 | + length: int = 100000, |
| 50 | + image_precached: bool = True, |
| 51 | + text_precached: bool = True, |
| 52 | + prompt_seq_len: int = 512, |
| 53 | + pooled_prompt_dim: int = 768, |
| 54 | + context_dim: int = 4096, |
| 55 | + vae_scale_factor: int = 8, |
| 56 | + vae_channels: int = 16, |
| 57 | + ): |
| 58 | + super().__init__() |
| 59 | + self.length = length |
| 60 | + self.H = image_H |
| 61 | + self.W = image_W |
| 62 | + self.image_precached = image_precached |
| 63 | + self.text_precached = text_precached |
| 64 | + self.vae_channels = vae_channels |
| 65 | + self.vae_scale_factor = vae_scale_factor |
| 66 | + self.prompt_seq_len = prompt_seq_len |
| 67 | + self.pooled_prompt_dim = pooled_prompt_dim |
| 68 | + self.context_dim = context_dim |
| 69 | + |
| 70 | + if self.image_precached: |
| 71 | + self.latent_shape = ( |
| 72 | + vae_channels, |
| 73 | + int(image_H // vae_scale_factor), |
| 74 | + int(image_W // vae_scale_factor), |
| 75 | + ) |
| 76 | + if self.text_precached: |
| 77 | + self.prompt_embeds_shape = (prompt_seq_len, context_dim) |
| 78 | + self.pooled_prompt_embeds_shape = (pooled_prompt_dim,) |
| 79 | + self.text_ids_shape = (prompt_seq_len, 3) |
| 80 | + |
| 81 | + def __getitem__(self, index): |
| 82 | + """ |
| 83 | + Retrieves a single sample from the dataset. |
| 84 | +
|
| 85 | + The sample includes pre-cached latent representations for images and text. |
| 86 | +
|
| 87 | + Args: |
| 88 | + index (int): Index of the sample to retrieve. |
| 89 | +
|
| 90 | + Returns: |
| 91 | + dict: A dictionary containing the generated data sample with keys: |
| 92 | + - 'latents': Pre-cached latent representation of the image [C, H, W]. |
| 93 | + - 'prompt_embeds': Pre-cached text prompt embeddings [seq_len, context_dim]. |
| 94 | + - 'pooled_prompt_embeds': Pooled text prompt embeddings [pooled_dim]. |
| 95 | + - 'text_ids': Text position IDs [seq_len, 3]. |
| 96 | + """ |
| 97 | + item = {} |
| 98 | + |
| 99 | + if self.image_precached: |
| 100 | + # Latents in [C, H, W] format - will be batched to [B, C, H, W] |
| 101 | + item["latents"] = torch.randn(self.latent_shape, dtype=torch.bfloat16) |
| 102 | + else: |
| 103 | + # Raw images [3, H, W] |
| 104 | + item["images"] = torch.randn(3, self.H, self.W, dtype=torch.bfloat16) |
| 105 | + |
| 106 | + if self.text_precached: |
| 107 | + # T5 embeddings [seq_len, context_dim] |
| 108 | + item["prompt_embeds"] = torch.randn(self.prompt_embeds_shape, dtype=torch.bfloat16) |
| 109 | + # CLIP pooled embeddings [pooled_dim] |
| 110 | + item["pooled_prompt_embeds"] = torch.randn(self.pooled_prompt_embeds_shape, dtype=torch.bfloat16) |
| 111 | + # Text position IDs [seq_len, 3] |
| 112 | + item["text_ids"] = torch.zeros(self.text_ids_shape, dtype=torch.bfloat16) |
| 113 | + else: |
| 114 | + item["txt"] = "This is a sample caption input" |
| 115 | + |
| 116 | + return item |
| 117 | + |
| 118 | + def __len__(self): |
| 119 | + """Returns the total number of samples in the dataset.""" |
| 120 | + return self.length |
| 121 | + |
| 122 | + |
| 123 | +def _collate_fn(samples): |
| 124 | + """ |
| 125 | + Collate function to batch samples from _MockT2IDataset. |
| 126 | +
|
| 127 | + Args: |
| 128 | + samples: List of sample dictionaries from the dataset. |
| 129 | +
|
| 130 | + Returns: |
| 131 | + dict: Batched dictionary with stacked tensors. |
| 132 | + """ |
| 133 | + batch = {} |
| 134 | + |
| 135 | + # Stack latents: [B, C, H, W] |
| 136 | + if "latents" in samples[0]: |
| 137 | + batch["latents"] = torch.stack([s["latents"] for s in samples], dim=0) |
| 138 | + elif "images" in samples[0]: |
| 139 | + batch["images"] = torch.stack([s["images"] for s in samples], dim=0) |
| 140 | + |
| 141 | + # Stack text embeddings |
| 142 | + if "prompt_embeds" in samples[0]: |
| 143 | + # [B, seq_len, context_dim] |
| 144 | + batch["prompt_embeds"] = torch.stack([s["prompt_embeds"] for s in samples], dim=0) |
| 145 | + # [B, pooled_dim] |
| 146 | + batch["pooled_prompt_embeds"] = torch.stack([s["pooled_prompt_embeds"] for s in samples], dim=0) |
| 147 | + # [B, seq_len, 3] |
| 148 | + batch["text_ids"] = torch.stack([s["text_ids"] for s in samples], dim=0) |
| 149 | + elif "txt" in samples[0]: |
| 150 | + batch["txt"] = [s["txt"] for s in samples] |
| 151 | + |
| 152 | + # Add loss mask (all ones) |
| 153 | + if "latents" in batch: |
| 154 | + batch_size = batch["latents"].shape[0] |
| 155 | + latent_h = batch["latents"].shape[2] |
| 156 | + latent_w = batch["latents"].shape[3] |
| 157 | + # Loss mask covers all latent positions |
| 158 | + batch["loss_mask"] = torch.ones(batch_size, latent_h * latent_w, dtype=torch.bfloat16) |
| 159 | + |
| 160 | + return batch |
| 161 | + |
| 162 | + |
| 163 | +@dataclass(kw_only=True) |
| 164 | +class FluxMockDataModuleConfig(DatasetProvider): |
| 165 | + """ |
| 166 | + Configuration for FLUX mock data module. |
| 167 | +
|
| 168 | + This data module generates synthetic data for FLUX model training, |
| 169 | + matching the expected input format of FluxForwardStep. |
| 170 | +
|
| 171 | + Attributes: |
| 172 | + path: Unused, kept for interface compatibility. |
| 173 | + seq_length: Sequence length (unused for FLUX, kept for interface compatibility). |
| 174 | + packing_buffer_size: Packing buffer size (unused for FLUX). |
| 175 | + micro_batch_size: Micro batch size for training. |
| 176 | + global_batch_size: Global batch size for training. |
| 177 | + num_workers: Number of data loading workers. |
| 178 | + dataloader_type: Type of dataloader ("external" for mock data). |
| 179 | + image_H: Height of input images. |
| 180 | + image_W: Width of input images. |
| 181 | + vae_channels: Number of VAE latent channels. |
| 182 | + vae_scale_factor: VAE spatial downsampling factor. |
| 183 | + prompt_seq_len: Sequence length for T5 text embeddings. |
| 184 | + context_dim: Dimensionality of T5 text embeddings. |
| 185 | + pooled_prompt_dim: Dimensionality of CLIP pooled embeddings. |
| 186 | + image_precached: Whether images are pre-encoded as VAE latents. |
| 187 | + text_precached: Whether text is pre-encoded as embeddings. |
| 188 | + num_train_samples: Number of training samples. |
| 189 | + """ |
| 190 | + |
| 191 | + path: str = "" |
| 192 | + seq_length: int = 1024 |
| 193 | + packing_buffer_size: int = None |
| 194 | + micro_batch_size: int = 1 |
| 195 | + global_batch_size: int = 4 |
| 196 | + num_workers: int = 8 |
| 197 | + dataloader_type: str = "external" |
| 198 | + |
| 199 | + # Image dimensions |
| 200 | + image_H: int = 1024 |
| 201 | + image_W: int = 1024 |
| 202 | + |
| 203 | + # VAE settings |
| 204 | + vae_channels: int = 16 |
| 205 | + vae_scale_factor: int = 8 |
| 206 | + |
| 207 | + # Text embedding settings |
| 208 | + prompt_seq_len: int = 512 |
| 209 | + context_dim: int = 4096 |
| 210 | + pooled_prompt_dim: int = 768 |
| 211 | + |
| 212 | + # Precaching settings (FLUX typically uses precached data) |
| 213 | + image_precached: bool = True |
| 214 | + text_precached: bool = True |
| 215 | + |
| 216 | + # Dataset size |
| 217 | + num_train_samples: int = 10000 |
| 218 | + |
| 219 | + def __post_init__(self): |
| 220 | + """Initialize the mock dataset and dataloader.""" |
| 221 | + mock_ds = _MockT2IDataset( |
| 222 | + image_H=self.image_H, |
| 223 | + image_W=self.image_W, |
| 224 | + length=self.num_train_samples, |
| 225 | + image_precached=self.image_precached, |
| 226 | + text_precached=self.text_precached, |
| 227 | + prompt_seq_len=self.prompt_seq_len, |
| 228 | + pooled_prompt_dim=self.pooled_prompt_dim, |
| 229 | + context_dim=self.context_dim, |
| 230 | + vae_scale_factor=self.vae_scale_factor, |
| 231 | + vae_channels=self.vae_channels, |
| 232 | + ) |
| 233 | + |
| 234 | + kwargs = {} |
| 235 | + if self.num_workers > 0: |
| 236 | + kwargs["prefetch_factor"] = 8 |
| 237 | + kwargs["persistent_workers"] = True |
| 238 | + |
| 239 | + self._train_dl = DataLoader( |
| 240 | + mock_ds, |
| 241 | + batch_size=self.micro_batch_size, |
| 242 | + num_workers=self.num_workers, |
| 243 | + collate_fn=_collate_fn, |
| 244 | + shuffle=True, |
| 245 | + drop_last=True, |
| 246 | + pin_memory=True, |
| 247 | + **kwargs, |
| 248 | + ) |
| 249 | + self._train_dl_iter = iter(self._train_dl) |
| 250 | + self.sequence_length = self.seq_length |
| 251 | + |
| 252 | + def build_datasets(self, _context: DatasetBuildContext): |
| 253 | + """Build and return train/val/test dataloaders.""" |
| 254 | + # Return iterator for external dataloader type |
| 255 | + return self._train_dl_iter, self._train_dl_iter, self._train_dl_iter |
0 commit comments