|
53 | 53 | >>> import torch |
54 | 54 | >>> from diffusers import ChromaPipeline |
55 | 55 |
|
56 | | - >>> model_id = "lodestones/Chroma" |
57 | | - >>> ckpt_path = "https://huggingface.co/lodestones/Chroma/blob/main/chroma-unlocked-v37.safetensors" |
| 56 | + >>> model_id = "lodestones/Chroma1-HD" |
| 57 | + >>> ckpt_path = "https://huggingface.co/lodestones/Chroma1-HD/blob/main/Chroma1-HD.safetensors" |
58 | 58 | >>> transformer = ChromaTransformer2DModel.from_single_file(ckpt_path, torch_dtype=torch.bfloat16) |
59 | 59 | >>> pipe = ChromaPipeline.from_pretrained( |
60 | 60 | ... model_id, |
@@ -158,7 +158,7 @@ class ChromaPipeline( |
158 | 158 | r""" |
159 | 159 | The Chroma pipeline for text-to-image generation. |
160 | 160 |
|
161 | | - Reference: https://huggingface.co/lodestones/Chroma/ |
| 161 | + Reference: https://huggingface.co/lodestones/Chroma1-HD/ |
162 | 162 |
|
163 | 163 | Args: |
164 | 164 | transformer ([`ChromaTransformer2DModel`]): |
@@ -233,20 +233,23 @@ def _get_t5_prompt_embeds( |
233 | 233 | return_tensors="pt", |
234 | 234 | ) |
235 | 235 | text_input_ids = text_inputs.input_ids |
236 | | - attention_mask = text_inputs.attention_mask.clone() |
| 236 | + tokenizer_mask = text_inputs.attention_mask |
237 | 237 |
|
238 | | - # Chroma requires the attention mask to include one padding token |
239 | | - seq_lengths = attention_mask.sum(dim=1) |
240 | | - mask_indices = torch.arange(attention_mask.size(1)).unsqueeze(0).expand(batch_size, -1) |
241 | | - attention_mask = (mask_indices <= seq_lengths.unsqueeze(1)).bool() |
| 238 | + tokenizer_mask_device = tokenizer_mask.to(device) |
242 | 239 |
|
| 240 | + # unlike FLUX, Chroma uses the attention mask when generating the T5 embedding |
243 | 241 | prompt_embeds = self.text_encoder( |
244 | | - text_input_ids.to(device), output_hidden_states=False, attention_mask=attention_mask.to(device) |
| 242 | + text_input_ids.to(device), |
| 243 | + output_hidden_states=False, |
| 244 | + attention_mask=tokenizer_mask_device, |
245 | 245 | )[0] |
246 | 246 |
|
247 | | - dtype = self.text_encoder.dtype |
248 | 247 | prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) |
249 | | - attention_mask = attention_mask.to(device=device) |
| 248 | + |
| 249 | + # for the text tokens, Chroma requires that all except the first padding token are masked out during the forward pass through the transformer |
| 250 | + seq_lengths = tokenizer_mask_device.sum(dim=1) |
| 251 | + mask_indices = torch.arange(tokenizer_mask_device.size(1), device=device).unsqueeze(0).expand(batch_size, -1) |
| 252 | + attention_mask = (mask_indices <= seq_lengths.unsqueeze(1)).to(dtype=dtype, device=device) |
250 | 253 |
|
251 | 254 | _, seq_len, _ = prompt_embeds.shape |
252 | 255 |
|
|
0 commit comments